diff --git a/src/main/kotlin/app/revanced/patcher/Matching.kt b/src/main/kotlin/app/revanced/patcher/Matching.kt index 85b3037..dc84f74 100644 --- a/src/main/kotlin/app/revanced/patcher/Matching.kt +++ b/src/main/kotlin/app/revanced/patcher/Matching.kt @@ -50,6 +50,7 @@ fun MethodImplementation.anyTryBlock(predicate: TryBlock.( fun MethodImplementation.anyDebugItem(predicate: Any.() -> Boolean) = debugItems.any(predicate) fun Iterable.anyInstruction(predicate: Instruction.() -> Boolean) = any(predicate) + fun BytecodePatchContext.firstClassDefOrNull(predicate: context(MatchContext) ClassDef.() -> Boolean) = with(MatchContext()) { classDefs.firstOrNull { it.predicate() } } @@ -88,33 +89,25 @@ fun Iterable.firstMethodOrNull(predicate: context(MatchContext) Method fun Iterable.firstMethod(predicate: context(MatchContext) Method.() -> Boolean) = requireNotNull(firstMethodOrNull(predicate)) -context(context: BytecodePatchContext) -fun Iterable.firstMethodMutableOrNull(predicate: context(MatchContext) Method.() -> Boolean): MutableMethod? = - with(context) { - with(MatchContext()) { - this@firstMethodMutableOrNull.forEach { classDef -> - classDef.methods.firstOrNull { it.predicate() }?.let { method -> - return classDef.mutable().methods.first { MethodUtil.methodSignaturesMatch(it, method) } - } - } - - null - } - } - -context(_: BytecodePatchContext) -fun Iterable.firstMethodMutable(predicate: context(MatchContext) Method.() -> Boolean) = - requireNotNull(firstMethodMutableOrNull(predicate)) +/** Can't compile due to JVM platform declaration clash +fun Iterable.firstMethodOrNull(predicate: context(MatchContext) Method.() -> Boolean) = +with(MatchContext()) { firstOrNull { it.predicate() } } +fun Iterable.firstMethod(predicate: context(MatchContext) Method.() -> Boolean) = +with(MatchContext()) { requireNotNull(firstMethodOrNull(predicate)) } + **/ fun BytecodePatchContext.firstMethodOrNull(predicate: context(MatchContext) Method.() -> Boolean) = classDefs.firstMethodOrNull(predicate) fun BytecodePatchContext.firstMethod(predicate: context(MatchContext) Method.() -> Boolean) = requireNotNull(firstMethodOrNull(predicate)) - fun BytecodePatchContext.firstMethodMutableOrNull(predicate: context(MatchContext) Method.() -> Boolean) = - classDefs.firstMethodMutableOrNull(predicate) + classDefs.firstMethodOrNull(predicate)?.let { method -> + lookupMaps.classDefsByType[method.definingClass]!!.mutable().methods.first { + MethodUtil.methodSignaturesMatch(method, it) + } + } fun BytecodePatchContext.firstMethodMutable(predicate: context(MatchContext) Method.() -> Boolean) = requireNotNull(firstMethodMutableOrNull(predicate)) @@ -150,72 +143,81 @@ fun BytecodePatchContext.firstMethodMutable( vararg strings: String, predicate: context(MatchContext) Method.() -> Boolean = { true } ) = requireNotNull(firstMethodMutableOrNull(*strings, predicate = predicate)) -inline fun ReadOnlyProperty(crossinline block: C.(KProperty<*>) -> T) = - ReadOnlyProperty { thisRef, property -> - require(thisRef is C) +class CachedReadOnlyProperty internal constructor( + private val block: BytecodePatchContext.(KProperty<*>) -> T +) : ReadOnlyProperty { + private var value: T? = null + private var cached = false - thisRef.block(property) + override fun getValue(thisRef: BytecodePatchContext, property: KProperty<*>): T { + if (!cached) { + value = thisRef.block(property) + cached = true + } + + return value!! } +} fun gettingFirstClassDefOrNull(predicate: context(MatchContext) ClassDef.() -> Boolean) = - ReadOnlyProperty { firstClassDefOrNull(predicate) } + CachedReadOnlyProperty { firstClassDefOrNull(predicate) } fun gettingFirstClassDef(predicate: context(MatchContext) ClassDef.() -> Boolean) = - requireNotNull(gettingFirstClassDefOrNull(predicate)) + CachedReadOnlyProperty { firstClassDef(predicate) } fun gettingFirstClassDefMutableOrNull(predicate: context(MatchContext) ClassDef.() -> Boolean) = - ReadOnlyProperty { firstClassDefMutableOrNull(predicate) } + CachedReadOnlyProperty { firstClassDefMutableOrNull(predicate) } fun gettingFirstClassDefMutable(predicate: context(MatchContext) ClassDef.() -> Boolean) = - requireNotNull(gettingFirstClassDefMutableOrNull(predicate)) + CachedReadOnlyProperty { firstClassDefMutable(predicate) } fun gettingFirstClassDefOrNull( type: String, predicate: (context(MatchContext) ClassDef.() -> Boolean)? = null -) = ReadOnlyProperty { firstClassDefOrNull(type, predicate) } +) = CachedReadOnlyProperty { firstClassDefOrNull(type, predicate) } fun gettingFirstClassDef( type: String, predicate: (context(MatchContext) ClassDef.() -> Boolean)? = null -) = requireNotNull(gettingFirstClassDefOrNull(type, predicate)) +) = CachedReadOnlyProperty { firstClassDef(type, predicate) } fun gettingFirstClassDefMutableOrNull( type: String, predicate: (context(MatchContext) ClassDef.() -> Boolean)? = null -) = ReadOnlyProperty { firstClassDefMutableOrNull(type, predicate) } +) = CachedReadOnlyProperty { firstClassDefMutableOrNull(type, predicate) } fun gettingFirstClassDefMutable( type: String, predicate: (context(MatchContext) ClassDef.() -> Boolean)? = null -) = requireNotNull(gettingFirstClassDefMutableOrNull(type, predicate)) +) = CachedReadOnlyProperty { firstClassDefMutable(type, predicate) } fun gettingFirstMethodOrNull(predicate: context(MatchContext) Method.() -> Boolean) = - ReadOnlyProperty { firstMethodOrNull(predicate) } + CachedReadOnlyProperty { firstMethodOrNull(predicate) } fun gettingFirstMethod(predicate: context(MatchContext) Method.() -> Boolean) = - requireNotNull(gettingFirstMethodOrNull(predicate)) + CachedReadOnlyProperty { firstMethod(predicate) } fun gettingFirstMethodMutableOrNull(predicate: context(MatchContext) Method.() -> Boolean) = - ReadOnlyProperty { firstMethodMutableOrNull(predicate) } + CachedReadOnlyProperty { firstMethodMutableOrNull(predicate) } fun gettingFirstMethodMutable(predicate: context(MatchContext) Method.() -> Boolean) = - requireNotNull(gettingFirstMethodMutableOrNull(predicate)) + CachedReadOnlyProperty { firstMethodMutable(predicate) } fun gettingFirstMethodOrNull( vararg strings: String, predicate: context(MatchContext) Method.() -> Boolean = { true }, -) = ReadOnlyProperty { firstMethodOrNull(*strings, predicate = predicate) } +) = CachedReadOnlyProperty { firstMethodOrNull(*strings, predicate = predicate) } fun gettingFirstMethod( vararg strings: String, predicate: context(MatchContext) Method.() -> Boolean = { true }, -) = requireNotNull(gettingFirstMethodOrNull(*strings, predicate = predicate)) +) = CachedReadOnlyProperty { firstMethod(*strings, predicate = predicate) } fun gettingFirstMethodMutableOrNull( vararg strings: String, predicate: context(MatchContext) Method.() -> Boolean = { true }, -) = ReadOnlyProperty { firstMethodMutableOrNull(*strings, predicate = predicate) } +) = CachedReadOnlyProperty { firstMethodMutableOrNull(*strings, predicate = predicate) } fun gettingFirstMethodMutable( vararg strings: String, predicate: context(MatchContext) Method.() -> Boolean = { true }, -) = requireNotNull(gettingFirstMethodMutableOrNull(*strings, predicate = predicate)) +) = CachedReadOnlyProperty { firstMethodMutable(*strings, predicate = predicate) } fun indexedMatcher() = IndexedMatcher() @@ -406,6 +408,22 @@ fun BytecodePatchContext.firstClassDefMutableByDeclarativePredicate( predicate: context(MatchContext, ClassDef) DeclarativePredicateBuilder.() -> Unit ) = requireNotNull(firstClassDefMutableByDeclarativePredicateOrNull(type, predicate)) +fun BytecodePatchContext.firstMethodByDeclarativePredicateOrNull( + predicate: context(MatchContext, Method) DeclarativePredicateBuilder.() -> Unit +) = firstMethodOrNull { rememberDeclarativePredicate(predicate) } + +fun BytecodePatchContext.firstMethodByDeclarativePredicate( + predicate: context(MatchContext, Method) DeclarativePredicateBuilder.() -> Unit +) = requireNotNull(firstMethodByDeclarativePredicateOrNull(predicate)) + +fun BytecodePatchContext.firstMethodMutableByDeclarativePredicateOrNull( + predicate: context(MatchContext, Method) DeclarativePredicateBuilder.() -> Unit +) = firstMethodMutableOrNull { rememberDeclarativePredicate(predicate) } + +fun BytecodePatchContext.firstMethodMutableByDeclarativePredicate( + predicate: context(MatchContext, Method) DeclarativePredicateBuilder.() -> Unit +) = requireNotNull(firstMethodMutableByDeclarativePredicateOrNull(predicate)) + fun BytecodePatchContext.firstMethodByDeclarativePredicateOrNull( vararg strings: String, predicate: context(MatchContext, Method) DeclarativePredicateBuilder.() -> Unit @@ -434,7 +452,7 @@ fun gettingFirstClassDefByDeclarativePredicateOrNull( fun gettingFirstClassDefByDeclarativePredicate( type: String, predicate: context(MatchContext, ClassDef) DeclarativePredicateBuilder.() -> Unit -) = requireNotNull(gettingFirstClassDefByDeclarativePredicateOrNull(type, predicate)) +) = CachedReadOnlyProperty { firstClassDefByDeclarativePredicate(type, predicate) } fun gettingFirstClassDefMutableByDeclarativePredicateOrNull( type: String, @@ -444,7 +462,7 @@ fun gettingFirstClassDefMutableByDeclarativePredicateOrNull( fun gettingFirstClassDefMutableByDeclarativePredicate( type: String, predicate: context(MatchContext, ClassDef) DeclarativePredicateBuilder.() -> Unit -) = requireNotNull(gettingFirstClassDefMutableByDeclarativePredicateOrNull(type, predicate)) +) = CachedReadOnlyProperty { firstClassDefMutableByDeclarativePredicate(type, predicate) } fun gettingFirstClassDefByDeclarativePredicateOrNull( predicate: context(MatchContext, ClassDef) DeclarativePredicateBuilder.() -> Unit @@ -452,7 +470,7 @@ fun gettingFirstClassDefByDeclarativePredicateOrNull( fun gettingFirstClassDefByDeclarativePredicate( predicate: context(MatchContext, ClassDef) DeclarativePredicateBuilder.() -> Unit -) = requireNotNull(gettingFirstClassDefByDeclarativePredicateOrNull(predicate)) +) = CachedReadOnlyProperty { firstClassDefByDeclarativePredicate(predicate) } fun gettingFirstClassDefMutableByDeclarativePredicateOrNull( predicate: context(MatchContext, ClassDef) DeclarativePredicateBuilder.() -> Unit @@ -460,7 +478,7 @@ fun gettingFirstClassDefMutableByDeclarativePredicateOrNull( fun gettingFirstClassDefMutableByDeclarativePredicate( predicate: context(MatchContext, ClassDef) DeclarativePredicateBuilder.() -> Unit -) = requireNotNull(gettingFirstClassDefMutableByDeclarativePredicateOrNull(predicate)) +) = CachedReadOnlyProperty { firstClassDefMutableByDeclarativePredicate(predicate) } fun gettingFirstMethodByDeclarativePredicateOrNull( predicate: context(MatchContext, Method) DeclarativePredicateBuilder.() -> Unit @@ -468,7 +486,7 @@ fun gettingFirstMethodByDeclarativePredicateOrNull( fun gettingFirstMethodByDeclarativePredicate( predicate: context(MatchContext, Method) DeclarativePredicateBuilder.() -> Unit -) = requireNotNull(gettingFirstMethodByDeclarativePredicateOrNull(predicate)) +) = CachedReadOnlyProperty { firstMethodByDeclarativePredicate(predicate = predicate) } fun gettingFirstMethodMutableByDeclarativePredicateOrNull( predicate: context(MatchContext, Method) DeclarativePredicateBuilder.() -> Unit @@ -476,7 +494,7 @@ fun gettingFirstMethodMutableByDeclarativePredicateOrNull( fun gettingFirstMethodMutableByDeclarativePredicate( predicate: context(MatchContext, Method) DeclarativePredicateBuilder.() -> Unit -) = requireNotNull(gettingFirstMethodMutableByDeclarativePredicateOrNull(predicate)) +) = CachedReadOnlyProperty { firstMethodMutableByDeclarativePredicate(predicate = predicate) } fun gettingFirstMethodByDeclarativePredicateOrNull( vararg strings: String, @@ -486,7 +504,7 @@ fun gettingFirstMethodByDeclarativePredicateOrNull( fun gettingFirstMethodByDeclarativePredicate( vararg strings: String, predicate: context(MatchContext, Method) DeclarativePredicateBuilder.() -> Unit -) = requireNotNull(gettingFirstMethodByDeclarativePredicateOrNull(*strings, predicate = predicate)) +) = CachedReadOnlyProperty { firstMethodByDeclarativePredicate(*strings, predicate = predicate) } fun gettingFirstMethodMutableByDeclarativePredicateOrNull( vararg strings: String, @@ -496,7 +514,7 @@ fun gettingFirstMethodMutableByDeclarativePredicateOrNull( fun gettingFirstMethodMutableByDeclarativePredicate( vararg strings: String, predicate: context(MatchContext, Method) DeclarativePredicateBuilder.() -> Unit -) = requireNotNull(gettingFirstMethodMutableByDeclarativePredicateOrNull(*strings, predicate = predicate)) +) = CachedReadOnlyProperty { firstMethodMutableByDeclarativePredicate(*strings, predicate = predicate) } class DeclarativePredicateBuilder internal constructor() { @@ -545,8 +563,20 @@ fun DeclarativePredicateBuilder.custom(block: context(MatchContext) Meth class Composition internal constructor( val indices: List, - predicate: context(MatchContext, Method) DeclarativePredicateBuilder.() -> Unit + private val predicate: context(MatchContext, Method) DeclarativePredicateBuilder.() -> Unit ) { - val methodOrNull by gettingFirstMethodMutableByDeclarativePredicateOrNull(predicate) - val method = requireNotNull(methodOrNull) + private var _methodOrNull: MutableMethod? = null + + context(context: BytecodePatchContext) + val methodOrNull: MutableMethod? + get() { + if (_methodOrNull == null) { + _methodOrNull = context.firstMethodMutableByDeclarativePredicateOrNull(predicate) + } + + return _methodOrNull + } + + context(_: BytecodePatchContext) + val method get() = requireNotNull(methodOrNull) } diff --git a/src/test/kotlin/app/revanced/patcher/PatcherTest.kt b/src/test/kotlin/app/revanced/patcher/PatcherTest.kt index 0448960..e2db339 100644 --- a/src/test/kotlin/app/revanced/patcher/PatcherTest.kt +++ b/src/test/kotlin/app/revanced/patcher/PatcherTest.kt @@ -3,7 +3,7 @@ package app.revanced.patcher import app.revanced.patcher.extensions.toInstructions import app.revanced.patcher.patch.* import com.android.tools.smali.dexlib2.Opcode -import com.android.tools.smali.dexlib2.iface.instruction.Instruction +import com.android.tools.smali.dexlib2.iface.ClassDef import com.android.tools.smali.dexlib2.immutable.ImmutableClassDef import com.android.tools.smali.dexlib2.immutable.ImmutableMethod import com.android.tools.smali.dexlib2.immutable.ImmutableMethodImplementation @@ -15,8 +15,12 @@ import kotlinx.coroutines.flow.toList import kotlinx.coroutines.runBlocking import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.assertAll +import org.junit.jupiter.api.assertThrows import java.util.logging.Logger -import kotlin.test.* +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertNotNull internal object PatcherTest { private lateinit var patcher: Patcher @@ -79,7 +83,11 @@ internal object PatcherTest { infix fun Patch<*>.produces(equals: List) { val patches = setOf(this) - patches() + try { + patches() + } catch (_: PatchException) { + // Swallow expected exceptions for testing purposes. + } assertEquals(equals, executed, "Expected patches to be executed in correct order.") @@ -159,10 +167,7 @@ internal object PatcherTest { } } - assertTrue( - patch().exception != null, - "Expected an exception because the fingerprint can't match.", - ) + assertThrows("Expected an exception because the fingerprint can't match.") { patch() } } @Test @@ -290,18 +295,22 @@ internal object PatcherTest { ), ), ) + every { + with(patcher.context.bytecodeContext) { + any(ClassDef::class).mutable() + } + } answers { callOriginal() } + + val a = gettingFirstMethodOrNull { true } val fingerprint = fingerprint { returns("V") } val fingerprint2 = fingerprint { returns("V") } val fingerprint3 = fingerprint { returns("V") } - val matchIndices = indexedMatcher() - val method by gettingFirstMethod { - implementation { - matchIndices(instructions) { - head { opcode == Opcode.CONST_STRING } - add { opcode == Opcode.IPUT_OBJECT } - } + val composite = firstMethodComposite { + instructions { + head { opcode == Opcode.CONST_STRING } + add { opcode == Opcode.IPUT_OBJECT } } } @@ -311,19 +320,21 @@ internal object PatcherTest { fingerprint.match(classDefs.first().methods.first()) fingerprint2.match(classDefs.first()) fingerprint3.originalClassDef - println(method) + composite.method } }, ) patches() - with(patcher.context.bytecodeContext) { + with(patcher.context.bytecodeContext) + { assertAll( "Expected fingerprints to match.", { assertNotNull(fingerprint.originalClassDefOrNull) }, { assertNotNull(fingerprint2.originalClassDefOrNull) }, { assertNotNull(fingerprint3.originalClassDefOrNull) }, + { assertEquals("method", composite.method.name) }, ) } } @@ -333,7 +344,11 @@ internal object PatcherTest { every { patcher.context.bytecodeContext.lookupMaps } returns with(patcher.context.bytecodeContext) { LookupMaps() } every { with(patcher.context.bytecodeContext) { mergeExtension(any()) } } just runs - return runBlocking { patcher().toList() } + return runBlocking { + patcher().toList().also { results -> + results.firstOrNull { result -> result.exception != null }?.let { result -> throw result.exception!! } + } + } } private operator fun Patch<*>.invoke() = setOf(this)().first()