mirror of
https://github.com/dani-garcia/vaultwarden.git
synced 2025-06-12 05:07:40 +02:00
WebSockets via Rocket's Upgrade connection
This PR implements a (not yet fully released) new feature of Rocket which allows WebSockets/Upgrade connections. No more need for multiple ports to be opened for Vaultwarden. No explicit need for a reverse proxy to get WebSockets to work (Although I still suggest to use a reverse proxy). - Using a git revision for Rocket, since `rocket_ws` is not yet released. - Updated other crates as well. - Added a connection guard to clear the WS connection from the Users list. Fixes #685 Fixes #2917 Fixes #1424
This commit is contained in:
@ -1,16 +1,15 @@
|
||||
use std::{
|
||||
net::SocketAddr,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
net::{IpAddr, SocketAddr},
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use chrono::NaiveDateTime;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use rmpv::Value;
|
||||
use rocket::Route;
|
||||
use rocket::{
|
||||
futures::{SinkExt, StreamExt},
|
||||
Route,
|
||||
};
|
||||
use tokio::{
|
||||
net::{TcpListener, TcpStream},
|
||||
sync::mpsc::Sender,
|
||||
@ -21,34 +20,127 @@ use tokio_tungstenite::{
|
||||
};
|
||||
|
||||
use crate::{
|
||||
api::EmptyResult,
|
||||
db::models::{Cipher, Folder, Send, User},
|
||||
auth::ClientIp,
|
||||
db::models::{Cipher, Folder, Send as DbSend, User},
|
||||
Error, CONFIG,
|
||||
};
|
||||
|
||||
use once_cell::sync::Lazy;
|
||||
|
||||
static WS_USERS: Lazy<Arc<WebSocketUsers>> = Lazy::new(|| {
|
||||
Arc::new(WebSocketUsers {
|
||||
map: Arc::new(dashmap::DashMap::new()),
|
||||
})
|
||||
});
|
||||
|
||||
pub fn routes() -> Vec<Route> {
|
||||
routes![websockets_err]
|
||||
routes![websockets_hub]
|
||||
}
|
||||
|
||||
#[get("/hub")]
|
||||
fn websockets_err() -> EmptyResult {
|
||||
static SHOW_WEBSOCKETS_MSG: AtomicBool = AtomicBool::new(true);
|
||||
#[derive(FromForm, Debug)]
|
||||
struct WsAccessToken {
|
||||
access_token: Option<String>,
|
||||
}
|
||||
|
||||
if CONFIG.websocket_enabled()
|
||||
&& SHOW_WEBSOCKETS_MSG.compare_exchange(true, false, Ordering::Relaxed, Ordering::Relaxed).is_ok()
|
||||
{
|
||||
err!(
|
||||
"
|
||||
###########################################################
|
||||
'/notifications/hub' should be proxied to the websocket server or notifications won't work.
|
||||
Go to the Wiki for more info, or disable WebSockets setting WEBSOCKET_ENABLED=false.
|
||||
###########################################################################################\n"
|
||||
)
|
||||
} else {
|
||||
Err(Error::empty())
|
||||
struct WSEntryMapGuard {
|
||||
users: Arc<WebSocketUsers>,
|
||||
user_uuid: String,
|
||||
entry_uuid: uuid::Uuid,
|
||||
addr: IpAddr,
|
||||
}
|
||||
|
||||
impl WSEntryMapGuard {
|
||||
fn new(users: Arc<WebSocketUsers>, user_uuid: String, entry_uuid: uuid::Uuid, addr: IpAddr) -> Self {
|
||||
Self {
|
||||
users,
|
||||
user_uuid,
|
||||
entry_uuid,
|
||||
addr,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for WSEntryMapGuard {
|
||||
fn drop(&mut self) {
|
||||
info!("Closing WS connection from {}", self.addr);
|
||||
if let Some(mut entry) = self.users.map.get_mut(&self.user_uuid) {
|
||||
entry.retain(|(uuid, _)| uuid != &self.entry_uuid);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[get("/hub?<data..>")]
|
||||
async fn websockets_hub<'r>(
|
||||
ws: rocket_ws::WebSocket,
|
||||
data: WsAccessToken,
|
||||
ip: ClientIp,
|
||||
) -> Result<rocket_ws::Channel<'r>, Error> {
|
||||
let addr = ip.ip;
|
||||
info!("Accepting Rocket WS connection from {addr}");
|
||||
|
||||
let Some(token) = data.access_token else { err_code!("Invalid claim", 401) };
|
||||
let Ok(claims) = crate::auth::decode_login(&token) else { err_code!("Invalid token", 401) };
|
||||
|
||||
let (mut rx, guard) = {
|
||||
let users = Arc::clone(&WS_USERS);
|
||||
|
||||
// Add a channel to send messages to this client to the map
|
||||
let entry_uuid = uuid::Uuid::new_v4();
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<Message>(100);
|
||||
users.map.entry(claims.sub.clone()).or_default().push((entry_uuid, tx));
|
||||
|
||||
// Once the guard goes out of scope, the connection will have been closed and the entry will be deleted from the map
|
||||
(rx, WSEntryMapGuard::new(users, claims.sub, entry_uuid, addr))
|
||||
};
|
||||
|
||||
Ok(ws.channel(move |mut stream| {
|
||||
Box::pin(async move {
|
||||
// Make sure the guard is moved into the channel future so it's not dropped earlier
|
||||
let _guard = guard;
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(15));
|
||||
loop {
|
||||
tokio::select! {
|
||||
res = stream.next() => {
|
||||
match res {
|
||||
Some(Ok(message)) => {
|
||||
match message {
|
||||
// Respond to any pings
|
||||
Message::Ping(ping) => stream.send(Message::Pong(ping)).await?,
|
||||
Message::Pong(_) => {/* Ignored */},
|
||||
|
||||
// We should receive an initial message with the protocol and version, and we will reply to it
|
||||
Message::Text(ref message) => {
|
||||
let msg = message.strip_suffix(RECORD_SEPARATOR as char).unwrap_or(message);
|
||||
|
||||
if serde_json::from_str(msg).ok() == Some(INITIAL_MESSAGE) {
|
||||
stream.send(Message::binary(INITIAL_RESPONSE)).await?;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
// Just echo anything else the client sends
|
||||
_ => stream.send(message).await?,
|
||||
}
|
||||
}
|
||||
_ => break,
|
||||
}
|
||||
}
|
||||
|
||||
res = rx.recv() => {
|
||||
match res {
|
||||
Some(res) => stream.send(res).await?,
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
|
||||
_ = interval.tick() => stream.send(Message::Ping(create_ping())).await?
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}))
|
||||
}
|
||||
|
||||
//
|
||||
// Websockets server
|
||||
//
|
||||
@ -127,8 +219,8 @@ impl WebSocketUsers {
|
||||
async fn send_update(&self, user_uuid: &str, data: &[u8]) {
|
||||
if let Some(user) = self.map.get(user_uuid).map(|v| v.clone()) {
|
||||
for (_, sender) in user.iter() {
|
||||
if sender.send(Message::binary(data)).await.is_err() {
|
||||
// TODO: Delete from map here too?
|
||||
if let Err(e) = sender.send(Message::binary(data)).await {
|
||||
error!("Error sending WS update {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -196,7 +288,7 @@ impl WebSocketUsers {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_send_update(&self, ut: UpdateType, send: &Send, user_uuids: &[String]) {
|
||||
pub async fn send_send_update(&self, ut: UpdateType, send: &DbSend, user_uuids: &[String]) {
|
||||
let user_uuid = convert_option(send.user_uuid.clone());
|
||||
|
||||
let data = create_update(
|
||||
@ -280,15 +372,12 @@ pub enum UpdateType {
|
||||
None = 100,
|
||||
}
|
||||
|
||||
pub type Notify<'a> = &'a rocket::State<WebSocketUsers>;
|
||||
|
||||
pub fn start_notification_server() -> WebSocketUsers {
|
||||
let users = WebSocketUsers {
|
||||
map: Arc::new(dashmap::DashMap::new()),
|
||||
};
|
||||
pub type Notify<'a> = &'a rocket::State<Arc<WebSocketUsers>>;
|
||||
|
||||
pub fn start_notification_server() -> Arc<WebSocketUsers> {
|
||||
let users = Arc::clone(&WS_USERS);
|
||||
if CONFIG.websocket_enabled() {
|
||||
let users2 = users.clone();
|
||||
let users2 = Arc::<WebSocketUsers>::clone(&users);
|
||||
tokio::spawn(async move {
|
||||
let addr = (CONFIG.websocket_address(), CONFIG.websocket_port());
|
||||
info!("Starting WebSockets server on {}:{}", addr.0, addr.1);
|
||||
@ -300,7 +389,7 @@ pub fn start_notification_server() -> WebSocketUsers {
|
||||
loop {
|
||||
tokio::select! {
|
||||
Ok((stream, addr)) = listener.accept() => {
|
||||
tokio::spawn(handle_connection(stream, users2.clone(), addr));
|
||||
tokio::spawn(handle_connection(stream, Arc::<WebSocketUsers>::clone(&users2), addr));
|
||||
}
|
||||
|
||||
_ = &mut shutdown_rx => {
|
||||
@ -316,7 +405,7 @@ pub fn start_notification_server() -> WebSocketUsers {
|
||||
users
|
||||
}
|
||||
|
||||
async fn handle_connection(stream: TcpStream, users: WebSocketUsers, addr: SocketAddr) -> Result<(), Error> {
|
||||
async fn handle_connection(stream: TcpStream, users: Arc<WebSocketUsers>, addr: SocketAddr) -> Result<(), Error> {
|
||||
let mut user_uuid: Option<String> = None;
|
||||
|
||||
info!("Accepting WS connection from {addr}");
|
||||
@ -336,41 +425,39 @@ async fn handle_connection(stream: TcpStream, users: WebSocketUsers, addr: Socke
|
||||
|
||||
let user_uuid = user_uuid.expect("User UUID should be set after the handshake");
|
||||
|
||||
// Add a channel to send messages to this client to the map
|
||||
let entry_uuid = uuid::Uuid::new_v4();
|
||||
let (tx, mut rx) = tokio::sync::mpsc::channel(100);
|
||||
users.map.entry(user_uuid.clone()).or_default().push((entry_uuid, tx));
|
||||
let (mut rx, guard) = {
|
||||
// Add a channel to send messages to this client to the map
|
||||
let entry_uuid = uuid::Uuid::new_v4();
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<Message>(100);
|
||||
users.map.entry(user_uuid.clone()).or_default().push((entry_uuid, tx));
|
||||
|
||||
// Once the guard goes out of scope, the connection will have been closed and the entry will be deleted from the map
|
||||
(rx, WSEntryMapGuard::new(users, user_uuid, entry_uuid, addr.ip()))
|
||||
};
|
||||
|
||||
let _guard = guard;
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(15));
|
||||
loop {
|
||||
tokio::select! {
|
||||
res = stream.next() => {
|
||||
match res {
|
||||
Some(Ok(message)) => {
|
||||
// Respond to any pings
|
||||
if let Message::Ping(ping) = message {
|
||||
if stream.send(Message::Pong(ping)).await.is_err() {
|
||||
break;
|
||||
match message {
|
||||
// Respond to any pings
|
||||
Message::Ping(ping) => stream.send(Message::Pong(ping)).await?,
|
||||
Message::Pong(_) => {/* Ignored */},
|
||||
|
||||
// We should receive an initial message with the protocol and version, and we will reply to it
|
||||
Message::Text(ref message) => {
|
||||
let msg = message.strip_suffix(RECORD_SEPARATOR as char).unwrap_or(message);
|
||||
|
||||
if serde_json::from_str(msg).ok() == Some(INITIAL_MESSAGE) {
|
||||
stream.send(Message::binary(INITIAL_RESPONSE)).await?;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
} else if let Message::Pong(_) = message {
|
||||
/* Ignored */
|
||||
continue;
|
||||
}
|
||||
|
||||
// We should receive an initial message with the protocol and version, and we will reply to it
|
||||
if let Message::Text(ref message) = message {
|
||||
let msg = message.strip_suffix(RECORD_SEPARATOR as char).unwrap_or(message);
|
||||
|
||||
if serde_json::from_str(msg).ok() == Some(INITIAL_MESSAGE) {
|
||||
stream.send(Message::binary(INITIAL_RESPONSE)).await?;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Just echo anything else the client sends
|
||||
if stream.send(message).await.is_err() {
|
||||
break;
|
||||
// Just echo anything else the client sends
|
||||
_ => stream.send(message).await?,
|
||||
}
|
||||
}
|
||||
_ => break,
|
||||
@ -379,27 +466,15 @@ async fn handle_connection(stream: TcpStream, users: WebSocketUsers, addr: Socke
|
||||
|
||||
res = rx.recv() => {
|
||||
match res {
|
||||
Some(res) => {
|
||||
if stream.send(res).await.is_err() {
|
||||
break;
|
||||
}
|
||||
},
|
||||
Some(res) => stream.send(res).await?,
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
|
||||
_= interval.tick() => {
|
||||
if stream.send(Message::Ping(create_ping())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
_ = interval.tick() => stream.send(Message::Ping(create_ping())).await?
|
||||
}
|
||||
}
|
||||
|
||||
info!("Closing WS connection from {addr}");
|
||||
|
||||
// Delete from map
|
||||
users.map.entry(user_uuid).or_default().retain(|(uuid, _)| uuid != &entry_uuid);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user