Cache classes with their path & add ScanData for PatchData

This commit is contained in:
Lucaskyy 2022-03-19 21:28:50 +01:00
parent ae5007ebd1
commit 6bd4d80c47
No known key found for this signature in database
GPG Key ID: 1530BFF96D1EEB89
6 changed files with 47 additions and 17 deletions

View File

@ -7,12 +7,14 @@ import net.revanced.patcher.resolver.MethodResolver
import net.revanced.patcher.signature.Signature import net.revanced.patcher.signature.Signature
import net.revanced.patcher.util.Jar2ASM import net.revanced.patcher.util.Jar2ASM
import java.io.InputStream import java.io.InputStream
import java.io.OutputStream
import java.util.jar.JarFile import java.util.jar.JarFile
/** /**
* The patcher. (docs WIP) * The patcher. (docs WIP)
* *
* @param input the input stream to read from, must be a JAR file (for now) * @param input the input stream to read from, must be a JAR
* @param signatures the signatures
*/ */
class Patcher ( class Patcher (
input: InputStream, input: InputStream,
@ -22,7 +24,8 @@ class Patcher (
private val patches: MutableList<Patch> = mutableListOf() private val patches: MutableList<Patch> = mutableListOf()
init { init {
cache.methods.putAll(MethodResolver(Jar2ASM.jar2asm(input), signatures).resolve()) cache.classes.putAll(Jar2ASM.jar2asm(input))
cache.methods.putAll(MethodResolver(cache.classes.values.toList(), signatures).resolve())
} }
fun addPatches(vararg patches: Patch) { fun addPatches(vararg patches: Patch) {
@ -43,4 +46,8 @@ class Patcher (
} }
} }
} }
fun save(output: OutputStream) {
}
} }

View File

@ -1,8 +1,11 @@
package net.revanced.patcher.cache package net.revanced.patcher.cache
data class Cache( import org.objectweb.asm.tree.ClassNode
class Cache {
val classes: MutableMap<String, ClassNode> = mutableMapOf()
val methods: MethodMap = MethodMap() val methods: MethodMap = MethodMap()
) }
class MethodMap : LinkedHashMap<String, PatchData>() { class MethodMap : LinkedHashMap<String, PatchData>() {
override fun get(key: String): PatchData { override fun get(key: String): PatchData {

View File

@ -5,5 +5,11 @@ import org.objectweb.asm.tree.MethodNode
data class PatchData( data class PatchData(
val cls: ClassNode, val cls: ClassNode,
val method: MethodNode val method: MethodNode,
val sr: ScanData
)
data class ScanData(
val startIndex: Int,
val endIndex: Int
) )

View File

@ -2,6 +2,7 @@ package net.revanced.patcher.resolver
import mu.KotlinLogging import mu.KotlinLogging
import net.revanced.patcher.cache.PatchData import net.revanced.patcher.cache.PatchData
import net.revanced.patcher.cache.ScanData
import net.revanced.patcher.signature.Signature import net.revanced.patcher.signature.Signature
import org.objectweb.asm.Type import org.objectweb.asm.Type
import org.objectweb.asm.tree.ClassNode import org.objectweb.asm.tree.ClassNode
@ -22,12 +23,21 @@ internal class MethodResolver(private val classList: List<ClassNode>, private va
continue continue
} }
logger.debug { "Resolving sig ${signature.name}: ${classNode.name} / ${method.name}" } logger.debug { "Resolving sig ${signature.name}: ${classNode.name} / ${method.name}" }
if (!this.cmp(method, signature)) { val (r, sr) = this.cmp(method, signature)
if (!r || sr == null) {
logger.debug { "Compare result for sig ${signature.name} has failed!" } logger.debug { "Compare result for sig ${signature.name} has failed!" }
continue continue
} }
logger.debug { "Method for sig ${signature.name} found!" } logger.debug { "Method for sig ${signature.name} found!" }
patchData[signature.name] = PatchData(classNode, method) patchData[signature.name] = PatchData(
classNode,
method,
ScanData(
// sadly we cannot create contracts for a data class, so we must assert
sr.startIndex!!,
sr.endIndex!!
)
)
} }
} }
} }
@ -40,28 +50,27 @@ internal class MethodResolver(private val classList: List<ClassNode>, private va
return patchData return patchData
} }
private fun cmp(method: MethodNode, signature: Signature): Boolean { private fun cmp(method: MethodNode, signature: Signature): Pair<Boolean, ScanResult?> {
if (signature.returns != Type.getReturnType(method.desc)) { if (signature.returns != Type.getReturnType(method.desc)) {
logger.debug { "Comparing sig ${signature.name}: invalid return type:\nexpected ${signature.returns},\ngot ${Type.getReturnType(method.desc)}" } logger.debug { "Comparing sig ${signature.name}: invalid return type:\nexpected ${signature.returns},\ngot ${Type.getReturnType(method.desc)}" }
return false return false to null
} }
if (signature.accessors != method.access) { if (signature.accessors != method.access) {
logger.debug { "Comparing sig ${signature.name}: invalid accessors:\nexpected ${signature.accessors},\ngot ${method.access}" } logger.debug { "Comparing sig ${signature.name}: invalid accessors:\nexpected ${signature.accessors},\ngot ${method.access}" }
return false return false to null
} }
if (!signature.parameters.contentEquals(Type.getArgumentTypes(method.desc))) { if (!signature.parameters.contentEquals(Type.getArgumentTypes(method.desc))) {
logger.debug { "Comparing sig ${signature.name}: invalid parameter types:\nexpected ${signature.parameters},\ngot ${Type.getArgumentTypes(method.desc)}" } logger.debug { "Comparing sig ${signature.name}: invalid parameter types:\nexpected ${signature.parameters},\ngot ${Type.getArgumentTypes(method.desc)}" }
return false return false to null
} }
val result = method.instructions.scanFor(signature.opcodes) val result = method.instructions.scanFor(signature.opcodes)
if (!result.found) { if (!result.found) {
logger.debug { "Comparing sig ${signature.name}: invalid opcode pattern" } logger.debug { "Comparing sig ${signature.name}: invalid opcode pattern" }
return false return false to null
} }
// TODO make use of the startIndex and endIndex we have from the result
return true return true to result
} }
} }

View File

@ -0,0 +1,5 @@
package net.revanced.patcher.util
object ASMWriter {
}

View File

@ -6,15 +6,15 @@ import java.io.InputStream
import java.util.jar.JarInputStream import java.util.jar.JarInputStream
object Jar2ASM { object Jar2ASM {
fun jar2asm(input: InputStream): List<ClassNode> { fun jar2asm(input: InputStream): Map<String, ClassNode> {
return buildList { return buildMap {
val jar = JarInputStream(input) val jar = JarInputStream(input)
while (true) { while (true) {
val e = jar.nextJarEntry ?: break val e = jar.nextJarEntry ?: break
if (e.name.endsWith(".class")) { if (e.name.endsWith(".class")) {
val classNode = ClassNode() val classNode = ClassNode()
ClassReader(jar.readAllBytes()).accept(classNode, ClassReader.EXPAND_FRAMES) ClassReader(jar.readAllBytes()).accept(classNode, ClassReader.EXPAND_FRAMES)
this.add(classNode) this[e.name] = classNode
} }
} }
} }