diff --git a/patches/src/main/kotlin/app/revanced/util/BytecodeUtils.kt b/patches/src/main/kotlin/app/revanced/util/BytecodeUtils.kt index 1e5910f03..3b0b85818 100644 --- a/patches/src/main/kotlin/app/revanced/util/BytecodeUtils.kt +++ b/patches/src/main/kotlin/app/revanced/util/BytecodeUtils.kt @@ -22,12 +22,16 @@ import app.revanced.patches.shared.mapping.resourceMappingPatch import app.revanced.util.Utils.printWarn import com.android.tools.smali.dexlib2.AccessFlags import com.android.tools.smali.dexlib2.Opcode +import com.android.tools.smali.dexlib2.Opcode.* import com.android.tools.smali.dexlib2.iface.Method import com.android.tools.smali.dexlib2.iface.MethodParameter import com.android.tools.smali.dexlib2.iface.instruction.FiveRegisterInstruction import com.android.tools.smali.dexlib2.iface.instruction.Instruction import com.android.tools.smali.dexlib2.iface.instruction.OneRegisterInstruction import com.android.tools.smali.dexlib2.iface.instruction.ReferenceInstruction +import com.android.tools.smali.dexlib2.iface.instruction.RegisterRangeInstruction +import com.android.tools.smali.dexlib2.iface.instruction.ThreeRegisterInstruction +import com.android.tools.smali.dexlib2.iface.instruction.TwoRegisterInstruction import com.android.tools.smali.dexlib2.iface.instruction.WideLiteralInstruction import com.android.tools.smali.dexlib2.iface.instruction.formats.Instruction31i import com.android.tools.smali.dexlib2.iface.reference.MethodReference @@ -37,6 +41,7 @@ import com.android.tools.smali.dexlib2.immutable.ImmutableField import com.android.tools.smali.dexlib2.immutable.ImmutableMethod import com.android.tools.smali.dexlib2.immutable.ImmutableMethodImplementation import com.android.tools.smali.dexlib2.util.MethodUtil +import java.util.EnumSet const val REGISTER_TEMPLATE_REPLACEMENT: String = "REGISTER_INDEX" @@ -52,6 +57,121 @@ fun parametersEqual( return true } +/** + * Starting from and including the instruction at index [startIndex], + * finds the next register that is wrote to and not read from. If a return instruction + * is encountered, then the lowest unused register is returned. + * + * This method can return a non 4-bit register, and the calling code may need to temporarily + * swap register contents if a 4-bit register is required. + * + * @param startIndex Inclusive starting index. + * @param registersToExclude Registers to exclude, and consider as used. For most use cases, + * all registers used in injected code should be specified. + * @throws IllegalArgumentException If a branch or conditional statement is encountered + * before a suitable register is found. + */ +internal fun Method.findFreeRegister(startIndex: Int, vararg registersToExclude: Int): Int { + if (implementation == null) { + throw IllegalArgumentException("Method has no implementation: $this") + } + if (startIndex < 0 || startIndex >= instructions.count()) { + throw IllegalArgumentException("startIndex out of bounds: $startIndex") + } + + // All registers used by an instruction. + fun Instruction.getRegistersUsed() = when (this) { + is FiveRegisterInstruction -> listOf(registerC, registerD, registerE, registerF, registerG) + is ThreeRegisterInstruction -> listOf(registerA, registerB, registerC) + is TwoRegisterInstruction -> listOf(registerA, registerB) + is OneRegisterInstruction -> listOf(registerA) + is RegisterRangeInstruction -> (startRegister until (startRegister + registerCount)).toList() + else -> emptyList() + } + + // Register that is written to by an instruction. + fun Instruction.getRegisterWritten() = when (this) { + is ThreeRegisterInstruction -> registerA + is TwoRegisterInstruction -> registerA + is OneRegisterInstruction -> registerA + else -> throw IllegalStateException("Not a write instruction: $this") + } + + val writeOpcodes = EnumSet.of( + NEW_INSTANCE, NEW_ARRAY, + MOVE, MOVE_FROM16, MOVE_16, MOVE_WIDE, MOVE_WIDE_FROM16, MOVE_WIDE_16, MOVE_OBJECT, + MOVE_OBJECT_FROM16, MOVE_OBJECT_16, MOVE_RESULT, MOVE_RESULT_WIDE, MOVE_RESULT_OBJECT, MOVE_EXCEPTION, + IGET, IGET_WIDE, IGET_OBJECT, IGET_BOOLEAN, IGET_BYTE, IGET_CHAR, IGET_SHORT, + SGET, SGET_WIDE, SGET_OBJECT, SGET_BOOLEAN, SGET_BYTE, SGET_CHAR, SGET_SHORT, + ) + + val branchOpcodes = EnumSet.of( + GOTO, GOTO_16, GOTO_32, + IF_EQ, IF_NE, IF_LT, IF_GE, IF_GT, IF_LE, + IF_EQZ, IF_NEZ, IF_LTZ, IF_GEZ, IF_GTZ, IF_LEZ, + ) + + val returnOpcodes = EnumSet.of( + RETURN_VOID, RETURN, RETURN_WIDE, RETURN_OBJECT, + ) + + // Highest 4-bit register available, exclusive. Ideally return a free register less than this. + val maxRegister4Bits = 16 + var bestFreeRegisterFound: Int? = null + val usedRegisters = registersToExclude.toMutableSet() + + for (i in startIndex until instructions.count()) { + val instruction = getInstruction(i) + + if (instruction.opcode in returnOpcodes) { + // Method returns. Use lowest register that hasn't been encountered. + val freeRegister = (0 until implementation!!.registerCount).find { + it !in usedRegisters + } + if (freeRegister != null) { + return freeRegister + } + if (bestFreeRegisterFound != null) { + return bestFreeRegisterFound; + } + + // Somehow every method register was read from before any register was wrote to. + // In practice this never occurs. + throw IllegalArgumentException("Could not find a free register from startIndex: " + + "$startIndex excluding: $registersToExclude") + } + + if (instruction.opcode in branchOpcodes) { + if (bestFreeRegisterFound != null) { + return bestFreeRegisterFound; + } + // This method is simple and does not follow branching. + throw IllegalArgumentException("Encountered a branch statement before a free register could be found") + } + + if (instruction.opcode in writeOpcodes) { + val freeRegister = instruction.getRegisterWritten() + if (freeRegister !in usedRegisters) { + if (freeRegister < maxRegister4Bits) { + // Found an ideal register. + return freeRegister + } + + // Continue searching for a 4-bit register if available. + if (bestFreeRegisterFound == null || freeRegister < bestFreeRegisterFound) { + bestFreeRegisterFound = freeRegister + } + } + } + + usedRegisters.addAll(instruction.getRegistersUsed()) + } + + // Cannot be reached since a branch or return statement will + // be encountered before the end of the method. + throw IllegalStateException() +} + /** * Find the [MutableMethod] from a given [Method] in a [MutableClass]. *