mirror of
https://github.com/dani-garcia/vaultwarden.git
synced 2025-06-12 05:07:40 +02:00
Update to rocket 0.5 and made code async, missing updating all db calls, that are currently blocking
This commit is contained in:
70
src/util.rs
70
src/util.rs
@ -5,10 +5,10 @@ use std::io::Cursor;
|
||||
|
||||
use rocket::{
|
||||
fairing::{Fairing, Info, Kind},
|
||||
http::{ContentType, Header, HeaderMap, Method, RawStr, Status},
|
||||
http::{ContentType, Header, HeaderMap, Method, Status},
|
||||
request::FromParam,
|
||||
response::{self, Responder},
|
||||
Data, Request, Response, Rocket,
|
||||
Data, Orbit, Request, Response, Rocket,
|
||||
};
|
||||
|
||||
use std::thread::sleep;
|
||||
@ -18,6 +18,7 @@ use crate::CONFIG;
|
||||
|
||||
pub struct AppHeaders();
|
||||
|
||||
#[rocket::async_trait]
|
||||
impl Fairing for AppHeaders {
|
||||
fn info(&self) -> Info {
|
||||
Info {
|
||||
@ -26,7 +27,7 @@ impl Fairing for AppHeaders {
|
||||
}
|
||||
}
|
||||
|
||||
fn on_response(&self, _req: &Request, res: &mut Response) {
|
||||
async fn on_response<'r>(&self, _req: &'r Request<'_>, res: &mut Response<'r>) {
|
||||
res.set_raw_header("Permissions-Policy", "accelerometer=(), ambient-light-sensor=(), autoplay=(), camera=(), encrypted-media=(), fullscreen=(), geolocation=(), gyroscope=(), magnetometer=(), microphone=(), midi=(), payment=(), picture-in-picture=(), sync-xhr=(self \"https://haveibeenpwned.com\" \"https://2fa.directory\"), usb=(), vr=()");
|
||||
res.set_raw_header("Referrer-Policy", "same-origin");
|
||||
res.set_raw_header("X-Frame-Options", "SAMEORIGIN");
|
||||
@ -72,6 +73,7 @@ impl Cors {
|
||||
}
|
||||
}
|
||||
|
||||
#[rocket::async_trait]
|
||||
impl Fairing for Cors {
|
||||
fn info(&self) -> Info {
|
||||
Info {
|
||||
@ -80,7 +82,7 @@ impl Fairing for Cors {
|
||||
}
|
||||
}
|
||||
|
||||
fn on_response(&self, request: &Request, response: &mut Response) {
|
||||
async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
|
||||
let req_headers = request.headers();
|
||||
|
||||
if let Some(origin) = Cors::get_allowed_origin(req_headers) {
|
||||
@ -97,7 +99,7 @@ impl Fairing for Cors {
|
||||
response.set_header(Header::new("Access-Control-Allow-Credentials", "true"));
|
||||
response.set_status(Status::Ok);
|
||||
response.set_header(ContentType::Plain);
|
||||
response.set_sized_body(Cursor::new(""));
|
||||
response.set_sized_body(Some(0), Cursor::new(""));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -134,25 +136,21 @@ impl<R> Cached<R> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'r, R: Responder<'r>> Responder<'r> for Cached<R> {
|
||||
fn respond_to(self, req: &Request) -> response::Result<'r> {
|
||||
impl<'r, R: 'r + Responder<'r, 'static> + Send> Responder<'r, 'static> for Cached<R> {
|
||||
fn respond_to(self, request: &'r Request<'_>) -> response::Result<'static> {
|
||||
let mut res = self.response.respond_to(request)?;
|
||||
|
||||
let cache_control_header = if self.is_immutable {
|
||||
format!("public, immutable, max-age={}", self.ttl)
|
||||
} else {
|
||||
format!("public, max-age={}", self.ttl)
|
||||
};
|
||||
res.set_raw_header("Cache-Control", cache_control_header);
|
||||
|
||||
let time_now = chrono::Local::now();
|
||||
|
||||
match self.response.respond_to(req) {
|
||||
Ok(mut res) => {
|
||||
res.set_raw_header("Cache-Control", cache_control_header);
|
||||
let expiry_time = time_now + chrono::Duration::seconds(self.ttl.try_into().unwrap());
|
||||
res.set_raw_header("Expires", format_datetime_http(&expiry_time));
|
||||
Ok(res)
|
||||
}
|
||||
e @ Err(_) => e,
|
||||
}
|
||||
let expiry_time = time_now + chrono::Duration::seconds(self.ttl.try_into().unwrap());
|
||||
res.set_raw_header("Expires", format_datetime_http(&expiry_time));
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
@ -175,11 +173,9 @@ impl<'r> FromParam<'r> for SafeString {
|
||||
type Error = ();
|
||||
|
||||
#[inline(always)]
|
||||
fn from_param(param: &'r RawStr) -> Result<Self, Self::Error> {
|
||||
let s = param.percent_decode().map(|cow| cow.into_owned()).map_err(|_| ())?;
|
||||
|
||||
if s.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) {
|
||||
Ok(SafeString(s))
|
||||
fn from_param(param: &'r str) -> Result<Self, Self::Error> {
|
||||
if param.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) {
|
||||
Ok(SafeString(param.to_string()))
|
||||
} else {
|
||||
Err(())
|
||||
}
|
||||
@ -193,15 +189,16 @@ const LOGGED_ROUTES: [&str; 6] =
|
||||
|
||||
// Boolean is extra debug, when true, we ignore the whitelist above and also print the mounts
|
||||
pub struct BetterLogging(pub bool);
|
||||
#[rocket::async_trait]
|
||||
impl Fairing for BetterLogging {
|
||||
fn info(&self) -> Info {
|
||||
Info {
|
||||
name: "Better Logging",
|
||||
kind: Kind::Launch | Kind::Request | Kind::Response,
|
||||
kind: Kind::Liftoff | Kind::Request | Kind::Response,
|
||||
}
|
||||
}
|
||||
|
||||
fn on_launch(&self, rocket: &Rocket) {
|
||||
async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
|
||||
if self.0 {
|
||||
info!(target: "routes", "Routes loaded:");
|
||||
let mut routes: Vec<_> = rocket.routes().collect();
|
||||
@ -225,34 +222,36 @@ impl Fairing for BetterLogging {
|
||||
info!(target: "start", "Rocket has launched from {}", addr);
|
||||
}
|
||||
|
||||
fn on_request(&self, request: &mut Request<'_>, _data: &Data) {
|
||||
async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
|
||||
let method = request.method();
|
||||
if !self.0 && method == Method::Options {
|
||||
return;
|
||||
}
|
||||
let uri = request.uri();
|
||||
let uri_path = uri.path();
|
||||
let uri_subpath = uri_path.strip_prefix(&CONFIG.domain_path()).unwrap_or(uri_path);
|
||||
let uri_path_str = uri_path.url_decode_lossy();
|
||||
let uri_subpath = uri_path_str.strip_prefix(&CONFIG.domain_path()).unwrap_or(&uri_path_str);
|
||||
if self.0 || LOGGED_ROUTES.iter().any(|r| uri_subpath.starts_with(r)) {
|
||||
match uri.query() {
|
||||
Some(q) => info!(target: "request", "{} {}?{}", method, uri_path, &q[..q.len().min(30)]),
|
||||
None => info!(target: "request", "{} {}", method, uri_path),
|
||||
Some(q) => info!(target: "request", "{} {}?{}", method, uri_path_str, &q[..q.len().min(30)]),
|
||||
None => info!(target: "request", "{} {}", method, uri_path_str),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
fn on_response(&self, request: &Request, response: &mut Response) {
|
||||
async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
|
||||
if !self.0 && request.method() == Method::Options {
|
||||
return;
|
||||
}
|
||||
let uri_path = request.uri().path();
|
||||
let uri_subpath = uri_path.strip_prefix(&CONFIG.domain_path()).unwrap_or(uri_path);
|
||||
let uri_path_str = uri_path.url_decode_lossy();
|
||||
let uri_subpath = uri_path_str.strip_prefix(&CONFIG.domain_path()).unwrap_or(&uri_path_str);
|
||||
if self.0 || LOGGED_ROUTES.iter().any(|r| uri_subpath.starts_with(r)) {
|
||||
let status = response.status();
|
||||
if let Some(route) = request.route() {
|
||||
info!(target: "response", "{} => {} {}", route, status.code, status.reason)
|
||||
if let Some(ref route) = request.route() {
|
||||
info!(target: "response", "{} => {}", route, status)
|
||||
} else {
|
||||
info!(target: "response", "{} {}", status.code, status.reason)
|
||||
info!(target: "response", "{}", status)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -614,10 +613,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
use reqwest::{
|
||||
blocking::{Client, ClientBuilder},
|
||||
header,
|
||||
};
|
||||
use reqwest::{header, Client, ClientBuilder};
|
||||
|
||||
pub fn get_reqwest_client() -> Client {
|
||||
get_reqwest_client_builder().build().expect("Failed to build client")
|
||||
|
Reference in New Issue
Block a user