Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmd/src/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Usage:
The commands are:

benchmark runs benchmarks against Cody Gateway
benchmark-stream runs benchmarks against Cody Gateway code completion streaming endpoints

Use "src gateway [command] -h" for more information about a command.

Expand Down
164 changes: 123 additions & 41 deletions cmd/src/gateway_benchmark.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,14 @@ type Stats struct {
Total time.Duration
}

type requestResult struct {
duration time.Duration
traceID string // X-Trace header value
}

func init() {
usage := `
'src gateway benchmark' runs performance benchmarks against Cody Gateway endpoints.
'src gateway benchmark' runs performance benchmarks against Cody Gateway and Sourcegraph test endpoints.

Usage:

Expand All @@ -39,17 +44,20 @@ Examples:
$ src gateway benchmark --sgp <token>
$ src gateway benchmark --requests 50 --sgp <token>
$ src gateway benchmark --gateway http://localhost:9992 --sourcegraph http://localhost:3082 --sgp <token>
$ src gateway benchmark --requests 50 --csv results.csv --sgp <token>
$ src gateway benchmark --requests 50 --csv results.csv --request-csv requests.csv --sgp <token>
$ src gateway benchmark --gateway https://cody-gateway.sourcegraph.com --sourcegraph https://sourcegraph.com --sgp <token> --use-special-header
`

flagSet := flag.NewFlagSet("benchmark", flag.ExitOnError)

var (
requestCount = flagSet.Int("requests", 1000, "Number of requests to make per endpoint")
csvOutput = flagSet.String("csv", "", "Export results to CSV file (provide filename)")
gatewayEndpoint = flagSet.String("gateway", "https://cody-gateway.sourcegraph.com", "Cody Gateway endpoint")
sgEndpoint = flagSet.String("sourcegraph", "https://sourcegraph.com", "Sourcegraph endpoint")
sgpToken = flagSet.String("sgp", "", "Sourcegraph personal access token for the called instance")
requestCount = flagSet.Int("requests", 1000, "Number of requests to make per endpoint")
csvOutput = flagSet.String("csv", "", "Export results to CSV file (provide filename)")
requestLevelCsvOutput = flagSet.String("request-csv", "", "Export request results to CSV file (provide filename)")
gatewayEndpoint = flagSet.String("gateway", "", "Cody Gateway endpoint")
sgEndpoint = flagSet.String("sourcegraph", "", "Sourcegraph endpoint")
sgpToken = flagSet.String("sgp", "", "Sourcegraph personal access token for the called instance")
useSpecialHeader = flagSet.Bool("use-special-header", false, "Use special header to test the gateway")
)

handler := func(args []string) error {
Expand All @@ -61,15 +69,23 @@ Examples:
return cmderrors.Usage("additional arguments not allowed")
}

if *useSpecialHeader {
fmt.Println("Using special header 'cody-core-gc-test'")
}

var (
httpClient = &http.Client{}
endpoints = map[string]any{} // Values: URL `string`s or `*webSocketClient`s
)
if *gatewayEndpoint != "" {
fmt.Println("Benchmarking Cody Gateway instance:", *gatewayEndpoint)
headers := http.Header{
"X-Sourcegraph-Should-Trace": []string{"true"},
}
endpoints["ws(s): gateway"] = &webSocketClient{
conn: nil,
URL: strings.Replace(fmt.Sprint(*gatewayEndpoint, "/v2/websocket"), "http", "ws", 1),
conn: nil,
URL: strings.Replace(fmt.Sprint(*gatewayEndpoint, "/v2/websocket"), "http", "ws", 1),
reqHeaders: headers,
}
endpoints["http(s): gateway"] = fmt.Sprint(*gatewayEndpoint, "/v2/http")
} else {
Expand All @@ -80,12 +96,18 @@ Examples:
return cmderrors.Usage("must specify --sgp <Sourcegraph personal access token>")
}
fmt.Println("Benchmarking Sourcegraph instance:", *sgEndpoint)
headers := http.Header{
"Authorization": []string{"token " + *sgpToken},
"X-Sourcegraph-Should-Trace": []string{"true"},
}
if *useSpecialHeader {
headers.Set("cody-core-gc-test", "M2R{+6VI?1,M3n&<vpw1&AK>")
}

endpoints["ws(s): sourcegraph"] = &webSocketClient{
conn: nil,
URL: strings.Replace(fmt.Sprint(*sgEndpoint, "/.api/gateway/websocket"), "http", "ws", 1),
headers: http.Header{
"Authorization": []string{"token " + *sgpToken},
},
conn: nil,
URL: strings.Replace(fmt.Sprint(*sgEndpoint, "/.api/gateway/websocket"), "http", "ws", 1),
reqHeaders: headers,
}
endpoints["http(s): sourcegraph"] = fmt.Sprint(*sgEndpoint, "/.api/gateway/http")
endpoints["http(s): http-then-ws"] = fmt.Sprint(*sgEndpoint, "/.api/gateway/http-then-websocket")
Expand All @@ -95,29 +117,33 @@ Examples:

fmt.Printf("Starting benchmark with %d requests per endpoint...\n", *requestCount)

var results []endpointResult
var eResults []endpointResult
rResults := map[string][]requestResult{}
for name, clientOrURL := range endpoints {
durations := make([]time.Duration, 0, *requestCount)
rResults[name] = make([]requestResult, 0, *requestCount)
fmt.Printf("\nTesting %s...", name)

for i := 0; i < *requestCount; i++ {
if ws, ok := clientOrURL.(*webSocketClient); ok {
duration := benchmarkEndpointWebSocket(ws)
if duration > 0 {
durations = append(durations, duration)
result := benchmarkEndpointWebSocket(ws)
if result.duration > 0 {
durations = append(durations, result.duration)
rResults[name] = append(rResults[name], result)
}
} else if url, ok := clientOrURL.(string); ok {
duration := benchmarkEndpointHTTP(httpClient, url, *sgpToken)
if duration > 0 {
durations = append(durations, duration)
result := benchmarkEndpointHTTP(httpClient, url, *sgpToken, *useSpecialHeader)
if result.duration > 0 {
durations = append(durations, result.duration)
rResults[name] = append(rResults[name], result)
}
}
}
fmt.Println()

stats := calculateStats(durations)

results = append(results, endpointResult{
eResults = append(eResults, endpointResult{
name: name,
avg: stats.Avg,
median: stats.Median,
Expand All @@ -130,14 +156,20 @@ Examples:
})
}

printResults(results, requestCount)
printResults(eResults, requestCount)

if *csvOutput != "" {
if err := writeResultsToCSV(*csvOutput, results, requestCount); err != nil {
if err := writeResultsToCSV(*csvOutput, eResults, requestCount); err != nil {
return fmt.Errorf("failed to export CSV: %v", err)
}
fmt.Printf("\nResults exported to %s\n", *csvOutput)
}
if *requestLevelCsvOutput != "" {
if err := writeRequestResultsToCSV(*requestLevelCsvOutput, rResults); err != nil {
return fmt.Errorf("failed to export request-level CSV: %v", err)
}
fmt.Printf("\nRequest-level results exported to %s\n", *requestLevelCsvOutput)
}

return nil
}
Expand All @@ -158,9 +190,10 @@ Examples:
}

type webSocketClient struct {
conn *websocket.Conn
URL string
headers http.Header
conn *websocket.Conn
URL string
reqHeaders http.Header
respHeaders http.Header
}

func (c *webSocketClient) reconnect() error {
Expand All @@ -169,11 +202,13 @@ func (c *webSocketClient) reconnect() error {
}
fmt.Println("Connecting to WebSocket..", c.URL)
var err error
c.conn, _, err = websocket.DefaultDialer.Dial(c.URL, c.headers)
var resp *http.Response
c.conn, resp, err = websocket.DefaultDialer.Dial(c.URL, c.reqHeaders)
if err != nil {
c.conn = nil // retry again later
return fmt.Errorf("WebSocket dial(%s): %v", c.URL, err)
}
c.respHeaders = resp.Header
fmt.Println("Connected!")
return nil
}
Expand All @@ -190,19 +225,23 @@ type endpointResult struct {
successful int
}

func benchmarkEndpointHTTP(client *http.Client, url, accessToken string) time.Duration {
func benchmarkEndpointHTTP(client *http.Client, url, accessToken string, useSpecialHeader bool) requestResult {
start := time.Now()
req, err := http.NewRequest("POST", url, strings.NewReader("ping"))
if err != nil {
fmt.Printf("Error creating request: %v\n", err)
return 0
return requestResult{}
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "token "+accessToken)
req.Header.Set("X-Sourcegraph-Should-Trace", "true")
if useSpecialHeader {
req.Header.Set("cody-core-gc-test", "M2R{+6VI?1,M3n&<vpw1&AK>")
}
resp, err := client.Do(req)
if err != nil {
fmt.Printf("Error calling %s: %v\n", url, err)
return 0
return requestResult{}
}
defer func() {
err := resp.Body.Close()
Expand All @@ -212,27 +251,30 @@ func benchmarkEndpointHTTP(client *http.Client, url, accessToken string) time.Du
}()
if resp.StatusCode != http.StatusOK {
fmt.Printf("non-200 response: %v\n", resp.Status)
return 0
return requestResult{}
}
body, err := io.ReadAll(resp.Body)
if err != nil {
fmt.Printf("Error reading response body: %v\n", err)
return 0
return requestResult{}
}
if string(body) != "pong" {
fmt.Printf("Expected 'pong' response, got: %q\n", string(body))
return 0
return requestResult{}
}

return time.Since(start)
return requestResult{
duration: time.Since(start),
traceID: resp.Header.Get("X-Trace"),
}
}

func benchmarkEndpointWebSocket(client *webSocketClient) time.Duration {
func benchmarkEndpointWebSocket(client *webSocketClient) requestResult {
// Perform initial websocket connection, if needed.
if client.conn == nil {
if err := client.reconnect(); err != nil {
fmt.Printf("Error reconnecting: %v\n", err)
return 0
return requestResult{}
}
}

Expand All @@ -244,7 +286,7 @@ func benchmarkEndpointWebSocket(client *webSocketClient) time.Duration {
if err := client.reconnect(); err != nil {
fmt.Printf("Error reconnecting: %v\n", err)
}
return 0
return requestResult{}
}
_, message, err := client.conn.ReadMessage()

Expand All @@ -253,16 +295,19 @@ func benchmarkEndpointWebSocket(client *webSocketClient) time.Duration {
if err := client.reconnect(); err != nil {
fmt.Printf("Error reconnecting: %v\n", err)
}
return 0
return requestResult{}
}
if string(message) != "pong" {
fmt.Printf("Expected 'pong' response, got: %q\n", string(message))
if err := client.reconnect(); err != nil {
fmt.Printf("Error reconnecting: %v\n", err)
}
return 0
return requestResult{}
}
return requestResult{
duration: time.Since(start),
traceID: client.respHeaders.Get("Content-Type"),
}
return time.Since(start)
}

func calculateStats(durations []time.Duration) Stats {
Expand Down Expand Up @@ -438,3 +483,40 @@ func writeResultsToCSV(filename string, results []endpointResult, requestCount *

return nil
}

func writeRequestResultsToCSV(filename string, results map[string][]requestResult) error {
file, err := os.Create(filename)
if err != nil {
return fmt.Errorf("failed to create CSV file: %v", err)
}
defer func() {
err := file.Close()
if err != nil {
return
}
}()

writer := csv.NewWriter(file)
defer writer.Flush()

// Write header
header := []string{"Endpoint", "Duration (ms)", "Trace ID"}
if err := writer.Write(header); err != nil {
return fmt.Errorf("failed to write CSV header: %v", err)
}

for endpoint, requestResults := range results {
for _, result := range requestResults {
row := []string{
endpoint,
fmt.Sprintf("%.2f", float64(result.duration.Microseconds())/1000),
result.traceID,
}
if err := writer.Write(row); err != nil {
return fmt.Errorf("failed to write CSV row: %v", err)
}
}
}

return nil
}
Loading
Loading