Commit e7f1c031 authored by topjohnwu's avatar topjohnwu

Cleanup code for su request

parent 56602cb9
...@@ -20,8 +20,8 @@ import com.topjohnwu.magisk.extensions.get ...@@ -20,8 +20,8 @@ import com.topjohnwu.magisk.extensions.get
import com.topjohnwu.magisk.extensions.startActivity import com.topjohnwu.magisk.extensions.startActivity
import com.topjohnwu.magisk.extensions.startActivityWithRoot import com.topjohnwu.magisk.extensions.startActivityWithRoot
import com.topjohnwu.magisk.extensions.subscribeK import com.topjohnwu.magisk.extensions.subscribeK
import com.topjohnwu.magisk.ui.surequest.SuRequestActivity
import com.topjohnwu.magisk.model.entity.toLog import com.topjohnwu.magisk.model.entity.toLog
import com.topjohnwu.magisk.ui.surequest.SuRequestActivity
import com.topjohnwu.superuser.Shell import com.topjohnwu.superuser.Shell
import timber.log.Timber import timber.log.Timber
...@@ -51,20 +51,8 @@ object SuCallbackHandler : ProviderCallHandler { ...@@ -51,20 +51,8 @@ object SuCallbackHandler : ProviderCallHandler {
} }
when (action) { when (action) {
REQUEST -> { REQUEST -> handleRequest(context, data)
val intent = context.intent<SuRequestActivity>() LOG -> handleLogging(context, data)
.setAction(action)
.putExtras(data)
.addFlags(Intent.FLAG_ACTIVITY_NEW_TASK)
.addFlags(Intent.FLAG_ACTIVITY_MULTIPLE_TASK)
if (Build.VERSION.SDK_INT >= 29) {
// Android Q does not allow starting activity from background
intent.startActivityWithRoot()
} else {
intent.startActivity(context)
}
}
LOG -> handleLogs(context, data)
NOTIFY -> handleNotify(context, data) NOTIFY -> handleNotify(context, data)
TEST -> { TEST -> {
val mode = data.getInt("mode", 2) val mode = data.getInt("mode", 2)
...@@ -78,13 +66,26 @@ object SuCallbackHandler : ProviderCallHandler { ...@@ -78,13 +66,26 @@ object SuCallbackHandler : ProviderCallHandler {
private fun Any?.toInt(): Int? { private fun Any?.toInt(): Int? {
return when (this) { return when (this) {
is Int -> this is Number -> this.toInt()
is Long -> this.toInt()
else -> null else -> null
} }
} }
private fun handleLogs(context: Context, data: Bundle) { private fun handleRequest(context: Context, data: Bundle) {
val intent = context.intent<SuRequestActivity>()
.setAction(REQUEST)
.putExtras(data)
.addFlags(Intent.FLAG_ACTIVITY_NEW_TASK)
.addFlags(Intent.FLAG_ACTIVITY_MULTIPLE_TASK)
if (Build.VERSION.SDK_INT >= 29) {
// Android Q does not allow starting activity from background
intent.startActivityWithRoot()
} else {
intent.startActivity(context)
}
}
private fun handleLogging(context: Context, data: Bundle) {
val fromUid = data["from.uid"].toInt() ?: return val fromUid = data["from.uid"].toInt() ?: return
if (fromUid == Process.myUid()) if (fromUid == Process.myUid())
return return
......
package com.topjohnwu.magisk.core.su
import android.net.LocalSocket
import android.net.LocalSocketAddress
import androidx.collection.ArrayMap
import timber.log.Timber
import java.io.*
abstract class SuConnector @Throws(IOException::class)
protected constructor(name: String) {
private val socket: LocalSocket = LocalSocket()
protected var out: DataOutputStream
protected var input: DataInputStream
init {
socket.connect(LocalSocketAddress(name, LocalSocketAddress.Namespace.ABSTRACT))
out = DataOutputStream(BufferedOutputStream(socket.outputStream))
input = DataInputStream(BufferedInputStream(socket.inputStream))
}
private fun readString(): String {
val len = input.readInt()
val buf = ByteArray(len)
input.readFully(buf)
return String(buf, Charsets.UTF_8)
}
@Throws(IOException::class)
fun readRequest(): Map<String, String> {
val ret = ArrayMap<String, String>()
while (true) {
val name = readString()
if (name == "eof")
break
ret[name] = readString()
}
return ret
}
fun response() {
runCatching {
onResponse()
out.flush()
}.onFailure { Timber.e(it) }
runCatching {
input.close()
out.close()
socket.close()
}
}
@Throws(IOException::class)
protected abstract fun onResponse()
}
...@@ -2,7 +2,10 @@ package com.topjohnwu.magisk.core.su ...@@ -2,7 +2,10 @@ package com.topjohnwu.magisk.core.su
import android.content.Intent import android.content.Intent
import android.content.pm.PackageManager import android.content.pm.PackageManager
import android.net.LocalSocket
import android.net.LocalSocketAddress
import android.os.CountDownTimer import android.os.CountDownTimer
import androidx.collection.ArrayMap
import com.topjohnwu.magisk.BuildConfig import com.topjohnwu.magisk.BuildConfig
import com.topjohnwu.magisk.core.Config import com.topjohnwu.magisk.core.Config
import com.topjohnwu.magisk.core.Const import com.topjohnwu.magisk.core.Const
...@@ -11,43 +14,36 @@ import com.topjohnwu.magisk.core.model.MagiskPolicy ...@@ -11,43 +14,36 @@ import com.topjohnwu.magisk.core.model.MagiskPolicy
import com.topjohnwu.magisk.core.model.toPolicy import com.topjohnwu.magisk.core.model.toPolicy
import com.topjohnwu.magisk.extensions.now import com.topjohnwu.magisk.extensions.now
import timber.log.Timber import timber.log.Timber
import java.io.*
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
abstract class SuRequestHandler( abstract class SuRequestHandler(
private val packageManager: PackageManager, private val packageManager: PackageManager,
private val policyDB: PolicyDao private val policyDB: PolicyDao
) { ) {
protected var timer: CountDownTimer = object : CountDownTimer( private val socket: LocalSocket = LocalSocket()
TimeUnit.MINUTES.toMillis(1), TimeUnit.MINUTES.toMillis(1)) { private lateinit var out: DataOutputStream
override fun onFinish() { private lateinit var input: DataInputStream
respond(MagiskPolicy.DENY, 0)
} protected var timer: CountDownTimer = DefaultCountDown()
override fun onTick(remains: Long) {}
}
set(value) { set(value) {
field.cancel() field.cancel()
field = value field = value
field.start() field.start()
} }
protected lateinit var policy: MagiskPolicy protected lateinit var policy: MagiskPolicy
private set
private val cleanupTasks = mutableListOf<() -> Unit>()
private lateinit var connector: SuConnector
abstract fun onStart() abstract fun onStart()
abstract fun onRespond()
fun start(intent: Intent): Boolean { fun start(intent: Intent): Boolean {
val socketName = intent.getStringExtra("socket") ?: return false val socketName = intent.getStringExtra("socket") ?: return false
try { try {
connector = object : SuConnector(socketName) { socket.connect(LocalSocketAddress(socketName, LocalSocketAddress.Namespace.ABSTRACT))
override fun onResponse() { out = DataOutputStream(BufferedOutputStream(socket.outputStream))
out.writeInt(policy.policy) input = DataInputStream(BufferedInputStream(socket.inputStream))
} val map = readRequest()
}
val map = connector.readRequest()
val uid = map["uid"]?.toIntOrNull() ?: return false val uid = map["uid"]?.toIntOrNull() ?: return false
policy = uid.toPolicy(packageManager) policy = uid.toPolicy(packageManager)
} catch (e: Exception) { } catch (e: Exception) {
...@@ -71,20 +67,10 @@ abstract class SuRequestHandler( ...@@ -71,20 +67,10 @@ abstract class SuRequestHandler(
} }
timer.start() timer.start()
cleanupTasks.add {
timer.cancel()
}
onStart() onStart()
return true return true
} }
private fun respond() {
connector.response()
cleanupTasks.forEach { it() }
onRespond()
}
fun respond(action: Int, time: Int) { fun respond(action: Int, time: Int) {
val until = if (time > 0) val until = if (time > 0)
TimeUnit.MILLISECONDS.toSeconds(now) + TimeUnit.MINUTES.toSeconds(time.toLong()) TimeUnit.MILLISECONDS.toSeconds(now) + TimeUnit.MINUTES.toSeconds(time.toLong())
...@@ -98,6 +84,45 @@ abstract class SuRequestHandler( ...@@ -98,6 +84,45 @@ abstract class SuRequestHandler(
if (until >= 0) if (until >= 0)
policyDB.update(policy).blockingAwait() policyDB.update(policy).blockingAwait()
respond() try {
out.writeInt(policy.policy)
out.flush()
} catch (e: IOException) {
Timber.e(e)
} finally {
runCatching {
input.close()
out.close()
socket.close()
}
}
timer.cancel()
}
@Throws(IOException::class)
private fun readRequest(): Map<String, String> {
fun readString(): String {
val len = input.readInt()
val buf = ByteArray(len)
input.readFully(buf)
return String(buf, Charsets.UTF_8)
}
val ret = ArrayMap<String, String>()
while (true) {
val name = readString()
if (name == "eof")
break
ret[name] = readString()
}
return ret
}
private inner class DefaultCountDown
: CountDownTimer(TimeUnit.MINUTES.toMillis(1), TimeUnit.MINUTES.toMillis(1)) {
override fun onFinish() {
respond(MagiskPolicy.DENY, 0)
}
override fun onTick(remains: Long) {}
} }
} }
...@@ -85,6 +85,9 @@ class SuRequestViewModel( ...@@ -85,6 +85,9 @@ class SuRequestViewModel(
val pos = selectedItemPosition.value val pos = selectedItemPosition.value
timeoutPrefs.edit().putInt(policy.packageName, pos).apply() timeoutPrefs.edit().putInt(policy.packageName, pos).apply()
respond(action, Config.Value.TIMEOUT_LIST[pos]) respond(action, Config.Value.TIMEOUT_LIST[pos])
// Kill activity after response
DieEvent().publish()
} }
fun cancelTimer() { fun cancelTimer() {
...@@ -118,11 +121,6 @@ class SuRequestViewModel( ...@@ -118,11 +121,6 @@ class SuRequestViewModel(
} }
} }
} }
override fun onRespond() {
// Kill activity after response
DieEvent().publish()
}
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment