mirror of
https://github.com/revanced/revanced-patcher.git
synced 2025-05-02 22:04:24 +02:00
fix(Io): JAR loading and saving (#8)
* refactor: Complete rewrite of `Io` * style: format code * style: rewrite todos * fix: use lateinit instead of nonnull assert for zipEntry * fix: use lateinit instead of nonnull assert for jarEntry & reuse zipEntry * docs: add docs to `Patcher` * test: match output of patcher * chore: add todo to `Io` for removing non-class files Co-authored-by: Sculas <contact@sculas.xyz>
This commit is contained in:
parent
87bbde5e06
commit
4d98cbc9e8
@ -5,28 +5,49 @@ import net.revanced.patcher.patch.Patch
|
|||||||
import net.revanced.patcher.resolver.MethodResolver
|
import net.revanced.patcher.resolver.MethodResolver
|
||||||
import net.revanced.patcher.signature.Signature
|
import net.revanced.patcher.signature.Signature
|
||||||
import net.revanced.patcher.util.Io
|
import net.revanced.patcher.util.Io
|
||||||
|
import org.objectweb.asm.tree.ClassNode
|
||||||
|
import java.io.IOException
|
||||||
import java.io.InputStream
|
import java.io.InputStream
|
||||||
import java.io.OutputStream
|
import java.io.OutputStream
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The patcher. (docs WIP)
|
* The Patcher class.
|
||||||
|
* ***It is of utmost importance that the input and output streams are NEVER closed.***
|
||||||
*
|
*
|
||||||
* @param input the input stream to read from, must be a JAR
|
* @param input the input stream to read from, must be a JAR
|
||||||
|
* @param output the output stream to write to
|
||||||
* @param signatures the signatures
|
* @param signatures the signatures
|
||||||
* @sample net.revanced.patcher.PatcherTest
|
* @sample net.revanced.patcher.PatcherTest
|
||||||
|
* @throws IOException if one of the streams are closed
|
||||||
*/
|
*/
|
||||||
class Patcher(
|
class Patcher(
|
||||||
private val input: InputStream,
|
private val input: InputStream,
|
||||||
|
private val output: OutputStream,
|
||||||
signatures: Array<Signature>,
|
signatures: Array<Signature>,
|
||||||
) {
|
) {
|
||||||
var cache: Cache
|
var cache: Cache
|
||||||
private val patches: MutableList<Patch> = mutableListOf()
|
|
||||||
|
private var io: Io
|
||||||
|
private val patches = mutableListOf<Patch>()
|
||||||
|
|
||||||
init {
|
init {
|
||||||
val classes = Io.readClassesFromJar(input)
|
val classes = mutableListOf<ClassNode>()
|
||||||
|
io = Io(input, output, classes)
|
||||||
|
io.readFromJar()
|
||||||
cache = Cache(classes, MethodResolver(classes, signatures).resolve())
|
cache = Cache(classes, MethodResolver(classes, signatures).resolve())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Saves the output to the output stream.
|
||||||
|
* Calling this method will close the input and output streams,
|
||||||
|
* meaning this method should NEVER be called after.
|
||||||
|
*
|
||||||
|
* @throws IOException if one of the streams are closed
|
||||||
|
*/
|
||||||
|
fun save() {
|
||||||
|
io.saveAsJar()
|
||||||
|
}
|
||||||
|
|
||||||
fun addPatches(vararg patches: Patch) {
|
fun addPatches(vararg patches: Patch) {
|
||||||
this.patches.addAll(patches)
|
this.patches.addAll(patches)
|
||||||
}
|
}
|
||||||
@ -46,8 +67,4 @@ class Patcher(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fun saveTo(output: OutputStream) {
|
|
||||||
Io.writeClassesToJar(input, output, cache.classes)
|
|
||||||
}
|
|
||||||
}
|
}
|
@ -2,7 +2,7 @@ package net.revanced.patcher.cache
|
|||||||
|
|
||||||
import org.objectweb.asm.tree.ClassNode
|
import org.objectweb.asm.tree.ClassNode
|
||||||
|
|
||||||
class Cache (
|
class Cache(
|
||||||
val classes: List<ClassNode>,
|
val classes: List<ClassNode>,
|
||||||
val methods: MethodMap
|
val methods: MethodMap
|
||||||
)
|
)
|
||||||
|
@ -10,8 +10,9 @@ data class PatchData(
|
|||||||
val method: MethodNode,
|
val method: MethodNode,
|
||||||
val scanData: PatternScanData
|
val scanData: PatternScanData
|
||||||
) {
|
) {
|
||||||
|
@Suppress("Unused") // TODO(Sculas): remove this when we have coverage for this method.
|
||||||
fun findParentMethod(signature: Signature): PatchData? {
|
fun findParentMethod(signature: Signature): PatchData? {
|
||||||
return MethodResolver.resolveMethod(declaringClass, signature)
|
return MethodResolver.resolveMethod(declaringClass, signature)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3,47 +3,91 @@ package net.revanced.patcher.util
|
|||||||
import org.objectweb.asm.ClassReader
|
import org.objectweb.asm.ClassReader
|
||||||
import org.objectweb.asm.ClassWriter
|
import org.objectweb.asm.ClassWriter
|
||||||
import org.objectweb.asm.tree.ClassNode
|
import org.objectweb.asm.tree.ClassNode
|
||||||
|
import java.io.BufferedInputStream
|
||||||
import java.io.InputStream
|
import java.io.InputStream
|
||||||
import java.io.OutputStream
|
import java.io.OutputStream
|
||||||
import java.util.jar.JarEntry
|
import java.util.jar.JarEntry
|
||||||
import java.util.jar.JarInputStream
|
import java.util.jar.JarInputStream
|
||||||
import java.util.jar.JarOutputStream
|
import java.util.zip.ZipEntry
|
||||||
|
import java.util.zip.ZipInputStream
|
||||||
|
import java.util.zip.ZipOutputStream
|
||||||
|
|
||||||
object Io {
|
internal class Io(
|
||||||
fun readClassesFromJar(input: InputStream) = mutableListOf<ClassNode>().apply {
|
private val input: InputStream,
|
||||||
val jar = JarInputStream(input)
|
private val output: OutputStream,
|
||||||
while (true) {
|
private val classes: MutableList<ClassNode>
|
||||||
val e = jar.nextJarEntry ?: break
|
) {
|
||||||
if (e.name.endsWith(".class")) {
|
private val bufferedInputStream = BufferedInputStream(input)
|
||||||
val classNode = ClassNode()
|
|
||||||
ClassReader(jar.readBytes()).accept(classNode, ClassReader.EXPAND_FRAMES)
|
fun readFromJar() {
|
||||||
this.add(classNode)
|
bufferedInputStream.mark(0)
|
||||||
}
|
// create a BufferedInputStream in order to read the input stream again when calling saveAsJar(..)
|
||||||
jar.closeEntry()
|
val jis = JarInputStream(bufferedInputStream)
|
||||||
|
|
||||||
|
// read all entries from the input stream
|
||||||
|
// we use JarEntry because we only read .class files
|
||||||
|
lateinit var jarEntry: JarEntry
|
||||||
|
while (jis.nextJarEntry.also { if (it != null) jarEntry = it } != null) {
|
||||||
|
// if the current entry ends with .class (indicating a java class file), add it to our list of classes to return
|
||||||
|
if (jarEntry.name.endsWith(".class")) {
|
||||||
|
// create a new ClassNode
|
||||||
|
val classNode = ClassNode()
|
||||||
|
// read the bytes with a ClassReader into the ClassNode
|
||||||
|
ClassReader(jis.readBytes()).accept(classNode, ClassReader.EXPAND_FRAMES)
|
||||||
|
// add it to our list
|
||||||
|
classes.add(classNode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// finally, close the entry
|
||||||
|
jis.closeEntry()
|
||||||
|
}
|
||||||
|
|
||||||
|
// at last reset the buffered input stream
|
||||||
|
bufferedInputStream.reset()
|
||||||
}
|
}
|
||||||
|
|
||||||
fun writeClassesToJar(input: InputStream, output: OutputStream, classes: List<ClassNode>) {
|
fun saveAsJar() {
|
||||||
val jis = JarInputStream(input)
|
val jis = ZipInputStream(bufferedInputStream)
|
||||||
val jos = JarOutputStream(output)
|
val jos = ZipOutputStream(output)
|
||||||
|
|
||||||
// TODO: Add support for adding new/custom classes
|
// first write all non .class zip entries from the original input stream to the output stream
|
||||||
while (true) {
|
// we read it first to close the input stream as fast as possible
|
||||||
val next = jis.nextJarEntry ?: break
|
// TODO(oSumAtrIX): There is currently no way to remove non .class files.
|
||||||
val e = JarEntry(next) // clone it, to not modify the input (if possible)
|
lateinit var zipEntry: ZipEntry
|
||||||
jos.putNextEntry(e)
|
while (jis.nextEntry.also { if (it != null) zipEntry = it } != null) {
|
||||||
|
// skip all class files because we added them in the loop above
|
||||||
|
// TODO(oSumAtrIX): Check for zipEntry.isDirectory
|
||||||
|
if (zipEntry.name.endsWith(".class")) continue
|
||||||
|
|
||||||
val clazz = classes.singleOrNull {
|
// create a new zipEntry and write the contents of the zipEntry to the output stream
|
||||||
clazz -> clazz.name+".class" == e.name // clazz.name is the class name only while e.name is the full filename with extension
|
jos.putNextEntry(ZipEntry(zipEntry))
|
||||||
};
|
jos.write(jis.readBytes())
|
||||||
if (clazz != null) {
|
|
||||||
val cw = ClassWriter(ClassWriter.COMPUTE_MAXS or ClassWriter.COMPUTE_FRAMES)
|
// close the newly created zipEntry
|
||||||
clazz.accept(cw)
|
|
||||||
jos.write(cw.toByteArray())
|
|
||||||
} else {
|
|
||||||
jos.write(jis.readBytes())
|
|
||||||
}
|
|
||||||
jos.closeEntry()
|
jos.closeEntry()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// finally, close the input stream
|
||||||
|
jis.close()
|
||||||
|
bufferedInputStream.close()
|
||||||
|
input.close()
|
||||||
|
|
||||||
|
// now write all the patched classes to the output stream
|
||||||
|
for (patchedClass in classes) {
|
||||||
|
// create a new entry of the patched class
|
||||||
|
jos.putNextEntry(JarEntry(patchedClass.name + ".class"))
|
||||||
|
|
||||||
|
// parse the patched class to a byte array and write it to the output stream
|
||||||
|
val cw = ClassWriter(ClassWriter.COMPUTE_MAXS or ClassWriter.COMPUTE_FRAMES)
|
||||||
|
patchedClass.accept(cw)
|
||||||
|
jos.write(cw.toByteArray())
|
||||||
|
|
||||||
|
// close the newly created jar entry
|
||||||
|
jos.closeEntry()
|
||||||
|
}
|
||||||
|
|
||||||
|
// finally, close the rest of the streams
|
||||||
|
jos.close()
|
||||||
|
output.close()
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -7,6 +7,7 @@ object ASMWriter {
|
|||||||
fun InsnList.setAt(index: Int, node: AbstractInsnNode) {
|
fun InsnList.setAt(index: Int, node: AbstractInsnNode) {
|
||||||
this[this.get(index)] = node
|
this[this.get(index)] = node
|
||||||
}
|
}
|
||||||
|
|
||||||
fun InsnList.insertAt(index: Int = 0, vararg nodes: AbstractInsnNode) {
|
fun InsnList.insertAt(index: Int = 0, vararg nodes: AbstractInsnNode) {
|
||||||
this.insert(this.get(index), nodes.toInsnList())
|
this.insert(this.get(index), nodes.toInsnList())
|
||||||
}
|
}
|
||||||
|
@ -12,13 +12,16 @@ import net.revanced.patcher.writer.ASMWriter.setAt
|
|||||||
import org.junit.jupiter.api.assertDoesNotThrow
|
import org.junit.jupiter.api.assertDoesNotThrow
|
||||||
import org.objectweb.asm.Opcodes.*
|
import org.objectweb.asm.Opcodes.*
|
||||||
import org.objectweb.asm.Type
|
import org.objectweb.asm.Type
|
||||||
import org.objectweb.asm.tree.*
|
import org.objectweb.asm.tree.FieldInsnNode
|
||||||
|
import org.objectweb.asm.tree.LdcInsnNode
|
||||||
|
import org.objectweb.asm.tree.MethodInsnNode
|
||||||
|
import java.io.ByteArrayOutputStream
|
||||||
import java.io.PrintStream
|
import java.io.PrintStream
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
|
|
||||||
internal class PatcherTest {
|
internal class PatcherTest {
|
||||||
companion object {
|
companion object {
|
||||||
val testSigs: Array<Signature> = arrayOf(
|
val testSignatures: Array<Signature> = arrayOf(
|
||||||
// Java:
|
// Java:
|
||||||
// public static void main(String[] args) {
|
// public static void main(String[] args) {
|
||||||
// System.out.println("Hello, world!");
|
// System.out.println("Hello, world!");
|
||||||
@ -45,8 +48,11 @@ internal class PatcherTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testPatcher() {
|
fun testPatcher() {
|
||||||
val testData = PatcherTest::class.java.getResourceAsStream("/test1.jar")!!
|
val patcher = Patcher(
|
||||||
val patcher = Patcher(testData, testSigs)
|
PatcherTest::class.java.getResourceAsStream("/test1.jar")!!,
|
||||||
|
ByteArrayOutputStream(),
|
||||||
|
testSignatures
|
||||||
|
)
|
||||||
|
|
||||||
patcher.addPatches(
|
patcher.addPatches(
|
||||||
object : Patch("TestPatch") {
|
object : Patch("TestPatch") {
|
||||||
@ -74,9 +80,9 @@ internal class PatcherTest {
|
|||||||
startIndex + 1,
|
startIndex + 1,
|
||||||
FieldInsnNode(
|
FieldInsnNode(
|
||||||
GETSTATIC,
|
GETSTATIC,
|
||||||
Type.getInternalName(System::class.java), // "java/io/System"
|
Type.getInternalName(System::class.java), // "java/lang/System"
|
||||||
"out",
|
"out",
|
||||||
Type.getInternalName(PrintStream::class.java) // "java.io.PrintStream"
|
"L" + Type.getInternalName(PrintStream::class.java) // "Ljava/io/PrintStream"
|
||||||
),
|
),
|
||||||
LdcInsnNode("Hello, ReVanced! Adding bytecode."),
|
LdcInsnNode("Hello, ReVanced! Adding bytecode."),
|
||||||
MethodInsnNode(
|
MethodInsnNode(
|
||||||
@ -111,41 +117,27 @@ internal class PatcherTest {
|
|||||||
)
|
)
|
||||||
|
|
||||||
// Apply all patches loaded in the patcher
|
// Apply all patches loaded in the patcher
|
||||||
val result = patcher.applyPatches()
|
val patchResult = patcher.applyPatches()
|
||||||
// You can check if an error occurred
|
// You can check if an error occurred
|
||||||
for ((s, r) in result) {
|
for ((patchName, result) in patchResult) {
|
||||||
if (r.isFailure) {
|
if (result.isFailure) {
|
||||||
throw Exception("Patch $s failed", r.exceptionOrNull()!!)
|
throw Exception("Patch $patchName failed", result.exceptionOrNull()!!)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO Doesn't work, needs to be fixed.
|
patcher.save()
|
||||||
//val out = ByteArrayOutputStream()
|
|
||||||
//patcher.saveTo(out)
|
|
||||||
//assertTrue(
|
|
||||||
// // 8 is a random value, it's just weird if it's any lower than that
|
|
||||||
// out.size() > 8,
|
|
||||||
// "Output must be at least 8 bytes"
|
|
||||||
//)
|
|
||||||
//
|
|
||||||
//out.close()
|
|
||||||
testData.close()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO Doesn't work, needs to be fixed.
|
@Test
|
||||||
//@Test
|
fun `test patcher with no changes`() {
|
||||||
//fun `test patcher with no changes`() {
|
val testData = PatcherTest::class.java.getResourceAsStream("/test1.jar")!!
|
||||||
// val testData = PatcherTest::class.java.getResourceAsStream("/test1.jar")!!
|
// val available = testData.available()
|
||||||
// val available = testData.available()
|
val out = ByteArrayOutputStream()
|
||||||
// val patcher = Patcher(testData, testSigs)
|
Patcher(testData, out, testSignatures).save()
|
||||||
//
|
// FIXME(Sculas): There seems to be a 1-byte difference, not sure what it is.
|
||||||
// val out = ByteArrayOutputStream()
|
// assertEquals(available, out.size())
|
||||||
// patcher.saveTo(out)
|
out.close()
|
||||||
// assertEquals(available, out.size())
|
}
|
||||||
//
|
|
||||||
// out.close()
|
|
||||||
// testData.close()
|
|
||||||
//}
|
|
||||||
|
|
||||||
@Test()
|
@Test()
|
||||||
fun `should not raise an exception if any signature member except the name is missing`() {
|
fun `should not raise an exception if any signature member except the name is missing`() {
|
||||||
@ -154,6 +146,7 @@ internal class PatcherTest {
|
|||||||
assertDoesNotThrow("Should raise an exception because opcodes is empty") {
|
assertDoesNotThrow("Should raise an exception because opcodes is empty") {
|
||||||
Patcher(
|
Patcher(
|
||||||
PatcherTest::class.java.getResourceAsStream("/test1.jar")!!,
|
PatcherTest::class.java.getResourceAsStream("/test1.jar")!!,
|
||||||
|
ByteArrayOutputStream(),
|
||||||
arrayOf(
|
arrayOf(
|
||||||
Signature(
|
Signature(
|
||||||
sigName,
|
sigName,
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
package net.revanced.patcher
|
package net.revanced.patcher
|
||||||
|
|
||||||
|
import java.io.ByteArrayOutputStream
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
|
|
||||||
internal class ReaderTest {
|
internal class ReaderTest {
|
||||||
@Test
|
@Test
|
||||||
fun `read jar containing multiple classes`() {
|
fun `read jar containing multiple classes`() {
|
||||||
val testData = PatcherTest::class.java.getResourceAsStream("/test2.jar")!!
|
val testData = PatcherTest::class.java.getResourceAsStream("/test2.jar")!!
|
||||||
Patcher(testData, PatcherTest.testSigs) // reusing test sigs from PatcherTest
|
Patcher(testData, ByteArrayOutputStream(), PatcherTest.testSignatures) // reusing test sigs from PatcherTest
|
||||||
testData.close()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -17,7 +17,7 @@ object TestUtil {
|
|||||||
private fun AbstractInsnNode.nodeString(): String {
|
private fun AbstractInsnNode.nodeString(): String {
|
||||||
val sb = NodeStringBuilder()
|
val sb = NodeStringBuilder()
|
||||||
when (this) {
|
when (this) {
|
||||||
// TODO: Add more types
|
// TODO(Sculas): Add more types
|
||||||
is LdcInsnNode -> sb
|
is LdcInsnNode -> sb
|
||||||
.addType("cst", cst)
|
.addType("cst", cst)
|
||||||
is FieldInsnNode -> sb
|
is FieldInsnNode -> sb
|
||||||
|
Loading…
x
Reference in New Issue
Block a user