Implement fd I/O on Rust side

This commit is contained in:
topjohnwu 2025-01-29 21:02:22 +08:00 committed by John Wu
parent a48a9c858a
commit b575c95710
8 changed files with 42 additions and 142 deletions

View File

@ -7,14 +7,15 @@ use crate::package::ManagerInfo;
use base::libc::{O_CLOEXEC, O_RDONLY}; use base::libc::{O_CLOEXEC, O_RDONLY};
use base::{ use base::{
cstr, info, libc, open_fd, warn, BufReadExt, Directory, FsPath, FsPathBuf, LoggedResult, cstr, info, libc, open_fd, warn, BufReadExt, Directory, FsPath, FsPathBuf, LoggedResult,
ReadExt, Utf8CStr, Utf8CStrBufArr, WriteExt, ReadExt, ResultExt, Utf8CStr, Utf8CStrBufArr, WriteExt,
}; };
use bit_set::BitSet; use bit_set::BitSet;
use bytemuck::{bytes_of, bytes_of_mut, Pod, Zeroable}; use bytemuck::{bytes_of, bytes_of_mut, Pod, Zeroable};
use std::fs::File; use std::fs::File;
use std::io; use std::io;
use std::io::{BufReader, ErrorKind, IoSlice, IoSliceMut, Read, Write}; use std::io::{BufReader, ErrorKind, IoSlice, IoSliceMut, Read, Write};
use std::os::fd::{FromRawFd, OwnedFd, RawFd}; use std::mem::ManuallyDrop;
use std::os::fd::{FromRawFd, IntoRawFd, OwnedFd, RawFd};
use std::os::unix::net::{AncillaryData, SocketAncillary, UnixStream}; use std::os::unix::net::{AncillaryData, SocketAncillary, UnixStream};
use std::sync::{Mutex, OnceLock}; use std::sync::{Mutex, OnceLock};
@ -386,3 +387,29 @@ impl UnixSocketExt for UnixStream {
Ok(fds) Ok(fds)
} }
} }
pub fn send_fd(socket: RawFd, fd: RawFd) -> bool {
let mut socket = ManuallyDrop::new(unsafe { UnixStream::from_raw_fd(socket) });
socket.send_fds(&[fd]).log().is_ok()
}
pub fn send_fds(socket: RawFd, fds: &[RawFd]) -> bool {
let mut socket = ManuallyDrop::new(unsafe { UnixStream::from_raw_fd(socket) });
socket.send_fds(fds).log().is_ok()
}
pub fn recv_fd(socket: RawFd) -> RawFd {
let mut socket = ManuallyDrop::new(unsafe { UnixStream::from_raw_fd(socket) });
socket
.recv_fd()
.log()
.unwrap_or(None)
.map_or(-1, IntoRawFd::into_raw_fd)
}
pub fn recv_fds(socket: RawFd) -> Vec<RawFd> {
let mut socket = ManuallyDrop::new(unsafe { UnixStream::from_raw_fd(socket) });
let fds = socket.recv_fds().log().unwrap_or(Vec::new());
// SAFETY: OwnedFd and RawFd has the same layout
unsafe { std::mem::transmute(fds) }
}

View File

@ -11,10 +11,6 @@ struct sock_cred : public ucred {
}; };
bool get_client_cred(int fd, sock_cred *cred); bool get_client_cred(int fd, sock_cred *cred);
std::vector<int> recv_fds(int sockfd);
int recv_fd(int sockfd);
int send_fds(int sockfd, const int *fds, int cnt);
int send_fd(int sockfd, int fd);
int read_int(int fd); int read_int(int fd);
int read_int_be(int fd); int read_int_be(int fd);
void write_int(int fd, int val); void write_int(int fd, int val);

View File

