Update to rocket 0.5 and made code async, missing updating all db calls, that are currently blocking

This commit is contained in:
Daniel García
2021-11-07 18:53:39 +01:00
parent 08f0de7b46
commit 2d5f172e77
30 changed files with 1314 additions and 1028 deletions

View File

@ -5,10 +5,10 @@ use std::io::Cursor;
use rocket::{
fairing::{Fairing, Info, Kind},
http::{ContentType, Header, HeaderMap, Method, RawStr, Status},
http::{ContentType, Header, HeaderMap, Method, Status},
request::FromParam,
response::{self, Responder},
Data, Request, Response, Rocket,
Data, Orbit, Request, Response, Rocket,
};
use std::thread::sleep;
@ -18,6 +18,7 @@ use crate::CONFIG;
pub struct AppHeaders();
#[rocket::async_trait]
impl Fairing for AppHeaders {
fn info(&self) -> Info {
Info {
@ -26,7 +27,7 @@ impl Fairing for AppHeaders {
}
}
fn on_response(&self, _req: &Request, res: &mut Response) {
async fn on_response<'r>(&self, _req: &'r Request<'_>, res: &mut Response<'r>) {
res.set_raw_header("Permissions-Policy", "accelerometer=(), ambient-light-sensor=(), autoplay=(), camera=(), encrypted-media=(), fullscreen=(), geolocation=(), gyroscope=(), magnetometer=(), microphone=(), midi=(), payment=(), picture-in-picture=(), sync-xhr=(self \"https://haveibeenpwned.com\" \"https://2fa.directory\"), usb=(), vr=()");
res.set_raw_header("Referrer-Policy", "same-origin");
res.set_raw_header("X-Frame-Options", "SAMEORIGIN");
@ -72,6 +73,7 @@ impl Cors {
}
}
#[rocket::async_trait]
impl Fairing for Cors {
fn info(&self) -> Info {
Info {
@ -80,7 +82,7 @@ impl Fairing for Cors {
}
}
fn on_response(&self, request: &Request, response: &mut Response) {
async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
let req_headers = request.headers();
if let Some(origin) = Cors::get_allowed_origin(req_headers) {
@ -97,7 +99,7 @@ impl Fairing for Cors {
response.set_header(Header::new("Access-Control-Allow-Credentials", "true"));
response.set_status(Status::Ok);
response.set_header(ContentType::Plain);
response.set_sized_body(Cursor::new(""));
response.set_sized_body(Some(0), Cursor::new(""));
}
}
}
@ -134,25 +136,21 @@ impl<R> Cached<R> {
}
}
impl<'r, R: Responder<'r>> Responder<'r> for Cached<R> {
fn respond_to(self, req: &Request) -> response::Result<'r> {
impl<'r, R: 'r + Responder<'r, 'static> + Send> Responder<'r, 'static> for Cached<R> {
fn respond_to(self, request: &'r Request<'_>) -> response::Result<'static> {
let mut res = self.response.respond_to(request)?;
let cache_control_header = if self.is_immutable {
format!("public, immutable, max-age={}", self.ttl)
} else {
format!("public, max-age={}", self.ttl)
};
res.set_raw_header("Cache-Control", cache_control_header);
let time_now = chrono::Local::now();
match self.response.respond_to(req) {
Ok(mut res) => {
res.set_raw_header("Cache-Control", cache_control_header);
let expiry_time = time_now + chrono::Duration::seconds(self.ttl.try_into().unwrap());
res.set_raw_header("Expires", format_datetime_http(&expiry_time));
Ok(res)
}
e @ Err(_) => e,
}
let expiry_time = time_now + chrono::Duration::seconds(self.ttl.try_into().unwrap());
res.set_raw_header("Expires", format_datetime_http(&expiry_time));
Ok(res)
}
}
@ -175,11 +173,9 @@ impl<'r> FromParam<'r> for SafeString {
type Error = ();
#[inline(always)]
fn from_param(param: &'r RawStr) -> Result<Self, Self::Error> {
let s = param.percent_decode().map(|cow| cow.into_owned()).map_err(|_| ())?;
if s.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) {
Ok(SafeString(s))
fn from_param(param: &'r str) -> Result<Self, Self::Error> {
if param.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) {
Ok(SafeString(param.to_string()))
} else {
Err(())
}
@ -193,15 +189,16 @@ const LOGGED_ROUTES: [&str; 6] =
// Boolean is extra debug, when true, we ignore the whitelist above and also print the mounts
pub struct BetterLogging(pub bool);
#[rocket::async_trait]
impl Fairing for BetterLogging {
fn info(&self) -> Info {
Info {
name: "Better Logging",
kind: Kind::Launch | Kind::Request | Kind::Response,
kind: Kind::Liftoff | Kind::Request | Kind::Response,
}
}
fn on_launch(&self, rocket: &Rocket) {
async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
if self.0 {
info!(target: "routes", "Routes loaded:");
let mut routes: Vec<_> = rocket.routes().collect();
@ -225,34 +222,36 @@ impl Fairing for BetterLogging {
info!(target: "start", "Rocket has launched from {}", addr);
}
fn on_request(&self, request: &mut Request<'_>, _data: &Data) {
async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
let method = request.method();
if !self.0 && method == Method::Options {
return;
}
let uri = request.uri();
let uri_path = uri.path();
let uri_subpath = uri_path.strip_prefix(&CONFIG.domain_path()).unwrap_or(uri_path);
let uri_path_str = uri_path.url_decode_lossy();
let uri_subpath = uri_path_str.strip_prefix(&CONFIG.domain_path()).unwrap_or(&uri_path_str);
if self.0 || LOGGED_ROUTES.iter().any(|r| uri_subpath.starts_with(r)) {
match uri.query() {
Some(q) => info!(target: "request", "{} {}?{}", method, uri_path, &q[..q.len().min(30)]),
None => info!(target: "request", "{} {}", method, uri_path),
Some(q) => info!(target: "request", "{} {}?{}", method, uri_path_str, &q[..q.len().min(30)]),
None => info!(target: "request", "{} {}", method, uri_path_str),
};
}
}
fn on_response(&self, request: &Request, response: &mut Response) {
async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
if !self.0 && request.method() == Method::Options {
return;
}
let uri_path = request.uri().path();
let uri_subpath = uri_path.strip_prefix(&CONFIG.domain_path()).unwrap_or(uri_path);
let uri_path_str = uri_path.url_decode_lossy();
let uri_subpath = uri_path_str.strip_prefix(&CONFIG.domain_path()).unwrap_or(&uri_path_str);
if self.0 || LOGGED_ROUTES.iter().any(|r| uri_subpath.starts_with(r)) {
let status = response.status();
if let Some(route) = request.route() {
info!(target: "response", "{} => {} {}", route, status.code, status.reason)
if let Some(ref route) = request.route() {
info!(target: "response", "{} => {}", route, status)
} else {
info!(target: "response", "{} {}", status.code, status.reason)
info!(target: "response", "{}", status)
}
}
}
@ -614,10 +613,7 @@ where
}
}
use reqwest::{
blocking::{Client, ClientBuilder},
header,
};
use reqwest::{header, Client, ClientBuilder};
pub fn get_reqwest_client() -> Client {
get_reqwest_client_builder().build().expect("Failed to build client")