-
Notifications
You must be signed in to change notification settings - Fork 3
Implement load balancer #21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
1d751f4
d73d2fd
8724add
1ff491b
7980e37
16ffedb
9956b0e
d33405b
819c3a4
b94190c
2d60d8c
cef3a95
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,111 @@ | ||
| package gateway | ||
|
|
||
| import ( | ||
| "fmt" | ||
| "log" | ||
| "net" | ||
| "net/http" | ||
| "strings" | ||
| "sync" | ||
| "time" | ||
| ) | ||
|
|
||
| type DNSResolver interface { | ||
| LookupIP(string) ([]net.IP, error) | ||
| } | ||
|
|
||
| type DefaultDNSResolver struct{} | ||
|
|
||
| func (d DefaultDNSResolver) LookupIP(hostname string) ([]net.IP, error) { | ||
| return net.LookupIP(hostname) | ||
| } | ||
|
|
||
| type roundRobinLoadBalancer struct { | ||
| hostname string | ||
| ips []string | ||
| currentIndex int | ||
| transport http.RoundTripper | ||
| resolveIPs func(hostname string) ([]net.IP, error) | ||
| sync.RWMutex | ||
| } | ||
|
|
||
| func newRoundRobinLoadBalancer(hostname string, resolver func(hostname string) ([]net.IP, error)) *roundRobinLoadBalancer { | ||
| lb := &roundRobinLoadBalancer{ | ||
| hostname: hostname, | ||
| transport: http.DefaultTransport, | ||
| resolveIPs: resolver, | ||
| } | ||
|
|
||
| // Resolve IPs initially | ||
| ips, err := lb.resolveIPs(hostname) | ||
| if err != nil { | ||
| log.Printf("Failed to resolve IPs for hostname %s: %v", lb.hostname, err) | ||
| } else { | ||
| lb.ips = ipsToStrings(ips) | ||
| } | ||
|
|
||
| return lb | ||
| } | ||
|
|
||
| func (lb *roundRobinLoadBalancer) roundTrip(req *http.Request) (*http.Response, error) { | ||
| lb.Lock() | ||
| defer lb.Unlock() | ||
|
|
||
| if len(lb.ips) == 0 { | ||
| // TODO: replace format error with a log statement | ||
| return nil, fmt.Errorf("no IP addresses available") | ||
| } | ||
|
|
||
| ip := lb.getNextIP() | ||
| req.URL.Host = strings.Replace(req.URL.Host, lb.hostname, ip, 1) | ||
| lb.currentIndex++ | ||
|
|
||
| return lb.transport.RoundTrip(req) | ||
| } | ||
|
|
||
| func (lb *roundRobinLoadBalancer) getNextIP() string { | ||
| return lb.ips[lb.currentIndex%len(lb.ips)] | ||
| } | ||
|
|
||
| func (lb *roundRobinLoadBalancer) safeGetNextIP() string { | ||
| lb.RLock() | ||
| defer lb.RUnlock() | ||
|
|
||
| return lb.getNextIP() | ||
| } | ||
|
|
||
| // Refresh IPs periodically | ||
| func (lb *roundRobinLoadBalancer) refreshIPs(refreshInterval time.Duration) { | ||
| for { | ||
| ips, err := lb.resolveIPs(lb.hostname) | ||
| if err != nil { | ||
| // TODO: replace std library log package with logrus | ||
| log.Printf("Failed to resolve IPs for hostname %s: %v", lb.hostname, err) | ||
| } else { | ||
| lb.Lock() | ||
| lb.ips = ipsToStrings(ips) | ||
| lb.currentIndex = 0 | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not sure about this, as resetting to zero means we will use the first resolved IP more than others. But I think tests will give the final veredict: Include some tests that prove every ip address gets similar number of requests over many refresh dns. |
||
| lb.Unlock() | ||
| } | ||
| time.Sleep(refreshInterval) | ||
| } | ||
| } | ||
|
|
||
| func ipsToStrings(ips []net.IP) []string { | ||
| strs := make([]string, len(ips)) | ||
| for i, ip := range ips { | ||
| strs[i] = ip.String() | ||
| } | ||
| return strs | ||
| } | ||
|
|
||
| // CustomTransport wraps http.Transport and embeds the round-robin load balancer. | ||
| type CustomTransport struct { | ||
| http.Transport | ||
| lb *roundRobinLoadBalancer | ||
| } | ||
|
|
||
| // RoundTrip sends the HTTP request using round-robin load balancing. | ||
| func (ct *CustomTransport) RoundTrip(req *http.Request) (*http.Response, error) { | ||
| return ct.lb.roundTrip(req) | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,148 @@ | ||
| package gateway | ||
|
|
||
| import ( | ||
| "io" | ||
| "net" | ||
| "net/http" | ||
| "net/http/httptest" | ||
| "strings" | ||
| "testing" | ||
| "time" | ||
| ) | ||
|
|
||
| type mockDNSResolver struct { | ||
| IPs []net.IP | ||
| Err error | ||
| } | ||
|
|
||
| func (m mockDNSResolver) LookupIP(hostname string) ([]net.IP, error) { | ||
| return m.IPs, m.Err | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unfortunately this is not how it works. The DNS returns IPs in different order every time. for example: First try Second try I think if you randomize the order here, it can get pretty similar, though |
||
| } | ||
|
|
||
| type customRoundTripper struct{} | ||
|
|
||
| func (rt customRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { | ||
| resp := &http.Response{ | ||
| StatusCode: 200, | ||
| Body: io.NopCloser(strings.NewReader("Hello, client")), | ||
| Header: make(http.Header), | ||
| } | ||
| return resp, nil | ||
| } | ||
|
|
||
| func TestDistribution(t *testing.T) { | ||
| hostname := "example.com" | ||
| testCases := []struct { | ||
| name string | ||
| IPs []net.IP | ||
| numReqs int | ||
| refreshInterval time.Duration | ||
| tolerance float64 | ||
| }{ | ||
| { | ||
| name: "4 IPs, 1000 requests, 1 second refresh interval, 10% tolerance", | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 10% tolerance is probably too much. We will improve it in a follow up PR |
||
| IPs: []net.IP{ | ||
| net.ParseIP("192.0.0.1"), | ||
| net.ParseIP("192.0.0.2"), | ||
| net.ParseIP("192.0.0.3"), | ||
| net.ParseIP("192.0.0.4"), | ||
| }, | ||
| numReqs: 1000, | ||
| refreshInterval: 1 * time.Second, | ||
| tolerance: 0.1, | ||
| }, | ||
| { | ||
| name: "4 IPs, 1000 requests, 2 seconds refresh interval, 10% tolerance", | ||
| IPs: []net.IP{ | ||
| net.ParseIP("192.0.0.1"), | ||
| net.ParseIP("192.0.0.2"), | ||
| net.ParseIP("192.0.0.3"), | ||
| net.ParseIP("192.0.0.4"), | ||
| }, | ||
| numReqs: 1000, | ||
| refreshInterval: 2 * time.Second, | ||
| tolerance: 0.1, | ||
| }, | ||
| { | ||
| name: "4 IPs, 1000 requests, 3 seconds refresh interval, 10% tolerance", | ||
| IPs: []net.IP{ | ||
| net.ParseIP("192.0.0.1"), | ||
| net.ParseIP("192.0.0.2"), | ||
| net.ParseIP("192.0.0.3"), | ||
| net.ParseIP("192.0.0.4"), | ||
| }, | ||
| numReqs: 1000, | ||
| refreshInterval: 3 * time.Second, | ||
| tolerance: 0.1, | ||
| }, | ||
| { | ||
| name: "3 IPs, 1000 requests, 2 seconds refresh interval, 10% tolerance", | ||
| IPs: []net.IP{ | ||
| net.ParseIP("192.0.0.1"), | ||
| net.ParseIP("192.0.0.2"), | ||
| net.ParseIP("192.0.0.3"), | ||
| }, | ||
| numReqs: 1000, | ||
| refreshInterval: 2 * time.Second, | ||
| tolerance: 0.1, | ||
| }, | ||
| { | ||
| name: "10 IPs, 1000 requests, 2 seconds refresh interval, 10% tolerance", | ||
| IPs: []net.IP{ | ||
| net.ParseIP("192.0.0.1"), | ||
| net.ParseIP("192.0.0.2"), | ||
| net.ParseIP("192.0.0.3"), | ||
| net.ParseIP("192.0.0.4"), | ||
| net.ParseIP("192.0.0.5"), | ||
| net.ParseIP("192.0.0.6"), | ||
| net.ParseIP("192.0.0.7"), | ||
| net.ParseIP("192.0.0.8"), | ||
| net.ParseIP("192.0.0.9"), | ||
| net.ParseIP("192.0.0.10"), | ||
| }, | ||
| numReqs: 1000, | ||
| refreshInterval: 2 * time.Second, | ||
| tolerance: 0.1, | ||
| }, | ||
| } | ||
|
|
||
| for _, tc := range testCases { | ||
| t.Run(tc.name, func(t *testing.T) { | ||
| mockResolver := mockDNSResolver{ | ||
| IPs: tc.IPs, | ||
| Err: nil, | ||
| } | ||
|
|
||
| lb := newRoundRobinLoadBalancer(hostname, mockResolver.LookupIP) | ||
| lb.transport = &customRoundTripper{} | ||
|
|
||
| go lb.refreshIPs(tc.refreshInterval) | ||
|
|
||
| requestCounts := make(map[string]int) | ||
| for i := 0; i < tc.numReqs; i++ { | ||
| req := httptest.NewRequest("GET", "http://"+lb.safeGetNextIP(), nil) | ||
| resp, err := lb.roundTrip(req) | ||
| if err == nil { | ||
| addr := req.URL.Host | ||
| ip := strings.Replace(addr, hostname, "", 1) | ||
| requestCounts[ip]++ | ||
| resp.Body.Close() | ||
| } else { | ||
| t.Fatal(err) | ||
| } | ||
| time.Sleep(10 * time.Millisecond) | ||
| } | ||
|
|
||
| expectedCount := tc.numReqs / len(tc.IPs) | ||
| minCount := int(float64(expectedCount) * (1 - tc.tolerance)) | ||
| maxCount := int(float64(expectedCount) * (1 + tc.tolerance)) | ||
|
|
||
| for ip, count := range requestCounts { | ||
| if count < minCount || count > maxCount { | ||
| t.Errorf("IP %s received %d requests, which is outside the acceptable range (%d-%d)", ip, count, minCount, maxCount) | ||
| } | ||
| } | ||
| }) | ||
| } | ||
|
|
||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.