diff --git a/cmd/src/gateway.go b/cmd/src/gateway.go index c84340df48..0540d694e3 100644 --- a/cmd/src/gateway.go +++ b/cmd/src/gateway.go @@ -17,6 +17,7 @@ Usage: The commands are: benchmark runs benchmarks against Cody Gateway + benchmark-stream runs benchmarks against Cody Gateway code completion streaming endpoints Use "src gateway [command] -h" for more information about a command. diff --git a/cmd/src/gateway_benchmark.go b/cmd/src/gateway_benchmark.go index 0c1c78570d..2988dfb9db 100644 --- a/cmd/src/gateway_benchmark.go +++ b/cmd/src/gateway_benchmark.go @@ -26,9 +26,14 @@ type Stats struct { Total time.Duration } +type requestResult struct { + duration time.Duration + traceID string // X-Trace header value +} + func init() { usage := ` -'src gateway benchmark' runs performance benchmarks against Cody Gateway endpoints. +'src gateway benchmark' runs performance benchmarks against Cody Gateway and Sourcegraph test endpoints. Usage: @@ -39,17 +44,20 @@ Examples: $ src gateway benchmark --sgp $ src gateway benchmark --requests 50 --sgp $ src gateway benchmark --gateway http://localhost:9992 --sourcegraph http://localhost:3082 --sgp - $ src gateway benchmark --requests 50 --csv results.csv --sgp + $ src gateway benchmark --requests 50 --csv results.csv --request-csv requests.csv --sgp + $ src gateway benchmark --gateway https://cody-gateway.sourcegraph.com --sourcegraph https://sourcegraph.com --sgp --use-special-header ` flagSet := flag.NewFlagSet("benchmark", flag.ExitOnError) var ( - requestCount = flagSet.Int("requests", 1000, "Number of requests to make per endpoint") - csvOutput = flagSet.String("csv", "", "Export results to CSV file (provide filename)") - gatewayEndpoint = flagSet.String("gateway", "https://cody-gateway.sourcegraph.com", "Cody Gateway endpoint") - sgEndpoint = flagSet.String("sourcegraph", "https://sourcegraph.com", "Sourcegraph endpoint") - sgpToken = flagSet.String("sgp", "", "Sourcegraph personal access token for the called instance") + requestCount = flagSet.Int("requests", 1000, "Number of requests to make per endpoint") + csvOutput = flagSet.String("csv", "", "Export results to CSV file (provide filename)") + requestLevelCsvOutput = flagSet.String("request-csv", "", "Export request results to CSV file (provide filename)") + gatewayEndpoint = flagSet.String("gateway", "", "Cody Gateway endpoint") + sgEndpoint = flagSet.String("sourcegraph", "", "Sourcegraph endpoint") + sgpToken = flagSet.String("sgp", "", "Sourcegraph personal access token for the called instance") + useSpecialHeader = flagSet.Bool("use-special-header", false, "Use special header to test the gateway") ) handler := func(args []string) error { @@ -61,15 +69,23 @@ Examples: return cmderrors.Usage("additional arguments not allowed") } + if *useSpecialHeader { + fmt.Println("Using special header 'cody-core-gc-test'") + } + var ( httpClient = &http.Client{} endpoints = map[string]any{} // Values: URL `string`s or `*webSocketClient`s ) if *gatewayEndpoint != "" { fmt.Println("Benchmarking Cody Gateway instance:", *gatewayEndpoint) + headers := http.Header{ + "X-Sourcegraph-Should-Trace": []string{"true"}, + } endpoints["ws(s): gateway"] = &webSocketClient{ - conn: nil, - URL: strings.Replace(fmt.Sprint(*gatewayEndpoint, "/v2/websocket"), "http", "ws", 1), + conn: nil, + URL: strings.Replace(fmt.Sprint(*gatewayEndpoint, "/v2/websocket"), "http", "ws", 1), + reqHeaders: headers, } endpoints["http(s): gateway"] = fmt.Sprint(*gatewayEndpoint, "/v2/http") } else { @@ -80,12 +96,18 @@ Examples: return cmderrors.Usage("must specify --sgp ") } fmt.Println("Benchmarking Sourcegraph instance:", *sgEndpoint) + headers := http.Header{ + "Authorization": []string{"token " + *sgpToken}, + "X-Sourcegraph-Should-Trace": []string{"true"}, + } + if *useSpecialHeader { + headers.Set("cody-core-gc-test", "M2R{+6VI?1,M3n&") + } + endpoints["ws(s): sourcegraph"] = &webSocketClient{ - conn: nil, - URL: strings.Replace(fmt.Sprint(*sgEndpoint, "/.api/gateway/websocket"), "http", "ws", 1), - headers: http.Header{ - "Authorization": []string{"token " + *sgpToken}, - }, + conn: nil, + URL: strings.Replace(fmt.Sprint(*sgEndpoint, "/.api/gateway/websocket"), "http", "ws", 1), + reqHeaders: headers, } endpoints["http(s): sourcegraph"] = fmt.Sprint(*sgEndpoint, "/.api/gateway/http") endpoints["http(s): http-then-ws"] = fmt.Sprint(*sgEndpoint, "/.api/gateway/http-then-websocket") @@ -95,21 +117,25 @@ Examples: fmt.Printf("Starting benchmark with %d requests per endpoint...\n", *requestCount) - var results []endpointResult + var eResults []endpointResult + rResults := map[string][]requestResult{} for name, clientOrURL := range endpoints { durations := make([]time.Duration, 0, *requestCount) + rResults[name] = make([]requestResult, 0, *requestCount) fmt.Printf("\nTesting %s...", name) for i := 0; i < *requestCount; i++ { if ws, ok := clientOrURL.(*webSocketClient); ok { - duration := benchmarkEndpointWebSocket(ws) - if duration > 0 { - durations = append(durations, duration) + result := benchmarkEndpointWebSocket(ws) + if result.duration > 0 { + durations = append(durations, result.duration) + rResults[name] = append(rResults[name], result) } } else if url, ok := clientOrURL.(string); ok { - duration := benchmarkEndpointHTTP(httpClient, url, *sgpToken) - if duration > 0 { - durations = append(durations, duration) + result := benchmarkEndpointHTTP(httpClient, url, *sgpToken, *useSpecialHeader) + if result.duration > 0 { + durations = append(durations, result.duration) + rResults[name] = append(rResults[name], result) } } } @@ -117,7 +143,7 @@ Examples: stats := calculateStats(durations) - results = append(results, endpointResult{ + eResults = append(eResults, endpointResult{ name: name, avg: stats.Avg, median: stats.Median, @@ -130,14 +156,20 @@ Examples: }) } - printResults(results, requestCount) + printResults(eResults, requestCount) if *csvOutput != "" { - if err := writeResultsToCSV(*csvOutput, results, requestCount); err != nil { + if err := writeResultsToCSV(*csvOutput, eResults, requestCount); err != nil { return fmt.Errorf("failed to export CSV: %v", err) } fmt.Printf("\nResults exported to %s\n", *csvOutput) } + if *requestLevelCsvOutput != "" { + if err := writeRequestResultsToCSV(*requestLevelCsvOutput, rResults); err != nil { + return fmt.Errorf("failed to export request-level CSV: %v", err) + } + fmt.Printf("\nRequest-level results exported to %s\n", *requestLevelCsvOutput) + } return nil } @@ -158,9 +190,10 @@ Examples: } type webSocketClient struct { - conn *websocket.Conn - URL string - headers http.Header + conn *websocket.Conn + URL string + reqHeaders http.Header + respHeaders http.Header } func (c *webSocketClient) reconnect() error { @@ -169,11 +202,13 @@ func (c *webSocketClient) reconnect() error { } fmt.Println("Connecting to WebSocket..", c.URL) var err error - c.conn, _, err = websocket.DefaultDialer.Dial(c.URL, c.headers) + var resp *http.Response + c.conn, resp, err = websocket.DefaultDialer.Dial(c.URL, c.reqHeaders) if err != nil { c.conn = nil // retry again later return fmt.Errorf("WebSocket dial(%s): %v", c.URL, err) } + c.respHeaders = resp.Header fmt.Println("Connected!") return nil } @@ -190,19 +225,23 @@ type endpointResult struct { successful int } -func benchmarkEndpointHTTP(client *http.Client, url, accessToken string) time.Duration { +func benchmarkEndpointHTTP(client *http.Client, url, accessToken string, useSpecialHeader bool) requestResult { start := time.Now() req, err := http.NewRequest("POST", url, strings.NewReader("ping")) if err != nil { fmt.Printf("Error creating request: %v\n", err) - return 0 + return requestResult{} } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "token "+accessToken) + req.Header.Set("X-Sourcegraph-Should-Trace", "true") + if useSpecialHeader { + req.Header.Set("cody-core-gc-test", "M2R{+6VI?1,M3n&") + } resp, err := client.Do(req) if err != nil { fmt.Printf("Error calling %s: %v\n", url, err) - return 0 + return requestResult{} } defer func() { err := resp.Body.Close() @@ -212,27 +251,30 @@ func benchmarkEndpointHTTP(client *http.Client, url, accessToken string) time.Du }() if resp.StatusCode != http.StatusOK { fmt.Printf("non-200 response: %v\n", resp.Status) - return 0 + return requestResult{} } body, err := io.ReadAll(resp.Body) if err != nil { fmt.Printf("Error reading response body: %v\n", err) - return 0 + return requestResult{} } if string(body) != "pong" { fmt.Printf("Expected 'pong' response, got: %q\n", string(body)) - return 0 + return requestResult{} } - return time.Since(start) + return requestResult{ + duration: time.Since(start), + traceID: resp.Header.Get("X-Trace"), + } } -func benchmarkEndpointWebSocket(client *webSocketClient) time.Duration { +func benchmarkEndpointWebSocket(client *webSocketClient) requestResult { // Perform initial websocket connection, if needed. if client.conn == nil { if err := client.reconnect(); err != nil { fmt.Printf("Error reconnecting: %v\n", err) - return 0 + return requestResult{} } } @@ -244,7 +286,7 @@ func benchmarkEndpointWebSocket(client *webSocketClient) time.Duration { if err := client.reconnect(); err != nil { fmt.Printf("Error reconnecting: %v\n", err) } - return 0 + return requestResult{} } _, message, err := client.conn.ReadMessage() @@ -253,16 +295,19 @@ func benchmarkEndpointWebSocket(client *webSocketClient) time.Duration { if err := client.reconnect(); err != nil { fmt.Printf("Error reconnecting: %v\n", err) } - return 0 + return requestResult{} } if string(message) != "pong" { fmt.Printf("Expected 'pong' response, got: %q\n", string(message)) if err := client.reconnect(); err != nil { fmt.Printf("Error reconnecting: %v\n", err) } - return 0 + return requestResult{} + } + return requestResult{ + duration: time.Since(start), + traceID: client.respHeaders.Get("Content-Type"), } - return time.Since(start) } func calculateStats(durations []time.Duration) Stats { @@ -438,3 +483,40 @@ func writeResultsToCSV(filename string, results []endpointResult, requestCount * return nil } + +func writeRequestResultsToCSV(filename string, results map[string][]requestResult) error { + file, err := os.Create(filename) + if err != nil { + return fmt.Errorf("failed to create CSV file: %v", err) + } + defer func() { + err := file.Close() + if err != nil { + return + } + }() + + writer := csv.NewWriter(file) + defer writer.Flush() + + // Write header + header := []string{"Endpoint", "Duration (ms)", "Trace ID"} + if err := writer.Write(header); err != nil { + return fmt.Errorf("failed to write CSV header: %v", err) + } + + for endpoint, requestResults := range results { + for _, result := range requestResults { + row := []string{ + endpoint, + fmt.Sprintf("%.2f", float64(result.duration.Microseconds())/1000), + result.traceID, + } + if err := writer.Write(row); err != nil { + return fmt.Errorf("failed to write CSV row: %v", err) + } + } + } + + return nil +} diff --git a/cmd/src/gateway_benchmark_stream.go b/cmd/src/gateway_benchmark_stream.go new file mode 100644 index 0000000000..66154a5c8b --- /dev/null +++ b/cmd/src/gateway_benchmark_stream.go @@ -0,0 +1,283 @@ +package main + +import ( + "flag" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/sourcegraph/src-cli/internal/cmderrors" +) + +type httpEndpoint struct { + url string + authHeader string + body string +} + +func init() { + usage := ` +'src gateway benchmark-stream' runs performance benchmarks against Cody Gateway and Sourcegraph +code completion streaming endpoints. + +Usage: + + src gateway benchmark-stream [flags] + +Examples: + + $ src gateway benchmark-stream --requests 50 --csv results.csv --sgd --sgp + $ src gateway benchmark-stream --gateway http://localhost:9992 --sourcegraph http://localhost:3082 --sgd --sgp + $ src gateway benchmark-stream --requests 250 --gateway http://localhost:9992 --sourcegraph http://localhost:3082 --sgd --sgp --max-tokens 50 --provider fireworks --stream +` + + flagSet := flag.NewFlagSet("benchmark-stream", flag.ExitOnError) + + var ( + requestCount = flagSet.Int("requests", 1000, "Number of requests to make per endpoint") + csvOutput = flagSet.String("csv", "", "Export results to CSV file (provide filename)") + requestLevelCsvOutput = flagSet.String("request-csv", "", "Export request results to CSV file (provide filename)") + gatewayEndpoint = flagSet.String("gateway", "", "Cody Gateway endpoint") + sgEndpoint = flagSet.String("sourcegraph", "", "Sourcegraph endpoint") + sgdToken = flagSet.String("sgd", "", "Sourcegraph Dotcom user key for Cody Gateway") + sgpToken = flagSet.String("sgp", "", "Sourcegraph personal access token for the called instance") + maxTokens = flagSet.Int("max-tokens", 256, "Maximum number of tokens to generate") + provider = flagSet.String("provider", "anthropic", "Provider to use for completion. Supported values: 'anthropic', 'fireworks'") + stream = flagSet.Bool("stream", false, "Whether to stream completions. Default: false") + ) + + handler := func(args []string) error { + // Parse the flags. + if err := flagSet.Parse(args); err != nil { + return err + } + if len(flagSet.Args()) != 0 { + return cmderrors.Usage("additional arguments not allowed") + } + if *gatewayEndpoint != "" && *sgdToken == "" { + return cmderrors.Usage("must specify --sgp ") + } + if *sgEndpoint != "" && *sgpToken == "" { + return cmderrors.Usage("must specify --sgp ") + } + + var httpClient = &http.Client{} + var cgResult, sgResult endpointResult + var cgRequestResults, sgRequestResults []requestResult + + // Do the benchmarking. + fmt.Printf("Starting benchmark with %d requests per endpoint...\n", *requestCount) + if *gatewayEndpoint != "" { + fmt.Println("Benchmarking Cody Gateway instance:", *gatewayEndpoint) + endpoint := buildGatewayHttpEndpoint(*gatewayEndpoint, *sgdToken, *maxTokens, *provider, *stream) + cgResult, cgRequestResults = benchmarkCodeCompletions("gateway", httpClient, endpoint, *requestCount) + fmt.Println() + } else { + fmt.Println("warning: not benchmarking Cody Gateway (-gateway endpoint not provided)") + } + if *sgEndpoint != "" { + fmt.Println("Benchmarking Sourcegraph instance:", *sgEndpoint) + endpoint := buildSourcegraphHttpEndpoint(*sgEndpoint, *sgpToken, *maxTokens, *provider, *stream) + sgResult, sgRequestResults = benchmarkCodeCompletions("sourcegraph", httpClient, endpoint, *requestCount) + fmt.Println() + } else { + fmt.Println("warning: not benchmarking Sourcegraph instance (-sourcegraph endpoint not provided)") + } + + // Output the results. + endpointResults := []endpointResult{cgResult, sgResult} + printResults(endpointResults, requestCount) + if *csvOutput != "" { + if err := writeResultsToCSV(*csvOutput, endpointResults, requestCount); err != nil { + return fmt.Errorf("failed to export CSV: %v", err) + } + fmt.Printf("\nAggregate results exported to %s\n", *csvOutput) + } + if *requestLevelCsvOutput != "" { + if err := writeRequestResultsToCSV(*requestLevelCsvOutput, map[string][]requestResult{"gateway": cgRequestResults, "sourcegraph": sgRequestResults}); err != nil { + return fmt.Errorf("failed to export CSV: %v", err) + } + fmt.Printf("\nRequest-level results exported to %s\n", *requestLevelCsvOutput) + } + + return nil + } + gatewayCommands = append(gatewayCommands, &command{ + flagSet: flagSet, + aliases: []string{}, + handler: handler, + usageFunc: func() { + _, err := fmt.Fprintf(flag.CommandLine.Output(), "Usage of 'src gateway %s':\n", flagSet.Name()) + if err != nil { + return + } + flagSet.PrintDefaults() + fmt.Println(usage) + }, + }) +} + +func buildGatewayHttpEndpoint(gatewayEndpoint string, sgdToken string, maxTokens int, provider string, stream bool) httpEndpoint { + s := "true" + if !stream { + s = "false" + } + if provider == "anthropic" { + return httpEndpoint{ + url: fmt.Sprint(gatewayEndpoint, "/v1/completions/anthropic-messages"), + authHeader: fmt.Sprintf("Bearer %s", sgdToken), + body: fmt.Sprintf(`{ + "model": "claude-3-haiku-20240307", + "messages": [ + {"role": "user", "content": "def bubble_sort(arr):"}, + {"role": "assistant", "content": "Here is a bubble sort:"} + ], + "max_tokens": %d, + "temperature": 0.0, + "stream": %s +}`, maxTokens, s), + } + } else if provider == "fireworks" { + return httpEndpoint{ + url: fmt.Sprint(gatewayEndpoint, "/v1/completions/fireworks"), + authHeader: fmt.Sprintf("Bearer %s", sgdToken), + body: fmt.Sprintf(`{ + "model": "starcoder", + "prompt": "#hello.ts<|fim▁begin|>const sayHello = () => <|fim▁hole|><|fim▁end|>", + "max_tokens": %d, + "stop": [ + "\n\n", + "\n\r\n", + "<|fim▁begin|>", + "<|fim▁hole|>", + "<|fim▁end|>, <|eos_token|>" + ], + "temperature": 0.2, + "topK": 0, + "topP": 0, + "stream": %s +}`, maxTokens, s), + } + } + + return httpEndpoint{} +} + +func buildSourcegraphHttpEndpoint(sgEndpoint string, sgpToken string, maxTokens int, provider string, stream bool) httpEndpoint { + s := "true" + if !stream { + s = "false" + } + if provider == "anthropic" { + return httpEndpoint{ + url: fmt.Sprint(sgEndpoint, "/.api/completions/stream"), + authHeader: fmt.Sprintf("token %s", sgpToken), + body: fmt.Sprintf(`{ + "model": "anthropic::2023-06-01::claude-3-haiku", + "messages": [ + {"speaker": "human", "text": "def bubble_sort(arr):"}, + {"speaker": "assistant", "text": "Here is a bubble sort:"} + ], + "maxTokensToSample": %d, + "temperature": 0.0, + "stream": %s +}`, maxTokens, s), + } + } else if provider == "fireworks" { + return httpEndpoint{ + url: fmt.Sprint(sgEndpoint, "/.api/completions/code"), + authHeader: fmt.Sprintf("token %s", sgpToken), + body: fmt.Sprintf(`{ + "model": "fireworks::v1::starcoder", + "messages": [ + {"speaker": "human", "text": "#hello.ts<|fim▁begin|>const sayHello = () => <|fim▁hole|><|fim▁end|>"} + ], + "maxTokensToSample": %d, + "stopSequences": [ + "\n\n", + "\n\r\n", + "<|fim▁begin|>", + "<|fim▁hole|>", + "<|fim▁end|>, <|eos_token|>" + ], + "temperature": 0.2, + "topK": 0, + "topP": 0, + "stream": %s +}`, maxTokens, s), + } + } + + return httpEndpoint{} +} + +func benchmarkCodeCompletions(benchmarkName string, client *http.Client, endpoint httpEndpoint, requestCount int) (endpointResult, []requestResult) { + results := make([]requestResult, 0, requestCount) + durations := make([]time.Duration, 0, requestCount) + + for i := 0; i < requestCount; i++ { + result := benchmarkCodeCompletion(client, endpoint) + if result.duration > 0 { + results = append(results, result) + durations = append(durations, result.duration) + } + } + stats := calculateStats(durations) + + return toEndpointResult(benchmarkName, stats, len(durations)), results +} + +func benchmarkCodeCompletion(client *http.Client, endpoint httpEndpoint) requestResult { + start := time.Now() + req, err := http.NewRequest("POST", endpoint.url, strings.NewReader(endpoint.body)) + if err != nil { + fmt.Printf("Error creating request: %v\n", err) + return requestResult{0, ""} + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", endpoint.authHeader) + req.Header.Set("X-Sourcegraph-Should-Trace", "true") + req.Header.Set("X-Sourcegraph-Feature", "code_completions") + resp, err := client.Do(req) + if err != nil { + fmt.Printf("Error calling %s: %v\n", endpoint.url, err) + return requestResult{0, ""} + } + defer func() { + err := resp.Body.Close() + if err != nil { + fmt.Printf("Error closing response body: %v\n", err) + } + }() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + fmt.Printf("non-200 response: %v - %s\n", resp.Status, body) + return requestResult{0, ""} + } + _, err = io.ReadAll(resp.Body) + if err != nil { + fmt.Printf("Error reading response body: %v\n", err) + return requestResult{0, ""} + } + + return requestResult{ + duration: time.Since(start), + traceID: resp.Header.Get("X-Trace"), + } +} + +func toEndpointResult(name string, stats Stats, requestCount int) endpointResult { + return endpointResult{ + name: name, + avg: stats.Avg, + median: stats.Median, + p5: stats.P5, + p75: stats.P75, + p80: stats.P80, + p95: stats.P95, + successful: requestCount, + total: stats.Total, + } +}