fix(Io): JAR loading and saving (#8)

* refactor: Complete rewrite of `Io`

* style: format code

* style: rewrite todos

* fix: use lateinit instead of nonnull assert for zipEntry

* fix: use lateinit instead of nonnull assert for jarEntry & reuse zipEntry

* docs: add docs to `Patcher`

* test: match output of patcher

* chore: add todo to `Io` for removing non-class files

Co-authored-by: Sculas <contact@sculas.xyz>
This commit is contained in:
oSumAtrIX 2022-03-21 18:48:35 +01:00 committed by she11sh0cked
parent 87bbde5e06
commit 4d98cbc9e8
8 changed files with 133 additions and 77 deletions

View File

@ -5,28 +5,49 @@ import net.revanced.patcher.patch.Patch
import net.revanced.patcher.resolver.MethodResolver import net.revanced.patcher.resolver.MethodResolver
import net.revanced.patcher.signature.Signature import net.revanced.patcher.signature.Signature
import net.revanced.patcher.util.Io import net.revanced.patcher.util.Io
import org.objectweb.asm.tree.ClassNode
import java.io.IOException
import java.io.InputStream import java.io.InputStream
import java.io.OutputStream import java.io.OutputStream
/** /**
* The patcher. (docs WIP) * The Patcher class.
* ***It is of utmost importance that the input and output streams are NEVER closed.***
* *
* @param input the input stream to read from, must be a JAR * @param input the input stream to read from, must be a JAR
* @param output the output stream to write to
* @param signatures the signatures * @param signatures the signatures
* @sample net.revanced.patcher.PatcherTest * @sample net.revanced.patcher.PatcherTest
* @throws IOException if one of the streams are closed
*/ */
class Patcher( class Patcher(
private val input: InputStream, private val input: InputStream,
private val output: OutputStream,
signatures: Array<Signature>, signatures: Array<Signature>,
) { ) {
var cache: Cache var cache: Cache
private val patches: MutableList<Patch> = mutableListOf()
private var io: Io
private val patches = mutableListOf<Patch>()
init { init {
val classes = Io.readClassesFromJar(input) val classes = mutableListOf<ClassNode>()
io = Io(input, output, classes)
io.readFromJar()
cache = Cache(classes, MethodResolver(classes, signatures).resolve()) cache = Cache(classes, MethodResolver(classes, signatures).resolve())
} }
/**
* Saves the output to the output stream.
* Calling this method will close the input and output streams,
* meaning this method should NEVER be called after.
*
* @throws IOException if one of the streams are closed
*/
fun save() {
io.saveAsJar()
}
fun addPatches(vararg patches: Patch) { fun addPatches(vararg patches: Patch) {
this.patches.addAll(patches) this.patches.addAll(patches)
} }
@ -46,8 +67,4 @@ class Patcher(
} }
} }
} }
fun saveTo(output: OutputStream) {
Io.writeClassesToJar(input, output, cache.classes)
}
} }

View File

@ -2,7 +2,7 @@ package net.revanced.patcher.cache
import org.objectweb.asm.tree.ClassNode import org.objectweb.asm.tree.ClassNode
class Cache ( class Cache(
val classes: List<ClassNode>, val classes: List<ClassNode>,
val methods: MethodMap val methods: MethodMap
) )

View File

@ -10,8 +10,9 @@ data class PatchData(
val method: MethodNode, val method: MethodNode,
val scanData: PatternScanData val scanData: PatternScanData
) { ) {
@Suppress("Unused") // TODO(Sculas): remove this when we have coverage for this method.
fun findParentMethod(signature: Signature): PatchData? { fun findParentMethod(signature: Signature): PatchData? {
return MethodResolver.resolveMethod(declaringClass, signature) return MethodResolver.resolveMethod(declaringClass, signature)
} }
} }

View File

