Skip to content
1 change: 1 addition & 0 deletions gateway/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type Config struct {
type Upstream struct {
URL string `yaml:"url"`
Paths []string `yaml:"paths"`
DNSRefreshInterval time.Duration `yaml:"dns_refresh_interval"`
HTTPClientTimeout time.Duration `yaml:"http_client_timeout"`
HTTPClientDialerTimeout time.Duration `yaml:"http_client_dialer_timeout"`
HTTPClientTLSHandshakeTimeout time.Duration `yaml:"http_client_tls_handshake_timeout"`
Expand Down
58 changes: 2 additions & 56 deletions gateway/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package gateway

import (
"net/http"
"time"

"github.com/cortexproject/auth-gateway/server"
)
Expand Down Expand Up @@ -39,66 +38,13 @@ var defaultQueryFrontendAPIs = []string{
"/api/prom/api/v1/status/buildinfo",
}

// TODO: create a helper function for error handling and parsing the duration
func New(config *Config, srv *server.Server) (*Gateway, error) {
httpClientTimeout, err := time.ParseDuration(config.Distributor.HTTPClientTimeout.String())
distributor, err := NewProxy(config.Distributor.URL, config.Distributor, DISTRIBUTOR)
if err != nil {
return nil, err
}

httpClientDialerTimeout, err := time.ParseDuration(config.Distributor.HTTPClientDialerTimeout.String())
if err != nil {
return nil, err
}

httpClientTLSHandshakeTimeout, err := time.ParseDuration(config.Distributor.HTTPClientDialerTimeout.String())
if err != nil {
return nil, err
}

httpClientResponseHeaderTimeout, err := time.ParseDuration(config.Distributor.HTTPClientDialerTimeout.String())
if err != nil {
return nil, err
}

distributorTimeouts := Upstream{
HTTPClientTimeout: httpClientTimeout,
HTTPClientDialerTimeout: httpClientDialerTimeout,
HTTPClientTLSHandshakeTimeout: httpClientTLSHandshakeTimeout,
HTTPClientResponseHeaderTimeout: httpClientResponseHeaderTimeout,
}
distributor, err := NewProxy(config.Distributor.URL, distributorTimeouts, DISTRIBUTOR)
if err != nil {
return nil, err
}

httpClientTimeout, err = time.ParseDuration(config.QueryFrontend.HTTPClientTimeout.String())
if err != nil {
return nil, err
}

httpClientDialerTimeout, err = time.ParseDuration(config.QueryFrontend.HTTPClientDialerTimeout.String())
if err != nil {
return nil, err
}

httpClientTLSHandshakeTimeout, err = time.ParseDuration(config.QueryFrontend.HTTPClientDialerTimeout.String())
if err != nil {
return nil, err
}

httpClientResponseHeaderTimeout, err = time.ParseDuration(config.QueryFrontend.HTTPClientDialerTimeout.String())
if err != nil {
return nil, err
}

frontendTimeouts := Upstream{
HTTPClientTimeout: httpClientTimeout,
HTTPClientDialerTimeout: httpClientDialerTimeout,
HTTPClientTLSHandshakeTimeout: httpClientTLSHandshakeTimeout,
HTTPClientResponseHeaderTimeout: httpClientResponseHeaderTimeout,
}
frontend, err := NewProxy(config.QueryFrontend.URL, frontendTimeouts, FRONTEND)
frontend, err := NewProxy(config.QueryFrontend.URL, config.QueryFrontend, FRONTEND)
if err != nil {
return nil, err
}
Expand Down
6 changes: 6 additions & 0 deletions gateway/gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ func TestStartGateway(t *testing.T) {
HTTPClientDialerTimeout: 10 * time.Second,
HTTPClientTLSHandshakeTimeout: 5 * time.Second,
HTTPClientResponseHeaderTimeout: 5 * time.Second,
DNSRefreshInterval: 3 * time.Second,
}

testCases := []struct {
Expand Down Expand Up @@ -78,6 +79,7 @@ func TestStartGateway(t *testing.T) {
HTTPClientDialerTimeout: timeouts.HTTPClientDialerTimeout * time.Second,
HTTPClientTLSHandshakeTimeout: timeouts.HTTPClientTLSHandshakeTimeout * time.Second,
HTTPClientResponseHeaderTimeout: timeouts.HTTPClientResponseHeaderTimeout * time.Second,
DNSRefreshInterval: timeouts.DNSRefreshInterval,
},
QueryFrontend: Upstream{
URL: frontendServer.URL,
Expand All @@ -86,6 +88,7 @@ func TestStartGateway(t *testing.T) {
HTTPClientDialerTimeout: timeouts.HTTPClientDialerTimeout * time.Second,
HTTPClientTLSHandshakeTimeout: timeouts.HTTPClientTLSHandshakeTimeout * time.Second,
HTTPClientResponseHeaderTimeout: timeouts.HTTPClientResponseHeaderTimeout * time.Second,
DNSRefreshInterval: timeouts.DNSRefreshInterval,
},
},
authHeader: "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password")),
Expand Down Expand Up @@ -133,6 +136,7 @@ func TestStartGateway(t *testing.T) {
HTTPClientDialerTimeout: timeouts.HTTPClientDialerTimeout * time.Second,
HTTPClientTLSHandshakeTimeout: timeouts.HTTPClientTLSHandshakeTimeout * time.Second,
HTTPClientResponseHeaderTimeout: timeouts.HTTPClientResponseHeaderTimeout * time.Second,
DNSRefreshInterval: timeouts.DNSRefreshInterval,
},
QueryFrontend: Upstream{
URL: frontendServer.URL,
Expand All @@ -143,6 +147,7 @@ func TestStartGateway(t *testing.T) {
HTTPClientDialerTimeout: timeouts.HTTPClientDialerTimeout * time.Second,
HTTPClientTLSHandshakeTimeout: timeouts.HTTPClientTLSHandshakeTimeout * time.Second,
HTTPClientResponseHeaderTimeout: timeouts.HTTPClientResponseHeaderTimeout * time.Second,
DNSRefreshInterval: timeouts.DNSRefreshInterval,
},
},
paths: []string{
Expand All @@ -169,6 +174,7 @@ func TestStartGateway(t *testing.T) {
HTTPClientDialerTimeout: timeouts.HTTPClientDialerTimeout,
HTTPClientTLSHandshakeTimeout: timeouts.HTTPClientTLSHandshakeTimeout,
HTTPClientResponseHeaderTimeout: timeouts.HTTPClientResponseHeaderTimeout,
DNSRefreshInterval: timeouts.DNSRefreshInterval,
},
},
paths: []string{
Expand Down
111 changes: 111 additions & 0 deletions gateway/load_balancer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package gateway

import (
"fmt"
"log"
"net"
"net/http"
"strings"
"sync"
"time"
)

// DNSResolver is satisfied by any type that can turn a hostname into a list
// of IP addresses.
type DNSResolver interface {
	// LookupIP returns the IP addresses found for hostname, or an error when
	// the lookup fails.
	LookupIP(hostname string) ([]net.IP, error)
}

// DefaultDNSResolver performs lookups through the standard library's net
// package. It carries no state.
type DefaultDNSResolver struct{}

// LookupIP forwards hostname to net.LookupIP and hands back its result
// unchanged.
func (DefaultDNSResolver) LookupIP(hostname string) ([]net.IP, error) {
	addrs, err := net.LookupIP(hostname)
	return addrs, err
}

// roundRobinLoadBalancer holds the state used to spread requests for one
// hostname across the IP addresses that hostname resolves to.
//
// The embedded RWMutex is the lock its methods take around reads and writes
// of ips and currentIndex.
type roundRobinLoadBalancer struct {
	sync.RWMutex

	// hostname is the name passed to resolveIPs and replaced in request URLs.
	hostname string
	// resolveIPs performs the lookup for hostname.
	resolveIPs func(hostname string) ([]net.IP, error)
	// transport is the RoundTripper that actually sends each request.
	transport http.RoundTripper

	// ips is the most recently resolved address list, in string form.
	ips []string
	// currentIndex is the rotation counter; it is reduced modulo len(ips)
	// when an address is selected.
	currentIndex int
}

// newRoundRobinLoadBalancer builds a load balancer for hostname that sends
// requests through http.DefaultTransport and resolves addresses with resolver.
//
// resolver is called once immediately. When that first lookup fails the error
// is only logged, and the returned balancer starts with no IP addresses.
func newRoundRobinLoadBalancer(hostname string, resolver func(hostname string) ([]net.IP, error)) *roundRobinLoadBalancer {
	balancer := new(roundRobinLoadBalancer)
	balancer.hostname = hostname
	balancer.transport = http.DefaultTransport
	balancer.resolveIPs = resolver

	// Initial resolution; a failure here is not fatal.
	resolved, lookupErr := resolver(hostname)
	if lookupErr != nil {
		log.Printf("Failed to resolve IPs for hostname %s: %v", hostname, lookupErr)
		return balancer
	}

	balancer.ips = ipsToStrings(resolved)
	return balancer
}

// roundTrip picks the next IP in the rotation, rewrites the request's URL
// host to point at it, and sends the request through lb.transport.
//
// The lock is held only while the rotation state is read and advanced. It is
// released before the request is sent, so a slow upstream neither serializes
// every other request nor blocks refreshIPs from installing a new address
// list.
//
// req.URL.Host is modified in place: the first occurrence of lb.hostname is
// replaced by the chosen IP. When no IP addresses are currently known an
// error is returned and nothing is sent.
func (lb *roundRobinLoadBalancer) roundTrip(req *http.Request) (*http.Response, error) {
	lb.Lock()
	if len(lb.ips) == 0 {
		lb.Unlock()
		// TODO: replace format error with a log statement
		return nil, fmt.Errorf("no IP addresses available")
	}
	ip := lb.getNextIP()
	lb.currentIndex++
	// Snapshot the transport under the lock so the send below needs no locking.
	transport := lb.transport
	lb.Unlock()

	req.URL.Host = strings.Replace(req.URL.Host, lb.hostname, ip, 1)

	return transport.RoundTrip(req)
}

// getNextIP returns the IP address at the current position of the rotation,
// that is lb.ips[lb.currentIndex % len(lb.ips)]. It neither advances the
// rotation nor takes the lock; callers are expected to hold it.
//
// It returns "" when no IP addresses are known. Without that guard the modulo
// below panics with an integer divide by zero, which safeGetNextIP could
// trigger whenever DNS resolution had not yet produced any address.
func (lb *roundRobinLoadBalancer) getNextIP() string {
	if len(lb.ips) == 0 {
		return ""
	}
	return lb.ips[lb.currentIndex%len(lb.ips)]
}

// safeGetNextIP is getNextIP wrapped in the read lock: it reports the address
// at the current rotation position without advancing it.
func (lb *roundRobinLoadBalancer) safeGetNextIP() (ip string) {
	lb.RLock()
	defer lb.RUnlock()

	ip = lb.getNextIP()
	return ip
}

// Refresh IPs periodically
func (lb *roundRobinLoadBalancer) refreshIPs(refreshInterval time.Duration) {
for {
ips, err := lb.resolveIPs(lb.hostname)
if err != nil {
// TODO: replace std library log package with logrus
log.Printf("Failed to resolve IPs for hostname %s: %v", lb.hostname, err)
} else {
lb.Lock()
lb.ips = ipsToStrings(ips)
lb.currentIndex = 0
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure about this, as resetting to zero means we will use the first resolved IP more than others.

But I think tests will give the final verdict: include some tests that prove every IP address gets a similar number of requests over many DNS refreshes.

lb.Unlock()
}
time.Sleep(refreshInterval)
}
}

// ipsToStrings renders every address in ips with its String method and
// returns the results in the same order. The result is always a non-nil
// slice with len(ips) elements.
func ipsToStrings(ips []net.IP) []string {
	out := make([]string, 0, len(ips))
	for idx := range ips {
		out = append(out, ips[idx].String())
	}
	return out
}

// CustomTransport embeds an http.Transport and carries a round-robin load
// balancer. Its RoundTrip method replaces the one promoted from the embedded
// Transport.
type CustomTransport struct {
	http.Transport
	lb *roundRobinLoadBalancer
}

// RoundTrip hands req to the load balancer's roundTrip and returns whatever
// it produces.
func (ct *CustomTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := ct.lb.roundTrip(req)
	return resp, err
}
148 changes: 148 additions & 0 deletions gateway/load_balancer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
package gateway

import (
"io"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)

// mockDNSResolver is a test double that carries a canned lookup result.
type mockDNSResolver struct {
	// Err is the error its LookupIP returns.
	Err error
	// IPs is the address list its LookupIP returns.
	IPs []net.IP
}

func (m mockDNSResolver) LookupIP(hostname string) ([]net.IP, error) {
return m.IPs, m.Err
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, this is not how it works. DNS returns the IPs in a different order every time.

for example:

First try

$ nslookup cortexmetrics.io
Server:		10.128.4.247
Address:	10.128.4.247#53

Non-authoritative answer:
Name:	cortexmetrics.io
Address: 172.67.146.72
Name:	cortexmetrics.io
Address: 104.21.73.166

Second try

$ nslookup cortexmetrics.io
Server:		10.128.4.247
Address:	10.128.4.247#53

Non-authoritative answer:
Name:	cortexmetrics.io
Address: 104.21.73.166
Name:	cortexmetrics.io
Address: 172.67.146.72

I think if you randomize the order here, it can get pretty similar, though

}

// customRoundTripper is a stub http.RoundTripper that never touches the
// network.
type customRoundTripper struct{}

// RoundTrip ignores req and always succeeds with a fresh 200 response whose
// body is "Hello, client" and whose header map is empty.
func (customRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	body := io.NopCloser(strings.NewReader("Hello, client"))
	return &http.Response{
		StatusCode: 200,
		Header:     http.Header{},
		Body:       body,
	}, nil
}

func TestDistribution(t *testing.T) {
hostname := "example.com"
testCases := []struct {
name string
IPs []net.IP
numReqs int
refreshInterval time.Duration
tolerance float64
}{
{
name: "4 IPs, 1000 requests, 1 second refresh interval, 10% tolerance",
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A 10% tolerance is probably too much. We will improve it in a follow-up PR.

IPs: []net.IP{
net.ParseIP("192.0.0.1"),
net.ParseIP("192.0.0.2"),
net.ParseIP("192.0.0.3"),
net.ParseIP("192.0.0.4"),
},
numReqs: 1000,
refreshInterval: 1 * time.Second,
tolerance: 0.1,
},
{
name: "4 IPs, 1000 requests, 2 seconds refresh interval, 10% tolerance",
IPs: []net.IP{
net.ParseIP("192.0.0.1"),
net.ParseIP("192.0.0.2"),
net.ParseIP("192.0.0.3"),
net.ParseIP("192.0.0.4"),
},
numReqs: 1000,
refreshInterval: 2 * time.Second,
tolerance: 0.1,
},
{
name: "4 IPs, 1000 requests, 3 seconds refresh interval, 10% tolerance",
IPs: []net.IP{
net.ParseIP("192.0.0.1"),
net.ParseIP("192.0.0.2"),
net.ParseIP("192.0.0.3"),
net.ParseIP("192.0.0.4"),
},
numReqs: 1000,
refreshInterval: 3 * time.Second,
tolerance: 0.1,
},
{
name: "3 IPs, 1000 requests, 2 seconds refresh interval, 10% tolerance",
IPs: []net.IP{
net.ParseIP("192.0.0.1"),
net.ParseIP("192.0.0.2"),
net.ParseIP("192.0.0.3"),
},
numReqs: 1000,
refreshInterval: 2 * time.Second,
tolerance: 0.1,
},
{
name: "10 IPs, 1000 requests, 2 seconds refresh interval, 10% tolerance",
IPs: []net.IP{
net.ParseIP("192.0.0.1"),
net.ParseIP("192.0.0.2"),
net.ParseIP("192.0.0.3"),
net.ParseIP("192.0.0.4"),
net.ParseIP("192.0.0.5"),
net.ParseIP("192.0.0.6"),
net.ParseIP("192.0.0.7"),
net.ParseIP("192.0.0.8"),
net.ParseIP("192.0.0.9"),
net.ParseIP("192.0.0.10"),
},
numReqs: 1000,
refreshInterval: 2 * time.Second,
tolerance: 0.1,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
mockResolver := mockDNSResolver{
IPs: tc.IPs,
Err: nil,
}

lb := newRoundRobinLoadBalancer(hostname, mockResolver.LookupIP)
lb.transport = &customRoundTripper{}

go lb.refreshIPs(tc.refreshInterval)

requestCounts := make(map[string]int)
for i := 0; i < tc.numReqs; i++ {
req := httptest.NewRequest("GET", "http://"+lb.safeGetNextIP(), nil)
resp, err := lb.roundTrip(req)
if err == nil {
addr := req.URL.Host
ip := strings.Replace(addr, hostname, "", 1)
requestCounts[ip]++
resp.Body.Close()
} else {
t.Fatal(err)
}
time.Sleep(10 * time.Millisecond)
}

expectedCount := tc.numReqs / len(tc.IPs)
minCount := int(float64(expectedCount) * (1 - tc.tolerance))
maxCount := int(float64(expectedCount) * (1 + tc.tolerance))

for ip, count := range requestCounts {
if count < minCount || count > maxCount {
t.Errorf("IP %s received %d requests, which is outside the acceptable range (%d-%d)", ip, count, minCount, maxCount)
}
}
})
}

}
Loading