Use argument binding for query

This commit is contained in:
topjohnwu 2025-01-02 01:04:44 -08:00 committed by John Wu
parent 2722875190
commit 8e1a44e7eb
6 changed files with 115 additions and 115 deletions

View File

@ -8,39 +8,28 @@
#include <core.hpp> #include <core.hpp>
#define DB_VERSION 12 #define DB_VERSION 12
#define DB_VERSION_STR "12"
using namespace std; using namespace std;
#define DBLOGV(...) #define DBLOGV(...)
//#define DBLOGV(...) LOGD("magiskdb: " __VA_ARGS__) //#define DBLOGV(...) LOGD("magiskdb: " __VA_ARGS__)
struct db_result { #define sql_chk_log(fn, ...) if (int rc = fn(__VA_ARGS__); rc != SQLITE_OK) { \
db_result() = default; LOGE("sqlite3(db.cpp:%d): %s\n", __LINE__, sqlite3_errstr(rc)); \
db_result(const char *s) : err(s) {} return false; \
db_result(int code) : err(code == SQLITE_OK ? "" : (sqlite3_errstr(code) ?: "")) {} }
operator bool() {
if (!err.empty()) { static bool open_and_init_db_impl(sqlite3 **dbOut) {
LOGE("sqlite3: %s\n", err.data()); if (!load_sqlite()) {
LOGE("sqlite3: Cannot load libsqlite.so\n");
return false; return false;
} }
return true;
}
private:
string err;
};
static int sql_exec(sqlite3 *db, const char *sql, sql_exec_callback callback = nullptr, void *v = nullptr) {
return sql_exec(db, sql, nullptr, nullptr, callback, v);
}
static db_result open_and_init_db_impl(sqlite3 **dbOut) {
if (!load_sqlite())
return "Cannot load libsqlite.so";
unique_ptr<sqlite3, decltype(sqlite3_close)> db(nullptr, sqlite3_close); unique_ptr<sqlite3, decltype(sqlite3_close)> db(nullptr, sqlite3_close);
{ {
sqlite3 *sql; sqlite3 *sql;
fn_run_ret(sqlite3_open_v2, MAGISKDB, &sql, sql_chk_log(sqlite3_open_v2, MAGISKDB, &sql,
SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_FULLMUTEX, nullptr); SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_FULLMUTEX, nullptr);
db.reset(sql); db.reset(sql);
} }
@ -50,10 +39,11 @@ static db_result open_and_init_db_impl(sqlite3 **dbOut) {
auto ver_cb = [](void *ver, auto, DbValues &data) { auto ver_cb = [](void *ver, auto, DbValues &data) {
*static_cast<int *>(ver) = data.get_int(0); *static_cast<int *>(ver) = data.get_int(0);
}; };
fn_run_ret(sql_exec, db.get(), "PRAGMA user_version", ver_cb, &ver); sql_chk_log(sql_exec, db.get(), "PRAGMA user_version", nullptr, nullptr, ver_cb, &ver);
if (ver > DB_VERSION) { if (ver > DB_VERSION) {
// Don't support downgrading database // Don't support downgrading database
return "Downgrading database is not supported"; LOGE("sqlite3: Downgrading database is not supported\n");
return false;
} }
auto create_policy = [&] { auto create_policy = [&] {
@ -90,17 +80,17 @@ static db_result open_and_init_db_impl(sqlite3 **dbOut) {
// 12: rebuild table `policies` to drop column `package_name` // 12: rebuild table `policies` to drop column `package_name`
if (/* 0, 1, 2, 3, 4, 5, 6 */ ver <= 6) { if (/* 0, 1, 2, 3, 4, 5, 6 */ ver <= 6) {
fn_run_ret(create_policy); sql_chk_log(create_policy);
fn_run_ret(create_settings); sql_chk_log(create_settings);
fn_run_ret(create_strings); sql_chk_log(create_strings);
fn_run_ret(create_denylist); sql_chk_log(create_denylist);
// Directly jump to latest // Directly jump to latest
ver = DB_VERSION; ver = DB_VERSION;
upgrade = true; upgrade = true;
} }
if (ver == 7) { if (ver == 7) {
fn_run_ret(sql_exec, db.get(), sql_chk_log(sql_exec, db.get(),
"BEGIN TRANSACTION;" "BEGIN TRANSACTION;"
"ALTER TABLE hidelist RENAME TO hidelist_tmp;" "ALTER TABLE hidelist RENAME TO hidelist_tmp;"
"CREATE TABLE IF NOT EXISTS hidelist " "CREATE TABLE IF NOT EXISTS hidelist "
@ -113,7 +103,7 @@ static db_result open_and_init_db_impl(sqlite3 **dbOut) {
upgrade = true; upgrade = true;
} }
if (ver == 8) { if (ver == 8) {
fn_run_ret(sql_exec, db.get(), sql_chk_log(sql_exec, db.get(),
"BEGIN TRANSACTION;" "BEGIN TRANSACTION;"
"ALTER TABLE hidelist RENAME TO hidelist_tmp;" "ALTER TABLE hidelist RENAME TO hidelist_tmp;"
"CREATE TABLE IF NOT EXISTS hidelist " "CREATE TABLE IF NOT EXISTS hidelist "
@ -125,20 +115,20 @@ static db_result open_and_init_db_impl(sqlite3 **dbOut) {
upgrade = true; upgrade = true;
} }
if (ver == 9) { if (ver == 9) {
fn_run_ret(sql_exec, db.get(), "DROP TABLE IF EXISTS logs", nullptr, nullptr); sql_chk_log(sql_exec, db.get(), "DROP TABLE IF EXISTS logs", nullptr, nullptr);
ver = 10; ver = 10;
upgrade = true; upgrade = true;
} }
if (ver == 10) { if (ver == 10) {
fn_run_ret(sql_exec, db.get(), sql_chk_log(sql_exec, db.get(),
"DROP TABLE IF EXISTS hidelist;" "DROP TABLE IF EXISTS hidelist;"
"DELETE FROM settings WHERE key='magiskhide';"); "DELETE FROM settings WHERE key='magiskhide';");
fn_run_ret(create_denylist); sql_chk_log(create_denylist);
ver = 11; ver = 11;
upgrade = true; upgrade = true;
} }
if (ver == 11) { if (ver == 11) {
fn_run_ret(sql_exec, db.get(), sql_chk_log(sql_exec, db.get(),
"BEGIN TRANSACTION;" "BEGIN TRANSACTION;"
"ALTER TABLE policies RENAME TO policies_tmp;" "ALTER TABLE policies RENAME TO policies_tmp;"
"CREATE TABLE IF NOT EXISTS policies " "CREATE TABLE IF NOT EXISTS policies "
@ -154,20 +144,16 @@ static db_result open_and_init_db_impl(sqlite3 **dbOut) {
if (upgrade) { if (upgrade) {
// Set version // Set version
char query[32]; sql_chk_log(sql_exec, db.get(), "PRAGMA user_version=" DB_VERSION_STR);
sprintf(query, "PRAGMA user_version=%d", ver);
fn_run_ret(sql_exec, db.get(), query);
} }
*dbOut = db.release(); *dbOut = db.release();
return {}; return true;
} }
sqlite3 *open_and_init_db() { sqlite3 *open_and_init_db() {
sqlite3 *db = nullptr; sqlite3 *db = nullptr;
if (!open_and_init_db_impl(&db)) return open_and_init_db_impl(&db) ? db : nullptr;
return nullptr;
return db;
} }
static sqlite3 *get_db() { static sqlite3 *get_db() {
@ -183,13 +169,17 @@ static sqlite3 *get_db() {
return db; return db;
} }
bool db_exec(const char *sql, db_bind_callback bind_fn, db_exec_callback exec_fn) { bool db_exec(const char *sql, DbArgs args, db_exec_callback exec_fn) {
using db_bind_callback = std::function<int(int, DbStatement&)>;
if (sqlite3 *db = get_db()) { if (sqlite3 *db = get_db()) {
db_bind_callback bind_fn = {};
sql_bind_callback bind_cb = nullptr; sql_bind_callback bind_cb = nullptr;
if (bind_fn) { if (!args.empty()) {
bind_cb = [](void *v, int index, DbStatement &stmt) { bind_fn = std::ref(args);
bind_cb = [](void *v, int index, DbStatement &stmt) -> int {
auto fn = static_cast<db_bind_callback*>(v); auto fn = static_cast<db_bind_callback*>(v);
fn->operator()(index, stmt); return fn->operator()(index, stmt);
}; };
} }
sql_exec_callback exec_cb = nullptr; sql_exec_callback exec_cb = nullptr;
@ -199,52 +189,41 @@ bool db_exec(const char *sql, db_bind_callback bind_fn, db_exec_callback exec_fn
fn->operator()(columns, data); fn->operator()(columns, data);
}; };
} }
db_result res = sql_exec(db, sql, bind_cb, &bind_fn, exec_cb, &exec_fn); sql_chk_log(sql_exec, db, sql, bind_cb, &bind_fn, exec_cb, &exec_fn);
return res; return true;
} }
return false; return false;
} }
int get_db_settings(db_settings &cfg, int key) { bool get_db_settings(db_settings &cfg, int key) {
bool res;
if (key >= 0) { if (key >= 0) {
char query[128]; return db_exec("SELECT * FROM settings WHERE key=?", { DB_SETTING_KEYS[key] }, cfg);
ssprintf(query, sizeof(query), "SELECT * FROM settings WHERE key='%s'", DB_SETTING_KEYS[key]);
res = db_exec(query, cfg);
} else { } else {
res = db_exec("SELECT * FROM settings", cfg); return db_exec("SELECT * FROM settings", {}, cfg);
} }
return res ? 0 : 1;
} }
int set_db_settings(int key, int value) { bool set_db_settings(int key, int value) {
char sql[128]; return db_exec(
ssprintf(sql, sizeof(sql), "INSERT OR REPLACE INTO settings VALUES ('%s', %d)", "INSERT OR REPLACE INTO settings (key,value) VALUES(?,?)",
DB_SETTING_KEYS[key], value); { DB_SETTING_KEYS[key], value });
return db_exec(sql) ? 0 : 1;
} }
int get_db_strings(db_strings &str, int key) { bool get_db_strings(db_strings &str, int key) {
bool res;
if (key >= 0) { if (key >= 0) {
char query[128]; return db_exec("SELECT * FROM strings WHERE key=?", { DB_STRING_KEYS[key] }, str);
ssprintf(query, sizeof(query), "SELECT * FROM strings WHERE key='%s'", DB_STRING_KEYS[key]);
res = db_exec(query, str);
} else { } else {
res = db_exec("SELECT * FROM strings", str); return db_exec("SELECT * FROM strings", {}, str);
} }
return res ? 0 : 1;
} }
void rm_db_strings(int key) { bool rm_db_strings(int key) {
char query[128]; return db_exec("DELETE FROM strings WHERE key=?", { DB_STRING_KEYS[key] });
ssprintf(query, sizeof(query), "DELETE FROM strings WHERE key == '%s'", DB_STRING_KEYS[key]);
db_exec(query);
} }
void exec_sql(owned_fd client) { void exec_sql(owned_fd client) {
string sql = read_string(client); string sql = read_string(client);
db_exec(sql.data(), [fd = (int) client](StringSlice columns, DbValues &data) { db_exec(sql.data(), {}, [fd = (int) client](StringSlice columns, DbValues &data) {
string out; string out;
for (int i = 0; i < columns.size(); ++i) { for (int i = 0; i < columns.size(); ++i) {
if (i != 0) out += '|'; if (i != 0) out += '|';
@ -306,3 +285,16 @@ void db_strings::operator()(StringSlice columns, DbValues &data) {
su_manager = val; su_manager = val;
} }
} }
int DbArgs::operator()(int index, DbStatement &stmt) {
if (curr < args.size()) {
const auto &arg = args[curr++];
switch (arg.type) {
case DbArg::INT:
return stmt.bind_int64(index, arg.int_val);
case DbArg::TEXT:
return stmt.bind_text(index, arg.str_val);
}
}
return SQLITE_OK;
}

View File

@ -222,7 +222,7 @@ static bool ensure_data() {
LOGI("denylist: initializing internal data structures\n"); LOGI("denylist: initializing internal data structures\n");
default_new(pkg_to_procs_); default_new(pkg_to_procs_);
bool res = db_exec("SELECT * FROM denylist", [](StringSlice columns, DbValues &data) { bool res = db_exec("SELECT * FROM denylist", {}, [](StringSlice columns, DbValues &data) {
const char *package_name; const char *package_name;
const char *process; const char *process;
for (int i = 0; i < columns.size(); ++i) { for (int i = 0; i < columns.size(); ++i) {

View File

@ -88,23 +88,41 @@ struct db_strings {
********************/ ********************/
using db_exec_callback = std::function<void(StringSlice, DbValues&)>; using db_exec_callback = std::function<void(StringSlice, DbValues&)>;
using db_bind_callback = std::function<void(int, DbStatement&)>;
int get_db_settings(db_settings &cfg, int key = -1); struct DbArg {
int set_db_settings(int key, int value); enum {
int get_db_strings(db_strings &str, int key = -1); INT,
void rm_db_strings(int key); TEXT,
} type;
union {
int64_t int_val;
rust::Str str_val;
};
DbArg(int64_t v) : type(INT), int_val(v) {}
DbArg(const char *v) : type(TEXT), str_val(v) {}
};
struct DbArgs {
DbArgs() : curr(0) {}
DbArgs(std::initializer_list<DbArg> list) : args(list), curr(0) {}
int operator()(int index, DbStatement &stmt);
bool empty() const { return args.empty(); }
private:
std::vector<DbArg> args;
size_t curr;
};
bool get_db_settings(db_settings &cfg, int key = -1);
bool set_db_settings(int key, int value);
bool get_db_strings(db_strings &str, int key = -1);
bool rm_db_strings(int key);
void exec_sql(owned_fd client); void exec_sql(owned_fd client);
bool db_exec(const char *sql, db_bind_callback bind_fn = {}, db_exec_callback exec_fn = {}); bool db_exec(const char *sql, DbArgs args = {}, db_exec_callback exec_fn = {});
static inline bool db_exec(const char *sql, db_exec_callback exec_fn) {
return db_exec(sql, {}, std::move(exec_fn));
}
template<typename T> template<typename T>
concept DbData = requires(T t, StringSlice s, DbValues &v) { t(s, v); }; concept DbData = requires(T t, StringSlice s, DbValues &v) { t(s, v); };
template<DbData T> template<DbData T>
bool db_exec(const char *sql, T &data) { bool db_exec(const char *sql, DbArgs args, T &data) {
return db_exec(sql, (db_exec_callback) std::ref(data)); return db_exec(sql, std::move(args), (db_exec_callback) std::ref(data));
} }

View File

@ -17,27 +17,24 @@ extern int (*sqlite3_open_v2)(const char *filename, sqlite3 **ppDb, int flags, c
extern int (*sqlite3_close)(sqlite3 *db); extern int (*sqlite3_close)(sqlite3 *db);
extern const char *(*sqlite3_errstr)(int); extern const char *(*sqlite3_errstr)(int);
// Transparent wrapper of sqlite3_stmt // Transparent wrappers of sqlite3_stmt
struct DbValues { struct DbValues {
const char *get_text(int index); const char *get_text(int index);
int get_int(int index); int get_int(int index);
~DbValues() = delete; ~DbValues() = delete;
}; };
struct DbStatement { struct DbStatement {
int bind_text(int index, const char *val);
int bind_text(int index, rust::Str val); int bind_text(int index, rust::Str val);
int bind_int64(int index, int64_t val); int bind_int64(int index, int64_t val);
~DbStatement() = delete;
}; };
using StringSlice = rust::Slice<rust::String>; using StringSlice = rust::Slice<rust::String>;
using sql_bind_callback = void(*)(void*, int, DbStatement&); using sql_bind_callback = int(*)(void*, int, DbStatement&);
using sql_exec_callback = void(*)(void*, StringSlice, DbValues&); using sql_exec_callback = void(*)(void*, StringSlice, DbValues&);
#define fn_run_ret(fn, ...) if (int rc = fn(__VA_ARGS__); rc != SQLITE_OK) return rc
bool load_sqlite(); bool load_sqlite();
sqlite3 *open_and_init_db(); sqlite3 *open_and_init_db();
int sql_exec(sqlite3 *db, rust::Str zSql, int sql_exec(sqlite3 *db, rust::Str zSql,
sql_bind_callback bind_cb, void *bind_cookie, sql_bind_callback bind_cb = nullptr, void *bind_cookie = nullptr,
sql_exec_callback exec_cb, void *exec_cookie); sql_exec_callback exec_cb = nullptr, void *exec_cookie = nullptr);

View File

@ -91,8 +91,10 @@ bool load_sqlite() {
} }
using StringVec = rust::Vec<rust::String>; using StringVec = rust::Vec<rust::String>;
using sql_bind_callback_real = int(*)(void*, int, sqlite3_stmt*);
using sql_exec_callback_real = void(*)(void*, StringSlice, sqlite3_stmt*); using sql_exec_callback_real = void(*)(void*, StringSlice, sqlite3_stmt*);
using sql_bind_callback_real = void(*)(void*, int, sqlite3_stmt*);
#define sql_chk(fn, ...) if (int rc = fn(__VA_ARGS__); rc != SQLITE_OK) return rc
int sql_exec(sqlite3 *db, rust::Str zSql, sql_bind_callback bind_cb, void *bind_cookie, sql_exec_callback exec_cb, void *exec_cookie) { int sql_exec(sqlite3 *db, rust::Str zSql, sql_bind_callback bind_cb, void *bind_cookie, sql_exec_callback exec_cb, void *exec_cookie) {
const char *sql = zSql.begin(); const char *sql = zSql.begin();
@ -102,7 +104,7 @@ int sql_exec(sqlite3 *db, rust::Str zSql, sql_bind_callback bind_cb, void *bind_
// Step 1: prepare statement // Step 1: prepare statement
{ {
sqlite3_stmt *st = nullptr; sqlite3_stmt *st = nullptr;
fn_run_ret(sqlite3_prepare_v2, db, sql, zSql.end() - sql, &st, &sql); sql_chk(sqlite3_prepare_v2, db, sql, zSql.end() - sql, &st, &sql);
if (st == nullptr) continue; if (st == nullptr) continue;
stmt.reset(st); stmt.reset(st);
} }
@ -112,7 +114,7 @@ int sql_exec(sqlite3 *db, rust::Str zSql, sql_bind_callback bind_cb, void *bind_
if (int count = sqlite3_bind_parameter_count(stmt.get())) { if (int count = sqlite3_bind_parameter_count(stmt.get())) {
auto real_cb = reinterpret_cast<sql_bind_callback_real>(bind_cb); auto real_cb = reinterpret_cast<sql_bind_callback_real>(bind_cb);
for (int i = 1; i <= count; ++i) { for (int i = 1; i <= count; ++i) {
real_cb(bind_cookie, i, stmt.get()); sql_chk(real_cb, bind_cookie, i, stmt.get());
} }
} }
} }
@ -155,7 +157,3 @@ int DbStatement::bind_int64(int index, int64_t val) {
int DbStatement::bind_text(int index, rust::Str val) { int DbStatement::bind_text(int index, rust::Str val) {
return sqlite3_bind_text(reinterpret_cast<sqlite3_stmt*>(this), index, val.data(), val.size(), nullptr); return sqlite3_bind_text(reinterpret_cast<sqlite3_stmt*>(this), index, val.data(), val.size(), nullptr);
} }
int DbStatement::bind_text(int index, const char *val) {
return sqlite3_bind_text(reinterpret_cast<sqlite3_stmt*>(this), index, val, -1, nullptr);
}

View File

@ -76,11 +76,10 @@ void su_info::check_db() {
} }
if (eval_uid > 0) { if (eval_uid > 0) {
char query[256]; bool res = db_exec(
ssprintf(query, sizeof(query),
"SELECT policy, logging, notification FROM policies " "SELECT policy, logging, notification FROM policies "
"WHERE uid=%d AND (until=0 OR until>%li)", eval_uid, time(nullptr)); "WHERE uid=? AND (until=0 OR until>?)", { eval_uid, time(nullptr) }, access);
if (!db_exec(query, access)) if (!res)
return; return;
} }
@ -127,15 +126,11 @@ bool uid_granted_root(int uid) {
break; break;
} }
char query[256]; bool granted = false;
ssprintf(query, sizeof(query), db_exec("SELECT policy FROM policies WHERE uid=? AND (until=0 OR until>?)",
"SELECT policy FROM policies WHERE uid=%d AND (until=0 OR until>%li)", { uid, time(nullptr) },
uid, time(nullptr)); [&](auto, DbValues &data) { granted = data.get_int(0) == ALLOW; });
su_access access; return granted;
access.policy = QUERY;
if (!db_exec(query, access))
return false;
return access.policy == ALLOW;
} }
struct policy_uid_list : public vector<int> { struct policy_uid_list : public vector<int> {
@ -147,7 +142,7 @@ struct policy_uid_list : public vector<int> {
void prune_su_access() { void prune_su_access() {
cached.reset(); cached.reset();
policy_uid_list uids; policy_uid_list uids;
if (!db_exec("SELECT uid FROM policies", uids)) if (!db_exec("SELECT uid FROM policies", {}, uids))
return; return;
vector<bool> app_no_list = get_app_no_list(); vector<bool> app_no_list = get_app_no_list();
vector<int> rm_uids; vector<int> rm_uids;