Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,7 @@ func (d *Dialer) connectionInfoCache(
d.dialerID, useIAMAuthNDial,
)
}
c = newMonitoredCache(ctx, cache, cn, d.failoverPeriod, d.resolver, d.logger)
c = newMonitoredCache(cache, cn, d.failoverPeriod, d.resolver, d.logger)
d.cache[k] = c

return c, nil
Expand Down
154 changes: 139 additions & 15 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ func TestEngineVersionRemovesInvalidInstancesFromCache(t *testing.T) {
spy := &spyConnectionInfoCache{
connectInfoCalls: []connectionInfoResp{tc.resp},
}
d.cache[createKey(inst)] = newMonitoredCache(nil, spy, inst, 0, nil, nil)
d.cache[createKey(inst)] = newMonitoredCache(spy, inst, 0, nil, nil)

_, err = d.EngineVersion(context.Background(), tc.icn)
if err == nil {
Expand Down Expand Up @@ -561,7 +561,7 @@ func TestWarmupRemovesInvalidInstancesFromCache(t *testing.T) {
spy := &spyConnectionInfoCache{
connectInfoCalls: []connectionInfoResp{tc.resp},
}
d.cache[createKey(inst)] = newMonitoredCache(nil, spy, inst, 0, nil, nil)
d.cache[createKey(inst)] = newMonitoredCache(spy, inst, 0, nil, nil)

err = d.Warmup(context.Background(), tc.icn, tc.opts...)
if err == nil {
Expand Down Expand Up @@ -769,7 +769,7 @@ func TestDialerRemovesInvalidInstancesFromCache(t *testing.T) {
spy := &spyConnectionInfoCache{
connectInfoCalls: []connectionInfoResp{tc.resp},
}
d.cache[createKey(inst)] = newMonitoredCache(nil, spy, inst, 0, nil, nil)
d.cache[createKey(inst)] = newMonitoredCache(spy, inst, 0, nil, nil)

_, err = d.Dial(context.Background(), tc.icn, tc.opts...)
if err == nil {
Expand Down Expand Up @@ -819,7 +819,7 @@ func TestDialRefreshesExpiredCertificates(t *testing.T) {
},
},
}
d.cache[createKey(cn)] = newMonitoredCache(nil, spy, cn, 0, nil, nil)
d.cache[createKey(cn)] = newMonitoredCache(spy, cn, 0, nil, nil)

_, err = d.Dial(context.Background(), icn)
if !errors.Is(err, sentinel) {
Expand Down Expand Up @@ -1063,16 +1063,22 @@ type changingResolver struct {
stage atomic.Int32
}

func (r *changingResolver) Resolve(_ context.Context, name string) (instance.ConnName, error) {
// For TestDialerFailoverOnInstanceChange
if name == "update.example.com" {
if r.stage.Load() == 0 {
return instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "update.example.com")
func (r *changingResolver) Resolve(ctx context.Context, name string) (instance.ConnName, error) {
select {
// for TestDialerClosesOldConnectionsOpenAfterDnsChange
case <-ctx.Done():
return instance.ConnName{}, fmt.Errorf("mock dns timeout error")
default:
// For TestDialerFailoverOnInstanceChange
if name == "update.example.com" {
if r.stage.Load() == 0 {
return instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "update.example.com")
}
return instance.ParseConnNameWithDomainName("my-project:my-region:my-instance2", "update.example.com")
}
return instance.ParseConnNameWithDomainName("my-project:my-region:my-instance2", "update.example.com")
// TestDialerFailsDnsSrvRecordMissing
return instance.ConnName{}, fmt.Errorf("no resolution for %q", name)
}
// TestDialerFailsDnsSrvRecordMissing
return instance.ConnName{}, fmt.Errorf("no resolution for %q", name)
}

func TestDialerUpdatesAutomaticallyAfterDnsChange(t *testing.T) {
Expand Down Expand Up @@ -1107,7 +1113,7 @@ func TestDialerUpdatesAutomaticallyAfterDnsChange(t *testing.T) {
},
})

// Start the proxy for instance 1
// Execute the proxy for instance 1
stop1 := mock.StartServerProxy(t, inst)
t.Cleanup(func() {
stop1()
Expand All @@ -1127,7 +1133,7 @@ func TestDialerUpdatesAutomaticallyAfterDnsChange(t *testing.T) {
t.Fatal("Expected monitoredCache to be closed after domain name changed. monitoredCache was not closed.")
}

// Start the proxy for instance 2
// Execute the proxy for instance 2
stop2 := mock.StartServerProxy(t, inst2)
t.Cleanup(func() {
stop2()
Expand All @@ -1140,6 +1146,124 @@ func TestDialerUpdatesAutomaticallyAfterDnsChange(t *testing.T) {

}

func TestDialerClosesOldConnectionsOpenAfterDnsChange(t *testing.T) {
// At first, the resolver will resolve
// update.example.com to "my-instance"
// Then, the resolver will resolve the same domain name to
// "my-instance2".
// This shows that on every call to Dial(), the dialer will resolve the
// SRV record and connect to the correct instance.
inst := mock.NewFakeCSQLInstance(
"my-project", "my-region", "my-instance",
mock.WithDNS("update.example.com"),
)
inst2 := mock.NewFakeCSQLInstance(
"my-project", "my-region", "my-instance2",
mock.WithDNS("update.example.com"),
)
r := &changingResolver{}

d := setupDialer(t, setupConfig{
skipServer: true,
reqs: []*mock.Request{
mock.InstanceGetSuccess(inst, 1),
mock.CreateEphemeralSuccess(inst, 1),
mock.InstanceGetSuccess(inst2, 1),
mock.CreateEphemeralSuccess(inst2, 1),
},
dialerOptions: []Option{
WithFailoverPeriod(10 * time.Millisecond),
WithResolver(r),
WithTokenSource(mock.EmptyTokenSource{}),
WithContextDebugLogger(&testLog{t: t}),
},
})

srv := mock.NewFailoverTestServer(t)
t.Cleanup(func() {
srv.Close()
})

// Execute the mock server on 3307 for instance 1
srv.Start(&inst)
ctx, cancelFn := context.WithCancel(context.Background())
defer cancelFn()

// Dial using a context with a timeout, similar to how the auth proxy
// uses the dialer
dialCtx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
defer cancelFn()
conn, err := d.Dial(dialCtx, "update.example.com")
if err != nil {
t.Fatalf("Dial failed: %v", err)
}
c1 := mock.NewDbClient(t, conn, "c1")
go c1.Execute(ctx)
defer c1.Close()

// Wait 15 seconds to give the fake client time to connect and
// read from the socket
time.Sleep(15 * time.Second)

// Stop the instance1 mock, then start the instance2 mock server.
// The dialer will need to refresh before it can connect to instance2.
t.Logf("Switching to instance2 server")
srv.Stop()
srv.Start(&inst2)

// Update the DNS resolver. This should signal to the dialer that
// it should disconnect and refresh certificates.
t.Logf("Updating DNS record")
r.stage.Store(1)
time.Sleep(1 * time.Second)

// Dial the domain name again. This should now connect to instance2.
dialCtx, cancelFn = context.WithTimeout(context.Background(), 5*time.Second)
defer cancelFn()
conn, err = d.Dial(dialCtx, "update.example.com")
if err != nil {
t.Fatalf("Dial failed: %v", err)
}
c2 := mock.NewDbClient(t, conn, "c2")
go c2.Execute(ctx)
defer c2.Close()

// Wait for the dial context to expire. This ensures that the
// DNS record loop in the dialer is not dependent on the dial context.
time.Sleep(15 * time.Second)

// Check that the client connections are in the correct state:
// c1 should be closed because it connected before the domain name changed.
// c2 should be open because it connected after the domain name changed.

// Assert that c1 is closed because the dns record changed
if !c1.Closed() {
t.Errorf("want c1 closed, was open")
}
// Assert that c1 received some messages, only from inst1
if len(c1.Recv()) == 0 {
t.Errorf("c1 wants >0 messages received, got 0")
}
for _, m := range c1.Recv() {
if m != "my-instance" {
t.Errorf("c1 wants messages from my-instance, got %q", m)
}
}
// Assert that c2 is open. No domain name changes have occurred.
if c2.Closed() {
t.Errorf("want c2 open, was closed")
}
// Assert that c2 received some messages, only from inst2
if len(c2.Recv()) == 0 {
t.Errorf("c2 wants >0 messages received, got 0")
}
for _, m := range c2.Recv() {
if m != "my-instance2" {
t.Errorf("c2 wants messages from my-instance2, got %q", m)
}
}
}

func TestDialerChecksSubjectAlternativeNameAndSucceeds(t *testing.T) {

tcs := []struct {
Expand Down Expand Up @@ -1389,7 +1513,7 @@ func TestDialerRefreshesAfterRotateCACerts(t *testing.T) {
mock.RotateCA(inst)
}

// Start the server with new certificates
// Execute the server with new certificates
cancel2 := mock.StartServerProxy(t, inst)
defer cancel2()

Expand Down
Loading
Loading