diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index bc0e087..42898eb 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -5,7 +5,6 @@ android-compileSdk = "36" android-minSdk = "26" kotlin = "2.3.0" apktool-lib = "2.10.1.1" -kotlinx-coroutines-core = "1.10.2" mockk = "1.14.7" multidexlib2 = "3.0.3.r3" # Tracking https://github.com/google/smali/issues/64. @@ -17,7 +16,6 @@ vanniktechMavenPublish = "0.35.0" [libraries] apktool-lib = { module = "app.revanced:apktool-lib", version.ref = "apktool-lib" } kotlin-reflect = { module = "org.jetbrains.kotlin:kotlin-reflect", version.ref = "kotlin" } -kotlinx-coroutines-core = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-core", version.ref = "kotlinx-coroutines-core" } kotlin-test = { module = "org.jetbrains.kotlin:kotlin-test", version.ref = "kotlin" } mockk = { module = "io.mockk:mockk", version.ref = "mockk" } multidexlib2 = { module = "app.revanced:multidexlib2", version.ref = "multidexlib2" } diff --git a/patcher/build.gradle.kts b/patcher/build.gradle.kts index 3d27d69..2663e17 100644 --- a/patcher/build.gradle.kts +++ b/patcher/build.gradle.kts @@ -41,7 +41,6 @@ kotlin { commonMain.dependencies { implementation(libs.apktool.lib) implementation(libs.kotlin.reflect) - implementation(libs.kotlinx.coroutines.core) implementation(libs.multidexlib2) implementation(libs.smali) implementation(libs.xpp3) diff --git a/patcher/src/commonMain/kotlin/app/revanced/patcher/InstructionFilter.kt b/patcher/src/commonMain/kotlin/app/revanced/patcher/InstructionFilter.kt new file mode 100644 index 0000000..dd2f68b --- /dev/null +++ b/patcher/src/commonMain/kotlin/app/revanced/patcher/InstructionFilter.kt @@ -0,0 +1,1082 @@ +// Temporarily adding this file for development purposes of patches + +@file:Suppress("unused") + +package app.revanced.patcher + +import app.revanced.patcher.FieldAccessFilter.Companion.parseJvmFieldAccess +import app.revanced.patcher.MethodCallFilter.Companion.parseJvmMethodCall +import com.android.tools.smali.dexlib2.Opcode +import com.android.tools.smali.dexlib2.iface.Method +import com.android.tools.smali.dexlib2.iface.instruction.Instruction +import com.android.tools.smali.dexlib2.iface.instruction.ReferenceInstruction +import com.android.tools.smali.dexlib2.iface.instruction.WideLiteralInstruction +import com.android.tools.smali.dexlib2.iface.reference.FieldReference +import com.android.tools.smali.dexlib2.iface.reference.MethodReference +import com.android.tools.smali.dexlib2.iface.reference.StringReference +import com.android.tools.smali.dexlib2.iface.reference.TypeReference +import java.util.EnumSet + +/** + * Simple interface to control how much space is allowed between a previous + * [InstructionFilter] match and the current [InstructionFilter]. + */ +fun interface InstructionLocation { + /** + * @param previouslyMatchedIndex The previously matched index, or -1 if this is the first filter. + * @param currentIndex The current method index that is about to be checked. + */ + fun indexIsValidForMatching(previouslyMatchedIndex: Int, currentIndex: Int): Boolean + + /** + * Matching can occur anywhere after the previous instruction filter match index. + * Is the default behavior for all filters. + */ + class MatchAfterAnywhere : InstructionLocation { + override fun indexIsValidForMatching(previouslyMatchedIndex: Int, currentIndex: Int) = true + } + + /** + * Matches the first instruction of a method. + * + * This can only be used for the first filter, and using with any other filter will throw an exception. + */ + class MatchFirst() : InstructionLocation { + override fun indexIsValidForMatching(previouslyMatchedIndex: Int, currentIndex: Int): Boolean { + require(previouslyMatchedIndex < 0) { + "MatchFirst can only be used for the first instruction filter" + } + return true + } + } + + /** + * Instruction index immediately after the previous filter. + * + * Useful for opcodes that must always appear immediately after the last filter such as: + * - [Opcode.MOVE_RESULT] + * - [Opcode.MOVE_RESULT_WIDE] + * - [Opcode.MOVE_RESULT_OBJECT] + * + * This cannot be used for the first filter and will throw an exception. + */ + class MatchAfterImmediately() : InstructionLocation { + override fun indexIsValidForMatching(previouslyMatchedIndex: Int, currentIndex: Int): Boolean { + require(previouslyMatchedIndex >= 0) { + "MatchAfterImmediately cannot be used for the first instruction filter" + } + return currentIndex - 1 == previouslyMatchedIndex + } + } + + /** + * Instruction index can occur within a range of the previous instruction filter match index. + * used to constrain instruction matching to a region after the previous instruction filter. + * + * This cannot be used for the first filter and will throw an exception. + * + * @param matchDistance The number of unmatched instructions that can exist between the + * current instruction filter and the previously matched instruction filter. + * A value of 0 means the current filter can only match immediately after + * the previously matched instruction (making this functionally identical to + * [MatchAfterImmediately]). A value of 10 means between 0 and 10 unmatched + * instructions can exist between the previously matched instruction and + * the current instruction filter. + */ + class MatchAfterWithin(val matchDistance: Int) : InstructionLocation { + init { + require(matchDistance >= 0) { + "matchDistance must be non-negative" + } + } + + override fun indexIsValidForMatching(previouslyMatchedIndex: Int, currentIndex: Int): Boolean { + require(previouslyMatchedIndex >= 0) { + "MatchAfterImmediately cannot be used for the first instruction filter" + } + return currentIndex - previouslyMatchedIndex - 1 <= matchDistance + } + } + + /** + * Instruction index can occur only after a minimum number of unmatched instructions from the + * previous instruction match. Or if this is used with the first filter of a fingerprint then + * this can only match starting from a given instruction index. + * + * @param minimumDistanceFromLastInstruction The minimum number of unmatched instructions that + * must exist between this instruction and the last matched instruction. A value of 0 is + * functionally identical to [MatchAfterImmediately]. + */ + class MatchAfterAtLeast(var minimumDistanceFromLastInstruction: Int) : InstructionLocation { + init { + require(minimumDistanceFromLastInstruction >= 0) { + "minimumDistanceFromLastInstruction must >= 0" + } + } + + override fun indexIsValidForMatching(previouslyMatchedIndex: Int, currentIndex: Int): Boolean { + return currentIndex - previouslyMatchedIndex - 1 >= minimumDistanceFromLastInstruction + } + } + + /** + * Functionally combines both [MatchAfterAtLeast] and [MatchAfterWithin] to give a bounded range + * where the next instruction must match relative to the previous matched instruction. + * + * Unlike [MatchAfterImmediately] or [MatchAfterWithin], this can be used for the first filter + * to constrain matching to a specific range starting from index 0. + * + * @param minimumDistanceFromLastInstruction The minimum number of unmatched instructions that + * must exist between this instruction and the last + * matched instruction. + * @param maximumDistanceFromLastInstruction The maximum number of unmatched instructions + * that can exist between this instruction and the + * last matched instruction. + */ + class MatchAfterRange( + val minimumDistanceFromLastInstruction: Int, + val maximumDistanceFromLastInstruction: Int + ) : InstructionLocation { + + private val minMatcher = MatchAfterAtLeast(minimumDistanceFromLastInstruction) + private val maxMatcher = MatchAfterWithin(maximumDistanceFromLastInstruction) + + init { + require(minimumDistanceFromLastInstruction <= maximumDistanceFromLastInstruction) { + "minimumDistanceFromLastInstruction must be <= maximumDistanceFromLastInstruction" + } + } + + override fun indexIsValidForMatching(previouslyMatchedIndex: Int, currentIndex: Int): Boolean { + // For the first filter, previouslyMatchedIndex will be -1, and both delegates + // will correctly enforce their own semantics starting from index 0. + return minMatcher.indexIsValidForMatching(previouslyMatchedIndex, currentIndex) && + maxMatcher.indexIsValidForMatching(previouslyMatchedIndex, currentIndex) + } + } +} + + +/** + * String comparison type. + */ +enum class StringComparisonType { + EQUALS, + CONTAINS, + STARTS_WITH, + ENDS_WITH; + + /** + * @param targetString The target string to search + * @param searchString To search for in the target string (or to compare entirely for equality). + */ + fun compare(targetString: String, searchString: String): Boolean { + return when (this) { + EQUALS -> targetString == searchString + CONTAINS -> targetString.contains(searchString) + STARTS_WITH -> targetString.startsWith(searchString) + ENDS_WITH -> targetString.endsWith(searchString) + } + } + + /** + * Throws [IllegalArgumentException] if the class type search string is invalid and can never match. + */ + internal fun validateSearchStringForClassType(classTypeSearchString: String) { + when (this) { + EQUALS -> { + STARTS_WITH.validateSearchStringForClassType(classTypeSearchString) + ENDS_WITH.validateSearchStringForClassType(classTypeSearchString) + } + + CONTAINS -> Unit // Nothing to validate, anything goes. + STARTS_WITH -> require(classTypeSearchString.startsWith('L')) { + "Class type does not start with L: $classTypeSearchString" + } + + ENDS_WITH -> require(classTypeSearchString.endsWith(';')) { + "Class type does not end with a semicolon: $classTypeSearchString" + } + } + } +} + + +/** + * Matches method [Instruction] objects, similar to how [Fingerprint] matches entire methods. + * + * The most basic filters match only opcodes and nothing more, + * and more precise filters can match: + * - Field references (get/put opcodes) by name/type. + * - Method calls (invoke_* opcodes) by name/parameter/return type. + * - Object instantiation for specific class types. + * - Literal const values. + * + * If creating a custom filter for unusual or app specific purposes, consider extending + * [OpcodeFilter] or [OpcodesFilter] to reduce boilerplate opcode checking logic. + */ +fun interface InstructionFilter { + + /** + * The [InstructionLocation] associated with this filter. + */ + val location: InstructionLocation + get() = InstructionLocation.MatchAfterAnywhere() + + /** + * If this filter matches the method instruction. + * + * @param enclosingMethod The method of that contains [instruction]. + * @param instruction The instruction to check for a match. + */ + fun matches( + enclosingMethod: Method, + instruction: Instruction + ): Boolean +} + + +class AnyInstruction internal constructor( + internal val filters: List, + override val location: InstructionLocation +) : InstructionFilter { + + override fun matches( + enclosingMethod: Method, + instruction: Instruction + ): Boolean { + return filters.any { filter -> + filter.matches(enclosingMethod, instruction) + } + } +} + +/** + * Logical OR operator where the first filter that matches satisfies this filter. + * + * @param location Where this filter is allowed to match. Default is anywhere after the previous instruction. + */ +fun anyInstruction( + vararg filters: InstructionFilter, + location: InstructionLocation = InstructionLocation.MatchAfterAnywhere() +) = AnyInstruction(filters.asList(), location) + + +/** + * Single opcode match. + * + * Patches can extend this as desired to do unusual or app specific instruction filtering. + * Or Alternatively can implement [InstructionFilter] directly. + * + * @param opcode Opcode to match. + * @param location Where this filter is allowed to match. Default is anywhere after the previous instruction. + */ +open class OpcodeFilter( + val opcode: Opcode, + override val location: InstructionLocation = InstructionLocation.MatchAfterAnywhere() +) : InstructionFilter { + + override fun matches( + enclosingMethod: Method, + instruction: Instruction + ): Boolean { + return instruction.opcode == opcode + } +} + +/** + * Single opcode match. + * + * @param opcode Opcode to match. + * @param location Where this filter is allowed to match. Default is anywhere after the previous instruction. + */ +fun opcode( + opcode: Opcode, + location: InstructionLocation = InstructionLocation.MatchAfterAnywhere() +) = OpcodeFilter(opcode, location) + + +/** + * Matches a single instruction from many kinds of opcodes. + * + * Patches can extend this as desired to do unusual or app specific instruction filtering. + * Or Alternatively can implement [InstructionFilter] directly. + * + * @param opcodes Set of opcodes to match to. Value of `null` will match any opcode. + * If matching only a single opcode then instead use [OpcodeFilter]. + * @param location Where this filter is allowed to match. Default is anywhere after the previous instruction. + */ +open class OpcodesFilter protected constructor( + val opcodes: EnumSet?, + override val location: InstructionLocation = InstructionLocation.MatchAfterAnywhere() +) : InstructionFilter { + + protected constructor( + opcodes: List?, + location: InstructionLocation + ) : this(if (opcodes == null) null else EnumSet.copyOf(opcodes), location) + + override fun matches( + enclosingMethod: Method, + instruction: Instruction + ): Boolean { + if (opcodes == null) { + return true // Match anything. + } + return opcodes.contains(instruction.opcode) + } + + internal companion object { + /** + * First opcode can match anywhere in a method, but all + * subsequent opcodes must match after the previous opcode. + * + * A value of `null` indicates to match any opcode. + */ + internal fun listOfOpcodes(opcodes: Collection): List { + val list = ArrayList(opcodes.size) + var location: InstructionLocation? = null + + opcodes.forEach { opcode -> + // First opcode can match anywhere. + val opcodeLocation = location ?: InstructionLocation.MatchAfterAnywhere() + + list += if (opcode == null) { + // Null opcode matches anything. + OpcodesFilter( + null as List?, + opcodeLocation + ) + } else { + OpcodeFilter(opcode, opcodeLocation) + } + + if (location == null) { + location = InstructionLocation.MatchAfterImmediately() + } + } + + return list + } + } +} + + +class LiteralFilter internal constructor( + val literal: () -> Long, + opcodes: List? = null, + location: InstructionLocation +) : OpcodesFilter(opcodes, location) { + + /** + * Store the lambda value instead of calling it more than once. + */ + private val literalValue: Long by lazy(literal) + + override fun matches( + enclosingMethod: Method, + instruction: Instruction + ): Boolean { + if (!super.matches(enclosingMethod, instruction)) { + return false + } + + if (instruction !is WideLiteralInstruction) return false + + return instruction.wideLiteral == literalValue + } +} + +/** + * Long literal. Automatically converts literal to opcode hex. + * + * @param literal Literal number. + * @param opcodes Opcodes to match. By default this matches any literal number opcode such as: + * [Opcode.CONST_4], [Opcode.CONST_16], [Opcode.CONST], [Opcode.CONST_WIDE]. + * @param location Where this filter is allowed to match. Default is anywhere after the previous instruction. + */ +fun literal( + literal: Long, + opcodes: List? = null, + location: InstructionLocation = InstructionLocation.MatchAfterAnywhere() +) = LiteralFilter({ literal }, opcodes, location) + +/** + * Integer literal. Automatically converts literal to opcode hex. + * + * @param literal Literal number. + * @param opcodes Opcodes to match. By default this matches any literal number opcode such as: + * [Opcode.CONST_4], [Opcode.CONST_16], [Opcode.CONST], [Opcode.CONST_WIDE]. + * @param location Where this filter is allowed to match. Default is anywhere after the previous instruction. + */ +fun literal( + literal: Int, + opcodes: List? = null, + location: InstructionLocation = InstructionLocation.MatchAfterAnywhere() +) = LiteralFilter({ literal.toLong() }, opcodes, location) + +/** + * Double point literal. Automatically converts literal to opcode hex. + * + * @param literal Literal number. + * @param opcodes Opcodes to match. By default this matches any literal number opcode such as: + * [Opcode.CONST_4], [Opcode.CONST_16], [Opcode.CONST], [Opcode.CONST_WIDE]. + * @param location Where this filter is allowed to match. Default is anywhere after the previous instruction. + */ +fun literal( + literal: Double, + opcodes: List? = null, + location: InstructionLocation = InstructionLocation.MatchAfterAnywhere() +) = LiteralFilter({ literal.toRawBits() }, opcodes, location) + +/** + * Floating point literal. Automatically converts literal to opcode hex. + * + * @param literal Floating point literal. + * @param opcodes Opcodes to match. By default this matches any literal number opcode such as: + * [Opcode.CONST_4], [Opcode.CONST_16], [Opcode.CONST], [Opcode.CONST_WIDE]. + * @param location Where this filter is allowed to match. Default is anywhere after the previous instruction. + */ +fun literal( + literal: Float, + opcodes: List? = null, + location: InstructionLocation = InstructionLocation.MatchAfterAnywhere() +) = LiteralFilter({ literal.toRawBits().toLong() }, opcodes, location) + +/** + * Literal number value. Automatically converts the provided number to opcode hex. + * + * @param literal Literal number. + * @param opcodes Opcodes to match. By default this matches any literal number opcode such as: + * [Opcode.CONST_4], [Opcode.CONST_16], [Opcode.CONST], [Opcode.CONST_WIDE]. + * @param location Where this filter is allowed to match. Default is anywhere after the previous instruction. + */ +fun literal( + literal: () -> Long, + opcodes: List? = null, + location: InstructionLocation = InstructionLocation.MatchAfterAnywhere() +) = LiteralFilter(literal, opcodes, location) + + +class MethodCallFilter internal constructor( + val definingClass: String? = null, + val name: String? = null, + val parameters: List? = null, + val returnType: String? = null, + opcodes: List? = null, + location: InstructionLocation +) : OpcodesFilter(opcodes, location) { + + override fun matches( + enclosingMethod: Method, + instruction: Instruction + ): Boolean { + if (!super.matches(enclosingMethod, instruction)) { + return false + } + + val reference = (instruction as? ReferenceInstruction)?.reference as? MethodReference + ?: return false + + if (definingClass != null) { + val referenceClass = reference.definingClass + + if (!StringComparisonType.ENDS_WITH.compare(referenceClass, definingClass)) { + // Check if 'this' defining class is used. + // Would be nice if this also checked all super classes, + // but doing so requires iteratively checking all superclasses + // up to the root class since class defs are mere Strings. + if (!(definingClass == "this" && referenceClass == enclosingMethod.definingClass)) { + return false + } // else, the method call is for 'this' class. + } + } + + if (name != null && reference.name != name) { + return false + } + + if (returnType != null && + !StringComparisonType.STARTS_WITH.compare(reference.returnType, returnType) + ) { + return false + } + fun parametersStartsWith( + parameters1: Iterable, + parameters2: Iterable, + ): Boolean { + if (parameters1.count() != parameters2.count()) return false + val iterator1 = parameters1.iterator() + parameters2.forEach { + if (!it.startsWith(iterator1.next())) return false + } + return true + } + if (parameters != null && + !parametersStartsWith(reference.parameterTypes, parameters) + ) { + return false + } + + return true + } + + internal companion object { + private val regex = Regex("""^(L[^;]+;)->([^(\s]+)\(([^)]*)\)(\[?L[^;]+;|\[?[BCSIJFDZV])${'$'}""") + + internal fun parseJvmMethodCall( + methodSignature: String, + opcodes: List? = null, + location: InstructionLocation = InstructionLocation.MatchAfterAnywhere() + ): MethodCallFilter { + val matchResult = regex.matchEntire(methodSignature) + ?: throw IllegalArgumentException("Invalid method signature: $methodSignature") + + val classDescriptor = matchResult.groupValues[1] + val methodName = matchResult.groupValues[2] + val paramDescriptorString = matchResult.groupValues[3] + val returnDescriptor = matchResult.groupValues[4] + + val paramDescriptors = parseParameterDescriptors(paramDescriptorString) + + return MethodCallFilter( + classDescriptor, + methodName, + paramDescriptors, + returnDescriptor, + opcodes, + location + ) + } + + /** + * Parses a single JVM type descriptor or an array descriptor at the current position. + * For example: Lcom/example/SomeClass; or I or [I or [Lcom/example/SomeClass; + */ + private fun parseSingleType(params: String, startIndex: Int): Pair { + var i = startIndex + + // Skip past array declaration, including multi-dimensional arrays. + val paramsLength = params.length + while (i < paramsLength && params[i] == '[') { + i++ + } + + return if (i < paramsLength && params[i] == 'L') { + // It's an object type starting with 'L', read until ';' + val semicolonPos = params.indexOf(';', i) + if (semicolonPos < 0) { + throw IllegalArgumentException("Malformed object descriptor (missing semicolon): $params") + } + // Substring from startIndex up to and including the semicolon. + val typeDescriptor = params.substring(startIndex, semicolonPos + 1) + typeDescriptor to (semicolonPos + 1) + } else { + // It's either a primitive or we've already consumed the array part + // So just take one character (e.g. 'I', 'Z', 'B', etc.) + val typeDescriptor = params.substring(startIndex, i + 1) + typeDescriptor to (i + 1) + } + } + + /** + * Parses the parameters into a list of JVM type descriptors. + */ + private fun parseParameterDescriptors(paramString: String): List { + val result = mutableListOf() + var currentIndex = 0 + val stringLength = paramString.length + + while (currentIndex < stringLength) { + val (type, nextIndex) = parseSingleType(paramString, currentIndex) + result.add(type) + currentIndex = nextIndex + } + + return result + } + } +} + +/** + * Matches a method call, such as: + * `invoke-virtual {v3, v4}, La;->b(I)V` + * + * @param definingClass Defining class of the field call. Compares using [StringComparisonType.ENDS_WITH]. + * For calls to a method in the same class, use 'this' as the defining class. + * Note: 'this' does not work for fields found in superclasses. + * @param name Full name of the method. Compares using [StringComparisonType.EQUALS]. + * @param parameters Parameters of the method call. Each parameter matches using[StringComparisonType.STARTS_WITH] + * and semantics are the same as [Fingerprint] parameters. + * @param returnType Return type. Matches using [StringComparisonType.STARTS_WITH]. + * @param opcodes Opcode types to match. By default this matches any method call opcode: `Opcode.INVOKE_*`. + * If this filter must match specific types of method call, then specify the desired opcodes +such as [Opcode.INVOKE_STATIC], [Opcode.INVOKE_STATIC_RANGE] to match only static calls. + * @param location Where this filter is allowed to match. Default is anywhere after the previous instruction. + */ +fun methodCall( + definingClass: String? = null, + name: String? = null, + parameters: List? = null, + returnType: String? = null, + opcodes: List? = null, + location: InstructionLocation = InstructionLocation.MatchAfterAnywhere() +) = MethodCallFilter( + definingClass, + name, + parameters, + returnType, + opcodes, + location +) + +/** + * Matches a method call, such as: + * `invoke-virtual {v3, v4}, La;->b(I)V` + * + * @param definingClass Defining class of the field call. Compares using [StringComparisonType.ENDS_WITH]. + * For calls to a method in the same class, use 'this' as the defining class. + * Note: 'this' does not work for fields found in superclasses. + * @param name Full name of the method. Compares using [StringComparisonType.EQUALS]. + * @param parameters Parameters of the method call. Each parameter matches using[StringComparisonType.STARTS_WITH] + * and semantics are the same as [Fingerprint] parameters. + * @param returnType Return type. Matches using [StringComparisonType.STARTS_WITH]. + * @param opcode Single opcode type to match. + * @param location Where this filter is allowed to match. Default is anywhere after the previous instruction. + */ +fun methodCall( + definingClass: String? = null, + name: String? = null, + parameters: List? = null, + returnType: String? = null, + opcode: Opcode, + location: InstructionLocation = InstructionLocation.MatchAfterAnywhere() +) = MethodCallFilter( + definingClass, + name, + parameters, + returnType, + listOf(opcode), + location +) + +/** + * Method call for a copy pasted SMALI style method signature. e.g.: + * `Landroid/view/View;->inflate(Landroid/content/Context;ILandroid/view/ViewGroup;)Landroid/view/View;` + * + * Should never be used with obfuscated method names or parameter/return types. + * + * @param smali Smali method call reference, such as + * `Landroid/view/View;->inflate(Landroid/content/Context;ILandroid/view/ViewGroup;)Landroid/view/View;`. + * @param opcodes List of all possible opcodes to match. Defaults to matching all method calls types: `Opcode.INVOKE_*`. + * @param location Where this filter is allowed to match. Default is anywhere after the previous instruction. + */ +fun methodCall( + smali: String, + opcodes: List? = null, + location: InstructionLocation = InstructionLocation.MatchAfterAnywhere() +) = parseJvmMethodCall(smali, opcodes, location) + +/** + * Method call for a copy pasted SMALI style method signature. e.g.: + * `Landroid/view/View;->inflate(Landroid/content/Context;ILandroid/view/ViewGroup;)Landroid/view/View;` + * + * Should never be used with obfuscated method names or parameter/return types. + * + * @param smali Smali method call reference, such as + * `Landroid/view/View;->inflate(Landroid/content/Context;ILandroid/view/ViewGroup;)Landroid/view/View;`. + * @param opcode Single opcode type to match. + * @param location Where this filter is allowed to match. Default is anywhere after the previous instruction. + */ +fun methodCall( + smali: String, + opcode: Opcode, + location: InstructionLocation = InstructionLocation.MatchAfterAnywhere() +) = parseJvmMethodCall(smali, listOf(opcode), location) + + +class FieldAccessFilter internal constructor( + val definingClass: String? = null, + val name: String? = null, + val type: String? = null, + opcodes: List? = null, + location: InstructionLocation +) : OpcodesFilter(opcodes, location) { + + override fun matches( + enclosingMethod: Method, + instruction: Instruction + ): Boolean { + if (!super.matches(enclosingMethod, instruction)) { + return false + } + + val reference = (instruction as? ReferenceInstruction)?.reference as? FieldReference + ?: return false + + if (definingClass != null) { + val referenceClass = reference.definingClass + + if (!referenceClass.endsWith(definingClass)) { + if (!(definingClass == "this" && referenceClass == enclosingMethod.definingClass)) { + return false + } // else, the method call is for 'this' class. + } + } + + if (name != null && reference.name != name) { + return false + } + + if (type != null && !reference.type.startsWith(type)) { + return false + } + + return true + } + + internal companion object { + private val regex = Regex("""^(L[^;]+;)->([^:]+):(\[?L[^;]+;|\[?[BCSIJFDZV])${'$'}""") + + internal fun parseJvmFieldAccess( + fieldSignature: String, + opcodes: List? = null, + location: InstructionLocation = InstructionLocation.MatchAfterAnywhere() + ): FieldAccessFilter { + val matchResult = regex.matchEntire(fieldSignature) + ?: throw IllegalArgumentException("Invalid field access smali: $fieldSignature") + + return fieldAccess( + definingClass = matchResult.groupValues[1], + name = matchResult.groupValues[2], + type = matchResult.groupValues[3], + opcodes = opcodes, + location = location + ) + } + } +} + + +/** + * Matches a field call, such as: + * `iget-object v0, p0, Lahhh;->g:Landroid/view/View;` + * + * @param definingClass Defining class of the field call. Compares using [StringComparisonType.ENDS_WITH]. + * For calls to a method in the same class, use 'this' as the defining class. + * Note: 'this' does not work for fields found in superclasses. + * @param name Full name of the field. Compares using [StringComparisonType.EQUALS]. + * @param type Class type of field. Compares using [StringComparisonType.STARTS_WITH]. + * @param opcode Single opcode type to match. + * @param location Where this filter is allowed to match. Default is anywhere after the previous instruction. + */ +fun fieldAccess( + definingClass: String? = null, + name: String? = null, + type: String? = null, + opcode: Opcode, + location: InstructionLocation = InstructionLocation.MatchAfterAnywhere() +) = fieldAccess( + definingClass, + name, + type, + listOf(opcode), + location +) + +/** + * Matches a field call, such as: + * `iget-object v0, p0, Lahhh;->g:Landroid/view/View;` + * + * @param definingClass Defining class of the field call. Compares using [StringComparisonType.ENDS_WITH]. + * For calls to a method in the same class, use 'this' as the defining class. + * Note: 'this' does not work for fields found in superclasses. + * @param name Full name of the field. Compares using [StringComparisonType.EQUALS]. + * @param type Class type of field. Compares using [StringComparisonType.STARTS_WITH]. + * @param opcodes List of all possible opcodes to match. Defaults to matching all get/put opcodes. + * (`Opcode.IGET`, `Opcode.SGET`, `Opcode.IPUT`, `Opcode.SPUT`, etc). + * @param location Where this filter is allowed to match. Default is anywhere after the previous instruction. + */ +fun fieldAccess( + definingClass: String? = null, + name: String? = null, + type: String? = null, + opcodes: List? = null, + location: InstructionLocation = InstructionLocation.MatchAfterAnywhere() +) = FieldAccessFilter( + definingClass, + name, + type, + opcodes, + location +) + +/** + * Field access for a copy pasted SMALI style field access call. e.g.: + * `Ljava/lang/Boolean;->TRUE:Ljava/lang/Boolean;` + * + * Should never be used with obfuscated field names or obfuscated field types. + * @param smali Smali field access statement, such as `Ljava/lang/Boolean;->TRUE:Ljava/lang/Boolean;`. + * @param opcodes List of all possible opcodes to match. Defaults to matching all get/put opcodes. + * (`Opcode.IGET`, `Opcode.SGET`, `Opcode.IPUT`, `Opcode.SPUT`, etc). + * @param location Where this filter is allowed to match. Default is anywhere after the previous instruction. + */ +fun fieldAccess( + smali: String, + opcodes: List? = null, + location: InstructionLocation = InstructionLocation.MatchAfterAnywhere() +) = parseJvmFieldAccess(smali, opcodes, location) + +/** + * Field access for a copy pasted SMALI style field access call. e.g.: + * `Ljava/lang/Boolean;->TRUE:Ljava/lang/Boolean;` + * + * Should never be used with obfuscated field names or obfuscated field types. + * + * @param smali Smali field access statement, such as `Ljava/lang/Boolean;->TRUE:Ljava/lang/Boolean;`. + * @param opcode Single opcode type to match. + * @param location Where this filter is allowed to match. Default is anywhere after the previous instruction. + */ +fun fieldAccess( + smali: String, + opcode: Opcode, + location: InstructionLocation = InstructionLocation.MatchAfterAnywhere() +) = parseJvmFieldAccess(smali, listOf(opcode), location) + + +class StringFilter internal constructor( + val string: () -> String, + val comparison: StringComparisonType, + location: InstructionLocation +) : OpcodesFilter(listOf(Opcode.CONST_STRING, Opcode.CONST_STRING_JUMBO), location) { + + /** + * Store the lambda value instead of calling it more than once. + */ + private val stringValue: String by lazy(string) + + override fun matches( + enclosingMethod: Method, + instruction: Instruction + ): Boolean { + if (!super.matches(enclosingMethod, instruction)) { + return false + } + + val stringReference = (instruction as ReferenceInstruction).reference as StringReference + return comparison.compare(stringReference.string, stringValue) + } +} + +/** + * Literal String instruction. + * + * @param string string literal, using exact matching of [StringComparisonType.EQUALS]. + * @param location Where this filter is allowed to match. Default is anywhere after the previous instruction. + */ +fun string( + string: String, + location: InstructionLocation = InstructionLocation.MatchAfterAnywhere() +) = StringFilter({ string }, StringComparisonType.EQUALS, location) + +/** + * Literal String instruction. + * + * @param string string literal, using exact matching of [StringComparisonType.EQUALS]. + * @param location Where this filter is allowed to match. Default is anywhere after the previous instruction. + */ +fun string( + string: () -> String, + location: InstructionLocation = InstructionLocation.MatchAfterAnywhere() +) = StringFilter(string, StringComparisonType.EQUALS, location) + +/** + * Literal String instruction. + * + * @param string string literal. + * @param comparison How to compare the string literal. For more precise matching of strings, + * consider using [anyInstruction] with multiple exact string declarations. + * @param location Where this filter is allowed to match. Default is anywhere after the previous instruction. + */ +fun string( + string: String, + /** + * How to match a given string opcode literal. Default is exact string equality. For more + * precise matching of multiple strings, consider using [anyInstruction] with multiple + * exact string declarations. + */ + comparison: StringComparisonType = StringComparisonType.EQUALS, + location: InstructionLocation = InstructionLocation.MatchAfterAnywhere() +) = StringFilter({ string }, comparison, location) + +/** + * Literal String instruction. + * + * @param string string literal. + * @param comparison How to compare the string literal. For more precise matching of strings, + * consider using [anyInstruction] with multiple exact string declarations. + * @param location Where this filter is allowed to match. Default is anywhere after the previous instruction. + */ +fun string( + string: () -> String, + comparison: StringComparisonType, + location: InstructionLocation = InstructionLocation.MatchAfterAnywhere() +) = StringFilter(string, comparison, location) + + +class NewInstanceFilter internal constructor( + val type: () -> String, + val comparison: StringComparisonType, + location: InstructionLocation +) : OpcodesFilter(listOf(Opcode.NEW_INSTANCE, Opcode.NEW_ARRAY), location) { + + /** + * Store the lambda value instead of calling it more than once. + */ + private val typeValue: String by lazy { + val typeValue = type() + comparison.validateSearchStringForClassType(typeValue) + typeValue + } + + override fun matches( + enclosingMethod: Method, + instruction: Instruction + ): Boolean { + if (!super.matches(enclosingMethod, instruction)) { + return false + } + + val reference = (instruction as ReferenceInstruction).reference as TypeReference + return comparison.compare(reference.type, typeValue) + } +} + +/** + * Opcode type [Opcode.NEW_INSTANCE] or [Opcode.NEW_ARRAY] with a non obfuscated class type. + * + * @param type Class type, compared using [StringComparisonType.ENDS_WITH]. + * @param location Where this filter is allowed to match. Default is anywhere after the previous instruction. + */ +fun newInstance( + type: String, + location: InstructionLocation = InstructionLocation.MatchAfterAnywhere() +) = NewInstanceFilter({ type }, StringComparisonType.ENDS_WITH, location) + +/** + * Opcode type [Opcode.NEW_INSTANCE] or [Opcode.NEW_ARRAY] with a non obfuscated class type. + * + * @param type Class type, compared using [StringComparisonType.ENDS_WITH]. + * @param location Where this filter is allowed to match. Default is anywhere after the previous instruction. + */ +fun newInstance( + type: () -> String, + location: InstructionLocation = InstructionLocation.MatchAfterAnywhere(), +) = NewInstanceFilter(type, StringComparisonType.ENDS_WITH, location) + +/** + * Opcode type [Opcode.NEW_INSTANCE] or [Opcode.NEW_ARRAY] with a non obfuscated class type. + * + * @param type Class type. + * @param comparison How to compare the opcode class type. For more precise matching of types, + * consider using [anyInstruction] with multiple exact type declarations. + * @param location Where this filter is allowed to match. Default is anywhere after the previous instruction. + */ +fun newInstance( + type: String, + comparison: StringComparisonType, + location: InstructionLocation = InstructionLocation.MatchAfterAnywhere() +) = NewInstanceFilter({ type }, comparison, location) + +/** + * Opcode type [Opcode.NEW_INSTANCE] or [Opcode.NEW_ARRAY] with a non obfuscated class type. + * + * @param type Class type. + * @param comparison How to compare the opcode class type. For more precise matching of types, + * consider using [anyInstruction] with multiple exact type declarations. + * @param location Where this filter is allowed to match. Default is anywhere after the previous instruction. + */ +fun newInstance( + type: () -> String, + comparison: StringComparisonType, + location: InstructionLocation = InstructionLocation.MatchAfterAnywhere() +) = NewInstanceFilter(type, comparison, location) + + +class CheckCastFilter internal constructor( + val type: () -> String, + val comparison: StringComparisonType, + location: InstructionLocation = InstructionLocation.MatchAfterAnywhere() +) : OpcodeFilter(Opcode.CHECK_CAST, location) { + + /** + * Store the lambda value instead of calling it more than once. + */ + private val typeValue: String by lazy { + val typeValue = type() + comparison.validateSearchStringForClassType(typeValue) + typeValue + } + + override fun matches( + enclosingMethod: Method, + instruction: Instruction + ): Boolean { + if (!super.matches(enclosingMethod, instruction)) { + return false + } + + val reference = (instruction as ReferenceInstruction).reference as TypeReference + return comparison.compare(reference.type, typeValue) + } +} + +/** + * Opcode type [Opcode.CHECK_CAST] with a non obfuscated class type. + * + * @param type Class type, compared using [StringComparisonType.ENDS_WITH]. + * @param location Where this filter is allowed to match. Default is anywhere after the previous instruction. + */ +fun checkCast( + type: String, + location: InstructionLocation = InstructionLocation.MatchAfterAnywhere() +) = CheckCastFilter({ type }, StringComparisonType.ENDS_WITH, location) + +/** + * Opcode type [Opcode.CHECK_CAST] with a non obfuscated class type. + * + * @param type Class type, compared using [StringComparisonType.ENDS_WITH]. + * @param location Where this filter is allowed to match. Default is anywhere after the previous instruction. + */ +fun checkCast( + type: () -> String, + location: InstructionLocation = InstructionLocation.MatchAfterAnywhere() +) = CheckCastFilter(type, StringComparisonType.ENDS_WITH, location) + +/** + * Opcode type [Opcode.CHECK_CAST] with a non obfuscated class type using the provided string comparison type. + * + * @param type Class type. + * @param comparison How to compare the opcode class type. For more precise matching of types, + * consider using [anyInstruction] with multiple exact type declarations. + * @param location Where this filter is allowed to match. Default is anywhere after the previous instruction. + */ +fun checkCast( + type: String, + comparison: StringComparisonType, + location: InstructionLocation = InstructionLocation.MatchAfterAnywhere() +) = CheckCastFilter({ type }, comparison, location) + +/** + * Opcode type [Opcode.CHECK_CAST] with a non obfuscated class type using the provided string comparison type. + * + * @param type Class type. + * @param comparison How to compare the opcode class type. For more precise matching of types, + * consider using [anyInstruction] with multiple exact type declarations. + * @param location Where this filter is allowed to match. Default is anywhere after the previous instruction. + */ +fun checkCast( + type: () -> String, + comparison: StringComparisonType, + location: InstructionLocation = InstructionLocation.MatchAfterAnywhere() +) = CheckCastFilter(type, comparison, location) diff --git a/patcher/src/commonMain/kotlin/app/revanced/patcher/Matching.kt b/patcher/src/commonMain/kotlin/app/revanced/patcher/Matching.kt index 769c484..e660817 100644 --- a/patcher/src/commonMain/kotlin/app/revanced/patcher/Matching.kt +++ b/patcher/src/commonMain/kotlin/app/revanced/patcher/Matching.kt @@ -2,6 +2,17 @@ package app.revanced.patcher +import app.revanced.patcher.BytecodePatchContextClassDefMatching.firstClassDefOrNull +import app.revanced.patcher.BytecodePatchContextClassDefMatching.firstMutableClassDefOrNull +import app.revanced.patcher.BytecodePatchContextMethodMatching.firstMutableMethod +import app.revanced.patcher.BytecodePatchContextMethodMatching.firstMutableMethodOrNull +import app.revanced.patcher.BytecodePatchContextMethodMatching.gettingFirstMethodDeclarativelyOrNull +import app.revanced.patcher.ClassDefMethodMatching.firstMethodDeclarativelyOrNull +import app.revanced.patcher.IterableClassDefClassDefMatching.firstClassDefOrNull +import app.revanced.patcher.IterableClassDefMethodMatching.firstMethodOrNull +import app.revanced.patcher.IterableMethodMethodMatching.firstMethodDeclarativelyOrNull +import app.revanced.patcher.IterableMethodMethodMatching.firstMethodOrNull +import app.revanced.patcher.IterableMethodMethodMatching.firstMutableMethodOrNull import app.revanced.patcher.extensions.* import app.revanced.patcher.patch.BytecodePatchContext import com.android.tools.smali.dexlib2.AccessFlags @@ -10,6 +21,7 @@ import com.android.tools.smali.dexlib2.Opcode import com.android.tools.smali.dexlib2.iface.* import com.android.tools.smali.dexlib2.iface.Annotation import com.android.tools.smali.dexlib2.iface.instruction.* +import com.android.tools.smali.dexlib2.iface.reference.MethodReference import com.android.tools.smali.dexlib2.mutable.MutableMethod import com.android.tools.smali.dexlib2.util.MethodUtil import kotlin.properties.ReadOnlyProperty @@ -31,8 +43,7 @@ fun ClassDef.anyStaticField(predicate: Field.() -> Boolean) = staticFields.any(p fun ClassDef.anyInterface(predicate: String.() -> Boolean) = interfaces.any(predicate) -fun ClassDef.anyAnnotation(predicate: Annotation.() -> Boolean) = - annotations.any(predicate) +fun ClassDef.anyAnnotation(predicate: Annotation.() -> Boolean) = annotations.any(predicate) fun Method.implementation(predicate: MethodImplementation.() -> Boolean) = implementation?.predicate() ?: false @@ -56,99 +67,499 @@ private typealias ClassDefPredicate = context(PredicateContext) ClassDef.() -> B private typealias MethodPredicate = context(PredicateContext) Method.() -> Boolean -fun BytecodePatchContext.firstClassDefOrNull( - type: String? = null, predicate: ClassDefPredicate = { true } -) = with(PredicateContext()) { - if (type == null) classDefs.firstOrNull { it.predicate() } - else classDefs[type]?.takeIf { it.predicate() } -} -fun BytecodePatchContext.firstClassDef( - type: String? = null, - predicate: ClassDefPredicate = { true } -) = requireNotNull(firstClassDefOrNull(type, predicate)) +inline fun PredicateContext.remember(key: Any, defaultValue: () -> V) = if (key in this) get(key) as V +else defaultValue().also { put(key, it) } -fun BytecodePatchContext.firstClassDefMutableOrNull( - type: String? = null, - predicate: ClassDefPredicate = { true } -) = firstClassDefOrNull(type, predicate)?.let { classDefs.getOrReplaceMutable(it) } +private fun cachedReadOnlyProperty(block: BytecodePatchContext.(KProperty<*>) -> T) = + JVMConflict.cachedReadOnlyProperty(block) -fun BytecodePatchContext.firstClassDefMutable( - type: String? = null, - predicate: ClassDefPredicate = { true } -) = requireNotNull(firstClassDefMutableOrNull(type, predicate)) +private object JVMConflict { + fun cachedReadOnlyProperty(block: R.(KProperty<*>) -> T) = object : ReadOnlyProperty { + private val cache = HashMap(1) -fun BytecodePatchContext.firstMethodOrNull( - vararg strings: String, - predicate: MethodPredicate = { true }, -): Method? = with(PredicateContext()) { - if (strings.isEmpty()) - return classDefs.asSequence().flatMap { it.methods.asSequence() }.firstOrNull { it.predicate() } - - val methodsWithStrings = strings.mapNotNull { classDefs.methodsByString[it] } - if (methodsWithStrings.size != strings.size) return null - - return methodsWithStrings.minBy { it.size }.firstOrNull { method -> - val containsAllOtherStrings = methodsWithStrings.all { method in it } - containsAllOtherStrings && method.predicate() + override fun getValue(thisRef: R, property: KProperty<*>) = + (if (thisRef in cache) cache[thisRef] else cache.getOrPut(thisRef) { thisRef.block(property) }) } } -fun BytecodePatchContext.firstMethod( - vararg strings: String, - predicate: MethodPredicate = { true }, -) = requireNotNull(firstMethodOrNull(strings = strings, predicate)) +class MutablePredicateList internal constructor() : MutableList Boolean> by mutableListOf() -fun BytecodePatchContext.firstMethodMutableOrNull( - vararg strings: String, - predicate: MethodPredicate = { true }, -) = firstMethodOrNull(strings = strings, predicate)?.let { method -> - firstClassDefMutable(method.definingClass).methods.first { - MethodUtil.methodSignaturesMatch(method, it) +private typealias DeclarativePredicate = context(PredicateContext) MutablePredicateList.() -> Unit + + +fun T.declarativePredicate(build: MutablePredicateList.() -> Unit) = + context(MutablePredicateList().apply(build)) { + all(this) } + +context(context: PredicateContext) +fun T.rememberDeclarativePredicate(key: Any, block: MutablePredicateList.() -> Unit) = + context(context.remember(key) { MutablePredicateList().apply(block) }) { + all(this) + } + +context(_: PredicateContext) +private fun T.rememberDeclarativePredicate( + predicate: DeclarativePredicate +) = rememberDeclarativePredicate("declarativePredicate") { predicate() } + +object IterableMethodMethodMatching { + fun Iterable.firstMethodOrNull( + methodReference: MethodReference + ) = firstOrNull { MethodUtil.methodSignaturesMatch(methodReference, it) } + + fun Iterable.firstMethod( + methodReference: MethodReference + ) = requireNotNull(firstMethodOrNull(methodReference)) + + context(context: BytecodePatchContext) + fun Iterable.firstMutableMethodOrNull( + methodReference: MethodReference + ) = firstMethodOrNull(methodReference)?.let { context.firstMutableMethod(it) } + + context(_: BytecodePatchContext) + fun Iterable.firstMutableMethod( + methodReference: MethodReference + ) = requireNotNull(firstMutableMethodOrNull(methodReference)) + + fun Iterable.firstMethodOrNull( + vararg strings: String, + predicate: MethodPredicate = { true }, + ) = if (strings.isEmpty()) withPredicateContext { firstOrNull { it.predicate() } } + else withPredicateContext { + first { method -> + val instructions = method.instructionsOrNull ?: return@first false + + // TODO: Check potential to optimize (Set or not). + // Maybe even use context maps, but the methods may not be present in the context yet. + val methodStrings = instructions.asSequence().mapNotNull { it.string }.toSet() + + if (strings.any { it !in methodStrings }) return@first false + + method.predicate() + } + } + + fun Iterable.firstMethod( + vararg strings: String, + predicate: MethodPredicate = { true }, + ) = requireNotNull(firstMethodOrNull(strings = strings, predicate)) + + context(context: BytecodePatchContext) + fun Iterable.firstMutableMethodOrNull( + vararg strings: String, + predicate: MethodPredicate = { true }, + ) = firstMethodOrNull(strings = strings, predicate)?.let { context.firstMutableMethod(it) } + + context(_: BytecodePatchContext) + fun Iterable.firstMutableMethod( + vararg strings: String, + predicate: MethodPredicate = { true }, + ) = requireNotNull(firstMutableMethodOrNull(strings = strings, predicate)) + + fun Iterable.firstMethodDeclarativelyOrNull( + predicate: DeclarativePredicate + ) = firstMethodOrNull { rememberDeclarativePredicate(predicate) } + + fun Iterable.firstMethodDeclaratively( + predicate: DeclarativePredicate + ) = requireNotNull(firstMethodDeclarativelyOrNull(predicate)) } -fun BytecodePatchContext.firstMethodMutable( - vararg strings: String, predicate: MethodPredicate = { true } -) = requireNotNull(firstMethodMutableOrNull(strings = strings, predicate)) +object IterableClassDefMethodMatching { + fun Iterable.firstMethodOrNull( + methodReference: MethodReference + ) = asSequence().flatMap { it.methods.asSequence() }.asIterable().firstMethodOrNull(methodReference) -fun gettingFirstClassDefOrNull( - type: String? = null, predicate: ClassDefPredicate = { true } -) = cachedReadOnlyProperty { firstClassDefOrNull(type, predicate) } + fun Iterable.firstMethod( + methodReference: MethodReference + ) = requireNotNull(firstMethodOrNull(methodReference)) -fun gettingFirstClassDef( - type: String? = null, predicate: ClassDefPredicate = { true } -) = cachedReadOnlyProperty { firstClassDef(type, predicate) } + context(context: BytecodePatchContext) + fun Iterable.firstMutableMethodOrNull( + methodReference: MethodReference + ) = asSequence().flatMap { it.methods.asSequence() }.asIterable().firstMutableMethodOrNull(methodReference) -fun gettingFirstClassDefMutableOrNull( - type: String? = null, predicate: ClassDefPredicate = { true } -) = cachedReadOnlyProperty { firstClassDefMutableOrNull(type, predicate) } + context(_: BytecodePatchContext) + fun Iterable.firstMutableMethod( + methodReference: MethodReference + ) = requireNotNull(firstMutableMethodOrNull(methodReference)) -fun gettingFirstClassDefMutable( - type: String? = null, predicate: ClassDefPredicate = { true } -) = cachedReadOnlyProperty { firstClassDefMutable(type, predicate) } + fun Iterable.firstMethodOrNull( + predicate: MethodPredicate = { true }, + ) = asSequence().flatMap { it.methods.asSequence() }.asIterable() + .firstMethodOrNull(strings = emptyArray(), predicate) -fun gettingFirstMethodOrNull( - vararg strings: String, - predicate: MethodPredicate = { true }, -) = cachedReadOnlyProperty { firstMethodOrNull(strings = strings, predicate) } + fun Iterable.firstMethod( + predicate: MethodPredicate = { true }, + ) = requireNotNull(firstMethodOrNull(predicate)) -fun gettingFirstMethod( - vararg strings: String, - predicate: MethodPredicate = { true }, -) = cachedReadOnlyProperty { firstMethod(strings = strings, predicate) } + fun Iterable.firstMethodOrNull( + vararg strings: String, + predicate: MethodPredicate = { true }, + ) = asSequence().flatMap { it.methods.asSequence() }.asIterable().firstMethodOrNull(strings = strings, predicate) -fun gettingFirstMethodMutableOrNull( - vararg strings: String, - predicate: MethodPredicate = { true }, -) = cachedReadOnlyProperty { firstMethodMutableOrNull(strings = strings, predicate) } + fun Iterable.firstMethod( + vararg strings: String, + predicate: MethodPredicate = { true }, + ) = requireNotNull(firstMethodOrNull(strings = strings, predicate)) -fun gettingFirstMethodMutable( - vararg strings: String, - predicate: MethodPredicate = { true }, -) = cachedReadOnlyProperty { firstMethodMutable(strings = strings, predicate) } + context(context: BytecodePatchContext) + fun Iterable.firstMutableMethodOrNull( + vararg strings: String, + predicate: MethodPredicate = { true }, + ) = firstMethodOrNull(strings = strings, predicate)?.let { context.firstMutableMethod(it) } -class PredicateContext internal constructor() : MutableMap by mutableMapOf() + context(context: BytecodePatchContext) + fun Iterable.firstMutableMethod( + vararg strings: String, + predicate: MethodPredicate = { true }, + ) = requireNotNull(firstMutableMethodOrNull(strings = strings, predicate)) + + fun Iterable.firstMethodDeclarativelyOrNull( + predicate: DeclarativePredicate + ) = firstMethodOrNull { rememberDeclarativePredicate(predicate) } + + fun Iterable.firstMethodDeclaratively( + predicate: DeclarativePredicate + ) = requireNotNull(firstMethodDeclarativelyOrNull(predicate)) + + context(context: BytecodePatchContext) + fun Iterable.firstMethodDeclarativelyOrNull( + vararg strings: String, + predicate: DeclarativePredicate + ) = firstMethodOrNull(strings = strings) { rememberDeclarativePredicate(predicate) } + + context(context: BytecodePatchContext) + fun Iterable.firstMethodDeclaratively( + vararg strings: String, + predicate: DeclarativePredicate + ) = requireNotNull(firstMethodDeclarativelyOrNull(strings = strings, predicate)) + + context(context: BytecodePatchContext) + fun Iterable.firstMutableMethodDeclarativelyOrNull( + vararg strings: String, + predicate: DeclarativePredicate + ) = firstMutableMethodOrNull(strings = strings) { rememberDeclarativePredicate(predicate) } +} + +object ClassDefMethodMatching { + fun ClassDef.firstMethodOrNull( + methodReference: MethodReference + ) = methods.firstMethodOrNull(methodReference) + + fun ClassDef.firstMethod( + methodReference: MethodReference + ) = requireNotNull(firstMethodOrNull(methodReference)) + + context(_: BytecodePatchContext) + fun ClassDef.firstMutableMethodOrNull( + methodReference: MethodReference + ) = methods.firstMutableMethodOrNull(methodReference) + + context(_: BytecodePatchContext) + fun ClassDef.firstMutableMethod( + methodReference: MethodReference + ) = requireNotNull(firstMutableMethodOrNull(methodReference)) + + fun ClassDef.firstMethodOrNull( + predicate: MethodPredicate = { true }, + ) = methods.firstMethodOrNull(strings = emptyArray(), predicate) + + fun ClassDef.firstMethod( + predicate: MethodPredicate = { true }, + ) = requireNotNull(firstMethodOrNull(predicate)) + + fun ClassDef.firstMethodOrNull( + vararg strings: String, + predicate: MethodPredicate = { true }, + ) = methods.firstMethodOrNull(strings = strings, predicate) + + fun ClassDef.firstMethod( + vararg strings: String, + predicate: MethodPredicate = { true }, + ) = requireNotNull(firstMethodOrNull(strings = strings, predicate)) + + fun ClassDef.firstMethodDeclarativelyOrNull( + predicate: DeclarativePredicate + ) = methods.firstMethodDeclarativelyOrNull(predicate) + + fun ClassDef.firstMethodDeclaratively( + predicate: DeclarativePredicate + ) = requireNotNull(firstMethodDeclarativelyOrNull(predicate)) +} + +object IterableClassDefClassDefMatching { + fun Iterable.firstClassDefOrNull( + predicate: ClassDefPredicate = { true } + ) = withPredicateContext { firstOrNull { it.predicate() } } + + fun Iterable.firstClassDef( + predicate: ClassDefPredicate = { true } + ) = requireNotNull(firstClassDefOrNull(predicate)) + + fun Iterable.firstClassDefOrNull( + type: String? = null, + predicate: ClassDefPredicate = { true } + ) = if (type == null) firstClassDefOrNull(predicate) + else withPredicateContext { firstOrNull { it.type == type && it.predicate() } } + + fun Iterable.firstClassDef( + type: String? = null, + predicate: ClassDefPredicate = { true } + ) = requireNotNull(firstClassDefOrNull(type, predicate)) + + context(context: BytecodePatchContext) + fun Iterable.firstMutableClassDefOrNull( + type: String? = null, + predicate: ClassDefPredicate = { true } + ) = firstClassDefOrNull(type, predicate)?.let { context.classDefs.getOrReplaceMutable(it) } + + context(_: BytecodePatchContext) + fun Iterable.firstMutableClassDef( + type: String? = null, + predicate: ClassDefPredicate = { true } + ) = requireNotNull(firstMutableClassDefOrNull(type, predicate)) + + fun Iterable.firstClassDefDeclarativelyOrNull( + type: String? = null, + predicate: DeclarativePredicate + ) = firstClassDefOrNull(type) { rememberDeclarativePredicate(predicate) } + + fun Iterable.firstClassDefDeclaratively( + type: String? = null, + predicate: DeclarativePredicate + ) = requireNotNull(firstClassDefDeclarativelyOrNull(type, predicate)) + + context(_: BytecodePatchContext) + fun Iterable.firstMutableClassDefDeclarativelyOrNull( + type: String? = null, + predicate: DeclarativePredicate + ) = firstMutableClassDefOrNull(type) { rememberDeclarativePredicate(predicate) } + + context(_: BytecodePatchContext) + fun Iterable.firstMutableClassDefDeclaratively( + type: String? = null, + predicate: DeclarativePredicate + ) = requireNotNull(firstMutableClassDefDeclarativelyOrNull(type, predicate)) +} + +object BytecodePatchContextMethodMatching { + fun BytecodePatchContext.firstMethodOrNull( + methodReference: MethodReference + ) = firstClassDefOrNull(methodReference.definingClass)?.methods?.firstMethodOrNull(methodReference) + + fun BytecodePatchContext.firstMethod( + method: MethodReference + ) = requireNotNull(firstMethodOrNull(method)) + + fun BytecodePatchContext.firstMutableMethodOrNull( + methodReference: MethodReference + ): MutableMethod? = firstMutableClassDefOrNull(methodReference.definingClass)?.methods + ?.first { MethodUtil.methodSignaturesMatch(methodReference, it) } + + fun BytecodePatchContext.firstMutableMethod( + method: MethodReference + ) = requireNotNull(firstMutableMethodOrNull(method)) + + fun BytecodePatchContext.firstMethodOrNull( + vararg strings: String, + predicate: MethodPredicate = { true }, + ): Method? = withPredicateContext { + if (strings.isEmpty()) return classDefs.firstMethodOrNull(predicate) + + val methodsWithStrings = strings.mapNotNull { classDefs.methodsByString[it] } + if (methodsWithStrings.size != strings.size) return null + + return methodsWithStrings.minBy { it.size }.firstOrNull { method -> + val containsAllOtherStrings = methodsWithStrings.all { method in it } + containsAllOtherStrings && method.predicate() + } + } + + fun BytecodePatchContext.firstMethod( + vararg strings: String, + predicate: MethodPredicate = { true }, + ) = requireNotNull(firstMethodOrNull(strings = strings, predicate)) + + fun BytecodePatchContext.firstMutableMethodOrNull( + vararg strings: String, + predicate: MethodPredicate = { true }, + ) = firstMethodOrNull(strings = strings, predicate)?.let { method -> + firstMutableMethodOrNull(method) + } + + fun BytecodePatchContext.firstMutableMethod( + vararg strings: String, + predicate: MethodPredicate = { true } + ) = requireNotNull(firstMutableMethodOrNull(strings = strings, predicate)) + + fun gettingFirstMethodOrNull( + method: MethodReference + ) = cachedReadOnlyProperty { firstMethodOrNull(method) } + + fun gettingFirstMethod( + method: MethodReference + ) = cachedReadOnlyProperty { firstMethod(method) } + + fun gettingFirstMutableMethodOrNull( + method: MethodReference + ) = cachedReadOnlyProperty { firstMutableMethodOrNull(method) } + + fun gettingFirstMutableMethod( + method: MethodReference + ) = cachedReadOnlyProperty { firstMutableMethod(method) } + + fun gettingFirstMethodOrNull( + vararg strings: String, + predicate: MethodPredicate = { true }, + ) = cachedReadOnlyProperty { firstMethodOrNull(strings = strings, predicate) } + + fun gettingFirstMethod( + vararg strings: String, + predicate: MethodPredicate = { true }, + ) = cachedReadOnlyProperty { firstMethod(strings = strings, predicate) } + + fun gettingFirstMutableMethodOrNull( + vararg strings: String, + predicate: MethodPredicate = { true }, + ) = cachedReadOnlyProperty { firstMutableMethodOrNull(strings = strings, predicate) } + + fun gettingFirstMutableMethod( + vararg strings: String, + predicate: MethodPredicate = { true }, + ) = cachedReadOnlyProperty { firstMutableMethod(strings = strings, predicate) } + + fun BytecodePatchContext.firstMethodDeclarativelyOrNull( + vararg strings: String, + predicate: DeclarativePredicate = { } + ) = firstMethodOrNull(strings = strings) { rememberDeclarativePredicate(predicate) } + + fun BytecodePatchContext.firstMethodDeclaratively( + vararg strings: String, + predicate: DeclarativePredicate = { } + ) = requireNotNull(firstMethodDeclarativelyOrNull(strings = strings, predicate)) + + fun BytecodePatchContext.firstMutableMethodDeclarativelyOrNull( + vararg strings: String, + predicate: DeclarativePredicate = { } + ) = firstMutableMethodOrNull(strings = strings) { rememberDeclarativePredicate(predicate) } + + fun BytecodePatchContext.firstMutableMethodDeclaratively( + vararg strings: String, + predicate: DeclarativePredicate = { } + ) = requireNotNull(firstMutableMethodDeclarativelyOrNull(strings = strings, predicate)) + + fun gettingFirstMethodDeclarativelyOrNull( + vararg strings: String, + predicate: DeclarativePredicate = { } + ) = gettingFirstMethodOrNull(strings = strings) { rememberDeclarativePredicate(predicate) } + + fun gettingFirstMethodDeclaratively( + vararg strings: String, + predicate: DeclarativePredicate = { } + ) = gettingFirstMethod(strings = strings) { rememberDeclarativePredicate(predicate) } + + fun gettingFirstMutableMethodDeclarativelyOrNull( + vararg strings: String, + predicate: DeclarativePredicate = { } + ) = gettingFirstMutableMethodOrNull(strings = strings) { rememberDeclarativePredicate(predicate) } + + fun gettingFirstMutableMethodDeclaratively( + vararg strings: String, + predicate: DeclarativePredicate = { } + ) = gettingFirstMutableMethod(strings = strings) { rememberDeclarativePredicate(predicate) } +} + +object BytecodePatchContextClassDefMatching { + fun BytecodePatchContext.firstClassDefOrNull( + type: String? = null, + predicate: ClassDefPredicate = { true } + ) = withPredicateContext { + if (type == null) classDefs.firstClassDefOrNull(predicate) + else classDefs[type]?.takeIf { it.predicate() } + } + + fun BytecodePatchContext.firstClassDef( + type: String? = null, + predicate: ClassDefPredicate = { true } + ) = requireNotNull(firstClassDefOrNull(type, predicate)) + + fun BytecodePatchContext.firstMutableClassDefOrNull( + type: String? = null, + predicate: ClassDefPredicate = { true } + ) = firstClassDefOrNull(type, predicate)?.let { classDefs.getOrReplaceMutable(it) } + + fun BytecodePatchContext.firstMutableClassDef( + type: String? = null, + predicate: ClassDefPredicate = { true } + ) = requireNotNull(firstMutableClassDefOrNull(type, predicate)) + + fun gettingFirstClassDefOrNull( + type: String? = null, + predicate: ClassDefPredicate = { true } + ) = cachedReadOnlyProperty { firstClassDefOrNull(type, predicate) } + + fun gettingFirstClassDef( + type: String? = null, + predicate: ClassDefPredicate = { true } + ) = requireNotNull(gettingFirstClassDefOrNull(type, predicate)) + + fun gettingFirstMutableClassDefOrNull( + type: String? = null, + predicate: ClassDefPredicate = { true } + ) = cachedReadOnlyProperty { firstMutableClassDefOrNull(type, predicate) } + + fun gettingFirstMutableClassDef( + type: String? = null, + predicate: ClassDefPredicate = { true } + ) = requireNotNull(gettingFirstMutableClassDefOrNull(type, predicate)) + + fun BytecodePatchContext.firstClassDefDeclarativelyOrNull( + type: String? = null, + predicate: DeclarativePredicate = { } + ) = firstClassDefOrNull(type) { rememberDeclarativePredicate(predicate) } + + fun BytecodePatchContext.firstClassDefDeclaratively( + type: String? = null, + predicate: DeclarativePredicate = { } + ) = requireNotNull(firstClassDefDeclarativelyOrNull(type, predicate)) + + fun BytecodePatchContext.firstMutableClassDefDeclarativelyOrNull( + type: String? = null, + predicate: DeclarativePredicate = { } + ) = firstMutableClassDefOrNull(type) { rememberDeclarativePredicate(predicate) } + + fun BytecodePatchContext.firstMutableClassDefDeclaratively( + type: String? = null, + predicate: DeclarativePredicate = { } + ) = requireNotNull(firstMutableClassDefDeclarativelyOrNull(type, predicate)) + + fun gettingFirstClassDefDeclarativelyOrNull( + type: String? = null, + predicate: DeclarativePredicate = { } + ) = cachedReadOnlyProperty { firstClassDefDeclarativelyOrNull(type, predicate) } + + fun gettingFirstClassDefDeclaratively( + type: String? = null, + predicate: DeclarativePredicate = { } + ) = requireNotNull(gettingFirstClassDefDeclarativelyOrNull(type, predicate)) + + fun gettingFirstMutableClassDefDeclarativelyOrNull( + type: String? = null, + predicate: DeclarativePredicate = { } + ) = cachedReadOnlyProperty { firstMutableClassDefDeclarativelyOrNull(type, predicate) } + + fun gettingFirstMutableClassDefDeclaratively( + type: String? = null, + predicate: DeclarativePredicate = { } + ) = requireNotNull(gettingFirstMutableClassDefDeclarativelyOrNull(type, predicate)) +} + +class PredicateContext internal constructor() : MutableMap by mutableMapOf() + +private inline fun withPredicateContext(block: PredicateContext.() -> T) = PredicateContext().block() // region Matcher @@ -163,17 +574,15 @@ fun Iterable.matchIndexed(build: IndexedMatcher.() -> Unit) = indexedMatcher(build)(this) context(_: PredicateContext) -fun Iterable.rememberedMatchIndexed(key: Any, build: IndexedMatcher.() -> Unit) = +fun Iterable.rememberMatchIndexed(key: Any, build: IndexedMatcher.() -> Unit) = indexedMatcher()(key, this, build) -context(_: IndexedMatcher) fun head( predicate: T.(lastMatchedIndex: Int, currentIndex: Int) -> Boolean ): T.(Int, Int) -> Boolean = { lastMatchedIndex, currentIndex -> currentIndex == 0 && predicate(lastMatchedIndex, currentIndex) } -context(_: IndexedMatcher) fun head(predicate: T.() -> Boolean): T.(Int, Int) -> Boolean = head { _, _ -> predicate() } @@ -284,12 +693,12 @@ class IndexedMatcher : Matcher> M.invoke( key: Any, iterable: Iterable, builder: M.() -> Unit -) = remembered(key) { apply(builder) }(iterable) +) = context.remember(key) { apply(builder) }(iterable) context(_: PredicateContext) inline operator fun > M.invoke( @@ -306,198 +715,81 @@ abstract class Matcher : MutableList by mutableListOf() { // endregion Matcher -context(context: PredicateContext) - -inline fun remembered(key: Any, defaultValue: () -> V) = - context[key] as? V ?: defaultValue().also { context[key] = it } - -private fun cachedReadOnlyProperty(block: BytecodePatchContext.(KProperty<*>) -> T) = - object : ReadOnlyProperty { - private var value: T? = null - private var cached = false - - override fun getValue(thisRef: BytecodePatchContext, property: KProperty<*>): T { - if (!cached) { - value = thisRef.block(property) - cached = true - } - - return value!! - } - } - -private typealias DeclarativeClassDefPredicate = context(PredicateContext, MutableList Boolean>) () -> Unit - -private typealias DeclarativeMethodPredicate = context(PredicateContext, MutableList Boolean>) () -> Unit - -fun T.declarativePredicate(build: context(MutableList Boolean>) () -> Unit) = - context(mutableListOf Boolean>().apply(build)) { - all(this) - } - -context(_: PredicateContext) -fun T.rememberedDeclarativePredicate(key: Any, block: context(MutableList Boolean>) () -> Unit) = - context(remembered(key) { mutableListOf Boolean>().apply(block) }) { - all(this) - } - -context(_: PredicateContext) -private fun T.rememberedDeclarativePredicate( - predicate: context(PredicateContext, MutableList Boolean>) () -> Unit -) = rememberedDeclarativePredicate("declarativePredicate") { predicate() } - -fun BytecodePatchContext.firstClassDefByDeclarativePredicateOrNull( - predicate: DeclarativeClassDefPredicate -) = firstClassDefOrNull { rememberedDeclarativePredicate(predicate) } - -fun BytecodePatchContext.firstClassDefByDeclarativePredicateOrNull( - type: String? = null, - predicate: DeclarativeClassDefPredicate = { } -) = firstClassDefOrNull(type) { rememberedDeclarativePredicate(predicate) } - -fun BytecodePatchContext.firstClassDefByDeclarativePredicate( - type: String? = null, - predicate: DeclarativeClassDefPredicate = { } -) = requireNotNull(firstClassDefByDeclarativePredicateOrNull(type, predicate)) - -fun BytecodePatchContext.firstClassDefMutableByDeclarativePredicateOrNull( - type: String? = null, - predicate: DeclarativeClassDefPredicate = { } -) = firstClassDefMutableOrNull(type) { rememberedDeclarativePredicate(predicate) } - -fun BytecodePatchContext.firstClassDefMutableByDeclarativePredicate( - type: String? = null, - predicate: DeclarativeClassDefPredicate = { } -) = requireNotNull(firstClassDefMutableByDeclarativePredicateOrNull(type, predicate)) - -fun BytecodePatchContext.firstMethodByDeclarativePredicateOrNull( - vararg strings: String, - predicate: DeclarativeMethodPredicate = { } -) = firstMethodOrNull(strings = strings) { rememberedDeclarativePredicate(predicate) } - -fun BytecodePatchContext.firstMethodByDeclarativePredicate( - vararg strings: String, - predicate: DeclarativeMethodPredicate = { } -) = requireNotNull(firstMethodByDeclarativePredicateOrNull(strings = strings, predicate)) - -fun BytecodePatchContext.firstMethodMutableByDeclarativePredicateOrNull( - vararg strings: String, - predicate: DeclarativeMethodPredicate = { } -) = firstMethodMutableOrNull(strings = strings) { rememberedDeclarativePredicate(predicate) } - -fun BytecodePatchContext.firstMethodMutableByDeclarativePredicate( - vararg strings: String, - predicate: DeclarativeMethodPredicate = { } -) = requireNotNull(firstMethodMutableByDeclarativePredicateOrNull(strings = strings, predicate)) - -fun gettingFirstClassDefByDeclarativePredicateOrNull( - type: String? = null, - predicate: DeclarativeClassDefPredicate = { } -) = gettingFirstClassDefOrNull(type) { rememberedDeclarativePredicate(predicate) } - -fun gettingFirstClassDefByDeclarativePredicate( - type: String? = null, - predicate: DeclarativeClassDefPredicate = { } -) = cachedReadOnlyProperty { firstClassDefByDeclarativePredicate(type, predicate) } - -fun gettingFirstClassDefMutableByDeclarativePredicateOrNull( - type: String? = null, - predicate: DeclarativeClassDefPredicate = { } -) = gettingFirstClassDefMutableOrNull(type) { rememberedDeclarativePredicate(predicate) } - -fun gettingFirstClassDefMutableByDeclarativePredicate( - type: String? = null, - predicate: DeclarativeClassDefPredicate = { } -) = cachedReadOnlyProperty { firstClassDefMutableByDeclarativePredicate(type, predicate) } - -fun gettingFirstMethodByDeclarativePredicateOrNull( - vararg strings: String, - predicate: DeclarativeMethodPredicate = { } -) = gettingFirstMethodOrNull(strings = strings) { rememberedDeclarativePredicate(predicate) } - -fun gettingFirstMethodByDeclarativePredicate( - vararg strings: String, - predicate: DeclarativeMethodPredicate = { } -) = cachedReadOnlyProperty { firstMethodByDeclarativePredicate(strings = strings, predicate) } - -fun gettingFirstMethodMutableByDeclarativePredicateOrNull( - vararg strings: String, - predicate: DeclarativeMethodPredicate = { } -) = gettingFirstMethodMutableOrNull(strings = strings) { rememberedDeclarativePredicate(predicate) } - -fun gettingFirstMethodMutableByDeclarativePredicate( - vararg strings: String, - predicate: DeclarativeMethodPredicate = { } -) = cachedReadOnlyProperty { firstMethodMutableByDeclarativePredicate(strings = strings, predicate) } - -context(list: MutableList Boolean>) -fun allOf(block: MutableList Boolean>.() -> Unit) { - val child = mutableListOf Boolean>().apply(block) +context(list: MutablePredicateList) +fun allOf(block: MutablePredicateList.() -> Unit) { + val child = MutablePredicateList().apply(block) list.add { child.all { it() } } } -context(list: MutableList Boolean>) -fun anyOf(block: MutableList Boolean>.() -> Unit) { - val child = mutableListOf Boolean>().apply(block) +context(list: MutablePredicateList) +fun anyOf(block: MutablePredicateList.() -> Unit) { + val child = MutablePredicateList().apply(block) list.add { child.any { it() } } } -context(list: MutableList Boolean>) +context(list: MutablePredicateList) fun predicate(block: T.() -> Boolean) { list.add(block) } -context(list: MutableList Boolean>) +context(list: MutablePredicateList) fun all(target: T): Boolean = list.all { target.it() } -context(list: MutableList Boolean>) +context(list: MutablePredicateList) fun any(target: T): Boolean = list.any { target.it() } -context(_: MutableList Boolean>) -fun accessFlags(vararg flags: AccessFlags) = +fun MutablePredicateList.accessFlags(vararg flags: AccessFlags) = predicate { accessFlags(flags = flags) } -context(_: MutableList Boolean>) -fun returnType( +fun MutablePredicateList.returnType( returnType: String, compare: String.(String) -> Boolean = String::startsWith ) = predicate { this.returnType.compare(returnType) } -context(_: MutableList Boolean>) -fun name( +fun MutablePredicateList.name( name: String, compare: String.(String) -> Boolean = String::equals ) = predicate { this.name.compare(name) } -context(_: MutableList Boolean>) -fun definingClass( +fun MutablePredicateList.definingClass( definingClass: String, compare: String.(String) -> Boolean = String::equals ) = predicate { this.definingClass.compare(definingClass) } -context(_: MutableList Boolean>) -fun parameterTypes(vararg parameterTypePrefixes: String) = predicate { +fun MutablePredicateList.parameterTypes(vararg parameterTypePrefixes: String) = predicate { parameterTypes.size == parameterTypePrefixes.size && parameterTypes.zip(parameterTypePrefixes) .all { (a, b) -> a.startsWith(b) } } -context(_: MutableList Boolean>, matcher: IndexedMatcher) -fun instructions( - build: context(IndexedMatcher) () -> Unit +fun MutablePredicateList.instructions( + build: IndexedMatcher.() -> Unit ) { - build() - predicate { implementation { matcher(instructions) } } + val match = indexedMatcher() + predicate { implementation { match(instructions) } } } -context(_: MutableList Boolean>, matcher: IndexedMatcher) -fun instructions( +fun MutablePredicateList.instructions( vararg predicates: Instruction.(currentIndex: Int, lastMatchedIndex: Int) -> Boolean ) = instructions { predicates.forEach { +it } } -context(_: MutableList Boolean>) -fun custom(block: Method.() -> Boolean) { +context(matcher: IndexedMatcher) +fun MutablePredicateList.instructions( + build: IndexedMatcher.() -> Unit +) { + matcher.build() + predicate { implementation { matcher(instructions) } } +} + +context(matcher: IndexedMatcher) +fun MutablePredicateList.instructions( + vararg predicates: Instruction.(currentIndex: Int, lastMatchedIndex: Int) -> Boolean +) = instructions { + predicates.forEach { +it } +} + +fun MutablePredicateList.custom(block: Method.() -> Boolean) { predicate { block() } } @@ -612,49 +904,121 @@ fun noneOf( predicates.none { predicate -> predicate(currentIndex, lastMatchedIndex) } } -fun firstMethodBuilder( +private typealias BuildDeclarativePredicate = context( +PredicateContext, +IndexedMatcher, +MutableList +) MutablePredicateList.() -> Unit + +fun firstMethodComposite( vararg strings: String, - builder: - context(PredicateContext, MutableList Boolean>, IndexedMatcher, MutableList)() -> Unit -) = Match(strings = strings, builder) + build: BuildDeclarativePredicate +) = MatchBuilder(strings = strings, build) -class Match private constructor( - private val strings: MutableList, - indexedMatcher: IndexedMatcher = indexedMatcher(), - build: context( - PredicateContext, MutableList Boolean>, - IndexedMatcher, MutableList) () -> Unit +val a = firstMethodComposite { + name("exampleMethod") + definingClass("Lcom/example/MyClass;") + returnType("V") + instructions( + head(Opcode.RETURN_VOID()), + after(1..5, Opcode.INVOKE_VIRTUAL()) + ) +} + +class MatchBuilder private constructor( + val strings: MutableList, + indexedMatcher: IndexedMatcher, + build: BuildDeclarativePredicate, ) { - internal constructor( - vararg strings: String, - builder: context( - PredicateContext, MutableList Boolean>, - IndexedMatcher, MutableList) () -> Unit - ) : this(strings = mutableListOf(elements = strings), build = builder) - private val methodOrNullMap = HashMap(1) + internal constructor(vararg strings: String, build: BuildDeclarativePredicate) : + this(strings = mutableListOf(elements = strings), indexedMatcher(), build) - private val predicate: DeclarativeMethodPredicate = context(strings, indexedMatcher) { { build() } } + private val predicate: DeclarativePredicate = context(strings, indexedMatcher) { { build() } } + + private val indices = indexedMatcher.indices + + private val BytecodePatchContext.cachedImmutableMethodOrNull + by gettingFirstMethodDeclarativelyOrNull(strings = strings.toTypedArray(), predicate) + + private val BytecodePatchContext.cachedImmutableClassDefOrNull by cachedReadOnlyProperty { + val type = cachedImmutableMethodOrNull?.definingClass ?: return@cachedReadOnlyProperty null + firstClassDefOrNull(type) + } context(context: BytecodePatchContext) + val immutableMethodOrNull get() = context.cachedImmutableMethodOrNull - val methodOrNull: MutableMethod? - get() = if (context in methodOrNullMap) methodOrNullMap[context] - else methodOrNullMap.getOrPut(context) { - context.firstMethodMutableByDeclarativePredicateOrNull( - strings = strings.toTypedArray(), - predicate - ) - } + context(_: BytecodePatchContext) + val immutableMethod get() = requireNotNull(immutableMethodOrNull) + + context(context: BytecodePatchContext) + val immutableClassDefOrNull get() = context.cachedImmutableClassDefOrNull + + context(context: BytecodePatchContext) + val immutableClassDef get() = requireNotNull(immutableClassDefOrNull) + + val BytecodePatchContext.cachedMethodOrNull by cachedReadOnlyProperty { + firstMutableMethodOrNull(immutableMethodOrNull ?: return@cachedReadOnlyProperty null) + } + + private val BytecodePatchContext.cachedClassDefOrNull by cachedReadOnlyProperty { + val type = immutableMethodOrNull?.definingClass ?: return@cachedReadOnlyProperty null + firstMutableClassDefOrNull(type) + } + + context(context: BytecodePatchContext) + val methodOrNull get() = context.cachedMethodOrNull context(_: BytecodePatchContext) val method get() = requireNotNull(methodOrNull) context(context: BytecodePatchContext) - val classDefOrNull get() = methodOrNull?.definingClass?.let(context::firstClassDefOrNull) + val classDefOrNull get() = context.cachedClassDefOrNull context(_: BytecodePatchContext) val classDef get() = requireNotNull(classDefOrNull) - val indices = indexedMatcher.indices + context(context: BytecodePatchContext) + fun match(classDef: ClassDef) = Match( + context, + classDef.firstMethodDeclarativelyOrNull { predicate() }, + indices.toList() + ) } + +class Match( + val context: BytecodePatchContext, + val immutableMethodOrNull: Method?, + val indices: List +) { + val immutableMethod by lazy { requireNotNull(immutableMethodOrNull) } + + val methodOrNull by lazy { + context.firstMutableMethodOrNull(immutableMethodOrNull ?: return@lazy null) + } + + val method by lazy { requireNotNull(methodOrNull) } + + val immutableClassDefOrNull by lazy { context(context) { immutableMethodOrNull?.immutableClassDefOrNull } } + + val immutableClassDef by lazy { requireNotNull(context(context) { immutableMethod.immutableClassDef }) } + + val classDefOrNull by lazy { + context.firstMutableClassDefOrNull(immutableMethodOrNull?.definingClass ?: return@lazy null) + } + + val classDef by lazy { requireNotNull(classDefOrNull) } +} + +context(context: BytecodePatchContext) +val Method.immutableClassDefOrNull get() = context.classDefs[definingClass] + +context(_: BytecodePatchContext) +val Method.immutableClassDef get() = requireNotNull(immutableClassDefOrNull) + +context(context: BytecodePatchContext) +val Method.classDefOrNull get() = context.firstMutableClassDefOrNull(definingClass) + +context(_: BytecodePatchContext) +val Method.classDef get() = requireNotNull(classDefOrNull) diff --git a/patcher/src/commonMain/kotlin/app/revanced/patcher/extensions/Method.kt b/patcher/src/commonMain/kotlin/app/revanced/patcher/extensions/Method.kt index 55ff3a3..3b8f87c 100644 --- a/patcher/src/commonMain/kotlin/app/revanced/patcher/extensions/Method.kt +++ b/patcher/src/commonMain/kotlin/app/revanced/patcher/extensions/Method.kt @@ -1,6 +1,5 @@ package app.revanced.patcher.extensions -import com.android.tools.smali.dexlib2.mutable.MutableMethod import com.android.tools.smali.dexlib2.AccessFlags import com.android.tools.smali.dexlib2.builder.BuilderInstruction import com.android.tools.smali.dexlib2.builder.BuilderOffsetInstruction @@ -10,6 +9,7 @@ import com.android.tools.smali.dexlib2.builder.instruction.* import com.android.tools.smali.dexlib2.iface.Method import com.android.tools.smali.dexlib2.iface.MethodImplementation import com.android.tools.smali.dexlib2.iface.instruction.Instruction +import com.android.tools.smali.dexlib2.mutable.MutableMethod fun Method.accessFlags(vararg flags: AccessFlags) = accessFlags.and(flags.map { it.ordinal }.reduce { acc, i -> acc or i }) != 0 diff --git a/patcher/src/commonMain/kotlin/app/revanced/patcher/patch/BytecodePatchContext.kt b/patcher/src/commonMain/kotlin/app/revanced/patcher/patch/BytecodePatchContext.kt index 842b61f..7ed366f 100644 --- a/patcher/src/commonMain/kotlin/app/revanced/patcher/patch/BytecodePatchContext.kt +++ b/patcher/src/commonMain/kotlin/app/revanced/patcher/patch/BytecodePatchContext.kt @@ -2,14 +2,13 @@ package app.revanced.patcher.patch import app.revanced.patcher.PatchesResult import app.revanced.patcher.extensions.instructionsOrNull +import app.revanced.patcher.extensions.string import app.revanced.patcher.util.ClassMerger.merge import app.revanced.patcher.util.MethodNavigator import com.android.tools.smali.dexlib2.iface.ClassDef import com.android.tools.smali.dexlib2.iface.DexFile import com.android.tools.smali.dexlib2.iface.Method -import com.android.tools.smali.dexlib2.iface.instruction.ReferenceInstruction import com.android.tools.smali.dexlib2.iface.reference.MethodReference -import com.android.tools.smali.dexlib2.iface.reference.StringReference import com.android.tools.smali.dexlib2.mutable.MutableClassDef import com.android.tools.smali.dexlib2.mutable.MutableClassDef.Companion.toMutable import lanchon.multidexlib2.BasicDexFileNamer @@ -141,13 +140,8 @@ class BytecodePatchContext internal constructor( private fun ClassDef.forEachString(action: (Method, String) -> Unit) { methods.asSequence().forEach { method -> method.instructionsOrNull?.asSequence() - ?.filterIsInstance() - ?.map { it.reference } - ?.filterIsInstance() - ?.map { it.string } - ?.forEach { string -> - action(method, string) - } + ?.mapNotNull { it.string } + ?.forEach { string -> action(method, string) } } } diff --git a/patcher/src/jvmTest/kotlin/app/revanced/patcher/MatchingTest.kt b/patcher/src/jvmTest/kotlin/app/revanced/patcher/MatchingTest.kt index fd2df01..81282ec 100644 --- a/patcher/src/jvmTest/kotlin/app/revanced/patcher/MatchingTest.kt +++ b/patcher/src/jvmTest/kotlin/app/revanced/patcher/MatchingTest.kt @@ -1,8 +1,11 @@ package app.revanced.patcher +import app.revanced.patcher.BytecodePatchContextMethodMatching.firstMethod +import app.revanced.patcher.BytecodePatchContextMethodMatching.firstMethodDeclarativelyOrNull import app.revanced.patcher.patch.bytecodePatch import com.android.tools.smali.dexlib2.Opcode import com.android.tools.smali.dexlib2.iface.instruction.TwoRegisterInstruction +import com.android.tools.smali.dexlib2.immutable.ImmutableClassDef import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.BeforeAll import org.junit.jupiter.api.TestInstance @@ -13,13 +16,13 @@ import kotlin.test.assertFalse import kotlin.test.assertNotNull @TestInstance(TestInstance.Lifecycle.PER_CLASS) -object MatchingTest : PatcherTestBase() { +class MatchingTest : PatcherTestBase() { @BeforeAll fun setup() = setupMock() @Test fun `finds via builder api`() { - fun firstMethodBuilder(fail: Boolean = false) = firstMethodBuilder { + fun firstMethodComposite(fail: Boolean = false) = firstMethodComposite { name("method") definingClass("class") @@ -36,19 +39,20 @@ object MatchingTest : PatcherTestBase() { ) } - bytecodePatch { - apply { - assertNotNull(firstMethodBuilder().methodOrNull) { "Expected to find a method" } - Assertions.assertNull(firstMethodBuilder(fail = true).methodOrNull) { "Expected to not find a method" } - } - }() + with(bytecodePatchContext) { + assertNotNull(firstMethodComposite().methodOrNull) { "Expected to find a method" } + Assertions.assertNull(firstMethodComposite(fail = true).immutableMethodOrNull) { "Expected to not find a method" } + Assertions.assertNotNull( + firstMethodComposite().match(classDefs.first()).methodOrNull + ) { "Expected to find a method matching in a specific class" } + } } @Test fun `finds via declarative api`() { bytecodePatch { apply { - val method = firstMethodByDeclarativePredicateOrNull { + val method = firstMethodDeclarativelyOrNull { anyOf { predicate { name == "method" } add { false } @@ -78,7 +82,7 @@ object MatchingTest : PatcherTestBase() { val matcher = indexedMatcher() matcher.apply { - +head { this > 5 } + +head { this > 5 } } assertFalse( matcher(iterable), @@ -86,7 +90,7 @@ object MatchingTest : PatcherTestBase() { ) matcher.clear() - matcher.apply { +head { this == 1 } }(iterable) + matcher.apply { +head { this == 1 } }(iterable) assertEquals( listOf(0), matcher.indices, @@ -107,7 +111,7 @@ object MatchingTest : PatcherTestBase() { matcher.clear() matcher.apply { - +head { this == 1 } + +head { this == 1 } add { _, _ -> this == 2 } add { _, _ -> this == 4 } }(iterable) @@ -147,7 +151,7 @@ object MatchingTest : PatcherTestBase() { matcher.clear() matcher.apply { - +head { this == 1 } + +head { this == 1 } +after(2..5) { this == 4 } add { _, _ -> this == 8 } add { _, _ -> this == 9 }