From 58830be94329ce1a541f1f79c9e4e70aef430b19 Mon Sep 17 00:00:00 2001 From: Emelia Smith Date: Mon, 15 Jan 2024 11:36:30 +0100 Subject: [PATCH] Streaming: Rework websocket server initialisation & authentication code (#28631) --- streaming/index.js | 128 +++++++++++++++++++++++++++++++++------------ 1 file changed, 95 insertions(+), 33 deletions(-) diff --git a/streaming/index.js b/streaming/index.js index 42d0afc7c5b..c8124fcc0f1 100644 --- a/streaming/index.js +++ b/streaming/index.js @@ -182,14 +182,74 @@ const CHANNEL_NAMES = [ ]; const startServer = async () => { + const pgPool = new pg.Pool(pgConfigFromEnv(process.env)); + const server = http.createServer(); + const wss = new WebSocket.Server({ noServer: true }); + + // Set the X-Request-Id header on WebSockets: + wss.on("headers", function onHeaders(headers, req) { + headers.push(`X-Request-Id: ${req.id}`); + }); + const app = express(); app.set('trust proxy', process.env.TRUSTED_PROXY_IP ? process.env.TRUSTED_PROXY_IP.split(/(?:\s*,\s*|\s+)/) : 'loopback,uniquelocal'); - const pgPool = new pg.Pool(pgConfigFromEnv(process.env)); - const server = http.createServer(app); app.use(cors()); + // Handle eventsource & other http requests: + server.on('request', app); + + // Handle upgrade requests: + server.on('upgrade', async function handleUpgrade(request, socket, head) { + /** @param {Error} err */ + const onSocketError = (err) => { + log.error(`Error with websocket upgrade: ${err}`); + }; + + socket.on('error', onSocketError); + + // Authenticate: + try { + await accountFromRequest(request); + } catch (err) { + log.error(`Error authenticating request: ${err}`); + + // Unfortunately for using the on('upgrade') setup, we need to manually + // write a HTTP Response to the Socket to close the connection upgrade + // attempt, so the following code is to handle all of that. + const statusCode = err.status ?? 401; + + /** @type {Record} */ + const headers = { + 'Connection': 'close', + 'Content-Type': 'text/plain', + 'Content-Length': 0, + 'X-Request-Id': request.id, + // TODO: Send the error message via header so it can be debugged in + // developer tools + }; + + // Ensure the socket is closed once we've finished writing to it: + socket.once('finish', () => { + socket.destroy(); + }); + + // Write the HTTP response manually: + socket.end(`HTTP/1.1 ${statusCode} ${http.STATUS_CODES[statusCode]}\r\n${Object.keys(headers).map((key) => `${key}: ${headers[key]}`).join('\r\n')}\r\n\r\n`); + + return; + } + + wss.handleUpgrade(request, socket, head, function done(ws) { + // Remove the error handler: + socket.removeListener('error', onSocketError); + + // Start the connection: + wss.emit('connection', ws, request); + }); + }); + /** * @type {Object.): void>>} */ @@ -360,10 +420,19 @@ const startServer = async () => { const isInScope = (req, necessaryScopes) => req.scopes.some(scope => necessaryScopes.includes(scope)); + /** + * @typedef ResolvedAccount + * @property {string} accessTokenId + * @property {string[]} scopes + * @property {string} accountId + * @property {string[]} chosenLanguages + * @property {string} deviceId + */ + /** * @param {string} token * @param {any} req - * @returns {Promise.} + * @returns {Promise} */ const accountFromToken = (token, req) => new Promise((resolve, reject) => { pgPool.connect((err, client, done) => { @@ -394,14 +463,20 @@ const startServer = async () => { req.chosenLanguages = result.rows[0].chosen_languages; req.deviceId = result.rows[0].device_id; - resolve(); + resolve({ + accessTokenId: result.rows[0].id, + scopes: result.rows[0].scopes.split(' '), + accountId: result.rows[0].account_id, + chosenLanguages: result.rows[0].chosen_languages, + deviceId: result.rows[0].device_id + }); }); }); }); /** * @param {any} req - * @returns {Promise.} + * @returns {Promise} */ const accountFromRequest = (req) => new Promise((resolve, reject) => { const authorization = req.headers.authorization; @@ -494,25 +569,6 @@ const startServer = async () => { reject(err); }); - /** - * @param {any} info - * @param {function(boolean, number, string): void} callback - */ - const wsVerifyClient = (info, callback) => { - // When verifying the websockets connection, we no longer pre-emptively - // check OAuth scopes and drop the connection if they're missing. We only - // drop the connection if access without token is not allowed by environment - // variables. OAuth scope checks are moved to the point of subscription - // to a specific stream. - - accountFromRequest(info.req).then(() => { - callback(true, undefined, undefined); - }).catch(err => { - log.error(info.req.requestId, err.toString()); - callback(false, 401, 'Unauthorized'); - }); - }; - /** * @typedef SystemMessageHandlers * @property {function(): void} onKill @@ -944,8 +1000,8 @@ const startServer = async () => { }; /** - * @param {any} req - * @param {any} ws + * @param {http.IncomingMessage} req + * @param {WebSocket} ws * @param {string[]} streamName * @returns {function(string, string): void} */ @@ -955,7 +1011,9 @@ const startServer = async () => { return; } - ws.send(JSON.stringify({ stream: streamName, event, payload }), (err) => { + const message = JSON.stringify({ stream: streamName, event, payload }); + + ws.send(message, (/** @type {Error} */ err) => { if (err) { log.error(req.requestId, `Failed to send to websocket: ${err}`); } @@ -992,8 +1050,6 @@ const startServer = async () => { }); }); - const wss = new WebSocket.Server({ server, verifyClient: wsVerifyClient }); - /** * @typedef StreamParams * @property {string} [tag] @@ -1173,8 +1229,8 @@ const startServer = async () => { /** * @typedef WebSocketSession - * @property {any} socket - * @property {any} request + * @property {WebSocket} websocket + * @property {http.IncomingMessage} request * @property {Object.} subscriptions */ @@ -1297,7 +1353,11 @@ const startServer = async () => { } }; - wss.on('connection', (ws, req) => { + /** + * @param {WebSocket & { isAlive: boolean }} ws + * @param {http.IncomingMessage} req + */ + function onConnection(ws, req) { // Note: url.parse could throw, which would terminate the connection, so we // increment the connected clients metric straight away when we establish // the connection, without waiting: @@ -1375,7 +1435,9 @@ const startServer = async () => { if (location && location.query.stream) { subscribeWebsocketToChannel(session, firstParam(location.query.stream), location.query); } - }); + } + + wss.on('connection', onConnection); setInterval(() => { wss.clients.forEach(ws => {