Skip to content

Commit f1ae7ea

Browse files
committed
feat: add support for credentials file (#1151)
1 parent 0ac17a5 commit f1ae7ea

File tree

11 files changed

+220
-71
lines changed

11 files changed

+220
-71
lines changed

.envrc.example

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,5 @@ export SQLSERVER_CONNECTION_NAME="project:region:instance"
1515
export SQLSERVER_USER="sqlserver-user"
1616
export SQLSERVER_PASS="sqlserver-password"
1717
export SQLSERVER_DB="sqlserver-db-name"
18+
19+
export GOOGLE_APPLICATION_CREDENTIALS=/path/to/key.json

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,7 @@
88
# Compiled binary
99
/cmd/cloud_sql_proxy/cloud_sql_proxy
1010
/cloud_sql_proxy
11+
# v2 binary
12+
/cloudsql-proxy
13+
14+
/key.json

cmd/root.go

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ connecting to Cloud SQL instances. It listens on a local port and forwards conne
9292
to your instance's IP address, providing a secure connection without having to manage
9393
any client SSL certificates.`,
9494
Args: func(cmd *cobra.Command, args []string) error {
95-
err := parseConfig(c.conf, args)
95+
err := parseConfig(cmd, c.conf, args)
9696
if err != nil {
9797
return err
9898
}
@@ -108,6 +108,8 @@ any client SSL certificates.`,
108108
// Global-only flags
109109
cmd.PersistentFlags().StringVarP(&c.conf.Token, "token", "t", "",
110110
"Bearer token used for authorization.")
111+
cmd.PersistentFlags().StringVarP(&c.conf.CredentialsFile, "credentials-file", "c", "",
112+
"Path to a service account key to use for authentication.")
111113

112114
// Global and per instance flags
113115
cmd.PersistentFlags().StringVarP(&c.conf.Addr, "address", "a", "127.0.0.1",
@@ -119,7 +121,7 @@ any client SSL certificates.`,
119121
return c
120122
}
121123

122-
func parseConfig(conf *proxy.Config, args []string) error {
124+
func parseConfig(cmd *cobra.Command, conf *proxy.Config, args []string) error {
123125
// If no instance connection names were provided, error.
124126
if len(args) == 0 {
125127
return newBadCommandError("missing instance_connection_name (e.g., project:region:instance)")
@@ -129,6 +131,20 @@ func parseConfig(conf *proxy.Config, args []string) error {
129131
return newBadCommandError(fmt.Sprintf("not a valid IP address: %q", conf.Addr))
130132
}
131133

134+
// If both token and credentials file were set, error.
135+
if conf.Token != "" && conf.CredentialsFile != "" {
136+
return newBadCommandError("Cannot specify --token and --credentials-file flags at the same time")
137+
}
138+
139+
switch {
140+
case conf.Token != "":
141+
cmd.Printf("Authorizing with the -token flag\n")
142+
case conf.CredentialsFile != "":
143+
cmd.Printf("Authorizing with the credentials file at %q\n", conf.CredentialsFile)
144+
default:
145+
cmd.Printf("Authorizing with Application Default Credentials")
146+
}
147+
132148
var ics []proxy.InstanceConnConfig
133149
for _, a := range args {
134150
// Assume no query params initially
@@ -211,8 +227,8 @@ func runSignalWrapper(cmd *Command) error {
211227
// Otherwise, initialize a new one.
212228
d := cmd.conf.Dialer
213229
if d == nil {
214-
var err error
215230
opts := append(cmd.conf.DialerOpts(), cloudsqlconn.WithUserAgent(userAgent))
231+
var err error
216232
d, err = cloudsqlconn.NewDialer(ctx, opts...)
217233
if err != nil {
218234
shutdownCh <- fmt.Errorf("error initializing dialer: %v", err)

cmd/root_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,20 @@ func TestNewCommandArguments(t *testing.T) {
119119
Token: "MYCOOLTOKEN",
120120
}),
121121
},
122+
{
123+
desc: "using the credentiale file flag",
124+
args: []string{"--credentials-file", "/path/to/file", "proj:region:inst"},
125+
want: withDefaults(&proxy.Config{
126+
CredentialsFile: "/path/to/file",
127+
}),
128+
},
129+
{
130+
desc: "using the (short) credentiale file flag",
131+
args: []string{"-c", "/path/to/file", "proj:region:inst"},
132+
want: withDefaults(&proxy.Config{
133+
CredentialsFile: "/path/to/file",
134+
}),
135+
},
122136
}
123137

124138
for _, tc := range tcs {
@@ -186,6 +200,12 @@ func TestNewCommandWithErrors(t *testing.T) {
186200
desc: "when the port query param is not a number",
187201
args: []string{"proj:region:inst?port=hi"},
188202
},
203+
{
204+
desc: "when both token and credentials file is set",
205+
args: []string{
206+
"--token", "my-token",
207+
"--credentials-file", "/path/to/file", "proj:region:inst"},
208+
},
189209
}
190210

191211
for _, tc := range tcs {

internal/proxy/proxy.go

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ type Config struct {
4545
// Token is the Bearer token used for authorization.
4646
Token string
4747

48+
// CredentialsFile is the path to a service account key.
49+
CredentialsFile string
50+
4851
// Addr is the address on which to bind all instances.
4952
Addr string
5053

@@ -61,18 +64,17 @@ type Config struct {
6164
Dialer cloudsql.Dialer
6265
}
6366

64-
// NewConfig initializes a Config struct using the default database engine
65-
// ports.
66-
func NewConfig() *Config {
67-
return &Config{}
68-
}
69-
7067
func (c *Config) DialerOpts() []cloudsqlconn.Option {
7168
var opts []cloudsqlconn.Option
72-
if c.Token != "" {
69+
switch {
70+
case c.Token != "":
7371
opts = append(opts, cloudsqlconn.WithTokenSource(
7472
oauth2.StaticTokenSource(&oauth2.Token{AccessToken: c.Token}),
7573
))
74+
case c.CredentialsFile != "":
75+
opts = append(opts, cloudsqlconn.WithCredentialsFile(
76+
c.CredentialsFile,
77+
))
7678
}
7779
return opts
7880
}

internal/proxy/proxy_test.go

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -167,30 +167,3 @@ func TestClientInitialization(t *testing.T) {
167167
})
168168
}
169169
}
170-
171-
func TestConfigDialerOpts(t *testing.T) {
172-
tcs := []struct {
173-
desc string
174-
config proxy.Config
175-
wantLen int
176-
}{
177-
{
178-
desc: "when there are no options",
179-
config: proxy.Config{},
180-
wantLen: 0,
181-
},
182-
{
183-
desc: "when a token is present",
184-
config: proxy.Config{Token: "my-token"},
185-
wantLen: 1,
186-
},
187-
}
188-
189-
for _, tc := range tcs {
190-
t.Run(tc.desc, func(t *testing.T) {
191-
if got := tc.config.DialerOpts(); tc.wantLen != len(got) {
192-
t.Errorf("want len = %v, got = %v", tc.wantLen, len(got))
193-
}
194-
})
195-
}
196-
}

tests/alldb_test.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"fmt"
2121
"net/http"
2222
"testing"
23+
"time"
2324
)
2425

2526
// requireAllVars skips the given test if at least one environment variable is undefined.
@@ -43,13 +44,17 @@ func TestMultiInstanceDial(t *testing.T) {
4344
t.Skip("skipping Health Check integration tests")
4445
}
4546
requireAllVars(t)
46-
ctx := context.Background()
47-
48-
var args []string
49-
args = append(args, fmt.Sprintf("-instances=%s=tcp:%d,%s=tcp:%d,%s=tcp:%d", *mysqlConnName, mysqlPort, *postgresConnName, postgresPort, *sqlserverConnName, sqlserverPort))
50-
args = append(args, "-use_http_health_check")
47+
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
48+
defer cancel()
5149

5250
// Start the proxy.
51+
args := []string{
52+
// This test doesn't care what the instance port is, so use "0" which
53+
// means, let the runtime pick a random port.
54+
fmt.Sprintf("-instances=%s=tcp:0,%s=tcp:0,%s=tcp:0",
55+
*mysqlConnName, *postgresConnName, *sqlserverConnName),
56+
"-use_http_health_check",
57+
}
5358
p, err := StartProxy(ctx, args...)
5459
if err != nil {
5560
t.Fatalf("unable to start proxy: %v", err)

testsV2/connection_test.go

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,45 @@ package tests
1717
import (
1818
"context"
1919
"database/sql"
20+
"os"
2021
"testing"
22+
"time"
23+
24+
"golang.org/x/oauth2"
25+
"golang.org/x/oauth2/google"
26+
"google.golang.org/api/sqladmin/v1"
2127
)
2228

23-
// proxyConnTest is a test helper to verify the proxy works with a basic connectivity test.
24-
func proxyConnTest(t *testing.T, connName, driver, dsn string, port int, dir string) {
25-
ctx := context.Background()
29+
const connTestTimeout = time.Minute
2630

27-
args := []string{connName}
31+
// removeAuthEnvVar retrieves an OAuth2 token and a path to a service account key
32+
// and then unsets GOOGLE_APPLICATION_CREDENTIALS. It returns a cleanup function
33+
// that restores the original setup.
34+
func removeAuthEnvVar(t *testing.T) (*oauth2.Token, string, func()) {
35+
ts, err := google.DefaultTokenSource(context.Background(), sqladmin.SqlserviceAdminScope)
36+
if err != nil {
37+
t.Errorf("failed to resolve token source: %v", err)
38+
}
39+
tok, err := ts.Token()
40+
if err != nil {
41+
t.Errorf("failed to get token: %v", err)
42+
}
43+
path, ok := os.LookupEnv("GOOGLE_APPLICATION_CREDENTIALS")
44+
if !ok {
45+
t.Fatalf("GOOGLE_APPLICATION_CREDENTIALS was not set in the environment")
46+
}
47+
if err := os.Unsetenv("GOOGLE_APPLICATION_CREDENTIALS"); err != nil {
48+
t.Fatalf("failed to unset GOOGLE_APPLICATION_CREDENTIALS")
49+
}
50+
return tok, path, func() {
51+
os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", path)
52+
}
53+
}
2854

55+
// proxyConnTest is a test helper to verify the proxy works with a basic connectivity test.
56+
func proxyConnTest(t *testing.T, args []string, driver, dsn string) {
57+
ctx, cancel := context.WithTimeout(context.Background(), connTestTimeout)
58+
defer cancel()
2959
// Start the proxy
3060
p, err := StartProxy(ctx, args...)
3161
if err != nil {

testsV2/mysql_test.go

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,36 +27,76 @@ var (
2727
mysqlConnName = flag.String("mysql_conn_name", os.Getenv("MYSQL_CONNECTION_NAME"), "Cloud SQL MYSQL instance connection name, in the form of 'project:region:instance'.")
2828
mysqlUser = flag.String("mysql_user", os.Getenv("MYSQL_USER"), "Name of database user.")
2929
mysqlPass = flag.String("mysql_pass", os.Getenv("MYSQL_PASS"), "Password for the database user; be careful when entering a password on the command line (it may go into your terminal's history).")
30-
mysqlDb = flag.String("mysql_db", os.Getenv("MYSQL_DB"), "Name of the database to connect to.")
31-
32-
mysqlPort = 3306
30+
mysqlDB = flag.String("mysql_db", os.Getenv("MYSQL_DB"), "Name of the database to connect to.")
3331
)
3432

35-
func requireMysqlVars(t *testing.T) {
33+
func requireMySQLVars(t *testing.T) {
3634
switch "" {
3735
case *mysqlConnName:
3836
t.Fatal("'mysql_conn_name' not set")
3937
case *mysqlUser:
4038
t.Fatal("'mysql_user' not set")
4139
case *mysqlPass:
4240
t.Fatal("'mysql_pass' not set")
43-
case *mysqlDb:
41+
case *mysqlDB:
4442
t.Fatal("'mysql_db' not set")
4543
}
4644
}
4745

48-
func TestMysqlTcp(t *testing.T) {
46+
func TestMySQLTCP(t *testing.T) {
47+
if testing.Short() {
48+
t.Skip("skipping MySQL integration tests")
49+
}
50+
requireMySQLVars(t)
51+
cfg := mysql.Config{
52+
User: *mysqlUser,
53+
Passwd: *mysqlPass,
54+
DBName: *mysqlDB,
55+
AllowNativePasswords: true,
56+
Addr: "127.0.0.1:3306",
57+
Net: "tcp",
58+
}
59+
proxyConnTest(t, []string{*mysqlConnName}, "mysql", cfg.FormatDSN())
60+
}
61+
62+
func TestMySQLAuthWithToken(t *testing.T) {
4963
if testing.Short() {
5064
t.Skip("skipping MySQL integration tests")
5165
}
52-
requireMysqlVars(t)
66+
requireMySQLVars(t)
67+
tok, _, cleanup := removeAuthEnvVar(t)
68+
defer cleanup()
69+
70+
cfg := mysql.Config{
71+
User: *mysqlUser,
72+
Passwd: *mysqlPass,
73+
DBName: *mysqlDB,
74+
AllowNativePasswords: true,
75+
Addr: "127.0.0.1:3306",
76+
Net: "tcp",
77+
}
78+
proxyConnTest(t,
79+
[]string{"--token", tok.AccessToken, *mysqlConnName},
80+
"mysql", cfg.FormatDSN())
81+
}
82+
83+
func TestMySQLAuthWithCredentialsFile(t *testing.T) {
84+
if testing.Short() {
85+
t.Skip("skipping MySQL integration tests")
86+
}
87+
requireMySQLVars(t)
88+
_, path, cleanup := removeAuthEnvVar(t)
89+
defer cleanup()
90+
5391
cfg := mysql.Config{
5492
User: *mysqlUser,
5593
Passwd: *mysqlPass,
56-
DBName: *mysqlDb,
94+
DBName: *mysqlDB,
5795
AllowNativePasswords: true,
5896
Addr: "127.0.0.1:3306",
5997
Net: "tcp",
6098
}
61-
proxyConnTest(t, *mysqlConnName, "mysql", cfg.FormatDSN(), mysqlPort, "")
99+
proxyConnTest(t,
100+
[]string{"--credentials-file", path, *mysqlConnName},
101+
"mysql", cfg.FormatDSN())
62102
}

0 commit comments

Comments
 (0)