@@ -11,6 +11,8 @@ import (
1111 "strings"
1212 "time"
1313
14+ "github.com/gorilla/websocket"
15+
1416 "github.com/sourcegraph/src-cli/internal/cmderrors"
1517)
1618
@@ -42,8 +44,9 @@ Examples:
4244 flagSet := flag .NewFlagSet ("benchmark" , flag .ExitOnError )
4345
4446 var (
45- requestCount = flagSet .Int ("requests" , 1000 , "Number of requests to make per endpoint" )
46- csvOutput = flagSet .String ("csv" , "" , "Export results to CSV file (provide filename)" )
47+ requestCount = flagSet .Int ("requests" , 1000 , "Number of requests to make per endpoint" )
48+ csvOutput = flagSet .String ("csv" , "" , "Export results to CSV file (provide filename)" )
49+ gatewayEndpoint = flagSet .String ("gateway" , "https://cody-gateway.sourcegraph.com" , "Cody Gateway endpoint" )
4750 )
4851
4952 handler := func (args []string ) error {
@@ -55,26 +58,49 @@ Examples:
5558 return cmderrors .Usage ("additional arguments not allowed" )
5659 }
5760
58- // Create HTTP client with TLS skip verify
59- client := & http.Client {Transport : & http.Transport {}}
60-
61- endpoints := map [string ]string {
62- "HTTP" : fmt .Sprintf ("%s/gateway" , cfg .Endpoint ),
63- "HTTP then WebSocket" : fmt .Sprintf ("%s/gateway/http-then-websocket" , cfg .Endpoint ),
61+ var (
62+ gatewayWebsocket , sourcegraphWebsocket * websocket.Conn
63+ err error
64+ httpClient = & http.Client {}
65+ endpoints = map [string ]any {} // Values: URL `string`s or `*websocket.Conn`s
66+ )
67+ if * gatewayEndpoint != "" {
68+ wsURL := strings .Replace (fmt .Sprint (* gatewayEndpoint , "/v2/websocket" ), "http" , "ws" , 1 )
69+ gatewayWebsocket , _ , err = websocket .DefaultDialer .Dial (wsURL , nil )
70+ if err != nil {
71+ return fmt .Errorf ("WebSocket dial(%s): %v" , wsURL , err )
72+ }
73+ endpoints ["ws(s): gateway" ] = gatewayWebsocket
74+ endpoints ["http(s): gateway" ] = fmt .Sprint (* gatewayEndpoint , "/v2/http" )
75+ }
76+ if cfg .Endpoint != "" {
77+ wsURL := strings .Replace (fmt .Sprint (cfg .Endpoint , "/.api/gateway/websocket" ), "http" , "ws" , 1 )
78+ sourcegraphWebsocket , _ , err = websocket .DefaultDialer .Dial (wsURL , nil )
79+ if err != nil {
80+ return fmt .Errorf ("WebSocket dial(%s): %v" , wsURL , err )
81+ }
82+ endpoints ["ws(s): sourcegraph" ] = sourcegraphWebsocket
83+ endpoints ["http(s): sourcegraph" ] = fmt .Sprint (* gatewayEndpoint , "/.api/gateway/http" )
6484 }
6585
6686 fmt .Printf ("Starting benchmark with %d requests per endpoint...\n " , * requestCount )
6787
6888 var results []endpointResult
69-
70- for name , url := range endpoints {
89+ for name , clientOrURL := range endpoints {
7190 durations := make ([]time.Duration , 0 , * requestCount )
7291 fmt .Printf ("\n Testing %s..." , name )
7392
7493 for i := 0 ; i < * requestCount ; i ++ {
75- duration := benchmarkEndpoint (client , url )
76- if duration > 0 {
77- durations = append (durations , duration )
94+ if ws , ok := clientOrURL .(* websocket.Conn ); ok {
95+ duration := benchmarkEndpointWebSocket (ws )
96+ if duration > 0 {
97+ durations = append (durations , duration )
98+ }
99+ } else if url , ok := clientOrURL .(string ); ok {
100+ duration := benchmarkEndpointHTTP (httpClient , url )
101+ if duration > 0 {
102+ durations = append (durations , duration )
103+ }
78104 }
79105 }
80106 fmt .Println ()
@@ -133,15 +159,15 @@ type endpointResult struct {
133159 successful int
134160}
135161
136- func benchmarkEndpoint (client * http.Client , url string ) time.Duration {
162+ func benchmarkEndpointHTTP (client * http.Client , url string ) time.Duration {
137163 start := time .Now ()
138164 resp , err := client .Get (url )
139165 if err != nil {
140166 fmt .Printf ("Error calling %s: %v\n " , url , err )
141167 return 0
142168 }
143- defer func (Body io.ReadCloser ) {
144- err := Body .Close ()
169+ defer func (body io.ReadCloser ) {
170+ err := body .Close ()
145171 if err != nil {
146172 fmt .Printf ("Error closing response body: %v\n " , err )
147173 }
@@ -152,10 +178,42 @@ func benchmarkEndpoint(client *http.Client, url string) time.Duration {
152178 fmt .Printf ("Error reading response body: %v\n " , err )
153179 return 0
154180 }
181+ if resp .StatusCode != http .StatusOK {
182+ fmt .Printf ("non-200 response: %v\n " , resp .Status )
183+ return 0
184+ }
185+ body , err := io .ReadAll (resp .Body )
186+ if err != nil {
187+ fmt .Printf ("Error reading response body: %v\n " , err )
188+ return 0
189+ }
190+ if string (body ) != "pong" {
191+ fmt .Printf ("Expected 'pong' response, got: %q\n " , string (body ))
192+ return 0
193+ }
155194
156195 return time .Since (start )
157196}
158197
198+ func benchmarkEndpointWebSocket (conn * websocket.Conn ) time.Duration {
199+ start := time .Now ()
200+ err := conn .WriteMessage (websocket .TextMessage , []byte ("ping" ))
201+ if err != nil {
202+ fmt .Printf ("WebSocket write error: %v\n " , err )
203+ return 0
204+ }
205+ _ , message , err := conn .ReadMessage ()
206+ if err != nil {
207+ fmt .Printf ("WebSocket read error: %v\n " , err )
208+ return 0
209+ }
210+ if string (message ) != "pong" {
211+ fmt .Printf ("Expected 'pong' response, got: %q\n " , string (message ))
212+ return 0
213+ }
214+ return time .Since (start )
215+ }
216+
159217func calculateStats (durations []time.Duration ) Stats {
160218 if len (durations ) == 0 {
161219 return Stats {0 , 0 , 0 , 0 , 0 , 0 , 0 }
0 commit comments