|
6 | 6 | "log/slog" |
7 | 7 | "net/http" |
8 | 8 | "net/netip" |
| 9 | + "strings" |
9 | 10 | "sync" |
10 | 11 |
|
11 | 12 | "github.com/dropmorepackets/haproxy-go/pkg/buffer" |
@@ -55,6 +56,17 @@ func releaseHostBuf(b *buffer.SliceBuffer) { |
55 | 56 | hostBufPool.Put(b) |
56 | 57 | } |
57 | 58 |
|
| 59 | +func getTrustedDomain(host []byte, td []string) []byte { |
| 60 | + h := string(host) |
| 61 | + for _, d := range td { |
| 62 | + if strings.HasSuffix(h, d) { |
| 63 | + return []byte(d) |
| 64 | + } |
| 65 | + } |
| 66 | + |
| 67 | + return nil |
| 68 | +} |
| 69 | + |
58 | 70 | func (f *frontend) HandleSPOEValidate(ctx context.Context, w *encoding.ActionWriter, m *encoding.Message) { |
59 | 71 | k := encoding.AcquireKVEntry() |
60 | 72 | defer encoding.ReleaseKVEntry(k) |
@@ -93,9 +105,15 @@ func (f *frontend) HandleSPOEValidate(ctx context.Context, w *encoding.ActionWri |
93 | 105 | return |
94 | 106 | } |
95 | 107 |
|
| 108 | + td := getTrustedDomain(host, f.bh.TrustedDomains) |
| 109 | + if td != nil { |
| 110 | + host = td |
| 111 | + } |
| 112 | + |
96 | 113 | hostBuf := acquireHostBuf() |
97 | 114 | defer releaseHostBuf(hostBuf) |
98 | | - copy(hostBuf.WriteNBytes(len(k.ValueBytes())), k.ValueBytes()) |
| 115 | + |
| 116 | + copy(hostBuf.WriteNBytes(len(host)), host) |
99 | 117 | ri.Host = hostBuf.ReadBytes() |
100 | 118 |
|
101 | 119 | if err := readExpectedKVEntry(ctx, m, k, "cookie"); err != nil { |
@@ -141,13 +159,22 @@ func (f *frontend) HandleSPOEChallenge(ctx context.Context, w *encoding.ActionWr |
141 | 159 | if err := readExpectedKVEntry(ctx, m, k, "host"); err != nil { |
142 | 160 | return |
143 | 161 | } |
144 | | - if len(k.ValueBytes()) > hostBufferLength { |
| 162 | + host := k.ValueBytes() |
| 163 | + if len(host) > hostBufferLength { |
145 | 164 | slog.ErrorContext(ctx, "host length too big") |
146 | 165 | } |
147 | 166 |
|
| 167 | + td := getTrustedDomain(host, f.bh.TrustedDomains) |
| 168 | + if td != nil { |
| 169 | + host = td |
| 170 | + } |
| 171 | + |
| 172 | + _ = w.SetString(encoding.VarScopeTransaction, "domain", string(host)) |
| 173 | + |
148 | 174 | hostBuf := acquireHostBuf() |
149 | 175 | defer releaseHostBuf(hostBuf) |
150 | | - copy(hostBuf.WriteNBytes(len(k.ValueBytes())), k.ValueBytes()) |
| 176 | + |
| 177 | + copy(hostBuf.WriteNBytes(len(host)), host) |
151 | 178 | ri.Host = hostBuf.ReadBytes() |
152 | 179 |
|
153 | 180 | req := berghain.AcquireValidatorRequest() |
|
0 commit comments