diff --git a/native/src/core/daemon.rs b/native/src/core/daemon.rs index 4862feeab..cc86ffbae 100644 --- a/native/src/core/daemon.rs +++ b/native/src/core/daemon.rs @@ -7,14 +7,15 @@ use crate::package::ManagerInfo; use base::libc::{O_CLOEXEC, O_RDONLY}; use base::{ cstr, info, libc, open_fd, warn, BufReadExt, Directory, FsPath, FsPathBuf, LoggedResult, - ReadExt, Utf8CStr, Utf8CStrBufArr, WriteExt, + ReadExt, ResultExt, Utf8CStr, Utf8CStrBufArr, WriteExt, }; use bit_set::BitSet; use bytemuck::{bytes_of, bytes_of_mut, Pod, Zeroable}; use std::fs::File; use std::io; use std::io::{BufReader, ErrorKind, IoSlice, IoSliceMut, Read, Write}; -use std::os::fd::{FromRawFd, OwnedFd, RawFd}; +use std::mem::ManuallyDrop; +use std::os::fd::{FromRawFd, IntoRawFd, OwnedFd, RawFd}; use std::os::unix::net::{AncillaryData, SocketAncillary, UnixStream}; use std::sync::{Mutex, OnceLock}; @@ -386,3 +387,29 @@ impl UnixSocketExt for UnixStream { Ok(fds) } } + +pub fn send_fd(socket: RawFd, fd: RawFd) -> bool { + let mut socket = ManuallyDrop::new(unsafe { UnixStream::from_raw_fd(socket) }); + socket.send_fds(&[fd]).log().is_ok() +} + +pub fn send_fds(socket: RawFd, fds: &[RawFd]) -> bool { + let mut socket = ManuallyDrop::new(unsafe { UnixStream::from_raw_fd(socket) }); + socket.send_fds(fds).log().is_ok() +} + +pub fn recv_fd(socket: RawFd) -> RawFd { + let mut socket = ManuallyDrop::new(unsafe { UnixStream::from_raw_fd(socket) }); + socket + .recv_fd() + .log() + .unwrap_or(None) + .map_or(-1, IntoRawFd::into_raw_fd) +} + +pub fn recv_fds(socket: RawFd) -> Vec { + let mut socket = ManuallyDrop::new(unsafe { UnixStream::from_raw_fd(socket) }); + let fds = socket.recv_fds().log().unwrap_or(Vec::new()); + // SAFETY: OwnedFd and RawFd has the same layout + unsafe { std::mem::transmute(fds) } +} diff --git a/native/src/core/include/socket.hpp b/native/src/core/include/socket.hpp index 973991a6b..c51c6650d 100644 --- a/native/src/core/include/socket.hpp +++ b/native/src/core/include/socket.hpp @@ -11,10 +11,6 @@ struct sock_cred : public ucred { }; bool get_client_cred(int fd, sock_cred *cred); -std::vector recv_fds(int sockfd); -int recv_fd(int sockfd); -int send_fds(int sockfd, const int *fds, int cnt); -int send_fd(int sockfd, int fd); int read_int(int fd); int read_int_be(int fd); void write_int(int fd, int val); diff --git a/native/src/core/lib.rs b/native/src/core/lib.rs index 6a065f191..cc8b4aaed 100644 --- a/native/src/core/lib.rs +++ b/native/src/core/lib.rs @@ -7,7 +7,7 @@ #![allow(clippy::missing_safety_doc)] use base::Utf8CStr; -use daemon::{daemon_entry, get_magiskd, MagiskD}; +use daemon::{daemon_entry, get_magiskd, recv_fd, recv_fds, send_fd, send_fds, MagiskD}; use db::get_default_db_settings; use logging::{android_logging, setup_logfile, zygisk_close_logd, zygisk_get_logd, zygisk_logging}; use mount::{clean_mounts, find_preinit_device, revert_unmount, setup_mounts}; @@ -195,6 +195,10 @@ pub mod ffi { unsafe fn persist_get_props(prop_cb: Pin<&mut PropCb>); unsafe fn persist_delete_prop(name: Utf8CStrRef) -> bool; unsafe fn persist_set_prop(name: Utf8CStrRef, value: Utf8CStrRef) -> bool; + fn send_fd(socket: i32, fd: i32) -> bool; + fn send_fds(socket: i32, fds: &[i32]) -> bool; + fn recv_fd(socket: i32) -> i32; + fn recv_fds(socket: i32) -> Vec; #[namespace = "rust"] fn daemon_entry(); diff --git a/native/src/core/socket.cpp b/native/src/core/socket.cpp index a499b10a7..30640d43c 100644 --- a/native/src/core/socket.cpp +++ b/native/src/core/socket.cpp @@ -19,133 +19,6 @@ bool get_client_cred(int fd, sock_cred *cred) { return true; } -static int send_fds(int sockfd, void *cmsgbuf, size_t bufsz, const int *fds, int cnt) { - iovec iov = { - .iov_base = &cnt, - .iov_len = sizeof(cnt), - }; - msghdr msg = { - .msg_iov = &iov, - .msg_iovlen = 1, - }; - - if (cnt) { - msg.msg_control = cmsgbuf; - msg.msg_controllen = bufsz; - cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); - cmsg->cmsg_len = CMSG_LEN(sizeof(int) * cnt); - cmsg->cmsg_level = SOL_SOCKET; - cmsg->cmsg_type = SCM_RIGHTS; - - memcpy(CMSG_DATA(cmsg), fds, sizeof(int) * cnt); - } - - return xsendmsg(sockfd, &msg, 0); -} - -int send_fds(int sockfd, const int *fds, int cnt) { - if (cnt == 0) { - return send_fds(sockfd, nullptr, 0, nullptr, 0); - } - vector cmsgbuf; - cmsgbuf.resize(CMSG_SPACE(sizeof(int) * cnt)); - return send_fds(sockfd, cmsgbuf.data(), cmsgbuf.size(), fds, cnt); -} - -int send_fd(int sockfd, int fd) { - if (fd < 0) { - return send_fds(sockfd, nullptr, 0, nullptr, 0); - } - char cmsgbuf[CMSG_SPACE(sizeof(int))]; - return send_fds(sockfd, cmsgbuf, sizeof(cmsgbuf), &fd, 1); -} - -static void *recv_fds(int sockfd, char *cmsgbuf, size_t bufsz, int cnt) { - iovec iov = { - .iov_base = &cnt, - .iov_len = sizeof(cnt), - }; - msghdr msg = { - .msg_iov = &iov, - .msg_iovlen = 1, - .msg_control = cmsgbuf, - .msg_controllen = bufsz - }; - - xrecvmsg(sockfd, &msg, MSG_WAITALL); - if (msg.msg_controllen != bufsz) { - LOGE("recv_fd: msg_flags = %d, msg_controllen(%zu) != %zu\n", - msg.msg_flags, msg.msg_controllen, bufsz); - return nullptr; - } - - cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); - if (cmsg == nullptr) { - LOGE("recv_fd: cmsg == nullptr\n"); - return nullptr; - } - if (cmsg->cmsg_len != CMSG_LEN(sizeof(int) * cnt)) { - LOGE("recv_fd: cmsg_len(%zu) != %zu\n", cmsg->cmsg_len, CMSG_LEN(sizeof(int) * cnt)); - return nullptr; - } - if (cmsg->cmsg_level != SOL_SOCKET) { - LOGE("recv_fd: cmsg_level != SOL_SOCKET\n"); - return nullptr; - } - if (cmsg->cmsg_type != SCM_RIGHTS) { - LOGE("recv_fd: cmsg_type != SCM_RIGHTS\n"); - return nullptr; - } - - return CMSG_DATA(cmsg); -} - -vector recv_fds(int sockfd) { - vector results; - - // Peek fd count to allocate proper buffer - int cnt; - recv(sockfd, &cnt, sizeof(cnt), MSG_PEEK); - if (cnt == 0) { - // Consume data - recv(sockfd, &cnt, sizeof(cnt), MSG_WAITALL); - return results; - } - - vector cmsgbuf; - cmsgbuf.resize(CMSG_SPACE(sizeof(int) * cnt)); - - void *data = recv_fds(sockfd, cmsgbuf.data(), cmsgbuf.size(), cnt); - if (data == nullptr) - return results; - - results.resize(cnt); - memcpy(results.data(), data, sizeof(int) * cnt); - - return results; -} - -int recv_fd(int sockfd) { - // Peek fd count - int cnt; - recv(sockfd, &cnt, sizeof(cnt), MSG_PEEK); - if (cnt == 0) { - // Consume data - recv(sockfd, &cnt, sizeof(cnt), MSG_WAITALL); - return -1; - } - - char cmsgbuf[CMSG_SPACE(sizeof(int))]; - - void *data = recv_fds(sockfd, cmsgbuf, sizeof(cmsgbuf), 1); - if (data == nullptr) - return -1; - - int result; - memcpy(&result, data, sizeof(int)); - return result; -} - int read_int(int fd) { int val; if (xxread(fd, &val, sizeof(val)) != sizeof(val)) diff --git a/native/src/core/zygisk/entry.cpp b/native/src/core/zygisk/entry.cpp index b0fd044b9..a49a9a1d6 100644 --- a/native/src/core/zygisk/entry.cpp +++ b/native/src/core/zygisk/entry.cpp @@ -68,7 +68,7 @@ void MagiskD::connect_zygiskd(int client) const noexcept { } close(fds[1]); rust::Vec module_fds = get_module_fds(is_64_bit); - send_fds(zygiskd_socket, module_fds.data(), module_fds.size()); + send_fds(zygiskd_socket, rust::Slice(module_fds)); // Wait for ack if (read_int(zygiskd_socket) != 0) { LOGE("zygiskd startup error\n"); diff --git a/native/src/core/zygisk/main.cpp b/native/src/core/zygisk/main.cpp index f66a857f7..ea93bf20b 100644 --- a/native/src/core/zygisk/main.cpp +++ b/native/src/core/zygisk/main.cpp @@ -29,7 +29,7 @@ static void zygiskd(int socket) { using comp_entry = void(*)(int); vector modules; { - vector module_fds = recv_fds(socket); + auto module_fds = recv_fds(socket); for (int fd : module_fds) { comp_entry entry = nullptr; struct stat s{}; diff --git a/native/src/core/zygisk/module.cpp b/native/src/core/zygisk/module.cpp index dc5ad1fa8..a9ce48593 100644 --- a/native/src/core/zygisk/module.cpp +++ b/native/src/core/zygisk/module.cpp @@ -208,7 +208,7 @@ bool ZygiskContext::plt_hook_commit() { // ----------------------------------------------------------------- -int ZygiskContext::get_module_info(int uid, std::vector &fds) { +int ZygiskContext::get_module_info(int uid, rust::Vec &fds) { if (int fd = zygisk_request(+ZygiskRequest::GetInfo); fd >= 0) { write_int(fd, uid); write_string(fd, process); @@ -331,7 +331,7 @@ void ZygiskContext::fork_post() { sigmask(SIG_UNBLOCK, SIGCHLD); } -void ZygiskContext::run_modules_pre(vector &fds) { +void ZygiskContext::run_modules_pre(rust::Vec &fds) { for (int i = 0; i < fds.size(); ++i) { owned_fd fd = fds[i]; struct stat s{}; @@ -386,7 +386,7 @@ void ZygiskContext::run_modules_post() { void ZygiskContext::app_specialize_pre() { flags |= APP_SPECIALIZE; - vector module_fds; + rust::Vec module_fds; owned_fd fd = get_module_info(args.app->uid, module_fds); if ((info_flags & UNMOUNT_MASK) == UNMOUNT_MASK) { ZLOGI("[%s] is on the denylist\n", process); @@ -407,7 +407,7 @@ void ZygiskContext::app_specialize_post() { } void ZygiskContext::server_specialize_pre() { - vector module_fds; + rust::Vec module_fds; if (owned_fd fd = get_module_info(1000, module_fds); fd >= 0) { if (module_fds.empty()) { write_int(fd, 0); diff --git a/native/src/core/zygisk/module.hpp b/native/src/core/zygisk/module.hpp index f7bfb4954..a686a2020 100644 --- a/native/src/core/zygisk/module.hpp +++ b/native/src/core/zygisk/module.hpp @@ -261,7 +261,7 @@ struct ZygiskContext { ZygiskContext(JNIEnv *env, void *args); ~ZygiskContext(); - void run_modules_pre(std::vector &fds); + void run_modules_pre(rust::Vec &fds); void run_modules_post(); DCL_PRE_POST(fork) DCL_PRE_POST(app_specialize) @@ -270,7 +270,7 @@ struct ZygiskContext { DCL_PRE_POST(nativeSpecializeAppProcess) DCL_PRE_POST(nativeForkSystemServer) - int get_module_info(int uid, std::vector &fds); + int get_module_info(int uid, rust::Vec &fds); void sanitize_fds(); bool exempt_fd(int fd); bool can_exempt_fd() const;