Implement fd I/O on Rust side

This commit is contained in:
topjohnwu 2025-01-29 21:02:22 +08:00 committed by John Wu
parent a48a9c858a
commit b575c95710
8 changed files with 42 additions and 142 deletions

View File

@ -7,14 +7,15 @@ use crate::package::ManagerInfo;
use base::libc::{O_CLOEXEC, O_RDONLY};
use base::{
cstr, info, libc, open_fd, warn, BufReadExt, Directory, FsPath, FsPathBuf, LoggedResult,
ReadExt, Utf8CStr, Utf8CStrBufArr, WriteExt,
ReadExt, ResultExt, Utf8CStr, Utf8CStrBufArr, WriteExt,
};
use bit_set::BitSet;
use bytemuck::{bytes_of, bytes_of_mut, Pod, Zeroable};
use std::fs::File;
use std::io;
use std::io::{BufReader, ErrorKind, IoSlice, IoSliceMut, Read, Write};
use std::os::fd::{FromRawFd, OwnedFd, RawFd};
use std::mem::ManuallyDrop;
use std::os::fd::{FromRawFd, IntoRawFd, OwnedFd, RawFd};
use std::os::unix::net::{AncillaryData, SocketAncillary, UnixStream};
use std::sync::{Mutex, OnceLock};
@ -386,3 +387,29 @@ impl UnixSocketExt for UnixStream {
Ok(fds)
}
}
pub fn send_fd(socket: RawFd, fd: RawFd) -> bool {
let mut socket = ManuallyDrop::new(unsafe { UnixStream::from_raw_fd(socket) });
socket.send_fds(&[fd]).log().is_ok()
}
pub fn send_fds(socket: RawFd, fds: &[RawFd]) -> bool {
let mut socket = ManuallyDrop::new(unsafe { UnixStream::from_raw_fd(socket) });
socket.send_fds(fds).log().is_ok()
}
pub fn recv_fd(socket: RawFd) -> RawFd {
let mut socket = ManuallyDrop::new(unsafe { UnixStream::from_raw_fd(socket) });
socket
.recv_fd()
.log()
.unwrap_or(None)
.map_or(-1, IntoRawFd::into_raw_fd)
}
pub fn recv_fds(socket: RawFd) -> Vec<RawFd> {
let mut socket = ManuallyDrop::new(unsafe { UnixStream::from_raw_fd(socket) });
let fds = socket.recv_fds().log().unwrap_or(Vec::new());
// SAFETY: OwnedFd and RawFd has the same layout
unsafe { std::mem::transmute(fds) }
}

View File

@ -11,10 +11,6 @@ struct sock_cred : public ucred {
};
bool get_client_cred(int fd, sock_cred *cred);
std::vector<int> recv_fds(int sockfd);
int recv_fd(int sockfd);
int send_fds(int sockfd, const int *fds, int cnt);
int send_fd(int sockfd, int fd);
int read_int(int fd);
int read_int_be(int fd);
void write_int(int fd, int val);

View File

