fix(core): remote media resolver

Use of coroutines to cancel network requests if necessary

Signed-off-by: rhunk <101876869+rhunk@users.noreply.github.com>
This commit is contained in:
rhunk 2025-01-03 00:22:45 +01:00
parent 8e8220a55e
commit 3d50054d38
6 changed files with 56 additions and 51 deletions

View File

@ -1,7 +1,7 @@
package me.rhunk.snapenhance.common.util.snap
import me.rhunk.snapenhance.common.Constants
import me.rhunk.snapenhance.common.logger.AbstractLogger
import me.rhunk.snapenhance.common.util.ktx.await
import okhttp3.Headers
import okhttp3.OkHttpClient
import okhttp3.Request
@ -41,21 +41,21 @@ object RemoteMediaResolver {
.addHeader("User-Agent", Constants.USER_AGENT)
.build()
/**
* Download bolt media with memory allocation
*/
fun downloadBoltMedia(protoKey: ByteArray, decryptionCallback: (InputStream) -> InputStream = { it }): ByteArray? {
okHttpClient.newCall(newResolveRequest(protoKey)).execute().use { response ->
suspend inline fun downloadMedia(url: String, decryptionCallback: (InputStream) -> InputStream = { it }, result: (InputStream, Long) -> Unit) {
okHttpClient.newCall(Request.Builder().url(url).build()).await().use { response ->
if (!response.isSuccessful) {
AbstractLogger.directDebug("Unexpected code $response")
return null
throw Throwable("invalid response ${response.code}")
}
return decryptionCallback(response.body.byteStream()).readBytes()
result(decryptionCallback(response.body.byteStream()), response.body.contentLength())
}
}
inline fun downloadBoltMedia(protoKey: ByteArray, decryptionCallback: (InputStream) -> InputStream = { it }, resultCallback: (stream: InputStream, length: Long) -> Unit) {
okHttpClient.newCall(newResolveRequest(protoKey)).execute().use { response ->
suspend inline fun downloadBoltMedia(
protoKey: ByteArray,
decryptionCallback: (InputStream) -> InputStream = { it },
resultCallback: (stream: InputStream, length: Long) -> Unit
) {
okHttpClient.newCall(newResolveRequest(protoKey)).await().use { response ->
if (!response.isSuccessful) {
throw Throwable("invalid response ${response.code}")
}

View File

@ -45,6 +45,7 @@ import me.rhunk.snapenhance.common.ui.createComposeAlertDialog
import me.rhunk.snapenhance.common.ui.rememberAsyncMutableState
import me.rhunk.snapenhance.common.util.ktx.copyToClipboard
import me.rhunk.snapenhance.common.util.snap.BitmojiSelfie
import me.rhunk.snapenhance.common.util.snap.RemoteMediaResolver
import me.rhunk.snapenhance.core.action.AbstractAction
import me.rhunk.snapenhance.core.features.impl.experiments.AddFriendSourceSpoof
import me.rhunk.snapenhance.core.features.impl.experiments.BetterLocation
@ -405,13 +406,13 @@ class BulkMessagingAction : AbstractAction() {
if (bitmojiBitmap != null || friendInfo.bitmojiAvatarId == null || friendInfo.bitmojiSelfieId == null) return@withContext
val bitmojiUrl = BitmojiSelfie.getBitmojiSelfie(friendInfo.bitmojiSelfieId, friendInfo.bitmojiAvatarId, BitmojiSelfie.BitmojiSelfieType.NEW_THREE_D) ?: return@withContext
runCatching {
URL(bitmojiUrl).openStream().use { input ->
bitmojiCache[friendInfo.bitmojiAvatarId ?: return@withContext] = BitmapFactory.decodeStream(input)
RemoteMediaResolver.downloadMedia(bitmojiUrl) { inputStream, length ->
bitmojiCache[friendInfo.bitmojiAvatarId ?: return@withContext] = BitmapFactory.decodeStream(inputStream).also {
bitmojiBitmap = it
}
}
bitmojiBitmap = bitmojiCache[friendInfo.bitmojiAvatarId ?: return@withContext]
}.onFailure {
context.log.error("Failed to load bitmoji", it)
}
}
}

View File

@ -31,7 +31,7 @@ data class DecodedAttachment(
}
@OptIn(ExperimentalEncodingApi::class)
inline fun openStream(callback: (mediaStream: InputStream?, length: Long) -> Unit) {
suspend inline fun openStream(callback: (mediaStream: InputStream?, length: Long) -> Unit) {
boltKey?.let { mediaUrlKey ->
RemoteMediaResolver.downloadBoltMedia(Base64.UrlSafe.decode(mediaUrlKey), decryptionCallback = {
attachmentInfo?.encryption?.decryptInputStream(it) ?: it
@ -39,11 +39,10 @@ data class DecodedAttachment(
callback(inputStream, length)
})
} ?: directUrl?.let { rawMediaUrl ->
val connection = URL(rawMediaUrl).openConnection()
connection.getInputStream().let {
RemoteMediaResolver.downloadMedia(rawMediaUrl, decryptionCallback = {
attachmentInfo?.encryption?.decryptInputStream(it) ?: it
}.use {
callback(it, connection.contentLengthLong)
}) { inputStream, length ->
callback(inputStream, length)
}
} ?: callback(null, 0)
}

View File

@ -310,7 +310,7 @@ class Notifications : Feature("Notifications") {
}.send()
}
private fun onMessageReceived(data: NotificationData, notificationType: String, message: Message) {
private suspend fun onMessageReceived(data: NotificationData, notificationType: String, message: Message) {
val conversationId = message.messageDescriptor?.conversationId.toString()
val orderKey = message.orderKey ?: return
val senderUsername by lazy {
@ -487,16 +487,18 @@ class Notifications : Feature("Notifications") {
suspendCoroutine { continuation ->
conversationManager.fetchMessageByServerId(conversationId, serverMessageId.toLong(), onSuccess = {
if (it.senderId.toString() == context.database.myUserId) {
param.invokeOriginal()
continuation.resumeWith(Result.success(Unit))
param.invokeOriginal()
return@fetchMessageByServerId
}
onMessageReceived(notificationData, notificationType, it)
continuation.resumeWith(Result.success(Unit))
context.coroutineScope.launch(coroutineDispatcher) {
continuation.resumeWith(Result.success(Unit))
onMessageReceived(notificationData, notificationType, it)
}
}, onError = {
context.log.error("Failed to fetch message id ${serverMessageId}: $it")
param.invokeOriginal()
continuation.resumeWith(Result.success(Unit))
param.invokeOriginal()
})
}
}

View File

@ -318,7 +318,7 @@ class FriendTracker : Feature("Friend Tracker") {
// allow events when a notification is received
hookConstructor(HookStage.AFTER) { param ->
methods.first { it.name == "appStateChanged" }.let { method ->
method.invoke(param.thisObject(), method.parameterTypes[0].enumConstants.first { it.toString() == "ACTIVE" })
method.invoke(param.thisObject(), method.parameterTypes[0].enumConstants!!.first { it.toString() == "ACTIVE" })
}
}
}

View File

@ -3,6 +3,7 @@ package me.rhunk.snapenhance.core.messaging
import android.util.Base64InputStream
import android.util.Base64OutputStream
import com.google.gson.stream.JsonWriter
import kotlinx.coroutines.runBlocking
import me.rhunk.snapenhance.common.BuildConfig
import me.rhunk.snapenhance.common.data.ContentType
import me.rhunk.snapenhance.common.database.impl.FriendFeedEntry
@ -132,33 +133,35 @@ class ConversationExporter(
for (i in 0..5) {
printLog("downloading ${attachment.boltKey ?: attachment.directUrl}... (attempt ${i + 1}/5)")
runCatching {
attachment.openStream { downloadedInputStream, _ ->
MediaDownloaderHelper.getSplitElements(downloadedInputStream!!) { type, splitInputStream ->
val mediaKey = "${type}_${attachment.mediaUniqueId}"
val bufferedInputStream = BufferedInputStream(splitInputStream)
val fileType = MediaDownloaderHelper.getFileType(bufferedInputStream)
val mediaFile = cacheFolder.resolve("$mediaKey.${fileType.fileExtension}")
runBlocking {
attachment.openStream { downloadedInputStream, _ ->
MediaDownloaderHelper.getSplitElements(downloadedInputStream!!) { type, splitInputStream ->
val mediaKey = "${type}_${attachment.mediaUniqueId}"
val bufferedInputStream = BufferedInputStream(splitInputStream)
val fileType = MediaDownloaderHelper.getFileType(bufferedInputStream)
val mediaFile = cacheFolder.resolve("$mediaKey.${fileType.fileExtension}")
mediaFile.outputStream().use { fos ->
bufferedInputStream.copyTo(fos)
}
mediaFile.outputStream().use { fos ->
bufferedInputStream.copyTo(fos)
}
writeThreadExecutor.execute {
outputFileStream.write("<div class=\"media-$mediaKey\"><!-- ".toByteArray())
mediaFile.inputStream().use {
val deflateInputStream = DeflaterInputStream(it, Deflater(Deflater.BEST_SPEED, true))
(newBase64InputStream.newInstance(
deflateInputStream,
android.util.Base64.DEFAULT or android.util.Base64.NO_WRAP,
true
) as InputStream).copyTo(outputFileStream)
outputFileStream.write(" --></div>\n".toByteArray())
outputFileStream.flush()
writeThreadExecutor.execute {
outputFileStream.write("<div class=\"media-$mediaKey\"><!-- ".toByteArray())
mediaFile.inputStream().use {
val deflateInputStream = DeflaterInputStream(it, Deflater(Deflater.BEST_SPEED, true))
(newBase64InputStream.newInstance(
deflateInputStream,
android.util.Base64.DEFAULT or android.util.Base64.NO_WRAP,
true
) as InputStream).copyTo(outputFileStream)
outputFileStream.write(" --></div>\n".toByteArray())
outputFileStream.flush()
}
}
}
}
writeThreadExecutor.execute {
downloadedMediaIdCache.add(attachment.mediaUniqueId!!)
writeThreadExecutor.execute {
downloadedMediaIdCache.add(attachment.mediaUniqueId!!)
}
}
}
return@decode