Implement custom DNS resolver (#3988)

This commit is contained in:
Daniel García
2024-04-27 20:25:34 +02:00
committed by GitHub
parent 2ad33ec97f
commit 27dc67fadd
6 changed files with 345 additions and 255 deletions

View File

@ -4,6 +4,7 @@
use std::{collections::HashMap, io::Cursor, ops::Deref, path::Path};
use num_traits::ToPrimitive;
use once_cell::sync::Lazy;
use rocket::{
fairing::{Fairing, Info, Kind},
http::{ContentType, Header, HeaderMap, Method, Status},
@ -701,14 +702,9 @@ where
use reqwest::{header, Client, ClientBuilder};
pub fn get_reqwest_client() -> Client {
match get_reqwest_client_builder().build() {
Ok(client) => client,
Err(e) => {
error!("Possible trust-dns error, trying with trust-dns disabled: '{e}'");
get_reqwest_client_builder().hickory_dns(false).build().expect("Failed to build client")
}
}
pub fn get_reqwest_client() -> &'static Client {
static INSTANCE: Lazy<Client> = Lazy::new(|| get_reqwest_client_builder().build().expect("Failed to build client"));
&INSTANCE
}
pub fn get_reqwest_client_builder() -> ClientBuilder {
@ -767,3 +763,248 @@ pub fn parse_experimental_client_feature_flags(experimental_client_feature_flags
feature_states
}
mod dns_resolver {
use std::{
fmt,
net::{IpAddr, SocketAddr},
sync::Arc,
};
use hickory_resolver::{system_conf::read_system_conf, TokioAsyncResolver};
use once_cell::sync::Lazy;
use reqwest::dns::{Name, Resolve, Resolving};
use crate::{util::is_global, CONFIG};
#[derive(Debug, Clone)]
pub enum CustomResolverError {
Blacklist {
domain: String,
},
NonGlobalIp {
domain: String,
ip: IpAddr,
},
}
impl CustomResolverError {
pub fn downcast_ref(e: &dyn std::error::Error) -> Option<&Self> {
let mut source = e.source();
while let Some(err) = source {
source = err.source();
if let Some(err) = err.downcast_ref::<CustomResolverError>() {
return Some(err);
}
}
None
}
}
impl fmt::Display for CustomResolverError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Blacklist {
domain,
} => write!(f, "Blacklisted domain: {domain} matched ICON_BLACKLIST_REGEX"),
Self::NonGlobalIp {
domain,
ip,
} => write!(f, "IP {ip} for domain '{domain}' is not a global IP!"),
}
}
}
impl std::error::Error for CustomResolverError {}
#[derive(Debug, Clone)]
pub enum CustomDnsResolver {
Default(),
Hickory(Arc<TokioAsyncResolver>),
}
type BoxError = Box<dyn std::error::Error + Send + Sync>;
impl CustomDnsResolver {
pub fn instance() -> Arc<Self> {
static INSTANCE: Lazy<Arc<CustomDnsResolver>> = Lazy::new(CustomDnsResolver::new);
Arc::clone(&*INSTANCE)
}
fn new() -> Arc<Self> {
match read_system_conf() {
Ok((config, opts)) => {
let resolver = TokioAsyncResolver::tokio(config.clone(), opts.clone());
Arc::new(Self::Hickory(Arc::new(resolver)))
}
Err(e) => {
warn!("Error creating Hickory resolver, falling back to default: {e:?}");
Arc::new(Self::Default())
}
}
}
// Note that we get an iterator of addresses, but we only grab the first one for convenience
async fn resolve_domain(&self, name: &str) -> Result<Option<SocketAddr>, BoxError> {
pre_resolve(name)?;
let result = match self {
Self::Default() => tokio::net::lookup_host(name).await?.next(),
Self::Hickory(r) => r.lookup_ip(name).await?.iter().next().map(|a| SocketAddr::new(a, 0)),
};
if let Some(addr) = &result {
post_resolve(name, addr.ip())?;
}
Ok(result)
}
}
fn pre_resolve(name: &str) -> Result<(), CustomResolverError> {
if crate::api::is_domain_blacklisted(name) {
return Err(CustomResolverError::Blacklist {
domain: name.to_string(),
});
}
Ok(())
}
fn post_resolve(name: &str, ip: IpAddr) -> Result<(), CustomResolverError> {
if CONFIG.icon_blacklist_non_global_ips() && !is_global(ip) {
Err(CustomResolverError::NonGlobalIp {
domain: name.to_string(),
ip,
})
} else {
Ok(())
}
}
impl Resolve for CustomDnsResolver {
fn resolve(&self, name: Name) -> Resolving {
let this = self.clone();
Box::pin(async move {
let name = name.as_str();
let result = this.resolve_domain(name).await?;
Ok::<reqwest::dns::Addrs, _>(Box::new(result.into_iter()))
})
}
}
}
pub use dns_resolver::{CustomDnsResolver, CustomResolverError};
/// TODO: This is extracted from IpAddr::is_global, which is unstable:
/// https://doc.rust-lang.org/nightly/std/net/enum.IpAddr.html#method.is_global
/// Remove once https://github.com/rust-lang/rust/issues/27709 is merged
#[allow(clippy::nonminimal_bool)]
#[cfg(any(not(feature = "unstable"), test))]
pub fn is_global_hardcoded(ip: std::net::IpAddr) -> bool {
match ip {
std::net::IpAddr::V4(ip) => {
!(ip.octets()[0] == 0 // "This network"
|| ip.is_private()
|| (ip.octets()[0] == 100 && (ip.octets()[1] & 0b1100_0000 == 0b0100_0000)) //ip.is_shared()
|| ip.is_loopback()
|| ip.is_link_local()
// addresses reserved for future protocols (`192.0.0.0/24`)
||(ip.octets()[0] == 192 && ip.octets()[1] == 0 && ip.octets()[2] == 0)
|| ip.is_documentation()
|| (ip.octets()[0] == 198 && (ip.octets()[1] & 0xfe) == 18) // ip.is_benchmarking()
|| (ip.octets()[0] & 240 == 240 && !ip.is_broadcast()) //ip.is_reserved()
|| ip.is_broadcast())
}
std::net::IpAddr::V6(ip) => {
!(ip.is_unspecified()
|| ip.is_loopback()
// IPv4-mapped Address (`::ffff:0:0/96`)
|| matches!(ip.segments(), [0, 0, 0, 0, 0, 0xffff, _, _])
// IPv4-IPv6 Translat. (`64:ff9b:1::/48`)
|| matches!(ip.segments(), [0x64, 0xff9b, 1, _, _, _, _, _])
// Discard-Only Address Block (`100::/64`)
|| matches!(ip.segments(), [0x100, 0, 0, 0, _, _, _, _])
// IETF Protocol Assignments (`2001::/23`)
|| (matches!(ip.segments(), [0x2001, b, _, _, _, _, _, _] if b < 0x200)
&& !(
// Port Control Protocol Anycast (`2001:1::1`)
u128::from_be_bytes(ip.octets()) == 0x2001_0001_0000_0000_0000_0000_0000_0001
// Traversal Using Relays around NAT Anycast (`2001:1::2`)
|| u128::from_be_bytes(ip.octets()) == 0x2001_0001_0000_0000_0000_0000_0000_0002
// AMT (`2001:3::/32`)
|| matches!(ip.segments(), [0x2001, 3, _, _, _, _, _, _])
// AS112-v6 (`2001:4:112::/48`)
|| matches!(ip.segments(), [0x2001, 4, 0x112, _, _, _, _, _])
// ORCHIDv2 (`2001:20::/28`)
|| matches!(ip.segments(), [0x2001, b, _, _, _, _, _, _] if (0x20..=0x2F).contains(&b))
))
|| ((ip.segments()[0] == 0x2001) && (ip.segments()[1] == 0xdb8)) // ip.is_documentation()
|| ((ip.segments()[0] & 0xfe00) == 0xfc00) //ip.is_unique_local()
|| ((ip.segments()[0] & 0xffc0) == 0xfe80)) //ip.is_unicast_link_local()
}
}
}
#[cfg(not(feature = "unstable"))]
pub use is_global_hardcoded as is_global;
#[cfg(feature = "unstable")]
#[inline(always)]
pub fn is_global(ip: std::net::IpAddr) -> bool {
ip.is_global()
}
/// These are some tests to check that the implementations match
/// The IPv4 can be all checked in 30 seconds or so and they are correct as of nightly 2023-07-17
/// The IPV6 can't be checked in a reasonable time, so we check over a hundred billion random ones, so far correct
/// Note that the is_global implementation is subject to change as new IP RFCs are created
///
/// To run while showing progress output:
/// cargo +nightly test --release --features sqlite,unstable -- --nocapture --ignored
#[cfg(test)]
#[cfg(feature = "unstable")]
mod tests {
use super::*;
use std::net::IpAddr;
#[test]
#[ignore]
fn test_ipv4_global() {
for a in 0..u8::MAX {
println!("Iter: {}/255", a);
for b in 0..u8::MAX {
for c in 0..u8::MAX {
for d in 0..u8::MAX {
let ip = IpAddr::V4(std::net::Ipv4Addr::new(a, b, c, d));
assert_eq!(ip.is_global(), is_global_hardcoded(ip), "IP mismatch: {}", ip)
}
}
}
}
}
#[test]
#[ignore]
fn test_ipv6_global() {
use rand::Rng;
std::thread::scope(|s| {
for t in 0..16 {
let handle = s.spawn(move || {
let mut v = [0u8; 16];
let mut rng = rand::thread_rng();
for i in 0..20 {
println!("Thread {t} Iter: {i}/50");
for _ in 0..500_000_000 {
rng.fill(&mut v);
let ip = IpAddr::V6(std::net::Ipv6Addr::from(v));
assert_eq!(ip.is_global(), is_global_hardcoded(ip), "IP mismatch: {ip}");
}
}
});
}
});
}
}