Skip to content
38 changes: 36 additions & 2 deletions experimental/ssh/cmd/connect.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ssh

import (
"errors"
"time"

"github.com/databricks/cli/cmd/root"
Expand All @@ -18,10 +19,18 @@ func newConnectCommand() *cobra.Command {
This command establishes an SSH connection to Databricks compute, setting up
the SSH server and handling the connection proxy.

For dedicated clusters:
databricks ssh connect --cluster=<cluster-id>

For serverless compute:
databricks ssh connect --name=<connection-name> [--accelerator=<accelerator>]

` + disclaimer,
}

var clusterID string
var connectionName string
var accelerator string
var proxyMode bool
var serverMetadata string
var shutdownDelay time.Duration
Expand All @@ -30,9 +39,11 @@ the SSH server and handling the connection proxy.
var releasesDir string
var autoStartCluster bool
var userKnownHostsFile string
var liteswap string

cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID (required)")
cmd.MarkFlagRequired("cluster")
cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID (for dedicated clusters)")
cmd.Flags().StringVar(&connectionName, "name", "", "Connection name (for serverless compute)")
cmd.Flags().StringVar(&accelerator, "accelerator", "", "GPU accelerator type for serverless compute (GPU_1xA10 or GPU_8xH100)")
cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down the server after the last client disconnects")
cmd.Flags().IntVar(&maxClients, "max-clients", defaultMaxClients, "Maximum number of SSH clients")
cmd.Flags().BoolVar(&autoStartCluster, "auto-start-cluster", true, "Automatically start the cluster if it is not running")
Expand All @@ -50,6 +61,9 @@ the SSH server and handling the connection proxy.
cmd.Flags().StringVar(&userKnownHostsFile, "user-known-hosts-file", "", "Path to user known hosts file for SSH client")
cmd.Flags().MarkHidden("user-known-hosts-file")

cmd.Flags().StringVar(&liteswap, "liteswap", "", "Liteswap header value for traffic routing (dev/test only)")
cmd.Flags().MarkHidden("liteswap")

cmd.PreRunE = func(cmd *cobra.Command, args []string) error {
// CLI in the proxy mode is executed by the ssh client and can't prompt for input
if proxyMode {
Expand All @@ -64,20 +78,40 @@ the SSH server and handling the connection proxy.
cmd.RunE = func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
wsClient := cmdctx.WorkspaceClient(ctx)

if !proxyMode && clusterID == "" && connectionName == "" {
return errors.New("please provide --cluster flag with the cluster ID, or --name flag with the serverless connection name")
}

if accelerator != "" && connectionName == "" {
return errors.New("--accelerator flag can only be used with serverless compute (--name flag)")
}

// Remove when we add support for serverless CPU
if connectionName != "" && accelerator == "" {
return errors.New("--name flag requires --accelerator to be set (e.g. for now we only support serverless GPU compute)")
}

// TODO: validate connectionName if provided

opts := client.ClientOptions{
Profile: wsClient.Config.Profile,
ClusterID: clusterID,
ConnectionName: connectionName,
Accelerator: accelerator,
ProxyMode: proxyMode,
ServerMetadata: serverMetadata,
ShutdownDelay: shutdownDelay,
MaxClients: maxClients,
HandoverTimeout: handoverTimeout,
ReleasesDir: releasesDir,
ServerTimeout: serverTimeout,
TaskStartupTimeout: taskStartupTimeout,
AutoStartCluster: autoStartCluster,
ClientPublicKeyName: clientPublicKeyName,
ClientPrivateKeyName: clientPrivateKeyName,
UserKnownHostsFile: userKnownHostsFile,
Liteswap: liteswap,
AdditionalArgs: args,
}
return client.Run(ctx, wsClient, opts)
Expand Down
1 change: 1 addition & 0 deletions experimental/ssh/cmd/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ const (
defaultHandoverTimeout = 30 * time.Minute

serverTimeout = 24 * time.Hour
taskStartupTimeout = 10 * time.Minute
serverPortRange = 100
serverConfigDir = ".ssh-tunnel"
serverPrivateKeyName = "server-private-key"
Expand Down
4 changes: 4 additions & 0 deletions experimental/ssh/cmd/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,15 @@ and proxies them to local SSH daemon processes.
var maxClients int
var shutdownDelay time.Duration
var clusterID string
var sessionID string
var version string
var secretScopeName string
var authorizedKeySecretName string

cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID")
cmd.MarkFlagRequired("cluster")
cmd.Flags().StringVar(&sessionID, "session-id", "", "Session identifier (cluster ID or serverless connection name)")
cmd.MarkFlagRequired("session-id")
cmd.Flags().StringVar(&secretScopeName, "secret-scope-name", "", "Databricks secret scope name to store SSH keys")
cmd.MarkFlagRequired("secret-scope-name")
cmd.Flags().StringVar(&authorizedKeySecretName, "authorized-key-secret-name", "", "Name of the secret containing the client public key")
Expand All @@ -56,6 +59,7 @@ and proxies them to local SSH daemon processes.
wsc := cmdctx.WorkspaceClient(ctx)
opts := server.ServerOptions{
ClusterID: clusterID,
SessionID: sessionID,
MaxClients: maxClients,
ShutdownDelay: shutdownDelay,
Version: version,
Expand Down
Loading