Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 21 additions & 21 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,39 +87,39 @@ type ClientOpts struct {
ProxyPath string
}

// NewClient creates a new API client.
func NewClient(opts ClientOpts) Client {
if opts.Out == nil {
panic("unexpected nil out option")
}

flags := opts.Flags
if flags == nil {
flags = defaultFlags()
}

httpClient := http.DefaultClient

func buildTransport(opts ClientOpts, flags *Flags) *http.Transport {
transport := http.DefaultTransport.(*http.Transport).Clone()
customTransport := false

if flags.insecureSkipVerify != nil && *flags.insecureSkipVerify {
customTransport = true
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
}

if transport.TLSClientConfig == nil {
transport.TLSClientConfig = &tls.Config{}
}

if applyProxy(transport, opts.ProxyURL, opts.ProxyPath) {
customTransport = true
if opts.ProxyURL != nil || opts.ProxyPath != "" {
transport = NewProxyTransport(transport, opts.ProxyURL, opts.ProxyPath)
}

if customTransport {
httpClient = &http.Client{
Transport: transport,
}
return transport
}

// NewClient creates a new API client.
func NewClient(opts ClientOpts) Client {
if opts.Out == nil {
panic("unexpected nil out option")
}

flags := opts.Flags
if flags == nil {
flags = defaultFlags()
}

transport := buildTransport(opts, flags)

httpClient := &http.Client{
Transport: transport,
}

return &client{
Expand Down
39 changes: 38 additions & 1 deletion internal/api/api_test.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,40 @@
package api

// TODO: implement a super basic GraphQL server that can return canned results.
import (
"net/url"
"testing"
)

func TestBuildTransport(t *testing.T) {
boolPtr := func(b bool) *bool { return &b }

t.Run("insecure skip verify", func(t *testing.T) {
transport := buildTransport(ClientOpts{}, &Flags{insecureSkipVerify: boolPtr(true)})
if !transport.TLSClientConfig.InsecureSkipVerify {
t.Error("expected InsecureSkipVerify to be true")
}
})

t.Run("unix socket proxy clears Proxy", func(t *testing.T) {
transport := buildTransport(ClientOpts{ProxyPath: "/tmp/test.sock"}, defaultFlags())
if transport.Proxy != nil {
t.Error("expected Proxy to be nil")
}
})

// http.DefaultTransport.Dial / DialTLS is already set and we can't compare two funcs
// so our best effort here is to just check Proxy is nil / not nill based on the ProxyURL
t.Run("http proxy clears Proxy", func(t *testing.T) {
transport := buildTransport(ClientOpts{ProxyURL: &url.URL{Scheme: "http", Host: "proxy:8080"}}, defaultFlags())
if transport.Proxy != nil {
t.Error("expected Proxy to be nil")
}
})

t.Run("socks5 proxy sets Proxy", func(t *testing.T) {
transport := buildTransport(ClientOpts{ProxyURL: &url.URL{Scheme: "socks5", Host: "proxy:1080"}}, defaultFlags())
if transport.Proxy == nil {
t.Error("expected Proxy to be set")
}
})
}
Comment thread
burmudar marked this conversation as resolved.
Outdated
14 changes: 4 additions & 10 deletions internal/api/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@ import (
"net/url"
)

func applyProxy(transport *http.Transport, proxyURL *url.URL, proxyPath string) (applied bool) {
if proxyURL == nil && proxyPath == "" {
return false
}
func NewProxyTransport(base *http.Transport, proxyURL *url.URL, proxyPath string) *http.Transport {
Comment thread
burmudar marked this conversation as resolved.
Outdated
// Clone so that we don't change the original transport
transport := base.Clone()
Comment thread
burmudar marked this conversation as resolved.
Outdated

handshakeTLS := func(ctx context.Context, conn net.Conn, addr string) (net.Conn, error) {
// Extract the hostname (without the port) for TLS SNI
Expand All @@ -34,8 +33,6 @@ func applyProxy(transport *http.Transport, proxyURL *url.URL, proxyPath string)
return tlsConn, nil
}

proxyApplied := false

if proxyPath != "" {
dial := func(ctx context.Context, _, _ string) (net.Conn, error) {
d := net.Dialer{}
Expand All @@ -52,13 +49,11 @@ func applyProxy(transport *http.Transport, proxyURL *url.URL, proxyPath string)
transport.DialTLSContext = dialTLS
// clear out any system proxy settings
transport.Proxy = nil
proxyApplied = true
} else if proxyURL != nil {
switch proxyURL.Scheme {
case "socks5", "socks5h":
// SOCKS proxies work out of the box - no need to manually dial
transport.Proxy = http.ProxyURL(proxyURL)
proxyApplied = true
case "http", "https":
dial := func(ctx context.Context, network, addr string) (net.Conn, error) {
// Dial the proxy
Expand Down Expand Up @@ -130,9 +125,8 @@ func applyProxy(transport *http.Transport, proxyURL *url.URL, proxyPath string)
transport.DialTLSContext = dialTLS
// clear out any system proxy settings
transport.Proxy = nil
proxyApplied = true
}
}

return proxyApplied
return transport
}
Loading