Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 34 additions & 12 deletions pkg/lib/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package server
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net/http"
"path/filepath"
Expand All @@ -16,6 +17,11 @@ import (
"sigs.k8s.io/controller-runtime/pkg/metrics/filters"
)

// certPoolGetter is an interface for getting a certificate pool
type certPoolGetter interface {
GetCertPool() *x509.CertPool
}

// Option applies a configuration option to the given config.
type Option func(s *serverConfig)

Expand Down Expand Up @@ -94,6 +100,10 @@ func (sc *serverConfig) getAddress(tlsEnabled bool) string {
return ":8080"
}

func (sc *serverConfig) clientCAEnabled() bool {
return sc.clientCAPath != nil && *sc.clientCAPath != ""
}

func (sc serverConfig) getListenAndServeFunc() (func() error, error) {
tlsEnabled, err := sc.tlsEnabled()
if err != nil {
Expand Down Expand Up @@ -168,15 +178,23 @@ func (sc serverConfig) getListenAndServeFunc() (func() error, error) {
return nil, fmt.Errorf("error creating cert file watcher: %v", err)
}
csw.Run(context.Background())
certPoolStore, err := filemonitor.NewCertPoolStore(*sc.clientCAPath)
if err != nil {
return nil, fmt.Errorf("certificate monitoring for client-ca failed: %v", err)
}
cpsw, err := filemonitor.NewWatch(sc.logger, []string{filepath.Dir(*sc.clientCAPath)}, certPoolStore.HandleCABundleUpdate)
if err != nil {
return nil, fmt.Errorf("error creating cert file watcher: %v", err)

// Only setup client CA monitoring if clientCAPath is provided
var certPoolStore certPoolGetter
if sc.clientCAEnabled() {
cps, err := filemonitor.NewCertPoolStore(*sc.clientCAPath)
if err != nil {
return nil, fmt.Errorf("certificate monitoring for client-ca failed: %v", err)
}
cpsw, err := filemonitor.NewWatch(sc.logger, []string{filepath.Dir(*sc.clientCAPath)}, cps.HandleCABundleUpdate)
if err != nil {
return nil, fmt.Errorf("error creating cert file watcher: %v", err)
}
cpsw.Run(context.Background())
certPoolStore = cps
} else {
sc.logger.Info("No client CA provided, client certificate verification disabled")
}
cpsw.Run(context.Background())

s.TLSConfig = &tls.Config{
GetCertificate: func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
Expand All @@ -187,11 +205,15 @@ func (sc serverConfig) getListenAndServeFunc() (func() error, error) {
if cert := certStore.GetCertificate(); cert != nil {
certs = append(certs, *cert)
}
return &tls.Config{
tlsCfg := &tls.Config{
Certificates: certs,
ClientCAs: certPoolStore.GetCertPool(),
ClientAuth: tls.VerifyClientCertIfGiven,
}, nil
}
// Only configure client CA verification if certPoolStore is available
if certPoolStore != nil {
tlsCfg.ClientCAs = certPoolStore.GetCertPool()
tlsCfg.ClientAuth = tls.VerifyClientCertIfGiven
}
return tlsCfg, nil
},
NextProtos: []string{"http/1.1"}, // Disable HTTP/2 for security
}
Expand Down
109 changes: 109 additions & 0 deletions pkg/lib/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,115 @@ func TestGetListenAndServeFunc_WithoutKubeConfig(t *testing.T) {
assert.NoError(t, err, "GetListenAndServeFunc should succeed without kubeConfig")
}

// TestGetListenAndServeFunc_WithEmptyClientCA tests that the server
// starts successfully when TLS is enabled but client-ca is empty
func TestGetListenAndServeFunc_WithEmptyClientCA(t *testing.T) {
// Generate test certificates dynamically
caCert, caKey, err := generateCA()
require.NoError(t, err)

serverCert, serverKey, err := generateServerCert(caCert, caKey, "localhost")
require.NoError(t, err)

tmpDir, err := os.MkdirTemp("", "server-test-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)

tlsCertPath := filepath.Join(tmpDir, "tls.crt")
tlsKeyPath := filepath.Join(tmpDir, "tls.key")
emptyClientCAPath := "" // Empty client CA path

err = os.WriteFile(tlsCertPath, serverCert, 0644)
require.NoError(t, err)
err = os.WriteFile(tlsKeyPath, serverKey, 0600)
require.NoError(t, err)

logger := logrus.New()
logger.SetOutput(io.Discard)

// Test with TLS enabled but empty client CA - should succeed
_, err = GetListenAndServeFunc(
WithLogger(logger),
WithTLS(&tlsCertPath, &tlsKeyPath, &emptyClientCAPath),
WithDebug(false),
)

assert.NoError(t, err, "GetListenAndServeFunc should succeed with empty client-ca")
}

// TestGetListenAndServeFunc_WithNilClientCA tests that the server
// starts successfully when TLS is enabled but client-ca pointer is nil
func TestGetListenAndServeFunc_WithNilClientCA(t *testing.T) {
// Generate test certificates dynamically
caCert, caKey, err := generateCA()
require.NoError(t, err)

serverCert, serverKey, err := generateServerCert(caCert, caKey, "localhost")
require.NoError(t, err)

tmpDir, err := os.MkdirTemp("", "server-test-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)

tlsCertPath := filepath.Join(tmpDir, "tls.crt")
tlsKeyPath := filepath.Join(tmpDir, "tls.key")

err = os.WriteFile(tlsCertPath, serverCert, 0644)
require.NoError(t, err)
err = os.WriteFile(tlsKeyPath, serverKey, 0600)
require.NoError(t, err)

logger := logrus.New()
logger.SetOutput(io.Discard)

// Test with TLS enabled but nil client CA pointer - should succeed
_, err = GetListenAndServeFunc(
WithLogger(logger),
WithTLS(&tlsCertPath, &tlsKeyPath, nil),
WithDebug(false),
)

assert.NoError(t, err, "GetListenAndServeFunc should succeed with nil client-ca pointer")
}

// TestClientCAEnabled tests the clientCAEnabled helper function
func TestClientCAEnabled(t *testing.T) {
tests := []struct {
name string
clientCAPath *string
expected bool
}{
{
name: "nil pointer",
clientCAPath: nil,
expected: false,
},
{
name: "empty string",
clientCAPath: strPtr(""),
expected: false,
},
{
name: "valid path",
clientCAPath: strPtr("/path/to/ca.crt"),
expected: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sc := &serverConfig{
clientCAPath: tt.clientCAPath,
}
assert.Equal(t, tt.expected, sc.clientCAEnabled(), "clientCAEnabled result should match expected")
})
}
}

func strPtr(s string) *string {
return &s
}

// TestHTTPClientHasTLSConfig verifies that rest.HTTPClientFor creates a client
// with proper TLS configuration including CA certificates
func TestHTTPClientHasTLSConfig(t *testing.T) {
Expand Down