@ -7,7 +7,7 @@
#![allow(clippy::missing_safety_doc)] #![allow(clippy::missing_safety_doc)]
use base::Utf8CStr; use base::Utf8CStr;
use daemon::{daemon_entry, get_magiskd, MagiskD}; use daemon::{daemon_entry, get_magiskd, recv_fd, recv_fds, send_fd, send_fds, MagiskD};
use db::get_default_db_settings; use db::get_default_db_settings;
use logging::{android_logging, setup_logfile, zygisk_close_logd, zygisk_get_logd, zygisk_logging}; use logging::{android_logging, setup_logfile, zygisk_close_logd, zygisk_get_logd, zygisk_logging};
use mount::{clean_mounts, find_preinit_device, revert_unmount, setup_mounts}; use mount::{clean_mounts, find_preinit_device, revert_unmount, setup_mounts};
@ -195,6 +195,10 @@ pub mod ffi {
unsafe fn persist_get_props(prop_cb: Pin<&mut PropCb>); unsafe fn persist_get_props(prop_cb: Pin<&mut PropCb>);
unsafe fn persist_delete_prop(name: Utf8CStrRef) -> bool; unsafe fn persist_delete_prop(name: Utf8CStrRef) -> bool;
unsafe fn persist_set_prop(name: Utf8CStrRef, value: Utf8CStrRef) -> bool; unsafe fn persist_set_prop(name: Utf8CStrRef, value: Utf8CStrRef) -> bool;
fn send_fd(socket: i32, fd: i32) -> bool;
fn send_fds(socket: i32, fds: &[i32]) -> bool;
fn recv_fd(socket: i32) -> i32;
fn recv_fds(socket: i32) -> Vec<i32>;
#[namespace = "rust"] #[namespace = "rust"]
fn daemon_entry(); fn daemon_entry();

View File

@ -19,133 +19,6 @@ bool get_client_cred(int fd, sock_cred *cred) {
return true; return true;
} }
static int send_fds(int sockfd, void *cmsgbuf, size_t bufsz, const int *fds, int cnt) {
iovec iov = {
.iov_base = &cnt,
.iov_len = sizeof(cnt),
};
msghdr msg = {
.msg_iov = &iov,
.msg_iovlen = 1,
};
if (cnt) {
msg.msg_control = cmsgbuf;
msg.msg_controllen = bufsz;
cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
cmsg->cmsg_len = CMSG_LEN(sizeof(int) * cnt);
cmsg->cmsg_level = SOL_SOCKET;
cmsg->cmsg_type = SCM_RIGHTS;
memcpy(CMSG_DATA(cmsg), fds, sizeof(int) * cnt);
}
return xsendmsg(sockfd, &msg, 0);
}
int send_fds(int sockfd, const int *fds, int cnt) {
if (cnt == 0) {
return send_fds(sockfd, nullptr, 0, nullptr, 0);
}
vector<char> cmsgbuf;
cmsgbuf.resize(CMSG_SPACE(sizeof(int) * cnt));
return send_fds(sockfd, cmsgbuf.data(), cmsgbuf.size(), fds, cnt);
}
int send_fd(int sockfd, int fd) {
if (fd < 0) {
return send_fds(sockfd, nullptr, 0, nullptr, 0);
}
char cmsgbuf[CMSG_SPACE(sizeof(int))];
return send_fds(sockfd, cmsgbuf, sizeof(cmsgbuf), &fd, 1);
}
static void *recv_fds(int sockfd, char *cmsgbuf, size_t bufsz, int cnt) {
iovec iov = {
.iov_base = &cnt,
.iov_len = sizeof(cnt),
};
msghdr msg = {
.msg_iov = &iov,
.msg_iovlen = 1,
.msg_control = cmsgbuf,
.msg_controllen = bufsz
};
xrecvmsg(sockfd, &msg, MSG_WAITALL);
if (msg.msg_controllen != bufsz) {
LOGE("recv_fd: msg_flags = %d, msg_controllen(%zu) != %zu\n",
msg.msg_flags, msg.msg_controllen, bufsz);
return nullptr;
}
cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
if (cmsg == nullptr) {
LOGE("recv_fd: cmsg == nullptr\n");
return nullptr;
}
if (cmsg->cmsg_len != CMSG_LEN(sizeof(int) * cnt)) {
LOGE("recv_fd: cmsg_len(%zu) != %zu\n", cmsg->cmsg_len, CMSG_LEN(sizeof(int) * cnt));
return nullptr;
}
if (cmsg->cmsg_level != SOL_SOCKET) {
LOGE("recv_fd: cmsg_level != SOL_SOCKET\n");
return nullptr;
}
if (cmsg->cmsg_type != SCM_RIGHTS) {
LOGE("recv_fd: cmsg_type != SCM_RIGHTS\n");
return nullptr;
}
return CMSG_DATA(cmsg);
}
vector<int> recv_fds(int sockfd) {
vector<int> results;
// Peek fd count to allocate proper buffer
int cnt;
recv(sockfd, &cnt, sizeof(cnt), MSG_PEEK);
if (cnt == 0) {
// Consume data
recv(sockfd, &cnt, sizeof(cnt), MSG_WAITALL);
return results;
}
vector<char> cmsgbuf;
cmsgbuf.resize(CMSG_SPACE(sizeof(int) * cnt));
void *data = recv_fds(sockfd, cmsgbuf.data(), cmsgbuf.size(), cnt);
if (data == nullptr)
return results;
results.resize(cnt);
memcpy(results.data(), data, sizeof(int) * cnt);
return results;
}
int recv_fd(int sockfd) {
// Peek fd count
int cnt;
recv(sockfd, &cnt, sizeof(cnt), MSG_PEEK);
if (cnt == 0) {
// Consume data
recv(sockfd, &cnt, sizeof(cnt), MSG_WAITALL);
return -1;
}
char cmsgbuf[CMSG_SPACE(sizeof(int))];
void *data = recv_fds(sockfd, cmsgbuf, sizeof(cmsgbuf), 1);
if (data == nullptr)
return -1;
int result;
memcpy(&result, data, sizeof(int));
return result;
}
int read_int(int fd) { int read_int(int fd) {
int val; int val;
if (xxread(fd, &val, sizeof(val)) != sizeof(val)) if (xxread(fd, &val, sizeof(val)) != sizeof(val))

View File

@ -68,7 +68,7 @@ void MagiskD::connect_zygiskd(int client) const noexcept {
} }
close(fds[1]); close(fds[1]);
rust::Vec<int> module_fds = get_module_fds(is_64_bit); rust::Vec<int> module_fds = get_module_fds(is_64_bit);
send_fds(zygiskd_socket, module_fds.data(), module_fds.size()); send_fds(zygiskd_socket, rust::Slice<const int>(module_fds));
// Wait for ack // Wait for ack
if (read_int(zygiskd_socket) != 0) { if (read_int(zygiskd_socket) != 0) {
LOGE("zygiskd startup error\n"); LOGE("zygiskd startup error\n");

View File

@ -29,7 +29,7 @@ static void zygiskd(int socket) {
using comp_entry = void(*)(int); using comp_entry = void(*)(int);
vector<comp_entry> modules; vector<comp_entry> modules;
{ {
vector<int> module_fds = recv_fds(socket); auto module_fds = recv_fds(socket);
for (int fd : module_fds) { for (int fd : module_fds) {
comp_entry entry = nullptr; comp_entry entry = nullptr;
struct stat s{}; struct stat s{};

View File

@ -208,7 +208,7 @@ bool ZygiskContext::plt_hook_commit() {
// ----------------------------------------------------------------- // -----------------------------------------------------------------
int ZygiskContext::get_module_info(int uid, std::vector<int> &fds) { int ZygiskContext::get_module_info(int uid, rust::Vec<int> &fds) {
if (int fd = zygisk_request(+ZygiskRequest::GetInfo); fd >= 0) { if (int fd = zygisk_request(+ZygiskRequest::GetInfo); fd >= 0) {
write_int(fd, uid); write_int(fd, uid);
write_string(fd, process); write_string(fd, process);
@ -331,7 +331,7 @@ void ZygiskContext::fork_post() {
sigmask(SIG_UNBLOCK, SIGCHLD); sigmask(SIG_UNBLOCK, SIGCHLD);
} }
void ZygiskContext::run_modules_pre(vector<int> &fds) { void ZygiskContext::run_modules_pre(rust::Vec<int> &fds) {
for (int i = 0; i < fds.size(); ++i) { for (int i = 0; i < fds.size(); ++i) {
owned_fd fd = fds[i]; owned_fd fd = fds[i];
struct stat s{}; struct stat s{};
@ -386,7 +386,7 @@ void ZygiskContext::run_modules_post() {
void ZygiskContext::app_specialize_pre() { void ZygiskContext::app_specialize_pre() {
flags |= APP_SPECIALIZE; flags |= APP_SPECIALIZE;
vector<int> module_fds; rust::Vec<int> module_fds;
owned_fd fd = get_module_info(args.app->uid, module_fds); owned_fd fd = get_module_info(args.app->uid, module_fds);
if ((info_flags & UNMOUNT_MASK) == UNMOUNT_MASK) { if ((info_flags & UNMOUNT_MASK) == UNMOUNT_MASK) {
ZLOGI("[%s] is on the denylist\n", process); ZLOGI("[%s] is on the denylist\n", process);
@ -407,7 +407,7 @@ void ZygiskContext::app_specialize_post() {
} }
void ZygiskContext::server_specialize_pre() { void ZygiskContext::server_specialize_pre() {
vector<int> module_fds; rust::Vec<int> module_fds;
if (owned_fd fd = get_module_info(1000, module_fds); fd >= 0) { if (owned_fd fd = get_module_info(1000, module_fds); fd >= 0) {
if (module_fds.empty()) { if (module_fds.empty()) {
write_int(fd, 0); write_int(fd, 0);

View File

@ -261,7 +261,7 @@ struct ZygiskContext {
ZygiskContext(JNIEnv *env, void *args); ZygiskContext(JNIEnv *env, void *args);
~ZygiskContext(); ~ZygiskContext();
void run_modules_pre(std::vector<int> &fds); void run_modules_pre(rust::Vec<int> &fds);
void run_modules_post(); void run_modules_post();
DCL_PRE_POST(fork) DCL_PRE_POST(fork)
DCL_PRE_POST(app_specialize) DCL_PRE_POST(app_specialize)
@ -270,7 +270,7 @@ struct ZygiskContext {
DCL_PRE_POST(nativeSpecializeAppProcess) DCL_PRE_POST(nativeSpecializeAppProcess)
DCL_PRE_POST(nativeForkSystemServer) DCL_PRE_POST(nativeForkSystemServer)
int get_module_info(int uid, std::vector<int> &fds); int get_module_info(int uid, rust::Vec<int> &fds);
void sanitize_fds(); void sanitize_fds();
bool exempt_fd(int fd); bool exempt_fd(int fd);
bool can_exempt_fd() const; bool can_exempt_fd() const;