@@ -5,6 +5,7 @@ import type { ServerOptions as HttpsServerOptions } from 'node:https'
55import { createServer as createHttpsServer } from 'node:https'
66import type { Socket } from 'node:net'
77import type { Duplex } from 'node:stream'
8+ import crypto from 'node:crypto'
89import colors from 'picocolors'
910import type { WebSocket as WebSocketRaw } from 'ws'
1011import { WebSocketServer as WebSocketServerRaw_ } from 'ws'
@@ -89,6 +90,29 @@ function noop() {
8990 // noop
9091}
9192
93+ // we only allow websockets to be connected if it has a valid token
94+ // this is to prevent untrusted origins to connect to the server
95+ // for example, Cross-site WebSocket hijacking
96+ //
97+ // we should check the token before calling wss.handleUpgrade
98+ // otherwise untrusted ws clients will be included in wss.clients
99+ //
100+ // using the query params means the token might be logged out in server or middleware logs
101+ // but we assume that is not an issue since the token is regenerated for each process
102+ function hasValidToken ( config : ResolvedConfig , url : URL ) {
103+ const token = url . searchParams . get ( 'token' )
104+ if ( ! token ) return false
105+
106+ try {
107+ const isValidToken = crypto . timingSafeEqual (
108+ Buffer . from ( token ) ,
109+ Buffer . from ( config . webSocketToken ) ,
110+ )
111+ return isValidToken
112+ } catch { } // an error is thrown when the length is incorrect
113+ return false
114+ }
115+
92116export function createWebSocketServer (
93117 server : HttpServer | null ,
94118 config : ResolvedConfig ,
@@ -110,7 +134,6 @@ export function createWebSocketServer(
110134 }
111135 }
112136
113- let wss : WebSocketServerRaw_
114137 let wsHttpServer : Server | undefined = undefined
115138
116139 const hmr = isObject ( config . server . hmr ) && config . server . hmr
@@ -129,21 +152,50 @@ export function createWebSocketServer(
129152 const port = hmrPort || 24678
130153 const host = ( hmr && hmr . host ) || undefined
131154
155+ const shouldHandle = ( req : IncomingMessage ) => {
156+ if ( config . legacy ?. skipWebSocketTokenCheck ) {
157+ return true
158+ }
159+
160+ // If the Origin header is set, this request might be coming from a browser.
161+ // Browsers always sets the Origin header for WebSocket connections.
162+ if ( req . headers . origin ) {
163+ const parsedUrl = new URL ( `http://example.com${ req . url ! } ` )
164+ return hasValidToken ( config , parsedUrl )
165+ }
166+
167+ // We allow non-browser requests to connect without a token
168+ // for backward compat and convenience
169+ // This is fine because if you can sent a request without the SOP limitation,
170+ // you can also send a normal HTTP request to the server.
171+ return true
172+ }
173+ const handleUpgrade = (
174+ req : IncomingMessage ,
175+ socket : Duplex ,
176+ head : Buffer ,
177+ _isPing : boolean ,
178+ ) => {
179+ wss . handleUpgrade ( req , socket as Socket , head , ( ws ) => {
180+ wss . emit ( 'connection' , ws , req )
181+ } )
182+ }
183+ const wss : WebSocketServerRaw_ = new WebSocketServerRaw ( { noServer : true } )
184+ wss . shouldHandle = shouldHandle
185+
132186 if ( wsServer ) {
133187 let hmrBase = config . base
134188 const hmrPath = hmr ? hmr . path : undefined
135189 if ( hmrPath ) {
136190 hmrBase = path . posix . join ( hmrBase , hmrPath )
137191 }
138- wss = new WebSocketServerRaw ( { noServer : true } )
139192 hmrServerWsListener = ( req , socket , head ) => {
193+ const parsedUrl = new URL ( `http://example.com${ req . url ! } ` )
140194 if (
141195 req . headers [ 'sec-websocket-protocol' ] === HMR_HEADER &&
142- req . url === hmrBase
196+ parsedUrl . pathname === hmrBase
143197 ) {
144- wss . handleUpgrade ( req , socket as Socket , head , ( ws ) => {
145- wss . emit ( 'connection' , ws , req )
146- } )
198+ handleUpgrade ( req , socket as Socket , head , false )
147199 }
148200 }
149201 wsServer . on ( 'upgrade' , hmrServerWsListener )
@@ -167,9 +219,22 @@ export function createWebSocketServer(
167219 } else {
168220 wsHttpServer = createHttpServer ( route )
169221 }
170- // vite dev server in middleware mode
171- // need to call ws listen manually
172- wss = new WebSocketServerRaw ( { server : wsHttpServer } )
222+ wsHttpServer . on ( 'upgrade' , ( req , socket , head ) => {
223+ handleUpgrade ( req , socket as Socket , head , false )
224+ } )
225+ wsHttpServer . on ( 'error' , ( e : Error & { code : string } ) => {
226+ if ( e . code === 'EADDRINUSE' ) {
227+ config . logger . error (
228+ colors . red ( `WebSocket server error: Port is already in use` ) ,
229+ { error : e } ,
230+ )
231+ } else {
232+ config . logger . error (
233+ colors . red ( `WebSocket server error:\n${ e . stack || e . message } ` ) ,
234+ { error : e } ,
235+ )
236+ }
237+ } )
173238 }
174239
175240 wss . on ( 'connection' , ( socket ) => {
0 commit comments