fix(core): remote media resolver

Use of coroutines to cancel network requests if necessary

Signed-off-by: rhunk <101876869+rhunk@users.noreply.github.com>
This commit is contained in:
rhunk 2025-01-03 00:22:45 +01:00
parent 8e8220a55e
commit 3d50054d38
6 changed files with 56 additions and 51 deletions

View File

@ -1,7 +1,7 @@
package me.rhunk.snapenhance.common.util.snap package me.rhunk.snapenhance.common.util.snap
import me.rhunk.snapenhance.common.Constants import me.rhunk.snapenhance.common.Constants
import me.rhunk.snapenhance.common.logger.AbstractLogger import me.rhunk.snapenhance.common.util.ktx.await
import okhttp3.Headers import okhttp3.Headers
import okhttp3.OkHttpClient import okhttp3.OkHttpClient
import okhttp3.Request import okhttp3.Request
@ -41,21 +41,21 @@ object RemoteMediaResolver {
.addHeader("User-Agent", Constants.USER_AGENT) .addHeader("User-Agent", Constants.USER_AGENT)
.build() .build()
/** suspend inline fun downloadMedia(url: String, decryptionCallback: (InputStream) -> InputStream = { it }, result: (InputStream, Long) -> Unit) {
* Download bolt media with memory allocation okHttpClient.newCall(Request.Builder().url(url).build()).await().use { response ->
*/
fun downloadBoltMedia(protoKey: ByteArray, decryptionCallback: (InputStream) -> InputStream = { it }): ByteArray? {
okHttpClient.newCall(newResolveRequest(protoKey)).execute().use { response ->
if (!response.isSuccessful) { if (!response.isSuccessful) {
AbstractLogger.directDebug("Unexpected code $response") throw Throwable("invalid response ${response.code}")
return null
} }
return decryptionCallback(response.body.byteStream()).readBytes() result(decryptionCallback(response.body.byteStream()), response.body.contentLength())
} }
} }
inline fun downloadBoltMedia(protoKey: ByteArray, decryptionCallback: (InputStream) -> InputStream = { it }, resultCallback: (stream: InputStream, length: Long) -> Unit) { suspend inline fun downloadBoltMedia(
okHttpClient.newCall(newResolveRequest(protoKey)).execute().use { response -> protoKey: ByteArray,
decryptionCallback: (InputStream) -> InputStream = { it },
resultCallback: (stream: InputStream, length: Long) -> Unit
) {
okHttpClient.newCall(newResolveRequest(protoKey)).await().use { response ->
if (!response.isSuccessful) { if (!response.isSuccessful) {
throw Throwable("invalid response ${response.code}") throw Throwable("invalid response ${response.code}")
} }

View File

@ -45,6 +45,7 @@ import me.rhunk.snapenhance.common.ui.createComposeAlertDialog
import me.rhunk.snapenhance.common.ui.rememberAsyncMutableState import me.rhunk.snapenhance.common.ui.rememberAsyncMutableState
import me.rhunk.snapenhance.common.util.ktx.copyToClipboard import me.rhunk.snapenhance.common.util.ktx.copyToClipboard
import me.rhunk.snapenhance.common.util.snap.BitmojiSelfie import me.rhunk.snapenhance.common.util.snap.BitmojiSelfie
import me.rhunk.snapenhance.common.util.snap.RemoteMediaResolver
import me.rhunk.snapenhance.core.action.AbstractAction import me.rhunk.snapenhance.core.action.AbstractAction
import me.rhunk.snapenhance.core.features.impl.experiments.AddFriendSourceSpoof import me.rhunk.snapenhance.core.features.impl.experiments.AddFriendSourceSpoof
import me.rhunk.snapenhance.core.features.impl.experiments.BetterLocation import me.rhunk.snapenhance.core.features.impl.experiments.BetterLocation
@ -405,13 +406,13 @@ class BulkMessagingAction : AbstractAction() {
if (bitmojiBitmap != null || friendInfo.bitmojiAvatarId == null || friendInfo.bitmojiSelfieId == null) return@withContext if (bitmojiBitmap != null || friendInfo.bitmojiAvatarId == null || friendInfo.bitmojiSelfieId == null) return@withContext
val bitmojiUrl = BitmojiSelfie.getBitmojiSelfie(friendInfo.bitmojiSelfieId, friendInfo.bitmojiAvatarId, BitmojiSelfie.BitmojiSelfieType.NEW_THREE_D) ?: return@withContext val bitmojiUrl = BitmojiSelfie.getBitmojiSelfie(friendInfo.bitmojiSelfieId, friendInfo.bitmojiAvatarId, BitmojiSelfie.BitmojiSelfieType.NEW_THREE_D) ?: return@withContext
runCatching { runCatching {
URL(bitmojiUrl).openStream().use { input -> RemoteMediaResolver.downloadMedia(bitmojiUrl) { inputStream, length ->
bitmojiCache[friendInfo.bitmojiAvatarId ?: return@withContext] = BitmapFactory.decodeStream(input) bitmojiCache[friendInfo.bitmojiAvatarId ?: return@withContext] = BitmapFactory.decodeStream(inputStream).also {
bitmojiBitmap = it
}
} }
bitmojiBitmap = bitmojiCache[friendInfo.bitmojiAvatarId ?: return@withContext]
}.onFailure {
context.log.error("Failed to load bitmoji", it)
} }
} }
} }

View File

@ -31,7 +31,7 @@ data class DecodedAttachment(
} }
@OptIn(ExperimentalEncodingApi::class) @OptIn(ExperimentalEncodingApi::class)
inline fun openStream(callback: (mediaStream: InputStream?, length: Long) -> Unit) { suspend inline fun openStream(callback: (mediaStream: InputStream?, length: Long) -> Unit) {
boltKey?.let { mediaUrlKey -> boltKey?.let { mediaUrlKey ->
RemoteMediaResolver.downloadBoltMedia(Base64.UrlSafe.decode(mediaUrlKey), decryptionCallback = { RemoteMediaResolver.downloadBoltMedia(Base64.UrlSafe.decode(mediaUrlKey), decryptionCallback = {
attachmentInfo?.encryption?.decryptInputStream(it) ?: it attachmentInfo?.encryption?.decryptInputStream(it) ?: it
@ -39,11 +39,10 @@ data class DecodedAttachment(
callback(inputStream, length) callback(inputStream, length)
}) })
} ?: directUrl?.let { rawMediaUrl -> } ?: directUrl?.let { rawMediaUrl ->
val connection = URL(rawMediaUrl).openConnection() RemoteMediaResolver.downloadMedia(rawMediaUrl, decryptionCallback = {
connection.getInputStream().let {
attachmentInfo?.encryption?.decryptInputStream(it) ?: it attachmentInfo?.encryption?.decryptInputStream(it) ?: it
}.use { }) { inputStream, length ->
callback(it, connection.contentLengthLong) callback(inputStream, length)
} }
} ?: callback(null, 0) } ?: callback(null, 0)
} }

View File

@ -310,7 +310,7 @@ class Notifications : Feature("Notifications") {
}.send() }.send()
} }
private fun onMessageReceived(data: NotificationData, notificationType: String, message: Message) { private suspend fun onMessageReceived(data: NotificationData, notificationType: String, message: Message) {
val conversationId = message.messageDescriptor?.conversationId.toString() val conversationId = message.messageDescriptor?.conversationId.toString()
val orderKey = message.orderKey ?: return val orderKey = message.orderKey ?: return
val senderUsername by lazy { val senderUsername by lazy {
@ -487,16 +487,18 @@ class Notifications : Feature("Notifications") {
suspendCoroutine { continuation -> suspendCoroutine { continuation ->
conversationManager.fetchMessageByServerId(conversationId, serverMessageId.toLong(), onSuccess = { conversationManager.fetchMessageByServerId(conversationId, serverMessageId.toLong(), onSuccess = {
if (it.senderId.toString() == context.database.myUserId) { if (it.senderId.toString() == context.database.myUserId) {
param.invokeOriginal()
continuation.resumeWith(Result.success(Unit)) continuation.resumeWith(Result.success(Unit))
param.invokeOriginal()
return@fetchMessageByServerId return@fetchMessageByServerId
} }
onMessageReceived(notificationData, notificationType, it) context.coroutineScope.launch(coroutineDispatcher) {
continuation.resumeWith(Result.success(Unit)) continuation.resumeWith(Result.success(Unit))
onMessageReceived(notificationData, notificationType, it)
}
}, onError = { }, onError = {
context.log.error("Failed to fetch message id ${serverMessageId}: $it") context.log.error("Failed to fetch message id ${serverMessageId}: $it")
param.invokeOriginal()
continuation.resumeWith(Result.success(Unit)) continuation.resumeWith(Result.success(Unit))
param.invokeOriginal()
}) })
} }
} }

View File

@ -318,7 +318,7 @@ class FriendTracker : Feature("Friend Tracker") {
// allow events when a notification is received // allow events when a notification is received
hookConstructor(HookStage.AFTER) { param -> hookConstructor(HookStage.AFTER) { param ->
methods.first { it.name == "appStateChanged" }.let { method -> methods.first { it.name == "appStateChanged" }.let { method ->
method.invoke(param.thisObject(), method.parameterTypes[0].enumConstants.first { it.toString() == "ACTIVE" }) method.invoke(param.thisObject(), method.parameterTypes[0].enumConstants!!.first { it.toString() == "ACTIVE" })
} }
} }
} }

View File

@ -3,6 +3,7 @@ package me.rhunk.snapenhance.core.messaging
import android.util.Base64InputStream import android.util.Base64InputStream
import android.util.Base64OutputStream import android.util.Base64OutputStream
import com.google.gson.stream.JsonWriter import com.google.gson.stream.JsonWriter
import kotlinx.coroutines.runBlocking
import me.rhunk.snapenhance.common.BuildConfig import me.rhunk.snapenhance.common.BuildConfig
import me.rhunk.snapenhance.common.data.ContentType import me.rhunk.snapenhance.common.data.ContentType
import me.rhunk.snapenhance.common.database.impl.FriendFeedEntry import me.rhunk.snapenhance.common.database.impl.FriendFeedEntry
@ -132,6 +133,7 @@ class ConversationExporter(
for (i in 0..5) { for (i in 0..5) {
printLog("downloading ${attachment.boltKey ?: attachment.directUrl}... (attempt ${i + 1}/5)") printLog("downloading ${attachment.boltKey ?: attachment.directUrl}... (attempt ${i + 1}/5)")
runCatching { runCatching {
runBlocking {
attachment.openStream { downloadedInputStream, _ -> attachment.openStream { downloadedInputStream, _ ->
MediaDownloaderHelper.getSplitElements(downloadedInputStream!!) { type, splitInputStream -> MediaDownloaderHelper.getSplitElements(downloadedInputStream!!) { type, splitInputStream ->
val mediaKey = "${type}_${attachment.mediaUniqueId}" val mediaKey = "${type}_${attachment.mediaUniqueId}"
@ -161,6 +163,7 @@ class ConversationExporter(
downloadedMediaIdCache.add(attachment.mediaUniqueId!!) downloadedMediaIdCache.add(attachment.mediaUniqueId!!)
} }
} }
}
return@decode return@decode
}.onFailure { }.onFailure {
downloadedMediaIdCache.remove(attachment.mediaUniqueId!!) downloadedMediaIdCache.remove(attachment.mediaUniqueId!!)