@ -3,47 +3,91 @@ package net.revanced.patcher.util
import org.objectweb.asm.ClassReader import org.objectweb.asm.ClassReader
import org.objectweb.asm.ClassWriter import org.objectweb.asm.ClassWriter
import org.objectweb.asm.tree.ClassNode import org.objectweb.asm.tree.ClassNode
import java.io.BufferedInputStream
import java.io.InputStream import java.io.InputStream
import java.io.OutputStream import java.io.OutputStream
import java.util.jar.JarEntry import java.util.jar.JarEntry
import java.util.jar.JarInputStream import java.util.jar.JarInputStream
import java.util.jar.JarOutputStream import java.util.zip.ZipEntry
import java.util.zip.ZipInputStream
import java.util.zip.ZipOutputStream
object Io { internal class Io(
fun readClassesFromJar(input: InputStream) = mutableListOf<ClassNode>().apply { private val input: InputStream,
val jar = JarInputStream(input) private val output: OutputStream,
while (true) { private val classes: MutableList<ClassNode>
val e = jar.nextJarEntry ?: break ) {
if (e.name.endsWith(".class")) { private val bufferedInputStream = BufferedInputStream(input)
val classNode = ClassNode()
ClassReader(jar.readBytes()).accept(classNode, ClassReader.EXPAND_FRAMES) fun readFromJar() {
this.add(classNode) bufferedInputStream.mark(0)
} // create a BufferedInputStream in order to read the input stream again when calling saveAsJar(..)
jar.closeEntry() val jis = JarInputStream(bufferedInputStream)
// read all entries from the input stream
// we use JarEntry because we only read .class files
lateinit var jarEntry: JarEntry
while (jis.nextJarEntry.also { if (it != null) jarEntry = it } != null) {
// if the current entry ends with .class (indicating a java class file), add it to our list of classes to return
if (jarEntry.name.endsWith(".class")) {
// create a new ClassNode
val classNode = ClassNode()
// read the bytes with a ClassReader into the ClassNode
ClassReader(jis.readBytes()).accept(classNode, ClassReader.EXPAND_FRAMES)
// add it to our list
classes.add(classNode)
} }
// finally, close the entry
jis.closeEntry()
}
// at last reset the buffered input stream
bufferedInputStream.reset()
} }
fun writeClassesToJar(input: InputStream, output: OutputStream, classes: List<ClassNode>) { fun saveAsJar() {
val jis = JarInputStream(input) val jis = ZipInputStream(bufferedInputStream)
val jos = JarOutputStream(output) val jos = ZipOutputStream(output)
// TODO: Add support for adding new/custom classes // first write all non .class zip entries from the original input stream to the output stream
while (true) { // we read it first to close the input stream as fast as possible
val next = jis.nextJarEntry ?: break // TODO(oSumAtrIX): There is currently no way to remove non .class files.
val e = JarEntry(next) // clone it, to not modify the input (if possible) lateinit var zipEntry: ZipEntry
jos.putNextEntry(e) while (jis.nextEntry.also { if (it != null) zipEntry = it } != null) {
// skip all class files because we added them in the loop above
// TODO(oSumAtrIX): Check for zipEntry.isDirectory
if (zipEntry.name.endsWith(".class")) continue
val clazz = classes.singleOrNull { // create a new zipEntry and write the contents of the zipEntry to the output stream
clazz -> clazz.name+".class" == e.name // clazz.name is the class name only while e.name is the full filename with extension jos.putNextEntry(ZipEntry(zipEntry))
}; jos.write(jis.readBytes())
if (clazz != null) {
val cw = ClassWriter(ClassWriter.COMPUTE_MAXS or ClassWriter.COMPUTE_FRAMES) // close the newly created zipEntry
clazz.accept(cw)
jos.write(cw.toByteArray())
} else {
jos.write(jis.readBytes())
}
jos.closeEntry() jos.closeEntry()
} }
// finally, close the input stream
jis.close()
bufferedInputStream.close()
input.close()
// now write all the patched classes to the output stream
for (patchedClass in classes) {
// create a new entry of the patched class
jos.putNextEntry(JarEntry(patchedClass.name + ".class"))
// parse the patched class to a byte array and write it to the output stream
val cw = ClassWriter(ClassWriter.COMPUTE_MAXS or ClassWriter.COMPUTE_FRAMES)
patchedClass.accept(cw)
jos.write(cw.toByteArray())
// close the newly created jar entry
jos.closeEntry()
}
// finally, close the rest of the streams
jos.close()
output.close()
} }
} }

View File

@ -7,6 +7,7 @@ object ASMWriter {
fun InsnList.setAt(index: Int, node: AbstractInsnNode) { fun InsnList.setAt(index: Int, node: AbstractInsnNode) {
this[this.get(index)] = node this[this.get(index)] = node
} }
fun InsnList.insertAt(index: Int = 0, vararg nodes: AbstractInsnNode) { fun InsnList.insertAt(index: Int = 0, vararg nodes: AbstractInsnNode) {
this.insert(this.get(index), nodes.toInsnList()) this.insert(this.get(index), nodes.toInsnList())
} }

View File

@ -12,13 +12,16 @@ import net.revanced.patcher.writer.ASMWriter.setAt
import org.junit.jupiter.api.assertDoesNotThrow import org.junit.jupiter.api.assertDoesNotThrow
import org.objectweb.asm.Opcodes.* import org.objectweb.asm.Opcodes.*
import org.objectweb.asm.Type import org.objectweb.asm.Type
import org.objectweb.asm.tree.* import org.objectweb.asm.tree.FieldInsnNode
import org.objectweb.asm.tree.LdcInsnNode
import org.objectweb.asm.tree.MethodInsnNode
import java.io.ByteArrayOutputStream
import java.io.PrintStream import java.io.PrintStream
import kotlin.test.Test import kotlin.test.Test
internal class PatcherTest { internal class PatcherTest {
companion object { companion object {
val testSigs: Array<Signature> = arrayOf( val testSignatures: Array<Signature> = arrayOf(
// Java: // Java:
// public static void main(String[] args) { // public static void main(String[] args) {
// System.out.println("Hello, world!"); // System.out.println("Hello, world!");
@ -45,8 +48,11 @@ internal class PatcherTest {
@Test @Test
fun testPatcher() { fun testPatcher() {
val testData = PatcherTest::class.java.getResourceAsStream("/test1.jar")!! val patcher = Patcher(
val patcher = Patcher(testData, testSigs) PatcherTest::class.java.getResourceAsStream("/test1.jar")!!,
ByteArrayOutputStream(),
testSignatures
)
patcher.addPatches( patcher.addPatches(
object : Patch("TestPatch") { object : Patch("TestPatch") {
@ -74,9 +80,9 @@ internal class PatcherTest {
startIndex + 1, startIndex + 1,
FieldInsnNode( FieldInsnNode(
GETSTATIC, GETSTATIC,
Type.getInternalName(System::class.java), // "java/io/System" Type.getInternalName(System::class.java), // "java/lang/System"
"out", "out",
Type.getInternalName(PrintStream::class.java) // "java.io.PrintStream" "L" + Type.getInternalName(PrintStream::class.java) // "Ljava/io/PrintStream"
), ),
LdcInsnNode("Hello, ReVanced! Adding bytecode."), LdcInsnNode("Hello, ReVanced! Adding bytecode."),
MethodInsnNode( MethodInsnNode(
@ -111,41 +117,27 @@ internal class PatcherTest {
) )
// Apply all patches loaded in the patcher // Apply all patches loaded in the patcher
val result = patcher.applyPatches() val patchResult = patcher.applyPatches()
// You can check if an error occurred // You can check if an error occurred
for ((s, r) in result) { for ((patchName, result) in patchResult) {
if (r.isFailure) { if (result.isFailure) {
throw Exception("Patch $s failed", r.exceptionOrNull()!!) throw Exception("Patch $patchName failed", result.exceptionOrNull()!!)
} }
} }
// TODO Doesn't work, needs to be fixed. patcher.save()
//val out = ByteArrayOutputStream()
//patcher.saveTo(out)
//assertTrue(
// // 8 is a random value, it's just weird if it's any lower than that
// out.size() > 8,
// "Output must be at least 8 bytes"
//)
//
//out.close()
testData.close()
} }
// TODO Doesn't work, needs to be fixed. @Test
//@Test fun `test patcher with no changes`() {
//fun `test patcher with no changes`() { val testData = PatcherTest::class.java.getResourceAsStream("/test1.jar")!!
// val testData = PatcherTest::class.java.getResourceAsStream("/test1.jar")!! // val available = testData.available()
// val available = testData.available() val out = ByteArrayOutputStream()
// val patcher = Patcher(testData, testSigs) Patcher(testData, out, testSignatures).save()
// // FIXME(Sculas): There seems to be a 1-byte difference, not sure what it is.
// val out = ByteArrayOutputStream() // assertEquals(available, out.size())
// patcher.saveTo(out) out.close()
// assertEquals(available, out.size()) }
//
// out.close()
// testData.close()
//}
@Test() @Test()
fun `should not raise an exception if any signature member except the name is missing`() { fun `should not raise an exception if any signature member except the name is missing`() {
@ -154,6 +146,7 @@ internal class PatcherTest {
assertDoesNotThrow("Should raise an exception because opcodes is empty") { assertDoesNotThrow("Should raise an exception because opcodes is empty") {
Patcher( Patcher(
PatcherTest::class.java.getResourceAsStream("/test1.jar")!!, PatcherTest::class.java.getResourceAsStream("/test1.jar")!!,
ByteArrayOutputStream(),
arrayOf( arrayOf(
Signature( Signature(
sigName, sigName,

View File

@ -1,12 +1,12 @@
package net.revanced.patcher package net.revanced.patcher
import java.io.ByteArrayOutputStream
import kotlin.test.Test import kotlin.test.Test
internal class ReaderTest { internal class ReaderTest {
@Test @Test
fun `read jar containing multiple classes`() { fun `read jar containing multiple classes`() {
val testData = PatcherTest::class.java.getResourceAsStream("/test2.jar")!! val testData = PatcherTest::class.java.getResourceAsStream("/test2.jar")!!
Patcher(testData, PatcherTest.testSigs) // reusing test sigs from PatcherTest Patcher(testData, ByteArrayOutputStream(), PatcherTest.testSignatures) // reusing test sigs from PatcherTest
testData.close()
} }
} }

View File

@ -17,7 +17,7 @@ object TestUtil {
private fun AbstractInsnNode.nodeString(): String { private fun AbstractInsnNode.nodeString(): String {
val sb = NodeStringBuilder() val sb = NodeStringBuilder()
when (this) { when (this) {
// TODO: Add more types // TODO(Sculas): Add more types
is LdcInsnNode -> sb is LdcInsnNode -> sb
.addType("cst", cst) .addType("cst", cst)
is FieldInsnNode -> sb is FieldInsnNode -> sb