mirror of
https://github.com/revanced/revanced-patcher.git
synced 2025-05-09 16:44:25 +02:00
Improve example test testPatcher
and increase caching speed
This commit is contained in:
parent
81e0220d15
commit
5d146c362f
@ -7,7 +7,6 @@ import net.revanced.patcher.signature.Signature
|
|||||||
import net.revanced.patcher.util.Jar2ASM
|
import net.revanced.patcher.util.Jar2ASM
|
||||||
import java.io.InputStream
|
import java.io.InputStream
|
||||||
import java.io.OutputStream
|
import java.io.OutputStream
|
||||||
import java.util.jar.JarOutputStream
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The patcher. (docs WIP)
|
* The patcher. (docs WIP)
|
||||||
@ -20,12 +19,12 @@ class Patcher(
|
|||||||
private val input: InputStream,
|
private val input: InputStream,
|
||||||
signatures: Array<Signature>,
|
signatures: Array<Signature>,
|
||||||
) {
|
) {
|
||||||
val cache = Cache()
|
var cache: Cache
|
||||||
private val patches: MutableList<Patch> = mutableListOf()
|
private val patches: MutableList<Patch> = mutableListOf()
|
||||||
|
|
||||||
init {
|
init {
|
||||||
cache.classes.putAll(Jar2ASM.jar2asm(input))
|
val classes = Jar2ASM.jar2asm(input);
|
||||||
cache.methods.putAll(MethodResolver(cache.classes.values.toList(), signatures).resolve())
|
cache = Cache(classes, MethodResolver(classes, signatures).resolve())
|
||||||
}
|
}
|
||||||
|
|
||||||
fun addPatches(vararg patches: Patch) {
|
fun addPatches(vararg patches: Patch) {
|
||||||
|
@ -2,14 +2,14 @@ package net.revanced.patcher.cache
|
|||||||
|
|
||||||
import org.objectweb.asm.tree.ClassNode
|
import org.objectweb.asm.tree.ClassNode
|
||||||
|
|
||||||
class Cache {
|
class Cache (
|
||||||
val classes: MutableMap<String, ClassNode> = mutableMapOf()
|
val classes: List<ClassNode>,
|
||||||
val methods: MethodMap = MethodMap()
|
val methods: MethodMap
|
||||||
}
|
)
|
||||||
|
|
||||||
class MethodMap : LinkedHashMap<String, PatchData>() {
|
class MethodMap : LinkedHashMap<String, PatchData>() {
|
||||||
override fun get(key: String): PatchData {
|
override fun get(key: String): PatchData {
|
||||||
return super.get(key) ?: throw MethodNotFoundException("Method $key not found in method cache")
|
return super.get(key) ?: throw MethodNotFoundException("Method $key was not found in the method cache")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4,12 +4,12 @@ import org.objectweb.asm.tree.ClassNode
|
|||||||
import org.objectweb.asm.tree.MethodNode
|
import org.objectweb.asm.tree.MethodNode
|
||||||
|
|
||||||
data class PatchData(
|
data class PatchData(
|
||||||
val cls: ClassNode,
|
val declaringClass: ClassNode,
|
||||||
val method: MethodNode,
|
val method: MethodNode,
|
||||||
val sd: ScanData
|
val scanData: PatternScanData
|
||||||
)
|
)
|
||||||
|
|
||||||
data class ScanData(
|
data class PatternScanData(
|
||||||
val startIndex: Int,
|
val startIndex: Int,
|
||||||
val endIndex: Int
|
val endIndex: Int
|
||||||
)
|
)
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
package net.revanced.patcher.resolver
|
package net.revanced.patcher.resolver
|
||||||
|
|
||||||
import mu.KotlinLogging
|
import mu.KotlinLogging
|
||||||
|
import net.revanced.patcher.cache.MethodMap
|
||||||
import net.revanced.patcher.cache.PatchData
|
import net.revanced.patcher.cache.PatchData
|
||||||
import net.revanced.patcher.cache.ScanData
|
import net.revanced.patcher.cache.PatternScanData
|
||||||
import net.revanced.patcher.signature.Signature
|
import net.revanced.patcher.signature.Signature
|
||||||
import net.revanced.patcher.util.ExtraTypes
|
import net.revanced.patcher.util.ExtraTypes
|
||||||
import org.objectweb.asm.Type
|
import org.objectweb.asm.Type
|
||||||
@ -13,13 +14,13 @@ import org.objectweb.asm.tree.MethodNode
|
|||||||
private val logger = KotlinLogging.logger("MethodResolver")
|
private val logger = KotlinLogging.logger("MethodResolver")
|
||||||
|
|
||||||
internal class MethodResolver(private val classList: List<ClassNode>, private val signatures: Array<Signature>) {
|
internal class MethodResolver(private val classList: List<ClassNode>, private val signatures: Array<Signature>) {
|
||||||
fun resolve(): MutableMap<String, PatchData> {
|
fun resolve(): MethodMap {
|
||||||
val patchData = mutableMapOf<String, PatchData>()
|
val methodMap = MethodMap()
|
||||||
|
|
||||||
for ((classNode, methods) in classList) {
|
for ((classNode, methods) in classList) {
|
||||||
for (method in methods) {
|
for (method in methods) {
|
||||||
for (signature in signatures) {
|
for (signature in signatures) {
|
||||||
if (patchData.containsKey(signature.name)) { // method already found for this sig
|
if (methodMap.containsKey(signature.name)) { // method already found for this sig
|
||||||
logger.debug { "Sig ${signature.name} already found, skipping." }
|
logger.debug { "Sig ${signature.name} already found, skipping." }
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -30,10 +31,10 @@ internal class MethodResolver(private val classList: List<ClassNode>, private va
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
logger.debug { "Method for sig ${signature.name} found!" }
|
logger.debug { "Method for sig ${signature.name} found!" }
|
||||||
patchData[signature.name] = PatchData(
|
methodMap[signature.name] = PatchData(
|
||||||
classNode,
|
classNode,
|
||||||
method,
|
method,
|
||||||
ScanData(
|
PatternScanData(
|
||||||
// sadly we cannot create contracts for a data class, so we must assert
|
// sadly we cannot create contracts for a data class, so we must assert
|
||||||
sr.startIndex!!,
|
sr.startIndex!!,
|
||||||
sr.endIndex!!
|
sr.endIndex!!
|
||||||
@ -44,11 +45,11 @@ internal class MethodResolver(private val classList: List<ClassNode>, private va
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (signature in signatures) {
|
for (signature in signatures) {
|
||||||
if (patchData.containsKey(signature.name)) continue
|
if (methodMap.containsKey(signature.name)) continue
|
||||||
logger.error { "Could not find method for sig ${signature.name}!" }
|
logger.error { "Could not find method for sig ${signature.name}!" }
|
||||||
}
|
}
|
||||||
|
|
||||||
return patchData
|
return methodMap
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun cmp(method: MethodNode, signature: Signature): Pair<Boolean, ScanResult?> {
|
private fun cmp(method: MethodNode, signature: Signature): Pair<Boolean, ScanResult?> {
|
||||||
|
@ -10,21 +10,20 @@ import java.util.jar.JarInputStream
|
|||||||
import java.util.jar.JarOutputStream
|
import java.util.jar.JarOutputStream
|
||||||
|
|
||||||
object Jar2ASM {
|
object Jar2ASM {
|
||||||
fun jar2asm(input: InputStream): Map<String, ClassNode> {
|
fun jar2asm(input: InputStream) = mutableListOf<ClassNode>().apply {
|
||||||
return buildMap {
|
val jar = JarInputStream(input)
|
||||||
val jar = JarInputStream(input)
|
|
||||||
while (true) {
|
while (true) {
|
||||||
val e = jar.nextJarEntry ?: break
|
val e = jar.nextJarEntry ?: break
|
||||||
if (e.name.endsWith(".class")) {
|
if (e.name.endsWith(".class")) {
|
||||||
val classNode = ClassNode()
|
val classNode = ClassNode()
|
||||||
ClassReader(jar.readAllBytes()).accept(classNode, ClassReader.EXPAND_FRAMES)
|
ClassReader(jar.readAllBytes()).accept(classNode, ClassReader.EXPAND_FRAMES)
|
||||||
this[e.name] = classNode
|
this.add(classNode)
|
||||||
}
|
}
|
||||||
jar.closeEntry()
|
jar.closeEntry()
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
fun asm2jar(input: InputStream, output: OutputStream, structure: Map<String, ClassNode>) {
|
|
||||||
|
fun asm2jar(input: InputStream, output: OutputStream, classes: List<ClassNode>) {
|
||||||
val jis = JarInputStream(input)
|
val jis = JarInputStream(input)
|
||||||
val jos = JarOutputStream(output)
|
val jos = JarOutputStream(output)
|
||||||
|
|
||||||
@ -33,10 +32,13 @@ object Jar2ASM {
|
|||||||
val next = jis.nextJarEntry ?: break
|
val next = jis.nextJarEntry ?: break
|
||||||
val e = JarEntry(next) // clone it, to not modify the input (if possible)
|
val e = JarEntry(next) // clone it, to not modify the input (if possible)
|
||||||
jos.putNextEntry(e)
|
jos.putNextEntry(e)
|
||||||
if (structure.containsKey(e.name)) {
|
|
||||||
|
val clazz = classes.singleOrNull {
|
||||||
|
clazz -> clazz.name == e.name
|
||||||
|
};
|
||||||
|
if (clazz != null) {
|
||||||
val cw = ClassWriter(ClassWriter.COMPUTE_MAXS or ClassWriter.COMPUTE_FRAMES)
|
val cw = ClassWriter(ClassWriter.COMPUTE_MAXS or ClassWriter.COMPUTE_FRAMES)
|
||||||
val cn = structure[e.name]!!
|
clazz.accept(cw)
|
||||||
cn.accept(cw)
|
|
||||||
jos.write(cw.toByteArray())
|
jos.write(cw.toByteArray())
|
||||||
} else {
|
} else {
|
||||||
jos.write(jis.readAllBytes())
|
jos.write(jis.readAllBytes())
|
||||||
|
@ -7,11 +7,8 @@ import net.revanced.patcher.util.ExtraTypes
|
|||||||
import net.revanced.patcher.writer.ASMWriter.setAt
|
import net.revanced.patcher.writer.ASMWriter.setAt
|
||||||
import org.objectweb.asm.Opcodes.*
|
import org.objectweb.asm.Opcodes.*
|
||||||
import org.objectweb.asm.Type
|
import org.objectweb.asm.Type
|
||||||
import org.objectweb.asm.tree.LdcInsnNode
|
import org.objectweb.asm.tree.*
|
||||||
import java.io.ByteArrayOutputStream
|
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
|
||||||
import kotlin.test.assertTrue
|
|
||||||
|
|
||||||
internal class PatcherTest {
|
internal class PatcherTest {
|
||||||
private val testSigs: Array<Signature> = arrayOf(
|
private val testSigs: Array<Signature> = arrayOf(
|
||||||
@ -46,14 +43,24 @@ internal class PatcherTest {
|
|||||||
patcher.addPatches(
|
patcher.addPatches(
|
||||||
Patch ("TestPatch") {
|
Patch ("TestPatch") {
|
||||||
// Get the method from the resolver cache
|
// Get the method from the resolver cache
|
||||||
val main = patcher.cache.methods["mainMethod"]
|
val mainMethod = patcher.cache.methods["mainMethod"]
|
||||||
// Get the instruction list
|
// Get the instruction list
|
||||||
val insn = main.method.instructions!!
|
val instructions = mainMethod.method.instructions!!
|
||||||
// Let's modify it, so it prints "Hello, ReVanced!"
|
// Let's modify it, so it prints "Hello, ReVanced!"
|
||||||
// Get the start index of our signature
|
// Get the start index of our opcode pattern
|
||||||
// This will be the index of the LDC instruction
|
// This will be the index of the LDC instruction
|
||||||
val startIndex = main.sd.startIndex
|
val startIndex = mainMethod.scanData.startIndex
|
||||||
insn.setAt(startIndex, LdcInsnNode("Hello, ReVanced!"))
|
// Create a new Ldc node and replace the LDC instruction
|
||||||
|
val stringNode = LdcInsnNode("Hello, ReVanced!");
|
||||||
|
instructions.setAt(startIndex, stringNode)
|
||||||
|
// Now lets print our string to the console output
|
||||||
|
// First create a list of instructions
|
||||||
|
val printCode = InsnList();
|
||||||
|
printCode.add(LdcInsnNode("Hello, ReVanced!"))
|
||||||
|
printCode.add(MethodInsnNode(INVOKEVIRTUAL, "java/io/PrintStream", "println", "(Ljava/lang/String;)V"))
|
||||||
|
// Add the list after the second instruction by our pattern
|
||||||
|
instructions.insert(instructions[startIndex + 1], printCode)
|
||||||
|
|
||||||
// Finally, tell the patcher that this patch was a success.
|
// Finally, tell the patcher that this patch was a success.
|
||||||
// You can also return PatchResultError with a message.
|
// You can also return PatchResultError with a message.
|
||||||
// If an exception is thrown inside this function,
|
// If an exception is thrown inside this function,
|
||||||
@ -62,7 +69,9 @@ internal class PatcherTest {
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Apply all patches loaded in the patcher
|
||||||
val result = patcher.applyPatches()
|
val result = patcher.applyPatches()
|
||||||
|
// You can check if an error occurred
|
||||||
for ((s, r) in result) {
|
for ((s, r) in result) {
|
||||||
if (r.isFailure) {
|
if (r.isFailure) {
|
||||||
throw Exception("Patch $s failed", r.exceptionOrNull()!!)
|
throw Exception("Patch $s failed", r.exceptionOrNull()!!)
|
||||||
@ -70,30 +79,30 @@ internal class PatcherTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO Doesn't work, needs to be fixed.
|
// TODO Doesn't work, needs to be fixed.
|
||||||
// val out = ByteArrayOutputStream()
|
//val out = ByteArrayOutputStream()
|
||||||
// patcher.saveTo(out)
|
//patcher.saveTo(out)
|
||||||
// assertTrue(
|
//assertTrue(
|
||||||
// // 8 is a random value, it's just weird if it's any lower than that
|
// // 8 is a random value, it's just weird if it's any lower than that
|
||||||
// out.size() > 8,
|
// out.size() > 8,
|
||||||
// "Output must be at least 8 bytes"
|
// "Output must be at least 8 bytes"
|
||||||
// )
|
//)
|
||||||
//
|
//
|
||||||
// out.close()
|
//out.close()
|
||||||
testData.close()
|
testData.close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO Doesn't work, needs to be fixed.
|
// TODO Doesn't work, needs to be fixed.
|
||||||
// @Test
|
//@Test
|
||||||
// fun noChanges() {
|
//fun noChanges() {
|
||||||
// val testData = PatcherTest::class.java.getResourceAsStream("/test1.jar")!!
|
// val testData = PatcherTest::class.java.getResourceAsStream("/test1.jar")!!
|
||||||
// val available = testData.available()
|
// val available = testData.available()
|
||||||
// val patcher = Patcher(testData, testSigs)
|
// val patcher = Patcher(testData, testSigs)
|
||||||
//
|
//
|
||||||
// val out = ByteArrayOutputStream()
|
// val out = ByteArrayOutputStream()
|
||||||
// patcher.saveTo(out)
|
// patcher.saveTo(out)
|
||||||
// assertEquals(available, out.size())
|
// assertEquals(available, out.size())
|
||||||
//
|
//
|
||||||
// out.close()
|
// out.close()
|
||||||
// testData.close()
|
// testData.close()
|
||||||
// }
|
//}
|
||||||
}
|
}
|
Loading…
x
Reference in New Issue
Block a user