Improve example test testPatcher and increase caching speed

This commit is contained in:
oSumAtrIX 2022-03-20 03:06:23 +01:00
parent 81e0220d15
commit 5d146c362f
No known key found for this signature in database
GPG Key ID: A9B3094ACDB604B4
6 changed files with 71 additions and 60 deletions

View File

@ -7,7 +7,6 @@ import net.revanced.patcher.signature.Signature
import net.revanced.patcher.util.Jar2ASM import net.revanced.patcher.util.Jar2ASM
import java.io.InputStream import java.io.InputStream
import java.io.OutputStream import java.io.OutputStream
import java.util.jar.JarOutputStream
/** /**
* The patcher. (docs WIP) * The patcher. (docs WIP)
@ -20,12 +19,12 @@ class Patcher(
private val input: InputStream, private val input: InputStream,
signatures: Array<Signature>, signatures: Array<Signature>,
) { ) {
val cache = Cache() var cache: Cache
private val patches: MutableList<Patch> = mutableListOf() private val patches: MutableList<Patch> = mutableListOf()
init { init {
cache.classes.putAll(Jar2ASM.jar2asm(input)) val classes = Jar2ASM.jar2asm(input);
cache.methods.putAll(MethodResolver(cache.classes.values.toList(), signatures).resolve()) cache = Cache(classes, MethodResolver(classes, signatures).resolve())
} }
fun addPatches(vararg patches: Patch) { fun addPatches(vararg patches: Patch) {

View File

@ -2,14 +2,14 @@ package net.revanced.patcher.cache
import org.objectweb.asm.tree.ClassNode import org.objectweb.asm.tree.ClassNode
class Cache { class Cache (
val classes: MutableMap<String, ClassNode> = mutableMapOf() val classes: List<ClassNode>,
val methods: MethodMap = MethodMap() val methods: MethodMap
} )
class MethodMap : LinkedHashMap<String, PatchData>() { class MethodMap : LinkedHashMap<String, PatchData>() {
override fun get(key: String): PatchData { override fun get(key: String): PatchData {
return super.get(key) ?: throw MethodNotFoundException("Method $key not found in method cache") return super.get(key) ?: throw MethodNotFoundException("Method $key was not found in the method cache")
} }
} }

View File

@ -4,12 +4,12 @@ import org.objectweb.asm.tree.ClassNode
import org.objectweb.asm.tree.MethodNode import org.objectweb.asm.tree.MethodNode
data class PatchData( data class PatchData(
val cls: ClassNode, val declaringClass: ClassNode,
val method: MethodNode, val method: MethodNode,
val sd: ScanData val scanData: PatternScanData
) )
data class ScanData( data class PatternScanData(
val startIndex: Int, val startIndex: Int,
val endIndex: Int val endIndex: Int
) )

View File

@ -1,8 +1,9 @@
package net.revanced.patcher.resolver package net.revanced.patcher.resolver
import mu.KotlinLogging import mu.KotlinLogging
import net.revanced.patcher.cache.MethodMap
import net.revanced.patcher.cache.PatchData import net.revanced.patcher.cache.PatchData
import net.revanced.patcher.cache.ScanData import net.revanced.patcher.cache.PatternScanData
import net.revanced.patcher.signature.Signature import net.revanced.patcher.signature.Signature
import net.revanced.patcher.util.ExtraTypes import net.revanced.patcher.util.ExtraTypes
import org.objectweb.asm.Type import org.objectweb.asm.Type
@ -13,13 +14,13 @@ import org.objectweb.asm.tree.MethodNode
private val logger = KotlinLogging.logger("MethodResolver") private val logger = KotlinLogging.logger("MethodResolver")
internal class MethodResolver(private val classList: List<ClassNode>, private val signatures: Array<Signature>) { internal class MethodResolver(private val classList: List<ClassNode>, private val signatures: Array<Signature>) {
fun resolve(): MutableMap<String, PatchData> { fun resolve(): MethodMap {
val patchData = mutableMapOf<String, PatchData>() val methodMap = MethodMap()
for ((classNode, methods) in classList) { for ((classNode, methods) in classList) {
for (method in methods) { for (method in methods) {
for (signature in signatures) { for (signature in signatures) {
if (patchData.containsKey(signature.name)) { // method already found for this sig if (methodMap.containsKey(signature.name)) { // method already found for this sig
logger.debug { "Sig ${signature.name} already found, skipping." } logger.debug { "Sig ${signature.name} already found, skipping." }
continue continue
} }
@ -30,10 +31,10 @@ internal class MethodResolver(private val classList: List<ClassNode>, private va
continue continue
} }
logger.debug { "Method for sig ${signature.name} found!" } logger.debug { "Method for sig ${signature.name} found!" }
patchData[signature.name] = PatchData( methodMap[signature.name] = PatchData(
classNode, classNode,
method, method,
ScanData( PatternScanData(
// sadly we cannot create contracts for a data class, so we must assert // sadly we cannot create contracts for a data class, so we must assert
sr.startIndex!!, sr.startIndex!!,
sr.endIndex!! sr.endIndex!!
@ -44,11 +45,11 @@ internal class MethodResolver(private val classList: List<ClassNode>, private va
} }
for (signature in signatures) { for (signature in signatures) {
if (patchData.containsKey(signature.name)) continue if (methodMap.containsKey(signature.name)) continue
logger.error { "Could not find method for sig ${signature.name}!" } logger.error { "Could not find method for sig ${signature.name}!" }
} }
return patchData return methodMap
} }
private fun cmp(method: MethodNode, signature: Signature): Pair<Boolean, ScanResult?> { private fun cmp(method: MethodNode, signature: Signature): Pair<Boolean, ScanResult?> {

View File

@ -10,21 +10,20 @@ import java.util.jar.JarInputStream
import java.util.jar.JarOutputStream import java.util.jar.JarOutputStream
object Jar2ASM { object Jar2ASM {
fun jar2asm(input: InputStream): Map<String, ClassNode> { fun jar2asm(input: InputStream) = mutableListOf<ClassNode>().apply {
return buildMap { val jar = JarInputStream(input)
val jar = JarInputStream(input)
while (true) { while (true) {
val e = jar.nextJarEntry ?: break val e = jar.nextJarEntry ?: break
if (e.name.endsWith(".class")) { if (e.name.endsWith(".class")) {
val classNode = ClassNode() val classNode = ClassNode()
ClassReader(jar.readAllBytes()).accept(classNode, ClassReader.EXPAND_FRAMES) ClassReader(jar.readAllBytes()).accept(classNode, ClassReader.EXPAND_FRAMES)
this[e.name] = classNode this.add(classNode)
} }
jar.closeEntry() jar.closeEntry()
} }
}
} }
fun asm2jar(input: InputStream, output: OutputStream, structure: Map<String, ClassNode>) {
fun asm2jar(input: InputStream, output: OutputStream, classes: List<ClassNode>) {
val jis = JarInputStream(input) val jis = JarInputStream(input)
val jos = JarOutputStream(output) val jos = JarOutputStream(output)
@ -33,10 +32,13 @@ object Jar2ASM {
val next = jis.nextJarEntry ?: break val next = jis.nextJarEntry ?: break
val e = JarEntry(next) // clone it, to not modify the input (if possible) val e = JarEntry(next) // clone it, to not modify the input (if possible)
jos.putNextEntry(e) jos.putNextEntry(e)
if (structure.containsKey(e.name)) {
val clazz = classes.singleOrNull {
clazz -> clazz.name == e.name
};
if (clazz != null) {
val cw = ClassWriter(ClassWriter.COMPUTE_MAXS or ClassWriter.COMPUTE_FRAMES) val cw = ClassWriter(ClassWriter.COMPUTE_MAXS or ClassWriter.COMPUTE_FRAMES)
val cn = structure[e.name]!! clazz.accept(cw)
cn.accept(cw)
jos.write(cw.toByteArray()) jos.write(cw.toByteArray())
} else { } else {
jos.write(jis.readAllBytes()) jos.write(jis.readAllBytes())

View File

@ -7,11 +7,8 @@ import net.revanced.patcher.util.ExtraTypes
import net.revanced.patcher.writer.ASMWriter.setAt import net.revanced.patcher.writer.ASMWriter.setAt
import org.objectweb.asm.Opcodes.* import org.objectweb.asm.Opcodes.*
import org.objectweb.asm.Type import org.objectweb.asm.Type
import org.objectweb.asm.tree.LdcInsnNode import org.objectweb.asm.tree.*
import java.io.ByteArrayOutputStream
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue
internal class PatcherTest { internal class PatcherTest {
private val testSigs: Array<Signature> = arrayOf( private val testSigs: Array<Signature> = arrayOf(
@ -46,14 +43,24 @@ internal class PatcherTest {
patcher.addPatches( patcher.addPatches(
Patch ("TestPatch") { Patch ("TestPatch") {
// Get the method from the resolver cache // Get the method from the resolver cache
val main = patcher.cache.methods["mainMethod"] val mainMethod = patcher.cache.methods["mainMethod"]
// Get the instruction list // Get the instruction list
val insn = main.method.instructions!! val instructions = mainMethod.method.instructions!!
// Let's modify it, so it prints "Hello, ReVanced!" // Let's modify it, so it prints "Hello, ReVanced!"
// Get the start index of our signature // Get the start index of our opcode pattern
// This will be the index of the LDC instruction // This will be the index of the LDC instruction
val startIndex = main.sd.startIndex val startIndex = mainMethod.scanData.startIndex
insn.setAt(startIndex, LdcInsnNode("Hello, ReVanced!")) // Create a new Ldc node and replace the LDC instruction
val stringNode = LdcInsnNode("Hello, ReVanced!");
instructions.setAt(startIndex, stringNode)
// Now lets print our string to the console output
// First create a list of instructions
val printCode = InsnList();
printCode.add(LdcInsnNode("Hello, ReVanced!"))
printCode.add(MethodInsnNode(INVOKEVIRTUAL, "java/io/PrintStream", "println", "(Ljava/lang/String;)V"))
// Add the list after the second instruction by our pattern
instructions.insert(instructions[startIndex + 1], printCode)
// Finally, tell the patcher that this patch was a success. // Finally, tell the patcher that this patch was a success.
// You can also return PatchResultError with a message. // You can also return PatchResultError with a message.
// If an exception is thrown inside this function, // If an exception is thrown inside this function,
@ -62,7 +69,9 @@ internal class PatcherTest {
} }
) )
// Apply all patches loaded in the patcher
val result = patcher.applyPatches() val result = patcher.applyPatches()
// You can check if an error occurred
for ((s, r) in result) { for ((s, r) in result) {
if (r.isFailure) { if (r.isFailure) {
throw Exception("Patch $s failed", r.exceptionOrNull()!!) throw Exception("Patch $s failed", r.exceptionOrNull()!!)
@ -70,30 +79,30 @@ internal class PatcherTest {
} }
// TODO Doesn't work, needs to be fixed. // TODO Doesn't work, needs to be fixed.
// val out = ByteArrayOutputStream() //val out = ByteArrayOutputStream()
// patcher.saveTo(out) //patcher.saveTo(out)
// assertTrue( //assertTrue(
// // 8 is a random value, it's just weird if it's any lower than that // // 8 is a random value, it's just weird if it's any lower than that
// out.size() > 8, // out.size() > 8,
// "Output must be at least 8 bytes" // "Output must be at least 8 bytes"
// ) //)
// //
// out.close() //out.close()
testData.close() testData.close()
} }
// TODO Doesn't work, needs to be fixed. // TODO Doesn't work, needs to be fixed.
// @Test //@Test
// fun noChanges() { //fun noChanges() {
// val testData = PatcherTest::class.java.getResourceAsStream("/test1.jar")!! // val testData = PatcherTest::class.java.getResourceAsStream("/test1.jar")!!
// val available = testData.available() // val available = testData.available()
// val patcher = Patcher(testData, testSigs) // val patcher = Patcher(testData, testSigs)
// //
// val out = ByteArrayOutputStream() // val out = ByteArrayOutputStream()
// patcher.saveTo(out) // patcher.saveTo(out)
// assertEquals(available, out.size()) // assertEquals(available, out.size())
// //
// out.close() // out.close()
// testData.close() // testData.close()
// } //}
} }