Compare commits

...

6 Commits

Author SHA1 Message Date
semantic-release-bot
34171b534b chore(release): 1.4.0-dev.1 [skip ci]
# [1.4.0-dev.1](https://github.com/ReVanced/revanced-library/compare/v1.3.0...v1.4.0-dev.1) (2023-11-27)

### Features

* Add `PatchUtils#getMostCommonCompatibleVersions` utility function ([c5f3536](c5f3536cbb))
2023-11-27 01:26:17 +00:00
oSumAtrIX
c5f3536cbb feat: Add PatchUtils#getMostCommonCompatibleVersions utility function 2023-11-27 02:24:55 +01:00
oSumAtrIX
893274074b refactor: Do not escape unnecessary 2023-11-27 02:24:55 +01:00
semantic-release-bot
b6c09d42ae chore(release): 1.3.0 [skip ci]
# [1.3.0](https://github.com/ReVanced/revanced-library/compare/v1.2.0...v1.3.0) (2023-11-26)

### Bug Fixes

* Add missing log when calling `UserAdbManager#install` ([90b612b](90b612bee8))
* Delete mount script ([4fe0fb0](4fe0fb0a61))

### Features

* Increase certainty of the possibility to mount ([10f8cd1](10f8cd1470))
* Select first Adb device, if none supplied automatically ([1a5f868](1a5f868ecd))
2023-11-26 04:35:19 +00:00
oSumAtrIX
fe8a2334e6 chore: Merge branch dev to main (#18) 2023-11-26 05:33:52 +01:00
oSumAtrIX
a9e5966145 chore: Lint code 2023-11-26 05:27:29 +01:00
18 changed files with 473 additions and 209 deletions

View File

@@ -1,3 +1,24 @@
# [1.4.0-dev.1](https://github.com/ReVanced/revanced-library/compare/v1.3.0...v1.4.0-dev.1) (2023-11-27)
### Features
* Add `PatchUtils#getMostCommonCompatibleVersions` utility function ([c5f3536](https://github.com/ReVanced/revanced-library/commit/c5f3536cbb6997766076595dc0b2b5d2e861ca73))
# [1.3.0](https://github.com/ReVanced/revanced-library/compare/v1.2.0...v1.3.0) (2023-11-26)
### Bug Fixes
* Add missing log when calling `UserAdbManager#install` ([90b612b](https://github.com/ReVanced/revanced-library/commit/90b612bee8591c01b8befabde4147c7de7a2a09f))
* Delete mount script ([4fe0fb0](https://github.com/ReVanced/revanced-library/commit/4fe0fb0a617082b24199331671193e4fa7f485e2))
### Features
* Increase certainty of the possibility to mount ([10f8cd1](https://github.com/ReVanced/revanced-library/commit/10f8cd1470fd29cfefe53bf00a4a014f71a3f706))
* Select first Adb device, if none supplied automatically ([1a5f868](https://github.com/ReVanced/revanced-library/commit/1a5f868ecd0d278d574c12664ee95139c2423c17))
# [1.3.0-dev.1](https://github.com/ReVanced/revanced-library/compare/v1.2.1-dev.1...v1.3.0-dev.1) (2023-11-26) # [1.3.0-dev.1](https://github.com/ReVanced/revanced-library/compare/v1.2.1-dev.1...v1.3.0-dev.1) (2023-11-26)

View File

@@ -63,6 +63,8 @@ public final class app/revanced/library/Options$Patch$Option {
public final class app/revanced/library/PatchUtils { public final class app/revanced/library/PatchUtils {
public static final field INSTANCE Lapp/revanced/library/PatchUtils; public static final field INSTANCE Lapp/revanced/library/PatchUtils;
public final fun getMostCommonCompatibleVersion (Ljava/util/Set;Ljava/lang/String;)Ljava/lang/String; public final fun getMostCommonCompatibleVersion (Ljava/util/Set;Ljava/lang/String;)Ljava/lang/String;
public final fun getMostCommonCompatibleVersions (Ljava/util/Set;Ljava/util/Set;Z)Ljava/util/Map;
public static synthetic fun getMostCommonCompatibleVersions$default (Lapp/revanced/library/PatchUtils;Ljava/util/Set;Ljava/util/Set;ZILjava/lang/Object;)Ljava/util/Map;
} }
public abstract class app/revanced/library/adb/AdbManager { public abstract class app/revanced/library/adb/AdbManager {
@@ -87,6 +89,7 @@ public final class app/revanced/library/adb/AdbManager$Companion {
} }
public final class app/revanced/library/adb/AdbManager$DeviceNotFoundException : java/lang/Exception { public final class app/revanced/library/adb/AdbManager$DeviceNotFoundException : java/lang/Exception {
public fun <init> ()V
} }
public final class app/revanced/library/adb/AdbManager$FailedToFindInstalledPackageException : java/lang/Exception { public final class app/revanced/library/adb/AdbManager$FailedToFindInstalledPackageException : java/lang/Exception {

View File

@@ -75,4 +75,4 @@ publishing {
} }
} }
} }
} }

View File

@@ -1,4 +1,4 @@
org.gradle.parallel = true org.gradle.parallel = true
org.gradle.caching = true org.gradle.caching = true
kotlin.code.style = official kotlin.code.style = official
version = 1.3.0-dev.1 version = 1.4.0-dev.1

View File

@@ -26,8 +26,9 @@ object ApkSigner {
private val logger = Logger.getLogger(app.revanced.library.ApkSigner::class.java.name) private val logger = Logger.getLogger(app.revanced.library.ApkSigner::class.java.name)
init { init {
if (Security.getProvider(BouncyCastleProvider.PROVIDER_NAME) == null) if (Security.getProvider(BouncyCastleProvider.PROVIDER_NAME) == null) {
Security.addProvider(BouncyCastleProvider()) Security.addProvider(BouncyCastleProvider())
}
} }
/** /**
@@ -39,14 +40,15 @@ object ApkSigner {
*/ */
fun newPrivateKeyCertificatePair( fun newPrivateKeyCertificatePair(
commonName: String = "ReVanced", commonName: String = "ReVanced",
validUntil: Date = Date(System.currentTimeMillis() + 356.days.inWholeMilliseconds * 24) validUntil: Date = Date(System.currentTimeMillis() + 356.days.inWholeMilliseconds * 24),
): PrivateKeyCertificatePair { ): PrivateKeyCertificatePair {
logger.fine("Creating certificate for $commonName") logger.fine("Creating certificate for $commonName")
// Generate a new key pair. // Generate a new key pair.
val keyPair = KeyPairGenerator.getInstance("RSA").apply { val keyPair =
initialize(4096) KeyPairGenerator.getInstance("RSA").apply {
}.generateKeyPair() initialize(4096)
}.generateKeyPair()
var serialNumber: BigInteger var serialNumber: BigInteger
do serialNumber = BigInteger.valueOf(SecureRandom().nextLong()) do serialNumber = BigInteger.valueOf(SecureRandom().nextLong())
@@ -55,22 +57,22 @@ object ApkSigner {
val name = X500Name("CN=$commonName") val name = X500Name("CN=$commonName")
// Create a new certificate. // Create a new certificate.
val certificate = JcaX509CertificateConverter().getCertificate( val certificate =
X509v3CertificateBuilder( JcaX509CertificateConverter().getCertificate(
name, X509v3CertificateBuilder(
serialNumber, name,
Date(System.currentTimeMillis()), serialNumber,
validUntil, Date(System.currentTimeMillis()),
Locale.ENGLISH, validUntil,
name, Locale.ENGLISH,
SubjectPublicKeyInfo.getInstance(keyPair.public.encoded) name,
).build(JcaContentSignerBuilder("SHA256withRSA").build(keyPair.private)) SubjectPublicKeyInfo.getInstance(keyPair.public.encoded),
) ).build(JcaContentSignerBuilder("SHA256withRSA").build(keyPair.private)),
)
return PrivateKeyCertificatePair(keyPair.private, certificate) return PrivateKeyCertificatePair(keyPair.private, certificate)
} }
/** /**
* Read a [PrivateKeyCertificatePair] from a keystore entry. * Read a [PrivateKeyCertificatePair] from a keystore entry.
* *
@@ -87,16 +89,18 @@ object ApkSigner {
): PrivateKeyCertificatePair { ): PrivateKeyCertificatePair {
logger.fine("Reading key and certificate pair from keystore entry $keyStoreEntryAlias") logger.fine("Reading key and certificate pair from keystore entry $keyStoreEntryAlias")
if (!keyStore.containsAlias(keyStoreEntryAlias)) if (!keyStore.containsAlias(keyStoreEntryAlias)) {
throw IllegalArgumentException("Keystore does not contain alias $keyStoreEntryAlias") throw IllegalArgumentException("Keystore does not contain alias $keyStoreEntryAlias")
}
// Read the private key and certificate from the keystore. // Read the private key and certificate from the keystore.
val privateKey = try { val privateKey =
keyStore.getKey(keyStoreEntryAlias, keyStoreEntryPassword.toCharArray()) as PrivateKey try {
} catch (exception: UnrecoverableKeyException) { keyStore.getKey(keyStoreEntryAlias, keyStoreEntryPassword.toCharArray()) as PrivateKey
throw IllegalArgumentException("Invalid password for keystore entry $keyStoreEntryAlias") } catch (exception: UnrecoverableKeyException) {
} throw IllegalArgumentException("Invalid password for keystore entry $keyStoreEntryAlias")
}
val certificate = keyStore.getCertificate(keyStoreEntryAlias) as X509Certificate val certificate = keyStore.getCertificate(keyStoreEntryAlias) as X509Certificate
@@ -110,9 +114,7 @@ object ApkSigner {
* @return The created keystore. * @return The created keystore.
* @see KeyStoreEntry * @see KeyStoreEntry
*/ */
fun newKeyStore( fun newKeyStore(entries: List<KeyStoreEntry>): KeyStore {
entries: List<KeyStoreEntry>
): KeyStore {
logger.fine("Creating keystore") logger.fine("Creating keystore")
return KeyStore.getInstance("BKS", BouncyCastleProvider.PROVIDER_NAME).apply { return KeyStore.getInstance("BKS", BouncyCastleProvider.PROVIDER_NAME).apply {
@@ -124,7 +126,7 @@ object ApkSigner {
entry.alias, entry.alias,
entry.privateKeyCertificatePair.privateKey, entry.privateKeyCertificatePair.privateKey,
entry.password.toCharArray(), entry.password.toCharArray(),
arrayOf(entry.privateKeyCertificatePair.certificate) arrayOf(entry.privateKeyCertificatePair.certificate),
) )
} }
} }
@@ -140,10 +142,10 @@ object ApkSigner {
fun newKeyStore( fun newKeyStore(
keyStoreOutputStream: OutputStream, keyStoreOutputStream: OutputStream,
keyStorePassword: String, keyStorePassword: String,
entries: List<KeyStoreEntry> entries: List<KeyStoreEntry>,
) = newKeyStore(entries).store( ) = newKeyStore(entries).store(
keyStoreOutputStream, keyStoreOutputStream,
keyStorePassword.toCharArray() keyStorePassword.toCharArray(),
) // Save the keystore. ) // Save the keystore.
/** /**
@@ -156,7 +158,7 @@ object ApkSigner {
*/ */
fun readKeyStore( fun readKeyStore(
keyStoreInputStream: InputStream, keyStoreInputStream: InputStream,
keyStorePassword: String? keyStorePassword: String?,
): KeyStore { ): KeyStore {
logger.fine("Reading keystore") logger.fine("Reading keystore")
@@ -164,10 +166,11 @@ object ApkSigner {
try { try {
load(keyStoreInputStream, keyStorePassword?.toCharArray()) load(keyStoreInputStream, keyStorePassword?.toCharArray())
} catch (exception: IOException) { } catch (exception: IOException) {
if (exception.cause is UnrecoverableKeyException) if (exception.cause is UnrecoverableKeyException) {
throw IllegalArgumentException("Invalid keystore password") throw IllegalArgumentException("Invalid keystore password")
else } else {
throw exception throw exception
}
} }
} }
} }
@@ -183,20 +186,21 @@ object ApkSigner {
fun newApkSignerBuilder( fun newApkSignerBuilder(
privateKeyCertificatePair: PrivateKeyCertificatePair, privateKeyCertificatePair: PrivateKeyCertificatePair,
signer: String, signer: String,
createdBy: String createdBy: String,
): ApkSigner.Builder { ): ApkSigner.Builder {
logger.fine( logger.fine(
"Creating new ApkSigner " + "Creating new ApkSigner " +
"with $signer as signer and " + "with $signer as signer and " +
"$createdBy as Created-By attribute in the APK's manifest" "$createdBy as Created-By attribute in the APK's manifest",
) )
// Create the signer config. // Create the signer config.
val signerConfig = ApkSigner.SignerConfig.Builder( val signerConfig =
signer, ApkSigner.SignerConfig.Builder(
privateKeyCertificatePair.privateKey, signer,
listOf(privateKeyCertificatePair.certificate) privateKeyCertificatePair.privateKey,
).build() listOf(privateKeyCertificatePair.certificate),
).build()
// Create the signer. // Create the signer.
return ApkSigner.Builder(listOf(signerConfig)).apply { return ApkSigner.Builder(listOf(signerConfig)).apply {
@@ -227,10 +231,13 @@ object ApkSigner {
) = newApkSignerBuilder( ) = newApkSignerBuilder(
readKeyCertificatePair(keyStore, keyStoreEntryAlias, keyStoreEntryPassword), readKeyCertificatePair(keyStore, keyStoreEntryAlias, keyStoreEntryPassword),
signer, signer,
createdBy createdBy,
) )
fun ApkSigner.Builder.signApk(input: File, output: File) { fun ApkSigner.Builder.signApk(
input: File,
output: File,
) {
logger.info("Signing ${input.name}") logger.info("Signing ${input.name}")
setInputApk(input) setInputApk(input)
@@ -250,7 +257,7 @@ object ApkSigner {
class KeyStoreEntry( class KeyStoreEntry(
val alias: String, val alias: String,
val password: String, val password: String,
val privateKeyCertificatePair: PrivateKeyCertificatePair = newPrivateKeyCertificatePair() val privateKeyCertificatePair: PrivateKeyCertificatePair = newPrivateKeyCertificatePair(),
) )
/** /**
@@ -263,4 +270,4 @@ object ApkSigner {
val privateKey: PrivateKey, val privateKey: PrivateKey,
val certificate: X509Certificate, val certificate: X509Certificate,
) )
} }

View File

@@ -22,7 +22,11 @@ object ApkUtils {
* @param outputFile The apk to write the new entries to. * @param outputFile The apk to write the new entries to.
* @param patchedEntriesSource The result of the patcher to add the patched dex files and resources. * @param patchedEntriesSource The result of the patcher to add the patched dex files and resources.
*/ */
fun copyAligned(apkFile: File, outputFile: File, patchedEntriesSource: PatcherResult) { fun copyAligned(
apkFile: File,
outputFile: File,
patchedEntriesSource: PatcherResult,
) {
logger.info("Aligning ${apkFile.name}") logger.info("Aligning ${apkFile.name}")
outputFile.toPath().deleteIfExists() outputFile.toPath().deleteIfExists()
@@ -30,13 +34,15 @@ object ApkUtils {
ZipFile(outputFile).use { file -> ZipFile(outputFile).use { file ->
patchedEntriesSource.dexFiles.forEach { patchedEntriesSource.dexFiles.forEach {
file.addEntryCompressData( file.addEntryCompressData(
ZipEntry(it.name), it.stream.readBytes() ZipEntry(it.name),
it.stream.readBytes(),
) )
} }
patchedEntriesSource.resourceFile?.let { patchedEntriesSource.resourceFile?.let {
file.copyEntriesFromFileAligned( file.copyEntriesFromFileAligned(
ZipFile(it), ZipFile.apkZipEntryAlignment ZipFile(it),
ZipFile.apkZipEntryAlignment,
) )
} }
@@ -44,7 +50,8 @@ object ApkUtils {
// TODO: Fix copying resources that are not needed anymore. // TODO: Fix copying resources that are not needed anymore.
file.copyEntriesFromFileAligned( file.copyEntriesFromFileAligned(
ZipFile(apkFile), ZipFile.apkZipEntryAlignment ZipFile(apkFile),
ZipFile.apkZipEntryAlignment,
) )
} }
} }
@@ -62,26 +69,27 @@ object ApkUtils {
signingOptions: SigningOptions, signingOptions: SigningOptions,
) { ) {
// Get the keystore from the file or create a new one. // Get the keystore from the file or create a new one.
val keyStore = if (signingOptions.keyStore.exists()) { val keyStore =
ApkSigner.readKeyStore(signingOptions.keyStore.inputStream(), signingOptions.keyStorePassword) if (signingOptions.keyStore.exists()) {
} else { ApkSigner.readKeyStore(signingOptions.keyStore.inputStream(), signingOptions.keyStorePassword)
val entry = ApkSigner.KeyStoreEntry(signingOptions.alias, signingOptions.password) } else {
val entry = ApkSigner.KeyStoreEntry(signingOptions.alias, signingOptions.password)
// Create a new keystore with a new keypair and saves it. // Create a new keystore with a new keypair and saves it.
ApkSigner.newKeyStore(listOf(entry)).also { keyStore -> ApkSigner.newKeyStore(listOf(entry)).also { keyStore ->
keyStore.store( keyStore.store(
signingOptions.keyStore.outputStream(), signingOptions.keyStore.outputStream(),
signingOptions.keyStorePassword?.toCharArray() signingOptions.keyStorePassword?.toCharArray(),
) )
}
} }
}
ApkSigner.newApkSignerBuilder( ApkSigner.newApkSignerBuilder(
keyStore, keyStore,
signingOptions.alias, signingOptions.alias,
signingOptions.password, signingOptions.password,
signingOptions.signer, signingOptions.signer,
signingOptions.signer signingOptions.signer,
).signApk(apk, output) ).signApk(apk, output)
} }
@@ -101,4 +109,4 @@ object ApkUtils {
val password: String = "", val password: String = "",
val signer: String = "ReVanced", val signer: String = "ReVanced",
) )
} }

View File

@@ -2,7 +2,6 @@
package app.revanced.library package app.revanced.library
import app.revanced.library.Options.Patch.Option import app.revanced.library.Options.Patch.Option
import app.revanced.patcher.PatchClass import app.revanced.patcher.PatchClass
import app.revanced.patcher.PatchSet import app.revanced.patcher.PatchSet
@@ -25,31 +24,37 @@ object Options {
* @param prettyPrint Whether to pretty print the JSON. * @param prettyPrint Whether to pretty print the JSON.
* @return The JSON string containing the options. * @return The JSON string containing the options.
*/ */
fun serialize(patches: PatchSet, prettyPrint: Boolean = false): String = patches fun serialize(
.filter { it.options.any() } patches: PatchSet,
.map { patch -> prettyPrint: Boolean = false,
Patch( ): String =
patch.name!!, patches
patch.options.values.map { option -> .filter { it.options.any() }
val optionValue = try { .map { patch ->
option.value Patch(
} catch (e: PatchOptionException) { patch.name!!,
logger.warning("Using default option value for the ${patch.name} patch: ${e.message}") patch.options.values.map { option ->
option.default val optionValue =
} try {
option.value
} catch (e: PatchOptionException) {
logger.warning("Using default option value for the ${patch.name} patch: ${e.message}")
option.default
}
Option(option.key, optionValue) Option(option.key, optionValue)
},
)
}
// See https://github.com/revanced/revanced-patches/pull/2434/commits/60e550550b7641705e81aa72acfc4faaebb225e7.
.distinctBy { it.patchName }
.let {
if (prettyPrint) {
mapper.writerWithDefaultPrettyPrinter().writeValueAsString(it)
} else {
mapper.writeValueAsString(it)
} }
) }
}
// See https://github.com/revanced/revanced-patches/pull/2434/commits/60e550550b7641705e81aa72acfc4faaebb225e7.
.distinctBy { it.patchName }
.let {
if (prettyPrint)
mapper.writerWithDefaultPrettyPrinter().writeValueAsString(it)
else
mapper.writeValueAsString(it)
}
/** /**
* Deserializes the options for the patches in the list. * Deserializes the options for the patches in the list.
@@ -70,9 +75,10 @@ object Options {
filter { it.options.any() }.let { patches -> filter { it.options.any() }.let { patches ->
if (patches.isEmpty()) return if (patches.isEmpty()) return
val jsonPatches = deserialize(json).associate { val jsonPatches =
it.patchName to it.options.associate { option -> option.key to option.value } deserialize(json).associate {
} it.patchName to it.options.associate { option -> option.key to option.value }
}
patches.forEach { patch -> patches.forEach { patch ->
jsonPatches[patch.name]?.let { jsonPatchOptions -> jsonPatches[patch.name]?.let { jsonPatchOptions ->
@@ -104,9 +110,8 @@ object Options {
*/ */
class Patch internal constructor( class Patch internal constructor(
val patchName: String, val patchName: String,
val options: List<Option> val options: List<Option>,
) { ) {
/** /**
* Data class for patch option. * Data class for patch option.
* *
@@ -115,4 +120,4 @@ object Options {
*/ */
class Option internal constructor(val key: String, val value: Any?) class Option internal constructor(val key: String, val value: Any?)
} }
} }

View File

@@ -1,6 +1,14 @@
package app.revanced.library package app.revanced.library
import app.revanced.patcher.PatchSet import app.revanced.patcher.PatchSet
import java.util.*
private typealias PackageName = String
private typealias Version = String
private typealias Count = Int
private typealias VersionMap = SortedMap<Version, Count>
internal typealias PackageNameMap = Map<PackageName, VersionMap>
/** /**
* Utility functions for working with patches. * Utility functions for working with patches.
@@ -14,7 +22,17 @@ object PatchUtils {
* @param packageName The name of the compatible package. * @param packageName The name of the compatible package.
* @return The most common version of. * @return The most common version of.
*/ */
fun getMostCommonCompatibleVersion(patches: PatchSet, packageName: String) = patches @Deprecated(
"Use getMostCommonCompatibleVersions instead.",
ReplaceWith(
"getMostCommonCompatibleVersions(patches, setOf(packageName))" +
".entries.firstOrNull()?.value?.keys?.firstOrNull()",
),
)
fun getMostCommonCompatibleVersion(
patches: PatchSet,
packageName: String,
) = patches
.mapNotNull { .mapNotNull {
// Map all patches to their compatible packages with version constraints. // Map all patches to their compatible packages with version constraints.
it.compatiblePackages?.firstOrNull { compatiblePackage -> it.compatiblePackages?.firstOrNull { compatiblePackage ->
@@ -25,4 +43,35 @@ object PatchUtils {
.groupingBy { it } .groupingBy { it }
.eachCount() .eachCount()
.maxByOrNull { it.value }?.key .maxByOrNull { it.value }?.key
}
/**
* Get the count of versions for each compatible package from a supplied set of [patches] ordered by the most common version.
*
* @param patches The set of patches to check.
* @param packageNames The names of the compatible packages.
* @param countUnusedPatches Whether to count patches that are not used.
* @return A map of package names to a map of versions to their count.
*/
fun getMostCommonCompatibleVersions(
patches: PatchSet,
packageNames: Set<String>,
countUnusedPatches: Boolean = false,
): PackageNameMap {
val wantedPackages = packageNames.toHashSet()
return buildMap {
patches
.filter { it.use || countUnusedPatches }
.flatMap { it.compatiblePackages ?: emptyList() }
.filter { it.name in wantedPackages }
.forEach { compatiblePackage ->
compatiblePackage.versions?.let { versions ->
val versionMap = getOrPut(compatiblePackage.name) { sortedMapOf() }
versions.forEach { version ->
versionMap[version] = versionMap.getOrDefault(version, 0) + 1
}
}
}
}
}
}

View File

@@ -29,15 +29,16 @@ import java.util.logging.Logger
sealed class AdbManager private constructor(deviceSerial: String?) { sealed class AdbManager private constructor(deviceSerial: String?) {
protected val logger: Logger = Logger.getLogger(AdbManager::class.java.name) protected val logger: Logger = Logger.getLogger(AdbManager::class.java.name)
protected val device = with(JadbConnection().devices) { protected val device =
if (isEmpty()) throw DeviceNotFoundException() with(JadbConnection().devices) {
if (isEmpty()) throw DeviceNotFoundException()
deviceSerial?.let { deviceSerial?.let {
firstOrNull { it.serial == deviceSerial } ?: throw DeviceNotFoundException(deviceSerial) firstOrNull { it.serial == deviceSerial } ?: throw DeviceNotFoundException(deviceSerial)
} ?: first().also { } ?: first().also {
logger.warning("No device serial supplied. Using device with serial ${it.serial}") logger.warning("No device serial supplied. Using device with serial ${it.serial}")
} }
}!! }!!
init { init {
logger.fine("Connected to ${device.serial}") logger.fine("Connected to ${device.serial}")
@@ -70,8 +71,10 @@ sealed class AdbManager private constructor(deviceSerial: String?) {
* @return The [AdbManager]. * @return The [AdbManager].
* @throws DeviceNotFoundException If the device can not be found. * @throws DeviceNotFoundException If the device can not be found.
*/ */
fun getAdbManager(deviceSerial: String? = null, root: Boolean = false): AdbManager = fun getAdbManager(
if (root) RootAdbManager(deviceSerial) else UserAdbManager(deviceSerial) deviceSerial: String? = null,
root: Boolean = false,
): AdbManager = if (root) RootAdbManager(deviceSerial) else UserAdbManager(deviceSerial)
} }
/** /**
@@ -123,7 +126,11 @@ sealed class AdbManager private constructor(deviceSerial: String?) {
} }
companion object Utils { companion object Utils {
private fun JadbDevice.run(command: String, with: String) = run(command.applyReplacement(with)) private fun JadbDevice.run(
command: String,
with: String,
) = run(command.applyReplacement(with))
private fun String.applyReplacement(with: String) = replace(PLACEHOLDER, with) private fun String.applyReplacement(with: String) = replace(PLACEHOLDER, with)
} }
} }
@@ -162,13 +169,15 @@ sealed class AdbManager private constructor(deviceSerial: String?) {
class Apk(val file: File, val packageName: String? = null) class Apk(val file: File, val packageName: String? = null)
class DeviceNotFoundException internal constructor(deviceSerial: String? = null) : class DeviceNotFoundException internal constructor(deviceSerial: String? = null) :
Exception(deviceSerial?.let { Exception(
"The device with the ADB device serial \"$deviceSerial\" can not be found" deviceSerial?.let {
} ?: "No ADB device found") "The device with the ADB device serial \"$deviceSerial\" can not be found"
} ?: "No ADB device found",
)
class FailedToFindInstalledPackageException internal constructor(packageName: String) : class FailedToFindInstalledPackageException internal constructor(packageName: String) :
Exception("Failed to find installed package \"$packageName\" because no activity was found") Exception("Failed to find installed package \"$packageName\" because no activity was found")
class PackageNameRequiredException internal constructor() : class PackageNameRequiredException internal constructor() :
Exception("Package name is required") Exception("Package name is required")
} }

View File

@@ -5,8 +5,10 @@ import se.vidstige.jadb.RemoteFile
import se.vidstige.jadb.ShellProcessBuilder import se.vidstige.jadb.ShellProcessBuilder
import java.io.File import java.io.File
internal fun JadbDevice.buildCommand(
internal fun JadbDevice.buildCommand(command: String, su: Boolean = true): ShellProcessBuilder { command: String,
su: Boolean = true,
): ShellProcessBuilder {
if (su) return shellProcessBuilder("su -c \'$command\'") if (su) return shellProcessBuilder("su -c \'$command\'")
val args = command.split(" ") as ArrayList<String> val args = command.split(" ") as ArrayList<String>
@@ -15,14 +17,19 @@ internal fun JadbDevice.buildCommand(command: String, su: Boolean = true): Shell
return shellProcessBuilder(cmd, *args.toTypedArray()) return shellProcessBuilder(cmd, *args.toTypedArray())
} }
internal fun JadbDevice.run(command: String, su: Boolean = true) = internal fun JadbDevice.run(
this.buildCommand(command, su).start() command: String,
su: Boolean = true,
) = this.buildCommand(command, su).start()
internal fun JadbDevice.hasSu() = internal fun JadbDevice.hasSu() = this.run("whoami", true).waitFor() == 0
this.run("whoami", true).waitFor() == 0
internal fun JadbDevice.push(file: File, targetFilePath: String) = internal fun JadbDevice.push(
push(file, RemoteFile(targetFilePath)) file: File,
targetFilePath: String,
) = push(file, RemoteFile(targetFilePath))
internal fun JadbDevice.createFile(targetFile: String, content: String) = internal fun JadbDevice.createFile(
push(content.byteInputStream(), System.currentTimeMillis(), 644, RemoteFile(targetFile)) targetFile: String,
content: String,
) = push(content.byteInputStream(), System.currentTimeMillis(), 644, RemoteFile(targetFile))

View File

@@ -13,30 +13,31 @@ internal object Constants {
internal const val RESTART = "am start -S $PLACEHOLDER" internal const val RESTART = "am start -S $PLACEHOLDER"
internal const val GET_INSTALLED_PATH = "pm path $PLACEHOLDER" internal const val GET_INSTALLED_PATH = "pm path $PLACEHOLDER"
internal const val INSTALL_PATCHED_APK = "base_path=\"$PATCHED_APK_PATH\" && " + internal const val INSTALL_PATCHED_APK =
"base_path=\"$PATCHED_APK_PATH\" && " +
"mv $TMP_PATH ${'$'}base_path && " + "mv $TMP_PATH ${'$'}base_path && " +
"chmod 644 ${'$'}base_path && " + "chmod 644 ${'$'}base_path && " +
"chown system:system ${'$'}base_path && " + "chown system:system ${'$'}base_path && " +
"chcon u:object_r:apk_data_file:s0 ${'$'}base_path" "chcon u:object_r:apk_data_file:s0 ${'$'}base_path"
internal const val UMOUNT = internal const val UMOUNT =
"grep $PLACEHOLDER /proc/mounts | while read -r line; do echo ${'$'}line | cut -d \" \" -f 2 | sed 's/apk.*/apk/' | xargs -r umount -l; done" "grep $PLACEHOLDER /proc/mounts | while read -r line; do echo ${'$'}line | cut -d ' ' -f 2 | sed 's/apk.*/apk/' | xargs -r umount -l; done"
internal const val INSTALL_MOUNT = "mv $TMP_PATH $MOUNT_PATH && chmod +x $MOUNT_PATH" internal const val INSTALL_MOUNT = "mv $TMP_PATH $MOUNT_PATH && chmod +x $MOUNT_PATH"
internal val MOUNT_SCRIPT = internal val MOUNT_SCRIPT =
""" """
#!/system/bin/sh #!/system/bin/sh
MAGISKTMP="${'$'}(magisk --path)" || MAGISKTMP=/sbin MAGISKTMP="$( magisk --path )" || MAGISKTMP=/sbin
MIRROR="${'$'}MAGISKTMP/.magisk/mirror" MIRROR="${'$'}MAGISKTMP/.magisk/mirror"
until [ "${'$'}(getprop sys.boot_completed)" = 1 ]; do sleep 3; done until [ "$( getprop sys.boot_completed )" = 1 ]; do sleep 3; done
until [ -d "/sdcard/Android" ]; do sleep 1; done until [ -d "/sdcard/Android" ]; do sleep 1; done
base_path="$PATCHED_APK_PATH" base_path="$PATCHED_APK_PATH"
stock_path=${'$'}( pm path $PLACEHOLDER | grep base | sed 's/package://g' ) stock_path=$( pm path $PLACEHOLDER | grep base | sed 's/package://g' )
chcon u:object_r:apk_data_file:s0 ${'$'}base_path chcon u:object_r:apk_data_file:s0 ${'$'}base_path
mount -o bind ${'$'}MIRROR${'$'}base_path ${'$'}stock_path mount -o bind ${'$'}MIRROR${'$'}base_path ${'$'}stock_path
""".trimIndent() """.trimIndent()
} }

View File

@@ -10,10 +10,11 @@ object Logger {
/** /**
* Rules for allowed loggers. * Rules for allowed loggers.
*/ */
private val allowedLoggersRules = arrayOf<String.() -> Boolean>( private val allowedLoggersRules =
{ startsWith("app.revanced") }, // ReVanced loggers. arrayOf<String.() -> Boolean>(
{ this == "" } // Logs warnings when compiling resources (Logger in class brut.util.OS). { startsWith("app.revanced") }, // ReVanced loggers.
) { this == "" }, // Logs warnings when compiling resources (Logger in class brut.util.OS).
)
private val rootLogger = java.util.logging.Logger.getLogger("") private val rootLogger = java.util.logging.Logger.getLogger("")
@@ -48,13 +49,14 @@ object Logger {
fun addHandler( fun addHandler(
publishHandler: (log: String, level: Level, loggerName: String?) -> Unit, publishHandler: (log: String, level: Level, loggerName: String?) -> Unit,
flushHandler: () -> Unit, flushHandler: () -> Unit,
closeHandler: () -> Unit closeHandler: () -> Unit,
) = object : Handler() { ) = object : Handler() {
override fun publish(record: LogRecord) = publishHandler( override fun publish(record: LogRecord) =
formatter.format(record), publishHandler(
record.level, formatter.format(record),
record.loggerName record.level,
) record.loggerName,
)
override fun flush() = flushHandler() override fun flush() = flushHandler()
@@ -77,10 +79,11 @@ object Logger {
} }
log.toByteArray().let { log.toByteArray().let {
if (level.intValue() > Level.WARNING.intValue()) if (level.intValue() > Level.WARNING.intValue()) {
System.err.write(it) System.err.write(it)
else } else {
System.out.write(it) System.out.write(it)
}
} }
} }
@@ -91,4 +94,4 @@ object Logger {
addHandler(publishHandler, flushHandler, flushHandler) addHandler(publishHandler, flushHandler, flushHandler)
} }
} }

View File

@@ -9,25 +9,34 @@ internal fun UInt.toLittleEndian() =
internal fun UShort.toLittleEndian() = (this.toUInt() shl 16).toLittleEndian().toUShort() internal fun UShort.toLittleEndian() = (this.toUInt() shl 16).toLittleEndian().toUShort()
internal fun UInt.toBigEndian() = (((this.toInt() and 0xff) shl 24) or ((this.toInt() and 0xff00) shl 8) internal fun UInt.toBigEndian() =
or ((this.toInt() and 0x00ff0000) ushr 8) or (this.toInt() ushr 24)).toUInt() (
((this.toInt() and 0xff) shl 24) or ((this.toInt() and 0xff00) shl 8)
or ((this.toInt() and 0x00ff0000) ushr 8) or (this.toInt() ushr 24)
).toUInt()
internal fun UShort.toBigEndian() = (this.toUInt() shl 16).toBigEndian().toUShort() internal fun UShort.toBigEndian() = (this.toUInt() shl 16).toBigEndian().toUShort()
internal fun ByteBuffer.getUShort() = this.getShort().toUShort() internal fun ByteBuffer.getUShort() = this.getShort().toUShort()
internal fun ByteBuffer.getUInt() = this.getInt().toUInt() internal fun ByteBuffer.getUInt() = this.getInt().toUInt()
internal fun ByteBuffer.putUShort(ushort: UShort) = this.putShort(ushort.toShort()) internal fun ByteBuffer.putUShort(ushort: UShort) = this.putShort(ushort.toShort())
internal fun ByteBuffer.putUInt(uint: UInt) = this.putInt(uint.toInt()) internal fun ByteBuffer.putUInt(uint: UInt) = this.putInt(uint.toInt())
internal fun DataInput.readUShort() = this.readShort().toUShort() internal fun DataInput.readUShort() = this.readShort().toUShort()
internal fun DataInput.readUInt() = this.readInt().toUInt() internal fun DataInput.readUInt() = this.readInt().toUInt()
internal fun DataOutput.writeUShort(ushort: UShort) = this.writeShort(ushort.toInt()) internal fun DataOutput.writeUShort(ushort: UShort) = this.writeShort(ushort.toInt())
internal fun DataOutput.writeUInt(uint: UInt) = this.writeInt(uint.toInt()) internal fun DataOutput.writeUInt(uint: UInt) = this.writeInt(uint.toInt())
internal fun DataInput.readUShortLE() = this.readUShort().toBigEndian() internal fun DataInput.readUShortLE() = this.readUShort().toBigEndian()
internal fun DataInput.readUIntLE() = this.readUInt().toBigEndian() internal fun DataInput.readUIntLE() = this.readUInt().toBigEndian()
internal fun DataOutput.writeUShortLE(ushort: UShort) = this.writeUShort(ushort.toLittleEndian()) internal fun DataOutput.writeUShortLE(ushort: UShort) = this.writeUShort(ushort.toLittleEndian())
internal fun DataOutput.writeUIntLE(uint: UInt) = this.writeUInt(uint.toLittleEndian()) internal fun DataOutput.writeUIntLE(uint: UInt) = this.writeUInt(uint.toLittleEndian())

View File

@@ -14,10 +14,11 @@ class ZipFile(file: File) : Closeable {
private var entries: MutableList<ZipEntry> = mutableListOf() private var entries: MutableList<ZipEntry> = mutableListOf()
// Open file for writing if it doesn't exist (because the intention is to write) or is writable. // Open file for writing if it doesn't exist (because the intention is to write) or is writable.
private val filePointer: RandomAccessFile = RandomAccessFile( private val filePointer: RandomAccessFile =
file, RandomAccessFile(
if (!file.exists() || file.canWrite()) "rw" else "r" file,
) if (!file.exists() || file.canWrite()) "rw" else "r",
)
private var centralDirectoryNeedsRewrite = false private var centralDirectoryNeedsRewrite = false
@@ -28,8 +29,9 @@ class ZipFile(file: File) : Closeable {
if (file.length() > 0) { if (file.length() > 0) {
val endRecord = findEndRecord() val endRecord = findEndRecord()
if (endRecord.diskNumber > 0u || endRecord.totalEntries != endRecord.diskEntries) if (endRecord.diskNumber > 0u || endRecord.totalEntries != endRecord.diskEntries) {
throw IllegalArgumentException("Multi-file archives are not supported") throw IllegalArgumentException("Multi-file archives are not supported")
}
entries = readEntries(endRecord).toMutableList() entries = readEntries(endRecord).toMutableList()
} }
@@ -66,16 +68,17 @@ class ZipFile(file: File) : Closeable {
for (i in 1..numberOfEntries) { for (i in 1..numberOfEntries) {
add( add(
ZipEntry.fromCDE(filePointer).also ZipEntry.fromCDE(filePointer).also
{ {
//for some reason the local extra field can be different from the central one // for some reason the local extra field can be different from the central one
it.readLocalExtra( it.readLocalExtra(
filePointer.channel.map( filePointer.channel.map(
FileChannel.MapMode.READ_ONLY, FileChannel.MapMode.READ_ONLY,
it.localHeaderOffset.toLong() + 28, it.localHeaderOffset.toLong() + 28,
2 2,
),
) )
) },
}) )
} }
} }
} }
@@ -89,20 +92,24 @@ class ZipFile(file: File) : Closeable {
val entriesCount = entries.size.toUShort() val entriesCount = entries.size.toUShort()
val endRecord = ZipEndRecord( val endRecord =
0u, ZipEndRecord(
0u, 0u,
entriesCount, 0u,
entriesCount, entriesCount,
filePointer.channel.position().toUInt() - centralDirectoryStartOffset, entriesCount,
centralDirectoryStartOffset, filePointer.channel.position().toUInt() - centralDirectoryStartOffset,
"" centralDirectoryStartOffset,
) "",
)
filePointer.channel.write(endRecord.toECD()) filePointer.channel.write(endRecord.toECD())
} }
private fun addEntry(entry: ZipEntry, data: ByteBuffer) { private fun addEntry(
entry: ZipEntry,
data: ByteBuffer,
) {
centralDirectoryNeedsRewrite = true centralDirectoryNeedsRewrite = true
entry.localHeaderOffset = filePointer.channel.position().toUInt() entry.localHeaderOffset = filePointer.channel.position().toUInt()
@@ -113,7 +120,10 @@ class ZipFile(file: File) : Closeable {
entries.add(entry) entries.add(entry)
} }
fun addEntryCompressData(entry: ZipEntry, data: ByteArray) { fun addEntryCompressData(
entry: ZipEntry,
data: ByteArray,
) {
val compressor = Deflater(compressionLevel, true) val compressor = Deflater(compressionLevel, true)
compressor.setInput(data) compressor.setInput(data)
compressor.finish() compressor.finish()
@@ -138,7 +148,11 @@ class ZipFile(file: File) : Closeable {
addEntry(entry, compressedBuffer) addEntry(entry, compressedBuffer)
} }
private fun addEntryCopyData(entry: ZipEntry, data: ByteBuffer, alignment: Int? = null) { private fun addEntryCopyData(
entry: ZipEntry,
data: ByteBuffer,
alignment: Int? = null,
) {
alignment?.let { alignment?.let {
// Calculate where data would end up. // Calculate where data would end up.
val dataOffset = filePointer.filePointer + entry.LFHSize val dataOffset = filePointer.filePointer + entry.LFHSize
@@ -160,7 +174,7 @@ class ZipFile(file: File) : Closeable {
return filePointer.channel.map( return filePointer.channel.map(
FileChannel.MapMode.READ_ONLY, FileChannel.MapMode.READ_ONLY,
entry.dataOffset.toLong(), entry.dataOffset.toLong(),
entry.compressedSize.toLong() entry.compressedSize.toLong(),
) )
} }
@@ -170,7 +184,10 @@ class ZipFile(file: File) : Closeable {
* @param file The file to copy entries from. * @param file The file to copy entries from.
* @param entryAlignment A function that returns the alignment for a given entry. * @param entryAlignment A function that returns the alignment for a given entry.
*/ */
fun copyEntriesFromFileAligned(file: ZipFile, entryAlignment: (entry: ZipEntry) -> Int?) { fun copyEntriesFromFileAligned(
file: ZipFile,
entryAlignment: (entry: ZipEntry) -> Int?,
) {
for (entry in file.entries) { for (entry in file.entries) {
if (entries.any { it.fileName == entry.fileName }) continue // Skip duplicates if (entries.any { it.fileName == entry.fileName }) continue // Skip duplicates
@@ -189,9 +206,13 @@ class ZipFile(file: File) : Closeable {
private const val LIBRARY_ALIGNMENT = 4096 private const val LIBRARY_ALIGNMENT = 4096
val apkZipEntryAlignment = { entry: ZipEntry -> val apkZipEntryAlignment = { entry: ZipEntry ->
if (entry.compression.toUInt() != 0u) null if (entry.compression.toUInt() != 0u) {
else if (entry.fileName.endsWith(".so")) LIBRARY_ALIGNMENT null
else DEFAULT_ALIGNMENT } else if (entry.fileName.endsWith(".so")) {
LIBRARY_ALIGNMENT
} else {
DEFAULT_ALIGNMENT
}
} }
} }
} }

View File

@@ -17,7 +17,6 @@ internal class ZipEndRecord(
val centralDirectoryStartOffset: UInt, val centralDirectoryStartOffset: UInt,
val fileComment: String, val fileComment: String,
) { ) {
companion object { companion object {
const val ECD_HEADER_SIZE = 22 const val ECD_HEADER_SIZE = 22
const val ECD_SIGNATURE = 0x06054b50u const val ECD_SIGNATURE = 0x06054b50u
@@ -25,8 +24,9 @@ internal class ZipEndRecord(
fun fromECD(input: DataInput): ZipEndRecord { fun fromECD(input: DataInput): ZipEndRecord {
val signature = input.readUIntLE() val signature = input.readUIntLE()
if (signature != ECD_SIGNATURE) if (signature != ECD_SIGNATURE) {
throw IllegalArgumentException("Input doesn't start with end record signature") throw IllegalArgumentException("Input doesn't start with end record signature")
}
val diskNumber = input.readUShortLE() val diskNumber = input.readUShortLE()
val startingDiskNumber = input.readUShortLE() val startingDiskNumber = input.readUShortLE()
@@ -50,7 +50,7 @@ internal class ZipEndRecord(
totalEntries, totalEntries,
centralDirectorySize, centralDirectorySize,
centralDirectoryStartOffset, centralDirectoryStartOffset,
fileComment fileComment,
) )
} }
} }

View File

@@ -22,7 +22,7 @@ class ZipEntry private constructor(
internal val fileName: String, internal val fileName: String,
internal val extraField: ByteArray, internal val extraField: ByteArray,
internal val fileComment: String, internal val fileComment: String,
internal var localExtraField: ByteArray = ByteArray(0), //separate for alignment internal var localExtraField: ByteArray = ByteArray(0), // separate for alignment
) { ) {
internal val LFHSize: Int internal val LFHSize: Int
get() = LFH_HEADER_SIZE + fileName.toByteArray(Charsets.UTF_8).size + localExtraField.size get() = LFH_HEADER_SIZE + fileName.toByteArray(Charsets.UTF_8).size + localExtraField.size
@@ -31,12 +31,12 @@ class ZipEntry private constructor(
get() = localHeaderOffset + LFHSize.toUInt() get() = localHeaderOffset + LFHSize.toUInt()
constructor(fileName: String) : this( constructor(fileName: String) : this(
0x1403u, //made by unix, version 20 0x1403u, // made by unix, version 20
0u, 0u,
0u, 0u,
0u, 0u,
0x0821u, //seems to be static time google uses, no idea 0x0821u, // seems to be static time google uses, no idea
0x0221u, //same as above 0x0221u, // same as above
0u, 0u,
0u, 0u,
0u, 0u,
@@ -46,21 +46,22 @@ class ZipEntry private constructor(
0u, 0u,
fileName, fileName,
ByteArray(0), ByteArray(0),
"" "",
) )
companion object { companion object {
internal const val CDE_HEADER_SIZE = 46 internal const val CDE_HEADER_SIZE = 46
internal const val CDE_SIGNATURE = 0x02014b50u internal const val CDE_SIGNATURE = 0x02014b50u
internal const val LFH_HEADER_SIZE = 30 internal const val LFH_HEADER_SIZE = 30
internal const val LFH_SIGNATURE = 0x04034b50u internal const val LFH_SIGNATURE = 0x04034b50u
internal fun fromCDE(input: DataInput): ZipEntry { internal fun fromCDE(input: DataInput): ZipEntry {
val signature = input.readUIntLE() val signature = input.readUIntLE()
if (signature != CDE_SIGNATURE) if (signature != CDE_SIGNATURE) {
throw IllegalArgumentException("Input doesn't start with central directory entry signature") throw IllegalArgumentException("Input doesn't start with central directory entry signature")
}
val version = input.readUShortLE() val version = input.readUShortLE()
val versionNeeded = input.readUShortLE() val versionNeeded = input.readUShortLE()
@@ -97,8 +98,11 @@ class ZipEntry private constructor(
fileComment = fileCommentBytes.toString(Charsets.UTF_8) fileComment = fileCommentBytes.toString(Charsets.UTF_8)
} }
flags = (flags and 0b1000u.inv() flags = (
.toUShort()) //disable data descriptor flag as they are not used flags and
0b1000u.inv()
.toUShort()
) // disable data descriptor flag as they are not used
return ZipEntry( return ZipEntry(
version, version,
@@ -121,7 +125,7 @@ class ZipEntry private constructor(
} }
} }
internal fun readLocalExtra(buffer: ByteBuffer) { internal fun readLocalExtra(buffer: ByteBuffer) {
buffer.order(ByteOrder.LITTLE_ENDIAN) buffer.order(ByteOrder.LITTLE_ENDIAN)
localExtraField = ByteArray(buffer.getUShort().toInt()) localExtraField = ByteArray(buffer.getUShort().toInt())
} }
@@ -129,8 +133,9 @@ class ZipEntry private constructor(
internal fun toLFH(): ByteBuffer { internal fun toLFH(): ByteBuffer {
val nameBytes = fileName.toByteArray(Charsets.UTF_8) val nameBytes = fileName.toByteArray(Charsets.UTF_8)
val buffer = ByteBuffer.allocate(LFH_HEADER_SIZE + nameBytes.size + localExtraField.size) val buffer =
.also { it.order(ByteOrder.LITTLE_ENDIAN) } ByteBuffer.allocate(LFH_HEADER_SIZE + nameBytes.size + localExtraField.size)
.also { it.order(ByteOrder.LITTLE_ENDIAN) }
buffer.putUInt(LFH_SIGNATURE) buffer.putUInt(LFH_SIGNATURE)
buffer.putUShort(versionNeeded) buffer.putUShort(versionNeeded)
@@ -184,4 +189,4 @@ class ZipEntry private constructor(
buffer.flip() buffer.flip()
return buffer return buffer
} }
} }

View File

@@ -45,4 +45,4 @@ internal object PatchOptionsTest {
// Do nothing // Do nothing
} }
} }
} }

View File

@@ -8,11 +8,89 @@ import org.junit.jupiter.api.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
internal object PatchUtilsTest { internal object PatchUtilsTest {
private val patches =
arrayOf(
newPatch("some.package", "a"),
newPatch("some.package", "a", "b", use = false),
newPatch("some.package", "a", "b", "c", use = false),
newPatch("some.other.package", "b", use = false),
newPatch("some.other.package", "b", "c"),
newPatch("some.other.package", "b", "c", "d"),
newPatch("some.other.other.package"),
newPatch("some.other.other.package", "a"),
newPatch("some.other.other.package", "b"),
newPatch("some.other.other.other.package", use = false),
newPatch("some.other.other.other.package", use = false),
).toSet()
@Test
fun `return common versions correctly ordered for each package`() {
assertEqualsVersions(
expected =
mapOf(
"some.package" to sortedMapOf("a" to 3, "b" to 2, "c" to 1),
"some.other.package" to sortedMapOf("b" to 3, "c" to 2, "d" to 1),
"some.other.other.package" to sortedMapOf("a" to 1, "b" to 1),
"some.other.other.other.package" to sortedMapOf(),
),
patches,
compatiblePackageNames =
setOf(
"some.package",
"some.other.package",
"some.other.other.package",
"some.other.other.other.package",
),
countUnusedPatches = true,
)
}
@Test
fun `return common versions correctly ordered for each package without counting unused patches`() {
assertEqualsVersions(
expected =
mapOf(
"some.package" to sortedMapOf("a" to 1),
"some.other.package" to sortedMapOf("b" to 2, "c" to 2, "d" to 1),
"some.other.other.package" to sortedMapOf("a" to 1, "b" to 1),
),
patches,
compatiblePackageNames =
setOf(
"some.package",
"some.other.package",
"some.other.other.package",
"some.other.other.other.package",
),
countUnusedPatches = false,
)
}
@Test
fun `return an empty map because no known package was supplied`() {
assertEqualsVersions(
expected = emptyMap(),
patches,
compatiblePackageNames = setOf("unknown.package"),
)
}
@Test
fun `return empty set of versions because no compatible package is constrained to a version`() {
assertEqualsVersions(
expected = mapOf("some.package" to sortedMapOf()),
patches = setOf(newPatch("some.package")),
compatiblePackageNames = setOf("some.package"),
countUnusedPatches = true,
)
}
@Test @Test
fun `return 'a' because it is the most common version`() { fun `return 'a' because it is the most common version`() {
val patches = arrayOf("a", "a", "c", "d", "a", "b", "c", "d", "a", "b", "c", "d") val patches =
.map { version -> newPatch("some.package", version) } arrayOf("a", "a", "c", "d", "a", "b", "c", "d", "a", "b", "c", "d")
.toSet() .map { version -> newPatch("some.package", version) }
.toSet()
assertEqualsVersion("a", patches, "some.package") assertEqualsVersion("a", patches, "some.package")
} }
@@ -30,20 +108,50 @@ internal object PatchUtilsTest {
} }
@Test @Test
fun `return null because no patch compatible package is constrained to a version`() { fun `return null because no compatible package is constrained to a version`() {
val patches = setOf( val patches =
newPatch("other.package"), setOf(
newPatch("other.package"), newPatch("other.package"),
) newPatch("other.package"),
)
assertEqualsVersion(null, patches, "other.package") assertEqualsVersion(null, patches, "other.package")
} }
private fun assertEqualsVersion( private fun assertEqualsVersions(
expected: String?, patches: PatchSet, compatiblePackageName: String expected: PackageNameMap,
) = assertEquals(expected, PatchUtils.getMostCommonCompatibleVersion(patches, compatiblePackageName)) patches: PatchSet,
compatiblePackageNames: Set<String>,
countUnusedPatches: Boolean = false,
) = assertEquals(
expected,
PatchUtils.getMostCommonCompatibleVersions(patches, compatiblePackageNames, countUnusedPatches),
)
private fun newPatch(packageName: String, vararg versions: String) = object : BytecodePatch() { private fun assertEqualsVersion(
expected: String?,
patches: PatchSet,
compatiblePackageName: String,
) {
// Test both the deprecated and the new method.
assertEquals(
expected,
PatchUtils.getMostCommonCompatibleVersion(patches, compatiblePackageName),
)
assertEquals(
expected,
PatchUtils.getMostCommonCompatibleVersions(patches, setOf(compatiblePackageName))
.entries.firstOrNull()?.value?.keys?.firstOrNull(),
)
}
private fun newPatch(
packageName: String,
vararg versions: String,
use: Boolean = true,
) = object : BytecodePatch() {
init { init {
// Set the compatible packages field to the supplied package name and versions reflectively, // Set the compatible packages field to the supplied package name and versions reflectively,
// because the setter is private but needed for testing. // because the setter is private but needed for testing.
@@ -51,8 +159,16 @@ internal object PatchUtilsTest {
compatiblePackagesField.isAccessible = true compatiblePackagesField.isAccessible = true
compatiblePackagesField.set(this, setOf(CompatiblePackage(packageName, versions.toSet()))) compatiblePackagesField.set(this, setOf(CompatiblePackage(packageName, versions.toSet())))
val useField = Patch::class.java.getDeclaredField("use")
useField.isAccessible = true
useField.set(this, use)
} }
override fun execute(context: BytecodeContext) {} override fun execute(context: BytecodeContext) {}
// Needed to make the patches unique.
override fun equals(other: Any?) = false
} }
} }