feat: multiple media chat export

- optimize message exporter download
- optimize zip download/extract
This commit is contained in:
rhunk
2023-09-16 11:56:41 +02:00
parent 5a47e04093
commit 9cb9bd7a26
10 changed files with 289 additions and 323 deletions

View File

@ -23,17 +23,14 @@ import me.rhunk.snapenhance.core.download.data.DownloadMetadata
import me.rhunk.snapenhance.core.download.data.DownloadRequest import me.rhunk.snapenhance.core.download.data.DownloadRequest
import me.rhunk.snapenhance.core.download.data.DownloadStage import me.rhunk.snapenhance.core.download.data.DownloadStage
import me.rhunk.snapenhance.core.download.data.InputMedia import me.rhunk.snapenhance.core.download.data.InputMedia
import me.rhunk.snapenhance.core.download.data.MediaEncryptionKeyPair import me.rhunk.snapenhance.core.download.data.SplitMediaAssetType
import me.rhunk.snapenhance.core.util.download.RemoteMediaResolver import me.rhunk.snapenhance.core.util.download.RemoteMediaResolver
import me.rhunk.snapenhance.core.util.snap.MediaDownloaderHelper
import java.io.File import java.io.File
import java.io.InputStream import java.io.InputStream
import java.net.HttpURLConnection import java.net.HttpURLConnection
import java.net.URL import java.net.URL
import java.util.zip.ZipInputStream import java.util.zip.ZipInputStream
import javax.crypto.Cipher
import javax.crypto.CipherInputStream
import javax.crypto.spec.IvParameterSpec
import javax.crypto.spec.SecretKeySpec
import javax.xml.parsers.DocumentBuilderFactory import javax.xml.parsers.DocumentBuilderFactory
import javax.xml.transform.TransformerFactory import javax.xml.transform.TransformerFactory
import javax.xml.transform.dom.DOMSource import javax.xml.transform.dom.DOMSource
@ -110,14 +107,6 @@ class DownloadProcessor (
return files return files
} }
private fun decryptInputStream(inputStream: InputStream, encryption: MediaEncryptionKeyPair): InputStream {
val cipher = Cipher.getInstance("AES/CBC/PKCS5Padding")
val key = Base64.UrlSafe.decode(encryption.key)
val iv = Base64.UrlSafe.decode(encryption.iv)
cipher.init(Cipher.DECRYPT_MODE, SecretKeySpec(key, "AES"), IvParameterSpec(iv))
return CipherInputStream(inputStream, cipher)
}
@SuppressLint("UnspecifiedRegisterReceiverFlag") @SuppressLint("UnspecifiedRegisterReceiverFlag")
private suspend fun saveMediaToGallery(inputFile: File, downloadObject: DownloadObject) { private suspend fun saveMediaToGallery(inputFile: File, downloadObject: DownloadObject) {
if (coroutineContext.job.isCancelled) return if (coroutineContext.job.isCancelled) return
@ -202,24 +191,16 @@ class DownloadProcessor (
downloadRequest.inputMedias.forEach { inputMedia -> downloadRequest.inputMedias.forEach { inputMedia ->
fun handleInputStream(inputStream: InputStream) { fun handleInputStream(inputStream: InputStream) {
createMediaTempFile().apply { createMediaTempFile().apply {
if (inputMedia.encryption != null) { (inputMedia.encryption?.decryptInputStream(inputStream) ?: inputStream).copyTo(outputStream())
decryptInputStream(inputStream,
inputMedia.encryption!!
).use { decryptedInputStream ->
decryptedInputStream.copyTo(outputStream())
}
} else {
inputStream.copyTo(outputStream())
}
}.also { downloadedMedias[inputMedia] = it } }.also { downloadedMedias[inputMedia] = it }
} }
launch { launch {
when (inputMedia.type) { when (inputMedia.type) {
DownloadMediaType.PROTO_MEDIA -> { DownloadMediaType.PROTO_MEDIA -> {
RemoteMediaResolver.downloadBoltMedia(Base64.UrlSafe.decode(inputMedia.content))?.let { inputStream -> RemoteMediaResolver.downloadBoltMedia(Base64.UrlSafe.decode(inputMedia.content), decryptionCallback = { it }, resultCallback = {
handleInputStream(inputStream) handleInputStream(it)
} })
} }
DownloadMediaType.DIRECT_MEDIA -> { DownloadMediaType.DIRECT_MEDIA -> {
val decoded = Base64.UrlSafe.decode(inputMedia.content) val decoded = Base64.UrlSafe.decode(inputMedia.content)
@ -359,20 +340,26 @@ class DownloadProcessor (
var shouldMergeOverlay = downloadRequest.shouldMergeOverlay var shouldMergeOverlay = downloadRequest.shouldMergeOverlay
//if there is a zip file, extract it and replace the downloaded media with the extracted ones //if there is a zip file, extract it and replace the downloaded media with the extracted ones
downloadedMedias.values.find { it.fileType == FileType.ZIP }?.let { entry -> downloadedMedias.values.find { it.fileType == FileType.ZIP }?.let { zipFile ->
val extractedMedias = extractZip(entry.file.inputStream()).map { val oldDownloadedMedias = downloadedMedias.toMap()
InputMedia( downloadedMedias.clear()
type = DownloadMediaType.LOCAL_MEDIA,
content = it.absolutePath MediaDownloaderHelper.getSplitElements(zipFile.file.inputStream()) { type, inputStream ->
) to DownloadedFile(it, FileType.fromFile(it)) createMediaTempFile().apply {
inputStream.copyTo(outputStream())
}.also {
downloadedMedias[InputMedia(
type = DownloadMediaType.LOCAL_MEDIA,
content = it.absolutePath,
isOverlay = type == SplitMediaAssetType.OVERLAY
)] = DownloadedFile(it, FileType.fromFile(it))
}
} }
downloadedMedias.values.removeIf { oldDownloadedMedias.forEach { (_, value) ->
it.file.delete() value.file.delete()
true
} }
downloadedMedias.putAll(extractedMedias)
shouldMergeOverlay = true shouldMergeOverlay = true
} }

View File

@ -122,11 +122,16 @@
} }
.media_container { .chat_media {
max-width: 300px; max-width: 300px;
max-height: 500px; max-height: 500px;
} }
.overlay_media {
position: absolute;
pointer-events: none;
}
.red_snap_svg { .red_snap_svg {
color: var(--sigSnapWithoutSound); color: var(--sigSnapWithoutSound);
} }
@ -140,7 +145,7 @@
<div style="display: none;"> <div style="display: none;">
<svg class="red_snap_svg" width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg"> <svg class="red_snap_svg" width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<rect x="2.75" y="2.75" width="10.5" height="10.5" rx="1.808" stroke="currentColor" stroke-width="1.5"></rect> <rect x="2.75" y="2.75" width="10.5" height="10.5" rx="1.808" stroke="currentColor" stroke-width="1.5"></rect>
</svg> </svg>
</div> </div>
<script> <script>
@ -152,13 +157,24 @@
} }
function makeConversationSummary() { function makeConversationSummary() {
const conversationTitle = conversationData.conversationName != null ? const conversationTitle = conversationData.conversationName != null ?
conversationData.conversationName : conversationData.conversationName : "DM with " + Object.values(participants).map(user => user.username).join(", ")
"DM with " + Object.values(participants).map(user => user.username).join(", ")
document.querySelector(".conversation_summary .title").textContent = conversationTitle document.querySelector(".conversation_summary .title").textContent = conversationTitle
} }
function decodeMedia(element) {
const decodedData = new Uint8Array(
inflate(
base64decode(
element.innerHTML.substring(5, element.innerHTML.length - 4)
)
)
)
return URL.createObjectURL(new Blob([decodedData]))
}
function makeConversationMessageContainer() { function makeConversationMessageContainer() {
const messageTemplate = document.querySelector("#message_template") const messageTemplate = document.querySelector("#message_template")
Object.values(conversationData.messages).forEach(message => { Object.values(conversationData.messages).forEach(message => {
@ -185,63 +201,88 @@
return headerElement return headerElement
})(document.createElement("div"))) })(document.createElement("div")))
messageObject.appendChild(((elem) =>{ messageObject.appendChild(((messageContainer) =>{
elem.classList.add("content") messageContainer.classList.add("content")
elem.innerHTML = message.serializedContent messageContainer.innerHTML = message.serializedContent
if (!message.serializedContent) { if (!message.serializedContent) {
elem.innerHTML = "" messageContainer.innerHTML = ""
let messageData = "" let messageData = ""
switch(message.type) { switch(message.type) {
case "SNAP": case "SNAP":
elem.appendChild(document.querySelector('.red_snap_svg').cloneNode(true)) messageContainer.appendChild(document.querySelector('.red_snap_svg').cloneNode(true))
messageData = "Snap" messageData = "Snap"
break break
default: default:
messageData = message.type messageData = message.type
} }
elem.innerHTML += messageData messageContainer.innerHTML += messageData
} }
if (message.mediaReferences && message.mediaReferences.length > 0) { if (message.attachments && message.attachments.length > 0) {
//only get the first reference let observers = []
const reference = Object.values(message.mediaReferences)[0]
let fetched = false
var observer = new IntersectionObserver(function(entries) {
if(!fetched && entries[0].isIntersecting === true) {
fetched = true
const mediaDiv = document.querySelector('.media-ORIGINAL_' + reference.content.replace(/(=)/g, "")) message.attachments.forEach((attachment, index) => {
if (!mediaDiv) return const mediaKey = attachment.key.replace(/(=)/g, "")
const content = mediaDiv.innerHTML.substring(5, mediaDiv.innerHTML.length - 4) observers.push(() => {
const decodedData = new Uint8Array(inflate(base64decode(content))) const originalMedia = document.querySelector('.media-ORIGINAL_' + mediaKey)
if (!originalMedia) {
return
}
const originalMediaUrl = decodeMedia(originalMedia)
const mediaContainer = document.createElement("div")
messageContainer.appendChild(mediaContainer)
const blob = new Blob([decodedData])
const url = URL.createObjectURL(blob)
const imageTag = document.createElement("img") const imageTag = document.createElement("img")
imageTag.classList.add("media_container") imageTag.src = originalMediaUrl
imageTag.src = url imageTag.classList.add("chat_media")
mediaContainer.appendChild(imageTag)
imageTag.onerror = () => { imageTag.onerror = () => {
elem.removeChild(imageTag) mediaContainer.removeChild(imageTag)
const mediaTag = document.createElement(message.type === "NOTE" ? "audio" : "video") const mediaTag = document.createElement(message.type === "NOTE" ? "audio" : "video")
mediaTag.classList.add("media_container") mediaTag.classList.add("chat_media")
mediaTag.src = url mediaTag.src = originalMediaUrl
mediaTag.preload = "metadata" mediaTag.preload = "metadata"
mediaTag.controls = true mediaTag.controls = true
elem.appendChild(mediaTag) mediaContainer.appendChild(mediaTag)
} }
elem.innerHTML = ""
elem.appendChild(imageTag) const overlay = document.querySelector('.media-OVERLAY_' + mediaKey)
if (!overlay) {
return
}
const overlayImage = document.createElement("img")
overlayImage.src = decodeMedia(overlay)
overlayImage.classList.add("chat_media")
overlayImage.classList.add("overlay_media")
mediaContainer.appendChild(overlayImage)
})
})
let fetched = false
new IntersectionObserver(entries => {
if(!fetched && entries[0].isIntersecting === true) {
fetched = true
messageContainer.innerHTML = ""
observers.forEach(c => {
try {
c()
} catch (e) {
console.log(e)
}
})
} }
}, { threshold: [1] }); }).observe(messageContainer)
observer.observe(elem)
} }
return elem return messageContainer
})(document.createElement("div"))) })(document.createElement("div")))
document.querySelector('.conversation_message_container').appendChild(messageObject) document.querySelector('.conversation_message_container').appendChild(messageObject)

View File

@ -3,6 +3,11 @@
package me.rhunk.snapenhance.core.download.data package me.rhunk.snapenhance.core.download.data
import me.rhunk.snapenhance.data.wrapper.impl.media.EncryptionWrapper import me.rhunk.snapenhance.data.wrapper.impl.media.EncryptionWrapper
import java.io.InputStream
import javax.crypto.Cipher
import javax.crypto.CipherInputStream
import javax.crypto.spec.IvParameterSpec
import javax.crypto.spec.SecretKeySpec
import kotlin.io.encoding.Base64 import kotlin.io.encoding.Base64
import kotlin.io.encoding.ExperimentalEncodingApi import kotlin.io.encoding.ExperimentalEncodingApi
@ -10,12 +15,16 @@ import kotlin.io.encoding.ExperimentalEncodingApi
data class MediaEncryptionKeyPair( data class MediaEncryptionKeyPair(
val key: String, val key: String,
val iv: String val iv: String
) ) {
fun decryptInputStream(inputStream: InputStream): InputStream {
fun Pair<ByteArray, ByteArray>.toKeyPair(): MediaEncryptionKeyPair { val cipher = Cipher.getInstance("AES/CBC/PKCS5Padding")
return MediaEncryptionKeyPair(Base64.UrlSafe.encode(this.first), Base64.UrlSafe.encode(this.second)) cipher.init(Cipher.DECRYPT_MODE, SecretKeySpec(Base64.UrlSafe.decode(key), "AES"), IvParameterSpec(Base64.UrlSafe.decode(iv)))
return CipherInputStream(inputStream, cipher)
}
} }
fun EncryptionWrapper.toKeyPair(): MediaEncryptionKeyPair { fun Pair<ByteArray, ByteArray>.toKeyPair()
return MediaEncryptionKeyPair(Base64.UrlSafe.encode(this.keySpec), Base64.UrlSafe.encode(this.ivKeyParameterSpec)) = MediaEncryptionKeyPair(Base64.UrlSafe.encode(this.first), Base64.UrlSafe.encode(this.second))
}
fun EncryptionWrapper.toKeyPair()
= MediaEncryptionKeyPair(Base64.UrlSafe.encode(this.keySpec), Base64.UrlSafe.encode(this.ivKeyParameterSpec))

View File

@ -4,7 +4,6 @@ import me.rhunk.snapenhance.Constants
import me.rhunk.snapenhance.core.Logger import me.rhunk.snapenhance.core.Logger
import okhttp3.OkHttpClient import okhttp3.OkHttpClient
import okhttp3.Request import okhttp3.Request
import java.io.ByteArrayInputStream
import java.io.InputStream import java.io.InputStream
import java.util.Base64 import java.util.Base64
@ -36,18 +35,34 @@ object RemoteMediaResolver {
} }
.build() .build()
fun downloadBoltMedia(protoKey: ByteArray): InputStream? { private fun newResolveRequest(protoKey: ByteArray) = Request.Builder()
val request = Request.Builder() .url(BOLT_HTTP_RESOLVER_URL + "/resolve?co=" + Base64.getUrlEncoder().encodeToString(protoKey))
.url(BOLT_HTTP_RESOLVER_URL + "/resolve?co=" + Base64.getUrlEncoder().encodeToString(protoKey)) .addHeader("User-Agent", Constants.USER_AGENT)
.addHeader("User-Agent", Constants.USER_AGENT) .build()
.build()
okHttpClient.newCall(request).execute().use { response -> /**
* Download bolt media with memory allocation
*/
fun downloadBoltMedia(protoKey: ByteArray, decryptionCallback: (InputStream) -> InputStream = { it }): ByteArray? {
okHttpClient.newCall(newResolveRequest(protoKey)).execute().use { response ->
if (!response.isSuccessful) { if (!response.isSuccessful) {
Logger.directDebug("Unexpected code $response") Logger.directDebug("Unexpected code $response")
return null return null
} }
return ByteArrayInputStream(response.body.bytes()) return decryptionCallback(response.body.byteStream()).readBytes()
}
}
fun downloadBoltMedia(protoKey: ByteArray, decryptionCallback: (InputStream) -> InputStream = { it }, resultCallback: (InputStream) -> Unit) {
okHttpClient.newCall(newResolveRequest(protoKey)).execute().use { response ->
if (!response.isSuccessful) {
throw Throwable("invalid response ${response.code}")
}
resultCallback(
decryptionCallback(
response.body.byteStream()
)
)
} }
} }
} }

View File

@ -3,6 +3,7 @@ package me.rhunk.snapenhance.core.util.export
import android.os.Environment import android.os.Environment
import android.util.Base64InputStream import android.util.Base64InputStream
import com.google.gson.JsonArray import com.google.gson.JsonArray
import com.google.gson.JsonNull
import com.google.gson.JsonObject import com.google.gson.JsonObject
import de.robv.android.xposed.XposedHelpers import de.robv.android.xposed.XposedHelpers
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
@ -11,20 +12,21 @@ import me.rhunk.snapenhance.ModContext
import me.rhunk.snapenhance.core.BuildConfig import me.rhunk.snapenhance.core.BuildConfig
import me.rhunk.snapenhance.core.database.objects.FriendFeedEntry import me.rhunk.snapenhance.core.database.objects.FriendFeedEntry
import me.rhunk.snapenhance.core.database.objects.FriendInfo import me.rhunk.snapenhance.core.database.objects.FriendInfo
import me.rhunk.snapenhance.core.util.download.RemoteMediaResolver
import me.rhunk.snapenhance.core.util.protobuf.ProtoReader import me.rhunk.snapenhance.core.util.protobuf.ProtoReader
import me.rhunk.snapenhance.core.util.snap.EncryptionHelper
import me.rhunk.snapenhance.core.util.snap.MediaDownloaderHelper import me.rhunk.snapenhance.core.util.snap.MediaDownloaderHelper
import me.rhunk.snapenhance.data.ContentType import me.rhunk.snapenhance.data.ContentType
import me.rhunk.snapenhance.data.FileType import me.rhunk.snapenhance.data.FileType
import me.rhunk.snapenhance.data.MediaReferenceType
import me.rhunk.snapenhance.data.wrapper.impl.Message import me.rhunk.snapenhance.data.wrapper.impl.Message
import me.rhunk.snapenhance.data.wrapper.impl.SnapUUID import me.rhunk.snapenhance.data.wrapper.impl.SnapUUID
import me.rhunk.snapenhance.features.impl.downloader.decoder.AttachmentType
import me.rhunk.snapenhance.features.impl.downloader.decoder.MessageDecoder
import java.io.BufferedInputStream
import java.io.File import java.io.File
import java.io.FileOutputStream import java.io.FileOutputStream
import java.io.InputStream import java.io.InputStream
import java.io.OutputStream import java.io.OutputStream
import java.text.SimpleDateFormat import java.text.SimpleDateFormat
import java.util.Base64
import java.util.Collections import java.util.Collections
import java.util.Date import java.util.Date
import java.util.Locale import java.util.Locale
@ -33,6 +35,7 @@ import java.util.concurrent.TimeUnit
import java.util.zip.Deflater import java.util.zip.Deflater
import java.util.zip.DeflaterInputStream import java.util.zip.DeflaterInputStream
import java.util.zip.ZipFile import java.util.zip.ZipFile
import kotlin.io.encoding.Base64
import kotlin.io.encoding.ExperimentalEncodingApi import kotlin.io.encoding.ExperimentalEncodingApi
@ -44,6 +47,7 @@ enum class ExportFormat(
HTML("html"); HTML("html");
} }
@OptIn(ExperimentalEncodingApi::class)
class MessageExporter( class MessageExporter(
private val context: ModContext, private val context: ModContext,
private val outputFile: File, private val outputFile: File,
@ -94,7 +98,6 @@ class MessageExporter(
writer.flush() writer.flush()
} }
@OptIn(ExperimentalEncodingApi::class)
suspend fun exportHtml(output: OutputStream) { suspend fun exportHtml(output: OutputStream) {
val downloadMediaCacheFolder = File(Environment.getExternalStoragePublicDirectory(Environment.DIRECTORY_DOWNLOADS), "SnapEnhance/cache").also { it.mkdirs() } val downloadMediaCacheFolder = File(Environment.getExternalStoragePublicDirectory(Environment.DIRECTORY_DOWNLOADS), "SnapEnhance/cache").also { it.mkdirs() }
val mediaFiles = Collections.synchronizedMap(mutableMapOf<String, Pair<FileType, File>>()) val mediaFiles = Collections.synchronizedMap(mutableMapOf<String, Pair<FileType, File>>())
@ -115,34 +118,30 @@ class MessageExporter(
mediaToDownload?.contains(it.messageContent.contentType) ?: false mediaToDownload?.contains(it.messageContent.contentType) ?: false
}.forEach { message -> }.forEach { message ->
threadPool.execute { threadPool.execute {
val remoteMediaReferences by lazy { MessageDecoder.decode(message.messageContent).forEach decode@{ attachment ->
val serializedMessageContent = context.gson.toJsonTree(message.messageContent.instanceNonNull()).asJsonObject val protoMediaReference = Base64.UrlSafe.decode(attachment.mediaKey ?: return@decode)
serializedMessageContent["mRemoteMediaReferences"]
.asJsonArray.map { it.asJsonObject["mMediaReferences"].asJsonArray }
.flatten()
}
remoteMediaReferences.firstOrNull().takeIf { it != null }?.let { media ->
val protoMediaReference = media.asJsonObject["mContentObject"].asJsonArray.map { it.asByte }.toByteArray()
runCatching { runCatching {
val downloadedMedia = MediaDownloaderHelper.downloadMediaFromReference(protoMediaReference) { RemoteMediaResolver.downloadBoltMedia(protoMediaReference, decryptionCallback = {
EncryptionHelper.decryptInputStream(it, message.messageContent.contentType!!, ProtoReader(message.messageContent.content), isArroyo = false) (attachment.attachmentInfo?.encryption?.decryptInputStream(it) ?: it)
} }) {
it.use { inputStream ->
MediaDownloaderHelper.getSplitElements(inputStream) { type, splitInputStream ->
val fileName = "${type}_${Base64.UrlSafe.encode(protoMediaReference).replace("=", "")}"
val bufferedInputStream = BufferedInputStream(splitInputStream)
val fileType = MediaDownloaderHelper.getFileType(bufferedInputStream)
val mediaFile = File(downloadMediaCacheFolder, "$fileName.${fileType.fileExtension}")
downloadedMedia.forEach { (type, mediaData) -> FileOutputStream(mediaFile).use { fos ->
val fileType = FileType.fromByteArray(mediaData) bufferedInputStream.copyTo(fos)
val fileName = "${type}_${kotlin.io.encoding.Base64.UrlSafe.encode(protoMediaReference).replace("=", "")}" }
val mediaFile = File(downloadMediaCacheFolder, "$fileName.${fileType.fileExtension}") mediaFiles[fileName] = fileType to mediaFile
}
FileOutputStream(mediaFile).use { fos ->
mediaData.inputStream().copyTo(fos)
} }
mediaFiles[fileName] = fileType to mediaFile
updateProgress("downloaded")
} }
updateProgress("downloaded")
}.onFailure { }.onFailure {
printLog("failed to download media for ${message.messageDescriptor.conversationId}_${message.orderKey}") printLog("failed to download media for ${message.messageDescriptor.conversationId}_${message.orderKey}")
context.log.error("failed to download media for ${message.messageDescriptor.conversationId}_${message.orderKey}", it) context.log.error("failed to download media for ${message.messageDescriptor.conversationId}_${message.orderKey}", it)
@ -208,7 +207,7 @@ class MessageExporter(
//export avenir next font //export avenir next font
apkFile.getEntry("res/font/avenir_next_medium.ttf").let { entry -> apkFile.getEntry("res/font/avenir_next_medium.ttf").let { entry ->
val encodedFontData = kotlin.io.encoding.Base64.Default.encode(apkFile.getInputStream(entry).readBytes()) val encodedFontData = Base64.Default.encode(apkFile.getInputStream(entry).readBytes())
output.write(""" output.write("""
<style> <style>
@font-face { @font-face {
@ -284,41 +283,25 @@ class MessageExporter(
addProperty("createdTimestamp", message.messageMetadata.createdAt) addProperty("createdTimestamp", message.messageMetadata.createdAt)
addProperty("readTimestamp", message.messageMetadata.readAt) addProperty("readTimestamp", message.messageMetadata.readAt)
addProperty("serializedContent", serializeMessageContent(message)) addProperty("serializedContent", serializeMessageContent(message))
addProperty("rawContent", Base64.getUrlEncoder().encodeToString(message.messageContent.content)) addProperty("rawContent", Base64.UrlSafe.encode(message.messageContent.content))
val messageContentType = message.messageContent.contentType ?: ContentType.CHAT add("attachments", JsonArray().apply {
MessageDecoder.decode(message.messageContent)
EncryptionHelper.getEncryptionKeys(messageContentType, ProtoReader(message.messageContent.content), isArroyo = false)?.let { encryptionKeyPair -> .forEach attachments@{ attachments ->
add("encryption", JsonObject().apply encryption@{ if (attachments.type == AttachmentType.STICKER) //TODO: implement stickers
addProperty("key", Base64.getEncoder().encodeToString(encryptionKeyPair.first)) return@attachments
addProperty("iv", Base64.getEncoder().encodeToString(encryptionKeyPair.second))
})
}
val remoteMediaReferences by lazy {
val serializedMessageContent = context.gson.toJsonTree(message.messageContent.instanceNonNull()).asJsonObject
serializedMessageContent["mRemoteMediaReferences"]
.asJsonArray.map { it.asJsonObject["mMediaReferences"].asJsonArray }
.flatten()
}
add("mediaReferences", JsonArray().apply mediaReferences@ {
if (messageContentType != ContentType.EXTERNAL_MEDIA &&
messageContentType != ContentType.STICKER &&
messageContentType != ContentType.SNAP &&
messageContentType != ContentType.NOTE)
return@mediaReferences
remoteMediaReferences.forEach { media ->
val protoMediaReference = media.asJsonObject["mContentObject"].asJsonArray.map { it.asByte }.toByteArray()
val mediaType = MediaReferenceType.valueOf(media.asJsonObject["mMediaType"].asString)
add(JsonObject().apply { add(JsonObject().apply {
addProperty("mediaType", mediaType.toString()) addProperty("key", attachments.mediaKey?.replace("=", ""))
addProperty("content", Base64.getUrlEncoder().encodeToString(protoMediaReference)) addProperty("type", attachments.type.toString())
add("encryption", attachments.attachmentInfo?.encryption?.let { encryption ->
JsonObject().apply {
addProperty("key", encryption.key)
addProperty("iv", encryption.iv)
}
} ?: JsonNull.INSTANCE)
}) })
} }
}) })
}) })
} }
}) })

View File

@ -1,73 +0,0 @@
package me.rhunk.snapenhance.core.util.snap
import me.rhunk.snapenhance.Constants
import me.rhunk.snapenhance.core.download.data.MediaEncryptionKeyPair
import me.rhunk.snapenhance.core.util.protobuf.ProtoReader
import me.rhunk.snapenhance.data.ContentType
import java.io.InputStream
import javax.crypto.Cipher
import javax.crypto.CipherInputStream
import javax.crypto.spec.IvParameterSpec
import javax.crypto.spec.SecretKeySpec
import kotlin.io.encoding.Base64
import kotlin.io.encoding.ExperimentalEncodingApi
@OptIn(ExperimentalEncodingApi::class)
object EncryptionHelper {
fun getEncryptionKeys(contentType: ContentType, messageProto: ProtoReader, isArroyo: Boolean): Pair<ByteArray, ByteArray>? {
val mediaEncryptionInfo = MediaDownloaderHelper.getMessageMediaEncryptionInfo(
messageProto,
contentType,
isArroyo
) ?: return null
val encryptionProtoIndex = if (mediaEncryptionInfo.contains(Constants.ENCRYPTION_PROTO_INDEX_V2)) {
Constants.ENCRYPTION_PROTO_INDEX_V2
} else {
Constants.ENCRYPTION_PROTO_INDEX
}
val encryptionProto = mediaEncryptionInfo.followPath(encryptionProtoIndex) ?: return null
var key: ByteArray = encryptionProto.getByteArray(1)!!
var iv: ByteArray = encryptionProto.getByteArray(2)!!
if (encryptionProtoIndex == Constants.ENCRYPTION_PROTO_INDEX_V2) {
key = Base64.UrlSafe.decode(key)
iv = Base64.UrlSafe.decode(iv)
}
return Pair(key, iv)
}
fun decryptInputStream(
inputStream: InputStream,
contentType: ContentType,
messageProto: ProtoReader,
isArroyo: Boolean
): InputStream {
val encryptionKeys = getEncryptionKeys(contentType, messageProto, isArroyo) ?: throw Exception("Failed to get encryption keys")
Cipher.getInstance("AES/CBC/PKCS5Padding").apply {
init(Cipher.DECRYPT_MODE, SecretKeySpec(encryptionKeys.first, "AES"), IvParameterSpec(encryptionKeys.second))
}.let { cipher ->
return CipherInputStream(inputStream, cipher)
}
}
fun decryptInputStream(
inputStream: InputStream,
mediaEncryptionKeyPair: MediaEncryptionKeyPair?
): InputStream {
if (mediaEncryptionKeyPair == null) {
return inputStream
}
Cipher.getInstance("AES/CBC/PKCS5Padding").apply {
init(Cipher.DECRYPT_MODE,
SecretKeySpec(Base64.UrlSafe.decode(mediaEncryptionKeyPair.key), "AES"),
IvParameterSpec(Base64.UrlSafe.decode(mediaEncryptionKeyPair.iv))
)
}.let { cipher ->
return CipherInputStream(inputStream, cipher)
}
}
}

View File

@ -1,70 +1,45 @@
package me.rhunk.snapenhance.core.util.snap package me.rhunk.snapenhance.core.util.snap
import me.rhunk.snapenhance.Constants
import me.rhunk.snapenhance.core.download.data.SplitMediaAssetType import me.rhunk.snapenhance.core.download.data.SplitMediaAssetType
import me.rhunk.snapenhance.core.util.download.RemoteMediaResolver
import me.rhunk.snapenhance.core.util.protobuf.ProtoReader
import me.rhunk.snapenhance.data.ContentType
import me.rhunk.snapenhance.data.FileType import me.rhunk.snapenhance.data.FileType
import java.io.ByteArrayInputStream import java.io.BufferedInputStream
import java.io.FileNotFoundException
import java.io.InputStream import java.io.InputStream
import java.util.zip.ZipEntry
import java.util.zip.ZipInputStream import java.util.zip.ZipInputStream
object MediaDownloaderHelper { object MediaDownloaderHelper {
fun getMessageMediaEncryptionInfo(protoReader: ProtoReader, contentType: ContentType, isArroyo: Boolean): ProtoReader? { fun getFileType(bufferedInputStream: BufferedInputStream): FileType {
val messageContainerPath = if (isArroyo) protoReader.followPath(*Constants.ARROYO_MEDIA_CONTAINER_PROTO_PATH)!! else protoReader val buffer = ByteArray(16)
val mediaContainerPath = if (contentType == ContentType.NOTE) intArrayOf(6, 1, 1) else intArrayOf(5, 1, 1) bufferedInputStream.mark(16)
bufferedInputStream.read(buffer)
return when (contentType) { bufferedInputStream.reset()
ContentType.NOTE -> messageContainerPath.followPath(*mediaContainerPath) return FileType.fromByteArray(buffer)
ContentType.SNAP -> messageContainerPath.followPath(*(intArrayOf(11) + mediaContainerPath))
ContentType.EXTERNAL_MEDIA -> {
val externalMediaTypes = arrayOf(
intArrayOf(3, 3, *mediaContainerPath), //normal external media
intArrayOf(7, 15, 1, 1), //attached audio note
intArrayOf(7, 12, 3, *mediaContainerPath), //attached story reply
intArrayOf(7, 3, *mediaContainerPath), //original story reply
)
externalMediaTypes.forEach { path ->
messageContainerPath.followPath(*path)?.also { return it }
}
null
}
else -> null
}
} }
fun downloadMediaFromReference(
mediaReference: ByteArray,
decryptionCallback: (InputStream) -> InputStream,
): Map<SplitMediaAssetType, ByteArray> {
val inputStream = RemoteMediaResolver.downloadBoltMedia(mediaReference) ?: throw FileNotFoundException("Unable to get media key. Check the logs for more info")
val content = decryptionCallback(inputStream).readBytes()
val fileType = FileType.fromByteArray(content)
val isZipFile = fileType == FileType.ZIP
//videos with overlay are packed in a zip file fun getSplitElements(
//there are 2 files in the zip file, the video (webm) and the overlay (png) inputStream: InputStream,
if (isZipFile) { callback: (SplitMediaAssetType, InputStream) -> Unit
var videoData: ByteArray? = null ) {
var overlayData: ByteArray? = null val bufferedInputStream = BufferedInputStream(inputStream)
val zipInputStream = ZipInputStream(ByteArrayInputStream(content)) val fileType = getFileType(bufferedInputStream)
while (zipInputStream.nextEntry != null) {
val zipEntryData: ByteArray = zipInputStream.readBytes() if (fileType != FileType.ZIP) {
val entryFileType = FileType.fromByteArray(zipEntryData) callback(SplitMediaAssetType.ORIGINAL, bufferedInputStream)
if (entryFileType.isVideo) { return
videoData = zipEntryData
} else if (entryFileType.isImage) {
overlayData = zipEntryData
}
}
videoData ?: throw FileNotFoundException("Unable to find video file in zip file")
overlayData ?: throw FileNotFoundException("Unable to find overlay file in zip file")
return mapOf(SplitMediaAssetType.ORIGINAL to videoData, SplitMediaAssetType.OVERLAY to overlayData)
} }
return mapOf(SplitMediaAssetType.ORIGINAL to content) val zipInputStream = ZipInputStream(bufferedInputStream)
var entry: ZipEntry? = zipInputStream.nextEntry
while (entry != null) {
if (entry.name.startsWith("overlay")) {
callback(SplitMediaAssetType.OVERLAY, zipInputStream)
} else if (entry.name.startsWith("media")) {
callback(SplitMediaAssetType.ORIGINAL, zipInputStream)
}
entry = zipInputStream.nextEntry
}
} }
} }

View File

@ -28,7 +28,6 @@ import me.rhunk.snapenhance.core.util.download.RemoteMediaResolver
import me.rhunk.snapenhance.core.util.ktx.getObjectField import me.rhunk.snapenhance.core.util.ktx.getObjectField
import me.rhunk.snapenhance.core.util.protobuf.ProtoReader import me.rhunk.snapenhance.core.util.protobuf.ProtoReader
import me.rhunk.snapenhance.core.util.snap.BitmojiSelfie import me.rhunk.snapenhance.core.util.snap.BitmojiSelfie
import me.rhunk.snapenhance.core.util.snap.EncryptionHelper
import me.rhunk.snapenhance.core.util.snap.MediaDownloaderHelper import me.rhunk.snapenhance.core.util.snap.MediaDownloaderHelper
import me.rhunk.snapenhance.core.util.snap.PreviewUtils import me.rhunk.snapenhance.core.util.snap.PreviewUtils
import me.rhunk.snapenhance.data.FileType import me.rhunk.snapenhance.data.FileType
@ -47,6 +46,7 @@ import me.rhunk.snapenhance.hook.HookAdapter
import me.rhunk.snapenhance.hook.HookStage import me.rhunk.snapenhance.hook.HookStage
import me.rhunk.snapenhance.hook.Hooker import me.rhunk.snapenhance.hook.Hooker
import me.rhunk.snapenhance.ui.ViewAppearanceHelper import me.rhunk.snapenhance.ui.ViewAppearanceHelper
import java.io.ByteArrayInputStream
import java.nio.file.Paths import java.nio.file.Paths
import java.text.SimpleDateFormat import java.text.SimpleDateFormat
import java.util.Locale import java.util.Locale
@ -526,42 +526,24 @@ class MediaDownloader : MessagingRuleFeature("MediaDownloader", MessagingRuleTyp
val friendInfo: FriendInfo = context.database.getFriendInfo(message.senderId!!) ?: throw Exception("Friend not found in database") val friendInfo: FriendInfo = context.database.getFriendInfo(message.senderId!!) ?: throw Exception("Friend not found in database")
val authorName = friendInfo.usernameForSorting!! val authorName = friendInfo.usernameForSorting!!
var messageContent = message.messageContent!! val decodedAttachments = if (messageLogger.isMessageRemoved(message.clientConversationId!!, message.serverMessageId.toLong())) {
var customMediaReferences = mutableListOf<String>()
if (messageLogger.isMessageRemoved(message.clientConversationId!!, message.serverMessageId.toLong())) {
val messageObject = messageLogger.getMessageObject(message.clientConversationId!!, message.serverMessageId.toLong()) ?: throw Exception("Message not found in database") val messageObject = messageLogger.getMessageObject(message.clientConversationId!!, message.serverMessageId.toLong()) ?: throw Exception("Message not found in database")
val messageContentObject = messageObject.getAsJsonObject("mMessageContent") MessageDecoder.decode(messageObject.getAsJsonObject("mMessageContent"))
} else {
messageContent = messageContentObject MessageDecoder.decode(
.getAsJsonArray("mContent") protoReader = ProtoReader(message.messageContent!!)
.map { it.asByte } )
.toByteArray()
customMediaReferences = messageContentObject
.getAsJsonArray("mRemoteMediaReferences")
.map { it.asJsonObject.getAsJsonArray("mMediaReferences") }
.flatten().map { reference ->
Base64.UrlSafe.encode(
reference.asJsonObject.getAsJsonArray("mContentObject").map { it.asByte }.toByteArray()
)
}
.toMutableList()
} }
val messageReader = ProtoReader(messageContent)
val decodedAttachments = MessageDecoder.decode(
protoReader = messageReader,
customMediaReferences = customMediaReferences.takeIf { it.isNotEmpty() }
)
if (decodedAttachments.isEmpty()) { if (decodedAttachments.isEmpty()) {
context.shortToast(translations["no_attachments_toast"]) context.shortToast(translations["no_attachments_toast"])
return return
} }
if (!isPreview) { if (!isPreview) {
if (decodedAttachments.size == 1) { if (decodedAttachments.size == 1 ||
context.mainActivity == null // we can't show alert dialogs when it downloads from a notification, so it downloads the first one
) {
downloadMessageAttachments(friendInfo, message, authorName, downloadMessageAttachments(friendInfo, message, authorName,
listOf(decodedAttachments.first()) listOf(decodedAttachments.first())
) )
@ -600,11 +582,15 @@ class MediaDownloader : MessagingRuleFeature("MediaDownloader", MessagingRuleTyp
val firstAttachment = decodedAttachments.first() val firstAttachment = decodedAttachments.first()
val previewCoroutine = async { val previewCoroutine = async {
val downloadedMediaList = MediaDownloaderHelper.downloadMediaFromReference(Base64.decode(firstAttachment.mediaKey!!)) { val downloadedMedia = RemoteMediaResolver.downloadBoltMedia(Base64.decode(firstAttachment.mediaKey!!), decryptionCallback = {
EncryptionHelper.decryptInputStream( firstAttachment.attachmentInfo?.encryption?.decryptInputStream(it) ?: it
it, }) ?: return@async null
decodedAttachments.first().attachmentInfo?.encryption
) val downloadedMediaList = mutableMapOf<SplitMediaAssetType, ByteArray>()
MediaDownloaderHelper.getSplitElements(ByteArrayInputStream(downloadedMedia)) {
type, inputStream ->
downloadedMediaList[type] = inputStream.readBytes()
} }
val originalMedia = downloadedMediaList[SplitMediaAssetType.ORIGINAL] ?: return@async null val originalMedia = downloadedMediaList[SplitMediaAssetType.ORIGINAL] ?: return@async null

View File

@ -1,7 +1,11 @@
package me.rhunk.snapenhance.features.impl.downloader.decoder package me.rhunk.snapenhance.features.impl.downloader.decoder
import com.google.gson.GsonBuilder
import com.google.gson.JsonElement
import com.google.gson.JsonObject
import me.rhunk.snapenhance.core.download.data.toKeyPair import me.rhunk.snapenhance.core.download.data.toKeyPair
import me.rhunk.snapenhance.core.util.protobuf.ProtoReader import me.rhunk.snapenhance.core.util.protobuf.ProtoReader
import me.rhunk.snapenhance.data.wrapper.impl.MessageContent
import kotlin.io.encoding.Base64 import kotlin.io.encoding.Base64
import kotlin.io.encoding.ExperimentalEncodingApi import kotlin.io.encoding.ExperimentalEncodingApi
@ -13,6 +17,8 @@ data class DecodedAttachment(
@OptIn(ExperimentalEncodingApi::class) @OptIn(ExperimentalEncodingApi::class)
object MessageDecoder { object MessageDecoder {
private val gson = GsonBuilder().create()
private fun decodeAttachment(protoReader: ProtoReader): AttachmentInfo? { private fun decodeAttachment(protoReader: ProtoReader): AttachmentInfo? {
val mediaInfo = protoReader.followPath(1, 1) ?: return null val mediaInfo = protoReader.followPath(1, 1) ?: return null
@ -39,6 +45,43 @@ object MessageDecoder {
) )
} }
@OptIn(ExperimentalEncodingApi::class)
fun getEncodedMediaReferences(messageContent: JsonElement): List<String> {
return getMediaReferences(messageContent).map { reference ->
Base64.UrlSafe.encode(
reference.asJsonObject.getAsJsonArray("mContentObject").map { it.asByte }.toByteArray()
)
}
.toList()
}
fun getMediaReferences(messageContent: JsonElement): List<JsonElement> {
return messageContent.asJsonObject.getAsJsonArray("mRemoteMediaReferences")
.asSequence()
.map { it.asJsonObject.getAsJsonArray("mMediaReferences") }
.flatten()
.sortedBy {
it.asJsonObject["mMediaListId"].asLong
}.toList()
}
fun decode(messageContent: MessageContent): List<DecodedAttachment> {
return decode(
ProtoReader(messageContent.content),
customMediaReferences = getEncodedMediaReferences(gson.toJsonTree(messageContent.instanceNonNull()))
)
}
fun decode(messageContent: JsonObject): List<DecodedAttachment> {
return decode(
ProtoReader(messageContent.getAsJsonArray("mContent")
.map { it.asByte }
.toByteArray()),
customMediaReferences = getEncodedMediaReferences(messageContent)
)
}
fun decode( fun decode(
protoReader: ProtoReader, protoReader: ProtoReader,
customMediaReferences: List<String>? = null // when customReferences is null it means that the message is from arroyo database customMediaReferences: List<String>? = null // when customReferences is null it means that the message is from arroyo database
@ -138,7 +181,6 @@ object MessageDecoder {
} }
} }
return decodedAttachment return decodedAttachment
} }
} }

View File

@ -16,9 +16,9 @@ import me.rhunk.snapenhance.core.Logger
import me.rhunk.snapenhance.core.download.data.SplitMediaAssetType import me.rhunk.snapenhance.core.download.data.SplitMediaAssetType
import me.rhunk.snapenhance.core.eventbus.events.impl.SnapWidgetBroadcastReceiveEvent import me.rhunk.snapenhance.core.eventbus.events.impl.SnapWidgetBroadcastReceiveEvent
import me.rhunk.snapenhance.core.util.CallbackBuilder import me.rhunk.snapenhance.core.util.CallbackBuilder
import me.rhunk.snapenhance.core.util.download.RemoteMediaResolver
import me.rhunk.snapenhance.core.util.ktx.setObjectField import me.rhunk.snapenhance.core.util.ktx.setObjectField
import me.rhunk.snapenhance.core.util.protobuf.ProtoReader import me.rhunk.snapenhance.core.util.protobuf.ProtoReader
import me.rhunk.snapenhance.core.util.snap.EncryptionHelper
import me.rhunk.snapenhance.core.util.snap.MediaDownloaderHelper import me.rhunk.snapenhance.core.util.snap.MediaDownloaderHelper
import me.rhunk.snapenhance.core.util.snap.PreviewUtils import me.rhunk.snapenhance.core.util.snap.PreviewUtils
import me.rhunk.snapenhance.core.util.snap.SnapWidgetBroadcastReceiverHelper import me.rhunk.snapenhance.core.util.snap.SnapWidgetBroadcastReceiverHelper
@ -34,7 +34,6 @@ import me.rhunk.snapenhance.features.impl.downloader.decoder.MessageDecoder
import me.rhunk.snapenhance.hook.HookStage import me.rhunk.snapenhance.hook.HookStage
import me.rhunk.snapenhance.hook.Hooker import me.rhunk.snapenhance.hook.Hooker
import me.rhunk.snapenhance.hook.hook import me.rhunk.snapenhance.hook.hook
import kotlin.io.encoding.Base64
import kotlin.io.encoding.ExperimentalEncodingApi import kotlin.io.encoding.ExperimentalEncodingApi
class Notifications : Feature("Notifications", loadParams = FeatureLoadParams.INIT_SYNC) { class Notifications : Feature("Notifications", loadParams = FeatureLoadParams.INIT_SYNC) {
@ -246,29 +245,31 @@ class Notifications : Feature("Notifications", loadParams = FeatureLoadParams.IN
appendNotifications() appendNotifications()
} }
ContentType.SNAP, ContentType.EXTERNAL_MEDIA -> { ContentType.SNAP, ContentType.EXTERNAL_MEDIA -> {
val serializedMessageContent = context.gson.toJsonTree(snapMessage.messageContent.instanceNonNull()).asJsonObject val mediaReferences = MessageDecoder.getMediaReferences(
val mediaReferences = serializedMessageContent messageContent = context.gson.toJsonTree(snapMessage.messageContent.instanceNonNull())
.getAsJsonArray("mRemoteMediaReferences") )
.map { it.asJsonObject.getAsJsonArray("mMediaReferences") }
.flatten()
val mediaReferenceUrls = mediaReferences.map { reference -> val mediaReferenceKeys = mediaReferences.map { reference ->
reference.asJsonObject.getAsJsonArray("mContentObject").map { it.asByte }.toByteArray() reference.asJsonObject.getAsJsonArray("mContentObject").map { it.asByte }.toByteArray()
} }
MessageDecoder.decode( MessageDecoder.decode(snapMessage.messageContent).firstOrNull()?.also { media ->
ProtoReader(contentData), val mediaType = MediaReferenceType.valueOf(mediaReferences.first().asJsonObject["mMediaType"].asString)
customMediaReferences = mediaReferenceUrls.map { Base64.UrlSafe.encode(it) }
).forEachIndexed { index, media ->
val mediaType = MediaReferenceType.valueOf(mediaReferences[index].asJsonObject["mMediaType"].asString)
runCatching { runCatching {
val downloadedMediaList = MediaDownloaderHelper.downloadMediaFromReference(mediaReferenceUrls[index]) { inputStream -> val downloadedMedia = RemoteMediaResolver.downloadBoltMedia(mediaReferenceKeys.first(), decryptionCallback = {
media.attachmentInfo?.encryption?.let { EncryptionHelper.decryptInputStream(inputStream, it) } ?: inputStream media.attachmentInfo?.encryption?.decryptInputStream(it) ?: it
}) ?: throw Throwable("Unable to download media")
val downloadedMedias = mutableMapOf<SplitMediaAssetType, ByteArray>()
MediaDownloaderHelper.getSplitElements(downloadedMedia.inputStream()) { type, inputStream ->
downloadedMedias[type] = inputStream.readBytes()
} }
var bitmapPreview = PreviewUtils.createPreview(downloadedMediaList[SplitMediaAssetType.ORIGINAL]!!, mediaType.name.contains("VIDEO"))!! var bitmapPreview = PreviewUtils.createPreview(downloadedMedias[SplitMediaAssetType.ORIGINAL]!!, mediaType.name.contains("VIDEO"))!!
downloadedMediaList[SplitMediaAssetType.OVERLAY]?.let { downloadedMedias[SplitMediaAssetType.OVERLAY]?.let {
bitmapPreview = PreviewUtils.mergeBitmapOverlay(bitmapPreview, BitmapFactory.decodeByteArray(it, 0, it.size)) bitmapPreview = PreviewUtils.mergeBitmapOverlay(bitmapPreview, BitmapFactory.decodeByteArray(it, 0, it.size))
} }