WebSockets via Rocket's Upgrade connection

This PR implements a (not yet fully released) new feature of Rocket which allows WebSockets/Upgrade connections.
No more need for multiple ports to be opened for Vaultwarden.
No explicit need for a reverse proxy to get WebSockets to work (Although I still suggest to use a reverse proxy).

- Using a git revision for Rocket, since `rocket_ws` is not yet released.
- Updated other crates as well.
- Added a connection guard to clear the WS connection from the Users list.

Fixes #685
Fixes #2917
Fixes #1424
This commit is contained in:
BlackDex
2023-03-30 17:18:59 +02:00
parent 3bd4e42fb0
commit 3d11f4cd16
3 changed files with 343 additions and 191 deletions

View File

@ -1,16 +1,15 @@
use std::{
net::SocketAddr,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
net::{IpAddr, SocketAddr},
sync::Arc,
time::Duration,
};
use chrono::NaiveDateTime;
use futures::{SinkExt, StreamExt};
use rmpv::Value;
use rocket::Route;
use rocket::{
futures::{SinkExt, StreamExt},
Route,
};
use tokio::{
net::{TcpListener, TcpStream},
sync::mpsc::Sender,
@ -21,34 +20,127 @@ use tokio_tungstenite::{
};
use crate::{
api::EmptyResult,
db::models::{Cipher, Folder, Send, User},
auth::ClientIp,
db::models::{Cipher, Folder, Send as DbSend, User},
Error, CONFIG,
};
use once_cell::sync::Lazy;
static WS_USERS: Lazy<Arc<WebSocketUsers>> = Lazy::new(|| {
Arc::new(WebSocketUsers {
map: Arc::new(dashmap::DashMap::new()),
})
});
pub fn routes() -> Vec<Route> {
routes![websockets_err]
routes![websockets_hub]
}
#[get("/hub")]
fn websockets_err() -> EmptyResult {
static SHOW_WEBSOCKETS_MSG: AtomicBool = AtomicBool::new(true);
#[derive(FromForm, Debug)]
struct WsAccessToken {
access_token: Option<String>,
}
if CONFIG.websocket_enabled()
&& SHOW_WEBSOCKETS_MSG.compare_exchange(true, false, Ordering::Relaxed, Ordering::Relaxed).is_ok()
{
err!(
"
###########################################################
'/notifications/hub' should be proxied to the websocket server or notifications won't work.
Go to the Wiki for more info, or disable WebSockets setting WEBSOCKET_ENABLED=false.
###########################################################################################\n"
)
} else {
Err(Error::empty())
struct WSEntryMapGuard {
users: Arc<WebSocketUsers>,
user_uuid: String,
entry_uuid: uuid::Uuid,
addr: IpAddr,
}
impl WSEntryMapGuard {
fn new(users: Arc<WebSocketUsers>, user_uuid: String, entry_uuid: uuid::Uuid, addr: IpAddr) -> Self {
Self {
users,
user_uuid,
entry_uuid,
addr,
}
}
}
impl Drop for WSEntryMapGuard {
fn drop(&mut self) {
info!("Closing WS connection from {}", self.addr);
if let Some(mut entry) = self.users.map.get_mut(&self.user_uuid) {
entry.retain(|(uuid, _)| uuid != &self.entry_uuid);
}
}
}
#[get("/hub?<data..>")]
async fn websockets_hub<'r>(
ws: rocket_ws::WebSocket,
data: WsAccessToken,
ip: ClientIp,
) -> Result<rocket_ws::Channel<'r>, Error> {
let addr = ip.ip;
info!("Accepting Rocket WS connection from {addr}");
let Some(token) = data.access_token else { err_code!("Invalid claim", 401) };
let Ok(claims) = crate::auth::decode_login(&token) else { err_code!("Invalid token", 401) };
let (mut rx, guard) = {
let users = Arc::clone(&WS_USERS);
// Add a channel to send messages to this client to the map
let entry_uuid = uuid::Uuid::new_v4();
let (tx, rx) = tokio::sync::mpsc::channel::<Message>(100);
users.map.entry(claims.sub.clone()).or_default().push((entry_uuid, tx));
// Once the guard goes out of scope, the connection will have been closed and the entry will be deleted from the map
(rx, WSEntryMapGuard::new(users, claims.sub, entry_uuid, addr))
};
Ok(ws.channel(move |mut stream| {
Box::pin(async move {
// Make sure the guard is moved into the channel future so it's not dropped earlier
let _guard = guard;
let mut interval = tokio::time::interval(Duration::from_secs(15));
loop {
tokio::select! {
res = stream.next() => {
match res {
Some(Ok(message)) => {
match message {
// Respond to any pings
Message::Ping(ping) => stream.send(Message::Pong(ping)).await?,
Message::Pong(_) => {/* Ignored */},
// We should receive an initial message with the protocol and version, and we will reply to it
Message::Text(ref message) => {
let msg = message.strip_suffix(RECORD_SEPARATOR as char).unwrap_or(message);
if serde_json::from_str(msg).ok() == Some(INITIAL_MESSAGE) {
stream.send(Message::binary(INITIAL_RESPONSE)).await?;
continue;
}
}
// Just echo anything else the client sends
_ => stream.send(message).await?,
}
}
_ => break,
}
}
res = rx.recv() => {
match res {
Some(res) => stream.send(res).await?,
None => break,
}
}
_ = interval.tick() => stream.send(Message::Ping(create_ping())).await?
}
}
Ok(())
})
}))
}
//
// Websockets server
//
@ -127,8 +219,8 @@ impl WebSocketUsers {
async fn send_update(&self, user_uuid: &str, data: &[u8]) {
if let Some(user) = self.map.get(user_uuid).map(|v| v.clone()) {
for (_, sender) in user.iter() {
if sender.send(Message::binary(data)).await.is_err() {
// TODO: Delete from map here too?
if let Err(e) = sender.send(Message::binary(data)).await {
error!("Error sending WS update {e}");
}
}
}
@ -196,7 +288,7 @@ impl WebSocketUsers {
}
}
pub async fn send_send_update(&self, ut: UpdateType, send: &Send, user_uuids: &[String]) {
pub async fn send_send_update(&self, ut: UpdateType, send: &DbSend, user_uuids: &[String]) {
let user_uuid = convert_option(send.user_uuid.clone());
let data = create_update(
@ -280,15 +372,12 @@ pub enum UpdateType {
None = 100,
}
pub type Notify<'a> = &'a rocket::State<WebSocketUsers>;
pub fn start_notification_server() -> WebSocketUsers {
let users = WebSocketUsers {
map: Arc::new(dashmap::DashMap::new()),
};
pub type Notify<'a> = &'a rocket::State<Arc<WebSocketUsers>>;
pub fn start_notification_server() -> Arc<WebSocketUsers> {
let users = Arc::clone(&WS_USERS);
if CONFIG.websocket_enabled() {
let users2 = users.clone();
let users2 = Arc::<WebSocketUsers>::clone(&users);
tokio::spawn(async move {
let addr = (CONFIG.websocket_address(), CONFIG.websocket_port());
info!("Starting WebSockets server on {}:{}", addr.0, addr.1);
@ -300,7 +389,7 @@ pub fn start_notification_server() -> WebSocketUsers {
loop {
tokio::select! {
Ok((stream, addr)) = listener.accept() => {
tokio::spawn(handle_connection(stream, users2.clone(), addr));
tokio::spawn(handle_connection(stream, Arc::<WebSocketUsers>::clone(&users2), addr));
}
_ = &mut shutdown_rx => {
@ -316,7 +405,7 @@ pub fn start_notification_server() -> WebSocketUsers {
users
}
async fn handle_connection(stream: TcpStream, users: WebSocketUsers, addr: SocketAddr) -> Result<(), Error> {
async fn handle_connection(stream: TcpStream, users: Arc<WebSocketUsers>, addr: SocketAddr) -> Result<(), Error> {
let mut user_uuid: Option<String> = None;
info!("Accepting WS connection from {addr}");
@ -336,41 +425,39 @@ async fn handle_connection(stream: TcpStream, users: WebSocketUsers, addr: Socke
let user_uuid = user_uuid.expect("User UUID should be set after the handshake");
// Add a channel to send messages to this client to the map
let entry_uuid = uuid::Uuid::new_v4();
let (tx, mut rx) = tokio::sync::mpsc::channel(100);
users.map.entry(user_uuid.clone()).or_default().push((entry_uuid, tx));
let (mut rx, guard) = {
// Add a channel to send messages to this client to the map
let entry_uuid = uuid::Uuid::new_v4();
let (tx, rx) = tokio::sync::mpsc::channel::<Message>(100);
users.map.entry(user_uuid.clone()).or_default().push((entry_uuid, tx));
// Once the guard goes out of scope, the connection will have been closed and the entry will be deleted from the map
(rx, WSEntryMapGuard::new(users, user_uuid, entry_uuid, addr.ip()))
};
let _guard = guard;
let mut interval = tokio::time::interval(Duration::from_secs(15));
loop {
tokio::select! {
res = stream.next() => {
match res {
Some(Ok(message)) => {
// Respond to any pings
if let Message::Ping(ping) = message {
if stream.send(Message::Pong(ping)).await.is_err() {
break;
match message {
// Respond to any pings
Message::Ping(ping) => stream.send(Message::Pong(ping)).await?,
Message::Pong(_) => {/* Ignored */},
// We should receive an initial message with the protocol and version, and we will reply to it
Message::Text(ref message) => {
let msg = message.strip_suffix(RECORD_SEPARATOR as char).unwrap_or(message);
if serde_json::from_str(msg).ok() == Some(INITIAL_MESSAGE) {
stream.send(Message::binary(INITIAL_RESPONSE)).await?;
continue;
}
}
continue;
} else if let Message::Pong(_) = message {
/* Ignored */
continue;
}
// We should receive an initial message with the protocol and version, and we will reply to it
if let Message::Text(ref message) = message {
let msg = message.strip_suffix(RECORD_SEPARATOR as char).unwrap_or(message);
if serde_json::from_str(msg).ok() == Some(INITIAL_MESSAGE) {
stream.send(Message::binary(INITIAL_RESPONSE)).await?;
continue;
}
}
// Just echo anything else the client sends
if stream.send(message).await.is_err() {
break;
// Just echo anything else the client sends
_ => stream.send(message).await?,
}
}
_ => break,
@ -379,27 +466,15 @@ async fn handle_connection(stream: TcpStream, users: WebSocketUsers, addr: Socke
res = rx.recv() => {
match res {
Some(res) => {
if stream.send(res).await.is_err() {
break;
}
},
Some(res) => stream.send(res).await?,
None => break,
}
}
_= interval.tick() => {
if stream.send(Message::Ping(create_ping())).await.is_err() {
break;
}
}
_ = interval.tick() => stream.send(Message::Ping(create_ping())).await?
}
}
info!("Closing WS connection from {addr}");
// Delete from map
users.map.entry(user_uuid).or_default().retain(|(uuid, _)| uuid != &entry_uuid);
Ok(())
}