Commit 171ddab3 authored by topjohnwu's avatar topjohnwu

Reorganize code handling su requests

parent 2aee0b0b
...@@ -10,8 +10,8 @@ data class MagiskPolicy( ...@@ -10,8 +10,8 @@ data class MagiskPolicy(
val uid: Int, val uid: Int,
val packageName: String, val packageName: String,
val appName: String, val appName: String,
val policy: Int = INTERACTIVE, var policy: Int = INTERACTIVE,
val until: Long = -1L, var until: Long = -1L,
val logging: Boolean = true, val logging: Boolean = true,
val notification: Boolean = true, val notification: Boolean = true,
val applicationInfo: ApplicationInfo val applicationInfo: ApplicationInfo
......
...@@ -8,7 +8,6 @@ import android.view.Window ...@@ -8,7 +8,6 @@ import android.view.Window
import com.topjohnwu.magisk.R import com.topjohnwu.magisk.R
import com.topjohnwu.magisk.base.BaseActivity import com.topjohnwu.magisk.base.BaseActivity
import com.topjohnwu.magisk.databinding.ActivityRequestBinding import com.topjohnwu.magisk.databinding.ActivityRequestBinding
import com.topjohnwu.magisk.model.entity.MagiskPolicy
import com.topjohnwu.magisk.model.events.DieEvent import com.topjohnwu.magisk.model.events.DieEvent
import com.topjohnwu.magisk.model.events.ViewEvent import com.topjohnwu.magisk.model.events.ViewEvent
import com.topjohnwu.magisk.utils.SuHandler import com.topjohnwu.magisk.utils.SuHandler
...@@ -22,7 +21,7 @@ open class SuRequestActivity : BaseActivity<SuRequestViewModel, ActivityRequestB ...@@ -22,7 +21,7 @@ open class SuRequestActivity : BaseActivity<SuRequestViewModel, ActivityRequestB
override val viewModel: SuRequestViewModel by viewModel() override val viewModel: SuRequestViewModel by viewModel()
override fun onBackPressed() { override fun onBackPressed() {
viewModel.handler?.handleAction(MagiskPolicy.DENY, -1) viewModel.denyPressed()
} }
override fun onCreate(savedInstanceState: Bundle?) { override fun onCreate(savedInstanceState: Bundle?) {
......
package com.topjohnwu.magisk.ui.surequest package com.topjohnwu.magisk.ui.surequest
import android.annotation.SuppressLint
import android.content.Intent import android.content.Intent
import android.content.SharedPreferences import android.content.SharedPreferences
import android.content.pm.PackageManager import android.content.pm.PackageManager
...@@ -8,14 +7,12 @@ import android.content.res.Resources ...@@ -8,14 +7,12 @@ import android.content.res.Resources
import android.graphics.drawable.Drawable import android.graphics.drawable.Drawable
import android.hardware.fingerprint.FingerprintManager import android.hardware.fingerprint.FingerprintManager
import android.os.CountDownTimer import android.os.CountDownTimer
import android.text.TextUtils
import com.topjohnwu.magisk.BuildConfig import com.topjohnwu.magisk.BuildConfig
import com.topjohnwu.magisk.Config import com.topjohnwu.magisk.Config
import com.topjohnwu.magisk.R import com.topjohnwu.magisk.R
import com.topjohnwu.magisk.base.viewmodel.BaseViewModel import com.topjohnwu.magisk.base.viewmodel.BaseViewModel
import com.topjohnwu.magisk.data.database.PolicyDao import com.topjohnwu.magisk.data.database.PolicyDao
import com.topjohnwu.magisk.databinding.ComparableRvItem import com.topjohnwu.magisk.databinding.ComparableRvItem
import com.topjohnwu.magisk.extensions.addOnPropertyChangedCallback
import com.topjohnwu.magisk.extensions.now import com.topjohnwu.magisk.extensions.now
import com.topjohnwu.magisk.model.entity.MagiskPolicy import com.topjohnwu.magisk.model.entity.MagiskPolicy
import com.topjohnwu.magisk.model.entity.recycler.SpinnerRvItem import com.topjohnwu.magisk.model.entity.recycler.SpinnerRvItem
...@@ -58,48 +55,25 @@ class SuRequestViewModel( ...@@ -58,48 +55,25 @@ class SuRequestViewModel(
setItems(items) setItems(items)
} }
private val cancelTasks = mutableListOf<() -> Unit>()
var handler: ActionHandler? = null private lateinit var timer: CountDownTimer
private var timer: CountDownTimer? = null private lateinit var policy: MagiskPolicy
private var policy: MagiskPolicy? = null private lateinit var connector: SuConnector
set(value) {
field = value
updatePolicy(value)
}
init {
resources.getStringArray(R.array.allow_timeout)
.map { SpinnerRvItem(it) }
.let { items.update(it) }
selectedItemPosition.addOnPropertyChangedCallback {
Timber.e("Changed position to $it")
}
}
private fun updatePolicy(policy: MagiskPolicy?) {
policy ?: return
icon.value = policy.applicationInfo.loadIcon(packageManager)
title.value = policy.appName
packageName.value = policy.packageName
selectedItemPosition.value = timeoutPrefs.getInt(policy.packageName, 0)
}
private fun cancelTimer() { private fun cancelTimer() {
timer?.cancel() timer.cancel()
denyText.value = resources.getString(R.string.deny) denyText.value = resources.getString(R.string.deny)
} }
fun grantPressed() { fun grantPressed() {
handler?.handleAction(MagiskPolicy.ALLOW) handleAction(MagiskPolicy.ALLOW)
timer?.cancel() timer.cancel()
} }
fun denyPressed() { fun denyPressed() {
handler?.handleAction(MagiskPolicy.DENY) handleAction(MagiskPolicy.DENY)
timer?.cancel() timer.cancel()
} }
fun spinnerTouched(): Boolean { fun spinnerTouched(): Boolean {
...@@ -110,75 +84,27 @@ class SuRequestViewModel( ...@@ -110,75 +84,27 @@ class SuRequestViewModel(
fun handleRequest(intent: Intent): Boolean { fun handleRequest(intent: Intent): Boolean {
val socketName = intent.getStringExtra("socket") ?: return false val socketName = intent.getStringExtra("socket") ?: return false
val connector: SuConnector
try { try {
connector = object : SuConnector(socketName) { connector = Connector(socketName)
@Throws(IOException::class) val map = connector.readRequest()
override fun onResponse() { val uid = map["uid"]?.toIntOrNull() ?: return false
out.writeInt(policy?.policy ?: return) policy = uid.toPolicy(packageManager)
} } catch (e: Exception) {
} Timber.e(e)
val bundle = connector.readSocketInput()
val uid = bundle.getString("uid")?.toIntOrNull() ?: return false
policyDB.deleteOutdated().blockingGet() // wrong!
policy = runCatching { policyDB.fetch(uid).blockingGet() }
.getOrDefault(uid.toPolicy(packageManager))
} catch (e: IOException) {
e.printStackTrace()
return false
} catch (e: PackageManager.NameNotFoundException) {
e.printStackTrace()
return false return false
} }
handler = object : ActionHandler() {
override fun handleAction() {
connector.response()
done()
}
@SuppressLint("ApplySharedPref")
override fun handleAction(action: Int) {
val pos = selectedItemPosition.value
timeoutPrefs.edit().putInt(policy?.packageName, pos).commit()
handleAction(action, Config.Value.TIMEOUT_LIST[pos])
}
override fun handleAction(action: Int, time: Int) {
val until = if (time >= 0) {
if (time == 0) {
0
} else {
MILLISECONDS.toSeconds(now) + MINUTES.toSeconds(time.toLong())
}
} else {
policy?.until ?: 0
}
policy = policy?.copy(policy = action, until = until)?.apply {
policyDB.update(this).blockingGet()
}
handleAction()
}
}
// Never allow com.topjohnwu.magisk (could be malware) // Never allow com.topjohnwu.magisk (could be malware)
if (TextUtils.equals(policy?.packageName, BuildConfig.APPLICATION_ID)) if (policy.packageName == BuildConfig.APPLICATION_ID)
return false return false
// If not interactive, response directly
if (policy?.policy != MagiskPolicy.INTERACTIVE) {
handler?.handleAction()
return true
}
when (Config.suAutoReponse) { when (Config.suAutoReponse) {
Config.Value.SU_AUTO_DENY -> { Config.Value.SU_AUTO_DENY -> {
handler?.handleAction(MagiskPolicy.DENY, 0) handleAction(MagiskPolicy.DENY, 0)
return true return true
} }
Config.Value.SU_AUTO_ALLOW -> { Config.Value.SU_AUTO_ALLOW -> {
handler?.handleAction(MagiskPolicy.ALLOW, 0) handleAction(MagiskPolicy.ALLOW, 0)
return true return true
} }
} }
...@@ -187,77 +113,90 @@ class SuRequestViewModel( ...@@ -187,77 +113,90 @@ class SuRequestViewModel(
return true return true
} }
@SuppressLint("ClickableViewAccessibility")
private fun showUI() { private fun showUI() {
resources.getStringArray(R.array.allow_timeout)
.map { SpinnerRvItem(it) }
.let { items.update(it) }
icon.value = policy.applicationInfo.loadIcon(packageManager)
title.value = policy.appName
packageName.value = policy.packageName
selectedItemPosition.value = timeoutPrefs.getInt(policy.packageName, 0)
val millis = SECONDS.toMillis(Config.suDefaultTimeout.toLong()) val millis = SECONDS.toMillis(Config.suDefaultTimeout.toLong())
timer = object : CountDownTimer(millis, 1000) { timer = object : CountDownTimer(millis, 1000) {
override fun onTick(remains: Long) { override fun onTick(remains: Long) {
denyText.value = "%s (%d)" denyText.value = "${resources.getString(R.string.deny)} (${remains / 1000})"
.format(resources.getString(R.string.deny), remains / 1000)
} }
override fun onFinish() { override fun onFinish() {
denyText.value = resources.getString(R.string.deny) denyText.value = resources.getString(R.string.deny)
handler?.handleAction(MagiskPolicy.DENY) handleAction(MagiskPolicy.DENY)
} }
} }
timer?.start() timer.start()
handler?.addCancel(Runnable { cancelTimer() }) cancelTasks.add { cancelTimer() }
val useFP = canUseFingerprint.value
if (useFP) if (canUseFingerprint.value)
try { runCatching {
val helper = SuFingerprint() val helper = SuFingerprint()
helper.authenticate() helper.authenticate()
handler?.addCancel(Runnable { helper.cancel() }) cancelTasks.add { helper.cancel() }
} catch (e: Exception) {
e.printStackTrace()
} }
} }
private inner class SuFingerprint @Throws(Exception::class) private fun handleAction() {
internal constructor() : FingerprintHelper() { connector.response()
cancelTasks.forEach { it() }
override fun onAuthenticationError(errorCode: Int, errString: CharSequence) { DieEvent().publish()
warningText.value = errString
} }
override fun onAuthenticationHelp(helpCode: Int, helpString: CharSequence) { private fun handleAction(action: Int) {
warningText.value = helpString val pos = selectedItemPosition.value
timeoutPrefs.edit().putInt(policy.packageName, pos).apply()
handleAction(action, Config.Value.TIMEOUT_LIST[pos])
} }
override fun onAuthenticationSucceeded(result: FingerprintManager.AuthenticationResult) { private fun handleAction(action: Int, time: Int) {
handler?.handleAction(MagiskPolicy.ALLOW) val until = if (time > 0)
MILLISECONDS.toSeconds(now) + MINUTES.toSeconds(time.toLong())
else
time.toLong()
policy.policy = action
policy.until = until
if (until >= 0)
policyDB.update(policy).blockingAwait()
handleAction()
} }
override fun onAuthenticationFailed() { private inner class Connector @Throws(Exception::class)
warningText.value = resources.getString(R.string.auth_fail) internal constructor(name: String) : SuConnector(name) {
@Throws(IOException::class)
override fun onResponse() {
out.writeInt(policy.policy)
} }
} }
open inner class ActionHandler { private inner class SuFingerprint @Throws(Exception::class)
private val cancelTasks = mutableListOf<Runnable>() internal constructor() : FingerprintHelper() {
internal open fun handleAction() {
done()
}
internal open fun handleAction(action: Int) { override fun onAuthenticationError(errorCode: Int, errString: CharSequence) {
done() warningText.value = errString
} }
internal open fun handleAction(action: Int, time: Int) { override fun onAuthenticationHelp(helpCode: Int, helpString: CharSequence) {
done() warningText.value = helpString
} }
internal fun addCancel(r: Runnable) { override fun onAuthenticationSucceeded(result: FingerprintManager.AuthenticationResult) {
cancelTasks.add(r) handleAction(MagiskPolicy.ALLOW)
} }
internal fun done() { override fun onAuthenticationFailed() {
cancelTasks.forEach { it.run() } warningText.value = resources.getString(R.string.auth_fail)
DieEvent().publish()
} }
} }
......
...@@ -2,11 +2,9 @@ package com.topjohnwu.magisk.utils ...@@ -2,11 +2,9 @@ package com.topjohnwu.magisk.utils
import android.net.LocalSocket import android.net.LocalSocket
import android.net.LocalSocketAddress import android.net.LocalSocketAddress
import android.os.Bundle import androidx.collection.ArrayMap
import android.text.TextUtils
import timber.log.Timber import timber.log.Timber
import java.io.* import java.io.*
import java.nio.charset.Charset
abstract class SuConnector @Throws(IOException::class) abstract class SuConnector @Throws(IOException::class)
protected constructor(name: String) { protected constructor(name: String) {
...@@ -21,24 +19,23 @@ protected constructor(name: String) { ...@@ -21,24 +19,23 @@ protected constructor(name: String) {
input = DataInputStream(BufferedInputStream(socket.inputStream)) input = DataInputStream(BufferedInputStream(socket.inputStream))
} }
@Throws(IOException::class)
private fun readString(): String { private fun readString(): String {
val len = input.readInt() val len = input.readInt()
val buf = ByteArray(len) val buf = ByteArray(len)
input.readFully(buf) input.readFully(buf)
return String(buf, Charset.forName("UTF-8")) return String(buf, Charsets.UTF_8)
} }
@Throws(IOException::class) @Throws(IOException::class)
fun readSocketInput(): Bundle { fun readRequest(): Map<String, String> {
val bundle = Bundle() val ret = ArrayMap<String, String>()
while (true) { while (true) {
val name = readString() val name = readString()
if (TextUtils.equals(name, "eof")) if (name == "eof")
break break
bundle.putString(name, readString()) ret[name] = readString()
} }
return bundle return ret
} }
fun response() { fun response() {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment