Commit fb8000b5 authored by topjohnwu's avatar topjohnwu

Handle invalid SafetyNet results

Fix #4253
parent 1b9d8e06
...@@ -25,9 +25,7 @@ ...@@ -25,9 +25,7 @@
# Snet # Snet
-keepclassmembers class com.topjohnwu.magisk.ui.safetynet.SafetyNetHelper { *; } -keepclassmembers class com.topjohnwu.magisk.ui.safetynet.SafetyNetHelper { *; }
-keep,allowobfuscation interface com.topjohnwu.magisk.ui.safetynet.SafetyNetHelper$Callback -keep,allowobfuscation interface com.topjohnwu.magisk.ui.safetynet.SafetyNetHelper$Callback
-keepclassmembers class * implements com.topjohnwu.magisk.ui.safetynet.SafetyNetHelper$Callback { -keepclassmembers class * implements com.topjohnwu.magisk.ui.safetynet.SafetyNetHelper$Callback { *; }
void onResponse(java.lang.String);
}
# Stub # Stub
-keep class com.topjohnwu.magisk.core.App { <init>(java.lang.Object); } -keep class com.topjohnwu.magisk.core.App { <init>(java.lang.Object); }
......
...@@ -29,7 +29,7 @@ object Const { ...@@ -29,7 +29,7 @@ object Const {
const val MAGISK_LOG = "/cache/magisk.log" const val MAGISK_LOG = "/cache/magisk.log"
// Versions // Versions
const val SNET_EXT_VER = 16 const val SNET_EXT_VER = 17
const val SNET_REVISION = "22.0" const val SNET_REVISION = "22.0"
const val BOOTCTL_REVISION = "22.0" const val BOOTCTL_REVISION = "22.0"
......
...@@ -31,10 +31,9 @@ import java.io.ByteArrayInputStream ...@@ -31,10 +31,9 @@ import java.io.ByteArrayInputStream
import java.io.File import java.io.File
import java.io.IOException import java.io.IOException
import java.lang.reflect.InvocationHandler import java.lang.reflect.InvocationHandler
import java.security.GeneralSecurityException import java.lang.reflect.Proxy
import java.security.SecureRandom import java.security.SecureRandom
import java.security.Signature import java.security.Signature
import java.security.cert.X509Certificate
class CheckSafetyNetEvent( class CheckSafetyNetEvent(
private val callback: (SafetyNetResult) -> Unit = {} private val callback: (SafetyNetResult) -> Unit = {}
...@@ -42,19 +41,17 @@ class CheckSafetyNetEvent( ...@@ -42,19 +41,17 @@ class CheckSafetyNetEvent(
private val svc get() = ServiceLocator.networkService private val svc get() = ServiceLocator.networkService
private lateinit var apk: File private lateinit var jar: File
private lateinit var dex: File
private lateinit var nonce: ByteArray private lateinit var nonce: ByteArray
override fun invoke(context: Context) { override fun invoke(context: Context) {
apk = File("${context.filesDir.parent}/snet", "snet.jar") jar = File("${context.filesDir.parent}/snet", "snet.jar")
dex = File(apk.parent, "snet.dex")
scope.launch(Dispatchers.IO) { scope.launch(Dispatchers.IO) {
attest(context) { attest(context) {
// Download and retry // Download and retry
Shell.sh("rm -rf " + apk.parent).exec() Shell.sh("rm -rf " + jar.parent).exec()
apk.parentFile?.mkdir() jar.parentFile?.mkdir()
withContext(Dispatchers.Main) { withContext(Dispatchers.Main) {
showDialog(context) showDialog(context)
} }
...@@ -65,13 +62,12 @@ class CheckSafetyNetEvent( ...@@ -65,13 +62,12 @@ class CheckSafetyNetEvent(
private suspend fun attest(context: Context, onError: suspend (Exception) -> Unit) { private suspend fun attest(context: Context, onError: suspend (Exception) -> Unit) {
val helper: SafetyNetHelper val helper: SafetyNetHelper
try { try {
val loader = createClassLoader(apk) val loader = createClassLoader(jar)
// Scan through the dex and find our helper class // Scan through the dex and find our helper class
var clazz: Class<*>? = null var clazz: Class<*>? = null
loop@for (dex in loader.getDexFiles()) { loop@for (dex in loader.getDexFiles()) {
for (name in dex.entries()) { for (name in dex.entries()) {
if (name.startsWith("x.")) {
val cls = loader.loadClass(name) val cls = loader.loadClass(name)
if (InvocationHandler::class.java.isAssignableFrom(cls)) { if (InvocationHandler::class.java.isAssignableFrom(cls)) {
clazz = cls clazz = cls
...@@ -79,11 +75,11 @@ class CheckSafetyNetEvent( ...@@ -79,11 +75,11 @@ class CheckSafetyNetEvent(
} }
} }
} }
} clazz ?: throw Exception("Cannot find SafetyNetHelper implementation")
clazz ?: throw Exception("Cannot find SafetyNetHelper class")
helper = clazz.getMethod("get", Class::class.java, Context::class.java, Any::class.java) helper = Proxy.newProxyInstance(
.invoke(null, SafetyNetHelper::class.java, context, this) as SafetyNetHelper loader, arrayOf(SafetyNetHelper::class.java),
clazz.newInstance() as InvocationHandler) as SafetyNetHelper
if (helper.version != Const.SNET_EXT_VER) if (helper.version != Const.SNET_EXT_VER)
throw Exception("snet extension version mismatch") throw Exception("snet extension version mismatch")
...@@ -95,7 +91,7 @@ class CheckSafetyNetEvent( ...@@ -95,7 +91,7 @@ class CheckSafetyNetEvent(
val random = SecureRandom() val random = SecureRandom()
nonce = ByteArray(24) nonce = ByteArray(24)
random.nextBytes(nonce) random.nextBytes(nonce)
helper.attest(nonce) helper.attest(context, nonce, this)
} }
// All of these fields are whitelisted // All of these fields are whitelisted
...@@ -114,7 +110,7 @@ class CheckSafetyNetEvent( ...@@ -114,7 +110,7 @@ class CheckSafetyNetEvent(
} }
} }
try { try {
svc.fetchSafetynet().byteStream().writeTo(apk) svc.fetchSafetynet().byteStream().writeTo(jar)
attest(context, abort) attest(context, abort)
} catch (e: IOException) { } catch (e: IOException) {
abort(e) abort(e)
...@@ -147,7 +143,7 @@ class CheckSafetyNetEvent( ...@@ -147,7 +143,7 @@ class CheckSafetyNetEvent(
Base64.decode(this, Base64.URL_SAFE) Base64.decode(this, Base64.URL_SAFE)
} }
private fun String.parseJws(): SafetyNetResponse? { private fun String.parseJws(): SafetyNetResponse {
val jws = split('.') val jws = split('.')
val secondDot = lastIndexOf('.') val secondDot = lastIndexOf('.')
val rawHeader = String(jws[0].decode()) val rawHeader = String(jws[0].decode())
...@@ -156,7 +152,8 @@ class CheckSafetyNetEvent( ...@@ -156,7 +152,8 @@ class CheckSafetyNetEvent(
val signedBytes = substring(0, secondDot).toByteArray() val signedBytes = substring(0, secondDot).toByteArray()
val moshi = Moshi.Builder().build() val moshi = Moshi.Builder().build()
val header = moshi.adapter(JwsHeader::class.java).fromJson(rawHeader) ?: return null val header = moshi.adapter(JwsHeader::class.java).fromJson(rawHeader)
?: error("Invalid JWS header")
val alg = when (header.algorithm) { val alg = when (header.algorithm) {
"RS256" -> "SHA256withRSA" "RS256" -> "SHA256withRSA"
...@@ -165,41 +162,30 @@ class CheckSafetyNetEvent( ...@@ -165,41 +162,30 @@ class CheckSafetyNetEvent(
signature = ASN1Primitive.fromByteArray(signature).getEncoded(ASN1Encoding.DER) signature = ASN1Primitive.fromByteArray(signature).getEncoded(ASN1Encoding.DER)
"SHA256withECDSA" "SHA256withECDSA"
} }
else -> return null else -> error("Unsupported algorithm: ${header.algorithm}")
} }
// Verify signature // Verify signature
val certB64 = header.certificates?.first() ?: return null val certB64 = header.certificates?.first() ?: error("Cannot find certificate in JWS")
val certDer = certB64.decode() val bis = ByteArrayInputStream(certB64.decode())
val bis = ByteArrayInputStream(certDer) val cert = CryptoUtils.readCertificate(bis)
val cert: X509Certificate
try {
cert = CryptoUtils.readCertificate(bis)
val verifier = Signature.getInstance(alg) val verifier = Signature.getInstance(alg)
verifier.initVerify(cert.publicKey) verifier.initVerify(cert.publicKey)
verifier.update(signedBytes) verifier.update(signedBytes)
if (!verifier.verify(signature)) if (!verifier.verify(signature))
return null error("Signature mismatch")
} catch (e: GeneralSecurityException) {
Timber.e(e)
return null
}
// Verify hostname // Verify hostname
val hostNameVerifier = JsseDefaultHostnameAuthorizer(setOf()) val hostnameVerifier = JsseDefaultHostnameAuthorizer(setOf())
try { if (!hostnameVerifier.verify("attest.android.com", cert))
if (!hostNameVerifier.verify("attest.android.com", cert)) error("Hostname mismatch")
return null
} catch (e: IOException) {
Timber.e(e)
return null
}
val response = moshi.adapter(SafetyNetResponse::class.java).fromJson(payload) ?: return null val response = moshi.adapter(SafetyNetResponse::class.java).fromJson(payload)
?: error("Invalid SafetyNet response")
// Verify results // Verify results
if (!response.nonce.decode().contentEquals(nonce)) if (!response.nonce.decode().contentEquals(nonce))
return null error("nonce mismatch")
return response return response
} }
...@@ -207,7 +193,10 @@ class CheckSafetyNetEvent( ...@@ -207,7 +193,10 @@ class CheckSafetyNetEvent(
override fun onResponse(response: String?) { override fun onResponse(response: String?) {
if (response != null) { if (response != null) {
scope.launch(Dispatchers.Default) { scope.launch(Dispatchers.Default) {
val res = response.parseJws() val res = runCatching { response.parseJws() }.getOrElse {
Timber.e(it)
INVALID_RESPONSE
}
withContext(Dispatchers.Main) { withContext(Dispatchers.Main) {
callback(SafetyNetResult(res)) callback(SafetyNetResult(res))
} }
...@@ -231,3 +220,6 @@ data class SafetyNetResponse( ...@@ -231,3 +220,6 @@ data class SafetyNetResponse(
val basicIntegrity: Boolean, val basicIntegrity: Boolean,
val evaluationType: String = "" val evaluationType: String = ""
) )
// Special instance to indicate invalid SafetyNet response
val INVALID_RESPONSE = SafetyNetResponse("", ctsProfileMatch = false, basicIntegrity = false)
package com.topjohnwu.magisk.ui.safetynet package com.topjohnwu.magisk.ui.safetynet
import android.content.Context
interface SafetyNetHelper { interface SafetyNetHelper {
val version: Int val version: Int
fun attest(nonce: ByteArray) fun attest(context: Context, nonce: ByteArray, callback: Callback)
interface Callback { interface Callback {
fun onResponse(response: String?) fun onResponse(response: String?)
......
...@@ -6,7 +6,7 @@ import com.topjohnwu.magisk.R ...@@ -6,7 +6,7 @@ import com.topjohnwu.magisk.R
import com.topjohnwu.magisk.arch.BaseViewModel import com.topjohnwu.magisk.arch.BaseViewModel
import com.topjohnwu.magisk.utils.set import com.topjohnwu.magisk.utils.set
data class SafetyNetResult( class SafetyNetResult(
val response: SafetyNetResponse? = null, val response: SafetyNetResponse? = null,
val dismiss: Boolean = false val dismiss: Boolean = false
) )
...@@ -42,37 +42,43 @@ class SafetynetViewModel : BaseViewModel() { ...@@ -42,37 +42,43 @@ class SafetynetViewModel : BaseViewModel() {
init { init {
cachedResult?.also { cachedResult?.also {
handleResponse(SafetyNetResult(it)) handleResult(SafetyNetResult(it))
} ?: attest() } ?: attest()
} }
private fun attest() { private fun attest() {
isChecking = true isChecking = true
CheckSafetyNetEvent { CheckSafetyNetEvent(::handleResult).publish()
handleResponse(it)
}.publish()
} }
fun reset() = attest() fun reset() = attest()
private fun handleResponse(response: SafetyNetResult) { private fun handleResult(result: SafetyNetResult) {
isChecking = false isChecking = false
if (response.dismiss) { if (result.dismiss) {
back() back()
return return
} }
response.response?.apply { result.response?.apply {
val result = ctsProfileMatch && basicIntegrity
cachedResult = this cachedResult = this
if (this === INVALID_RESPONSE) {
isSuccess = false
ctsState = false
basicIntegrityState = false
evalType = "N/A"
safetyNetTitle = R.string.safetynet_res_invalid
} else {
val success = ctsProfileMatch && basicIntegrity
isSuccess = success
ctsState = ctsProfileMatch ctsState = ctsProfileMatch
basicIntegrityState = basicIntegrity basicIntegrityState = basicIntegrity
evalType = if (evaluationType.contains("HARDWARE")) "HARDWARE" else "BASIC" evalType = if (evaluationType.contains("HARDWARE")) "HARDWARE" else "BASIC"
isSuccess = result
safetyNetTitle = safetyNetTitle =
if (result) R.string.safetynet_attest_success if (success) R.string.safetynet_attest_success
else R.string.safetynet_attest_failure else R.string.safetynet_attest_failure
}
} ?: { } ?: {
isSuccess = false isSuccess = false
ctsState = false ctsState = false
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment