mirror of
https://github.com/mastodon/mastodon.git
synced 2024-08-20 21:08:15 -07:00
Streaming: Rework websocket server initialisation & authentication code (#28631)
This commit is contained in:
parent
e72676e83a
commit
58830be943
1 changed files with 95 additions and 33 deletions
|
@ -182,14 +182,74 @@ const CHANNEL_NAMES = [
|
||||||
];
|
];
|
||||||
|
|
||||||
const startServer = async () => {
|
const startServer = async () => {
|
||||||
|
const pgPool = new pg.Pool(pgConfigFromEnv(process.env));
|
||||||
|
const server = http.createServer();
|
||||||
|
const wss = new WebSocket.Server({ noServer: true });
|
||||||
|
|
||||||
|
// Set the X-Request-Id header on WebSockets:
|
||||||
|
wss.on("headers", function onHeaders(headers, req) {
|
||||||
|
headers.push(`X-Request-Id: ${req.id}`);
|
||||||
|
});
|
||||||
|
|
||||||
const app = express();
|
const app = express();
|
||||||
|
|
||||||
app.set('trust proxy', process.env.TRUSTED_PROXY_IP ? process.env.TRUSTED_PROXY_IP.split(/(?:\s*,\s*|\s+)/) : 'loopback,uniquelocal');
|
app.set('trust proxy', process.env.TRUSTED_PROXY_IP ? process.env.TRUSTED_PROXY_IP.split(/(?:\s*,\s*|\s+)/) : 'loopback,uniquelocal');
|
||||||
|
|
||||||
const pgPool = new pg.Pool(pgConfigFromEnv(process.env));
|
|
||||||
const server = http.createServer(app);
|
|
||||||
app.use(cors());
|
app.use(cors());
|
||||||
|
|
||||||
|
// Handle eventsource & other http requests:
|
||||||
|
server.on('request', app);
|
||||||
|
|
||||||
|
// Handle upgrade requests:
|
||||||
|
server.on('upgrade', async function handleUpgrade(request, socket, head) {
|
||||||
|
/** @param {Error} err */
|
||||||
|
const onSocketError = (err) => {
|
||||||
|
log.error(`Error with websocket upgrade: ${err}`);
|
||||||
|
};
|
||||||
|
|
||||||
|
socket.on('error', onSocketError);
|
||||||
|
|
||||||
|
// Authenticate:
|
||||||
|
try {
|
||||||
|
await accountFromRequest(request);
|
||||||
|
} catch (err) {
|
||||||
|
log.error(`Error authenticating request: ${err}`);
|
||||||
|
|
||||||
|
// Unfortunately for using the on('upgrade') setup, we need to manually
|
||||||
|
// write a HTTP Response to the Socket to close the connection upgrade
|
||||||
|
// attempt, so the following code is to handle all of that.
|
||||||
|
const statusCode = err.status ?? 401;
|
||||||
|
|
||||||
|
/** @type {Record<string, string | number>} */
|
||||||
|
const headers = {
|
||||||
|
'Connection': 'close',
|
||||||
|
'Content-Type': 'text/plain',
|
||||||
|
'Content-Length': 0,
|
||||||
|
'X-Request-Id': request.id,
|
||||||
|
// TODO: Send the error message via header so it can be debugged in
|
||||||
|
// developer tools
|
||||||
|
};
|
||||||
|
|
||||||
|
// Ensure the socket is closed once we've finished writing to it:
|
||||||
|
socket.once('finish', () => {
|
||||||
|
socket.destroy();
|
||||||
|
});
|
||||||
|
|
||||||
|
// Write the HTTP response manually:
|
||||||
|
socket.end(`HTTP/1.1 ${statusCode} ${http.STATUS_CODES[statusCode]}\r\n${Object.keys(headers).map((key) => `${key}: ${headers[key]}`).join('\r\n')}\r\n\r\n`);
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
wss.handleUpgrade(request, socket, head, function done(ws) {
|
||||||
|
// Remove the error handler:
|
||||||
|
socket.removeListener('error', onSocketError);
|
||||||
|
|
||||||
|
// Start the connection:
|
||||||
|
wss.emit('connection', ws, request);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @type {Object.<string, Array.<function(Object<string, any>): void>>}
|
* @type {Object.<string, Array.<function(Object<string, any>): void>>}
|
||||||
*/
|
*/
|
||||||
|
@ -360,10 +420,19 @@ const startServer = async () => {
|
||||||
const isInScope = (req, necessaryScopes) =>
|
const isInScope = (req, necessaryScopes) =>
|
||||||
req.scopes.some(scope => necessaryScopes.includes(scope));
|
req.scopes.some(scope => necessaryScopes.includes(scope));
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @typedef ResolvedAccount
|
||||||
|
* @property {string} accessTokenId
|
||||||
|
* @property {string[]} scopes
|
||||||
|
* @property {string} accountId
|
||||||
|
* @property {string[]} chosenLanguages
|
||||||
|
* @property {string} deviceId
|
||||||
|
*/
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param {string} token
|
* @param {string} token
|
||||||
* @param {any} req
|
* @param {any} req
|
||||||
* @returns {Promise.<void>}
|
* @returns {Promise<ResolvedAccount>}
|
||||||
*/
|
*/
|
||||||
const accountFromToken = (token, req) => new Promise((resolve, reject) => {
|
const accountFromToken = (token, req) => new Promise((resolve, reject) => {
|
||||||
pgPool.connect((err, client, done) => {
|
pgPool.connect((err, client, done) => {
|
||||||
|
@ -394,14 +463,20 @@ const startServer = async () => {
|
||||||
req.chosenLanguages = result.rows[0].chosen_languages;
|
req.chosenLanguages = result.rows[0].chosen_languages;
|
||||||
req.deviceId = result.rows[0].device_id;
|
req.deviceId = result.rows[0].device_id;
|
||||||
|
|
||||||
resolve();
|
resolve({
|
||||||
|
accessTokenId: result.rows[0].id,
|
||||||
|
scopes: result.rows[0].scopes.split(' '),
|
||||||
|
accountId: result.rows[0].account_id,
|
||||||
|
chosenLanguages: result.rows[0].chosen_languages,
|
||||||
|
deviceId: result.rows[0].device_id
|
||||||
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param {any} req
|
* @param {any} req
|
||||||
* @returns {Promise.<void>}
|
* @returns {Promise<ResolvedAccount>}
|
||||||
*/
|
*/
|
||||||
const accountFromRequest = (req) => new Promise((resolve, reject) => {
|
const accountFromRequest = (req) => new Promise((resolve, reject) => {
|
||||||
const authorization = req.headers.authorization;
|
const authorization = req.headers.authorization;
|
||||||
|
@ -494,25 +569,6 @@ const startServer = async () => {
|
||||||
reject(err);
|
reject(err);
|
||||||
});
|
});
|
||||||
|
|
||||||
/**
|
|
||||||
* @param {any} info
|
|
||||||
* @param {function(boolean, number, string): void} callback
|
|
||||||
*/
|
|
||||||
const wsVerifyClient = (info, callback) => {
|
|
||||||
// When verifying the websockets connection, we no longer pre-emptively
|
|
||||||
// check OAuth scopes and drop the connection if they're missing. We only
|
|
||||||
// drop the connection if access without token is not allowed by environment
|
|
||||||
// variables. OAuth scope checks are moved to the point of subscription
|
|
||||||
// to a specific stream.
|
|
||||||
|
|
||||||
accountFromRequest(info.req).then(() => {
|
|
||||||
callback(true, undefined, undefined);
|
|
||||||
}).catch(err => {
|
|
||||||
log.error(info.req.requestId, err.toString());
|
|
||||||
callback(false, 401, 'Unauthorized');
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @typedef SystemMessageHandlers
|
* @typedef SystemMessageHandlers
|
||||||
* @property {function(): void} onKill
|
* @property {function(): void} onKill
|
||||||
|
@ -944,8 +1000,8 @@ const startServer = async () => {
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param {any} req
|
* @param {http.IncomingMessage} req
|
||||||
* @param {any} ws
|
* @param {WebSocket} ws
|
||||||
* @param {string[]} streamName
|
* @param {string[]} streamName
|
||||||
* @returns {function(string, string): void}
|
* @returns {function(string, string): void}
|
||||||
*/
|
*/
|
||||||
|
@ -955,7 +1011,9 @@ const startServer = async () => {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
ws.send(JSON.stringify({ stream: streamName, event, payload }), (err) => {
|
const message = JSON.stringify({ stream: streamName, event, payload });
|
||||||
|
|
||||||
|
ws.send(message, (/** @type {Error} */ err) => {
|
||||||
if (err) {
|
if (err) {
|
||||||
log.error(req.requestId, `Failed to send to websocket: ${err}`);
|
log.error(req.requestId, `Failed to send to websocket: ${err}`);
|
||||||
}
|
}
|
||||||
|
@ -992,8 +1050,6 @@ const startServer = async () => {
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
const wss = new WebSocket.Server({ server, verifyClient: wsVerifyClient });
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @typedef StreamParams
|
* @typedef StreamParams
|
||||||
* @property {string} [tag]
|
* @property {string} [tag]
|
||||||
|
@ -1173,8 +1229,8 @@ const startServer = async () => {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @typedef WebSocketSession
|
* @typedef WebSocketSession
|
||||||
* @property {any} socket
|
* @property {WebSocket} websocket
|
||||||
* @property {any} request
|
* @property {http.IncomingMessage} request
|
||||||
* @property {Object.<string, { channelName: string, listener: SubscriptionListener, stopHeartbeat: function(): void }>} subscriptions
|
* @property {Object.<string, { channelName: string, listener: SubscriptionListener, stopHeartbeat: function(): void }>} subscriptions
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
@ -1297,7 +1353,11 @@ const startServer = async () => {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
wss.on('connection', (ws, req) => {
|
/**
|
||||||
|
* @param {WebSocket & { isAlive: boolean }} ws
|
||||||
|
* @param {http.IncomingMessage} req
|
||||||
|
*/
|
||||||
|
function onConnection(ws, req) {
|
||||||
// Note: url.parse could throw, which would terminate the connection, so we
|
// Note: url.parse could throw, which would terminate the connection, so we
|
||||||
// increment the connected clients metric straight away when we establish
|
// increment the connected clients metric straight away when we establish
|
||||||
// the connection, without waiting:
|
// the connection, without waiting:
|
||||||
|
@ -1375,7 +1435,9 @@ const startServer = async () => {
|
||||||
if (location && location.query.stream) {
|
if (location && location.query.stream) {
|
||||||
subscribeWebsocketToChannel(session, firstParam(location.query.stream), location.query);
|
subscribeWebsocketToChannel(session, firstParam(location.query.stream), location.query);
|
||||||
}
|
}
|
||||||
});
|
}
|
||||||
|
|
||||||
|
wss.on('connection', onConnection);
|
||||||
|
|
||||||
setInterval(() => {
|
setInterval(() => {
|
||||||
wss.clients.forEach(ws => {
|
wss.clients.forEach(ws => {
|
||||||
|
|
Loading…
Reference in a new issue