@ -7,7 +7,7 @@
#![allow(clippy::missing_safety_doc)]
use base::Utf8CStr;
use daemon::{daemon_entry, get_magiskd, MagiskD};
use daemon::{daemon_entry, get_magiskd, recv_fd, recv_fds, send_fd, send_fds, MagiskD};
use db::get_default_db_settings;
use logging::{android_logging, setup_logfile, zygisk_close_logd, zygisk_get_logd, zygisk_logging};
use mount::{clean_mounts, find_preinit_device, revert_unmount, setup_mounts};
@ -195,6 +195,10 @@ pub mod ffi {
unsafe fn persist_get_props(prop_cb: Pin<&mut PropCb>);
unsafe fn persist_delete_prop(name: Utf8CStrRef) -> bool;
unsafe fn persist_set_prop(name: Utf8CStrRef, value: Utf8CStrRef) -> bool;
fn send_fd(socket: i32, fd: i32) -> bool;
fn send_fds(socket: i32, fds: &[i32]) -> bool;
fn recv_fd(socket: i32) -> i32;
fn recv_fds(socket: i32) -> Vec<i32>;
#[namespace = "rust"]
fn daemon_entry();

View File

@ -19,133 +19,6 @@ bool get_client_cred(int fd, sock_cred *cred) {
return true;
}
static int send_fds(int sockfd, void *cmsgbuf, size_t bufsz, const int *fds, int cnt) {
iovec iov = {
.iov_base = &cnt,
.iov_len = sizeof(cnt),
};
msghdr msg = {
.msg_iov = &iov,
.msg_iovlen = 1,
};
if (cnt) {
msg.msg_control = cmsgbuf;
msg.msg_controllen = bufsz;
cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
cmsg->cmsg_len = CMSG_LEN(sizeof(int) * cnt);
cmsg->cmsg_level = SOL_SOCKET;
cmsg->cmsg_type = SCM_RIGHTS;
memcpy(CMSG_DATA(cmsg), fds, sizeof(int) * cnt);
}
return xsendmsg(sockfd, &msg, 0);
}
int send_fds(int sockfd, const int *fds, int cnt) {
if (cnt == 0) {
return send_fds(sockfd, nullptr, 0, nullptr, 0);
}
vector<char> cmsgbuf;
cmsgbuf.resize(CMSG_SPACE(sizeof(int) * cnt));
return send_fds(sockfd, cmsgbuf.data(), cmsgbuf.size(), fds, cnt);
}
int send_fd(int sockfd, int fd) {
if (fd < 0) {
return send_fds(sockfd, nullptr, 0, nullptr, 0);
}
char cmsgbuf[CMSG_SPACE(sizeof(int))];
return send_fds(sockfd, cmsgbuf, sizeof(cmsgbuf), &fd, 1);
}
static void *recv_fds(int sockfd, char *cmsgbuf, size_t bufsz, int cnt) {
iovec iov = {
.iov_base = &cnt,
.iov_len = sizeof(cnt),
};
msghdr msg = {
.msg_iov = &iov,
.msg_iovlen = 1,
.msg_control = cmsgbuf,
.msg_controllen = bufsz
};
xrecvmsg(sockfd, &msg, MSG_WAITALL);
if (msg.msg_controllen != bufsz) {
LOGE("recv_fd: msg_flags = %d, msg_controllen(%zu) != %zu\n",
msg.msg_flags, msg.msg_controllen, bufsz);
return nullptr;
}
cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
if (cmsg == nullptr) {
LOGE("recv_fd: cmsg == nullptr\n");
return nullptr;
}
if (cmsg->cmsg_len != CMSG_LEN(sizeof(int) * cnt)) {
LOGE("recv_fd: cmsg_len(%zu) != %zu\n", cmsg->cmsg_len, CMSG_LEN(sizeof(int) * cnt));
return nullptr;
}
if (cmsg->cmsg_level != SOL_SOCKET) {
LOGE("recv_fd: cmsg_level != SOL_SOCKET\n");
return nullptr;
}
if (cmsg->cmsg_type != SCM_RIGHTS) {
LOGE("recv_fd: cmsg_type != SCM_RIGHTS\n");
return nullptr;
}
return CMSG_DATA(cmsg);
}
vector<int> recv_fds(int sockfd) {
vector<int> results;
// Peek fd count to allocate proper buffer
int cnt;
recv(sockfd, &cnt, sizeof(cnt), MSG_PEEK);
if (cnt == 0) {
// Consume data
recv(sockfd, &cnt, sizeof(cnt), MSG_WAITALL);
return results;
}
vector<char> cmsgbuf;
cmsgbuf.resize(CMSG_SPACE(sizeof(int) * cnt));
void *data = recv_fds(sockfd, cmsgbuf.data(), cmsgbuf.size(), cnt);
if (data == nullptr)
return results;
results.resize(cnt);
memcpy(results.data(), data, sizeof(int) * cnt);
return results;
}
int recv_fd(int sockfd) {
// Peek fd count
int cnt;
recv(sockfd, &cnt, sizeof(cnt), MSG_PEEK);
if (cnt == 0) {
// Consume data
recv(sockfd, &cnt, sizeof(cnt), MSG_WAITALL);
return -1;
}
char cmsgbuf[CMSG_SPACE(sizeof(int))];
void *data = recv_fds(sockfd, cmsgbuf, sizeof(cmsgbuf), 1);
if (data == nullptr)
return -1;
int result;
memcpy(&result, data, sizeof(int));
return result;
}
int read_int(int fd) {
int val;
if (xxread(fd, &val, sizeof(val)) != sizeof(val))

View File

@ -68,7 +68,7 @@ void MagiskD::connect_zygiskd(int client) const noexcept {
}
close(fds[1]);
rust::Vec<int> module_fds = get_module_fds(is_64_bit);
send_fds(zygiskd_socket, module_fds.data(), module_fds.size());
send_fds(zygiskd_socket, rust::Slice<const int>(module_fds));
// Wait for ack
if (read_int(zygiskd_socket) != 0) {
LOGE("zygiskd startup error\n");

View File

@ -29,7 +29,7 @@ static void zygiskd(int socket) {
using comp_entry = void(*)(int);
vector<comp_entry> modules;
{
vector<int> module_fds = recv_fds(socket);
auto module_fds = recv_fds(socket);
for (int fd : module_fds) {
comp_entry entry = nullptr;
struct stat s{};

View File

@ -208,7 +208,7 @@ bool ZygiskContext::plt_hook_commit() {
// -----------------------------------------------------------------
int ZygiskContext::get_module_info(int uid, std::vector<int> &fds) {
int ZygiskContext::get_module_info(int uid, rust::Vec<int> &fds) {
if (int fd = zygisk_request(+ZygiskRequest::GetInfo); fd >= 0) {
write_int(fd, uid);
write_string(fd, process);
@ -331,7 +331,7 @@ void ZygiskContext::fork_post() {
sigmask(SIG_UNBLOCK, SIGCHLD);
}
void ZygiskContext::run_modules_pre(vector<int> &fds) {
void ZygiskContext::run_modules_pre(rust::Vec<int> &fds) {
for (int i = 0; i < fds.size(); ++i) {
owned_fd fd = fds[i];
struct stat s{};
@ -386,7 +386,7 @@ void ZygiskContext::run_modules_post() {
void ZygiskContext::app_specialize_pre() {
flags |= APP_SPECIALIZE;
vector<int> module_fds;
rust::Vec<int> module_fds;
owned_fd fd = get_module_info(args.app->uid, module_fds);
if ((info_flags & UNMOUNT_MASK) == UNMOUNT_MASK) {
ZLOGI("[%s] is on the denylist\n", process);
@ -407,7 +407,7 @@ void ZygiskContext::app_specialize_post() {
}
void ZygiskContext::server_specialize_pre() {
vector<int> module_fds;
rust::Vec<int> module_fds;
if (owned_fd fd = get_module_info(1000, module_fds); fd >= 0) {
if (module_fds.empty()) {
write_int(fd, 0);

View File

@ -261,7 +261,7 @@ struct ZygiskContext {
ZygiskContext(JNIEnv *env, void *args);
~ZygiskContext();
void run_modules_pre(std::vector<int> &fds);
void run_modules_pre(rust::Vec<int> &fds);
void run_modules_post();
DCL_PRE_POST(fork)
DCL_PRE_POST(app_specialize)
@ -270,7 +270,7 @@ struct ZygiskContext {
DCL_PRE_POST(nativeSpecializeAppProcess)
DCL_PRE_POST(nativeForkSystemServer)
int get_module_info(int uid, std::vector<int> &fds);
int get_module_info(int uid, rust::Vec<int> &fds);
void sanitize_fds();
bool exempt_fd(int fd);
bool can_exempt_fd() const;