feat: better proto utils

- new proto editor
This commit is contained in:
rhunk
2023-08-28 02:06:58 +02:00
parent 1d925136ff
commit 9b7ff40302
11 changed files with 266 additions and 131 deletions

View File

@ -14,18 +14,18 @@ class MessageSender(
companion object {
val redSnapProto: (Boolean) -> ByteArray = {hasAudio ->
ProtoWriter().apply {
write(11, 5) {
write(1) {
write(1) {
writeConstant(2, 0)
writeConstant(12, 0)
writeConstant(15, 0)
from(11, 5) {
from(1) {
from(1) {
addVarInt(2, 0)
addVarInt(12, 0)
addVarInt(15, 0)
}
writeConstant(6, 0)
addVarInt(6, 0)
}
write(2) {
writeConstant(5, if (hasAudio) 1 else 0)
writeBuffer(6, byteArrayOf())
from(2) {
addVarInt(5, if (hasAudio) 1 else 0)
addBuffer(6, byteArrayOf())
}
}
}.toByteArray()
@ -33,15 +33,15 @@ class MessageSender(
val audioNoteProto: (Long) -> ByteArray = { duration ->
ProtoWriter().apply {
write(6, 1) {
write(1) {
writeConstant(2, 4)
write(5) {
writeConstant(1, 0)
writeConstant(2, 0)
from(6, 1) {
from(1) {
addVarInt(2, 4)
from(5) {
addVarInt(1, 0)
addVarInt(2, 0)
}
writeConstant(7, 0)
writeConstant(13, duration)
addVarInt(7, 0)
addVarInt(13, duration)
}
}
}.toByteArray()
@ -153,8 +153,8 @@ class MessageSender(
fun sendChatMessage(conversations: List<SnapUUID>, message: String, onError: (Any) -> Unit = {}, onSuccess: () -> Unit = {}) {
internalSendMessage(conversations, createLocalMessageContentTemplate(ContentType.CHAT, ProtoWriter().apply {
write(2) {
writeString(1, message)
from(2) {
addString(1, message)
}
}.toByteArray(), savePolicy = "LIFETIME"), CallbackBuilder(sendMessageCallback)
.override("onSuccess", callback = { onSuccess() })

View File

@ -466,7 +466,7 @@ class MediaDownloader : MessagingRuleFeature("MediaDownloader", MessagingRuleTyp
val messageReader = ProtoReader(messageContent)
val urlProto: ByteArray = if (isArroyoMessage) {
var finalProto: ByteArray? = null
messageReader.readPath(4)?.each(5) {
messageReader.eachBuffer(4, 5) {
finalProto = getByteArray(1, 3)
}
finalProto!!

View File

@ -64,9 +64,9 @@ class ProfilePictureDownloader : Feature("ProfilePictureDownloader", loadParams
}
}
ProtoReader(content).readPath(1, 1, 2) {
friendUsername = getString(2) ?: return@readPath
readPath(4) {
ProtoReader(content).followPath(1, 1, 2) {
friendUsername = getString(2) ?: return@followPath
followPath(4) {
backgroundUrl = getString(2)
avatarUrl = getString(100)
}

View File

@ -28,7 +28,7 @@ class GalleryMediaSendOverride : Feature("Gallery Media Send Override", loadPara
//prevent story replies
val messageProtoReader = ProtoReader(localMessageContent.content)
if (messageProtoReader.exists(7)) return@subscribe
if (messageProtoReader.contains(7)) return@subscribe
event.canceled = true
@ -38,7 +38,7 @@ class GalleryMediaSendOverride : Feature("Gallery Media Send Override", loadPara
dialog.dismiss()
val overrideType = typeNames.keys.toTypedArray()[which]
if (overrideType != "ORIGINAL" && messageProtoReader.readPath(3)?.getCount(3) != 1) {
if (overrideType != "ORIGINAL" && messageProtoReader.followPath(3)?.getCount(3) != 1) {
context.runOnUiThread {
ViewAppearanceHelper.newAlertDialogBuilder(context.mainActivity!!)
.setMessage(context.translation["gallery_media_send_override.multiple_media_toast"])
@ -57,7 +57,7 @@ class GalleryMediaSendOverride : Feature("Gallery Media Send Override", loadPara
"NOTE" -> {
localMessageContent.contentType = ContentType.NOTE
val mediaDuration =
messageProtoReader.getLong(3, 3, 5, 1, 1, 15) ?: 0
messageProtoReader.getVarInt(3, 3, 5, 1, 1, 15) ?: 0
localMessageContent.content =
MessageSender.audioNoteProto(mediaDuration)
}

View File

@ -20,12 +20,12 @@ class UnlimitedSnapViewTime :
if (message.messageContent.contentType != ContentType.SNAP) return@hookConstructor
with(message.messageContent) {
val mediaAttributes = ProtoReader(this.content).readPath(11, 5, 2) ?: return@hookConstructor
if (mediaAttributes.exists(6)) return@hookConstructor
val mediaAttributes = ProtoReader(this.content).followPath(11, 5, 2) ?: return@hookConstructor
if (mediaAttributes.contains(6)) return@hookConstructor
this.content = ProtoEditor(this.content).apply {
edit(11, 5, 2) {
mediaAttributes.getInt(5)?.let { writeConstant(5, it) }
writeBuffer(6, byteArrayOf())
remove(8)
addBuffer(6, byteArrayOf())
}
}.toByteArray()
}

View File

@ -1,39 +1,64 @@
package me.rhunk.snapenhance.util.protobuf
typealias WireCallback = EditorContext.() -> Unit
class EditorContext(
private val wires: MutableMap<Int, MutableList<Wire>>
) {
fun clear() {
wires.clear()
}
fun addWire(wire: Wire) {
wires.getOrPut(wire.id) { mutableListOf() }.add(wire)
}
fun addVarInt(id: Int, value: Int) = addVarInt(id, value.toLong())
fun addVarInt(id: Int, value: Long) = addWire(Wire(id, WireType.VARINT, value))
fun addBuffer(id: Int, value: ByteArray) = addWire(Wire(id, WireType.LENGTH_DELIMITED, value))
fun addString(id: Int, value: String) = addBuffer(id, value.toByteArray())
fun addFixed64(id: Int, value: Long) = addWire(Wire(id, WireType.FIXED64, value))
fun addFixed32(id: Int, value: Int) = addWire(Wire(id, WireType.FIXED32, value))
fun firstOrNull(id: Int) = wires[id]?.firstOrNull()
fun getOrNull(id: Int) = wires[id]
fun get(id: Int) = wires[id]!!
fun remove(id: Int) = wires.remove(id)
fun remove(id: Int, index: Int) = wires[id]?.removeAt(index)
}
class ProtoEditor(
private var buffer: ByteArray
) {
fun edit(vararg path: Int, callback: ProtoWriter.() -> Unit) {
val writer = ProtoWriter()
callback(writer)
buffer = writeAtPath(path, 0, ProtoReader(buffer), writer.toByteArray())
fun edit(vararg path: Int, callback: WireCallback) {
buffer = writeAtPath(path, 0, ProtoReader(buffer), callback)
}
private fun writeAtPath(path: IntArray, currentIndex: Int, rootReader: ProtoReader, bufferToWrite: ByteArray): ByteArray {
if (currentIndex == path.size) {
return bufferToWrite
}
val id = path[currentIndex]
private fun writeAtPath(path: IntArray, currentIndex: Int, rootReader: ProtoReader, wireToWriteCallback: WireCallback): ByteArray {
val id = path.getOrNull(currentIndex)
val output = ProtoWriter()
val wires = mutableListOf<Pair<Int, ByteArray>>()
val wires = mutableMapOf<Int, MutableList<Wire>>()
rootReader.list { tag, value ->
if (tag == id) {
val childReader = rootReader.readPath(id)
rootReader.forEach { wireId, value ->
wires.putIfAbsent(wireId, mutableListOf())
if (id != null && wireId == id) {
val childReader = rootReader.followPath(id)
if (childReader == null) {
wires.add(Pair(tag, value))
return@list
wires.getOrPut(wireId) { mutableListOf() }.add(value)
return@forEach
}
wires.add(Pair(tag, writeAtPath(path, currentIndex + 1, childReader, bufferToWrite)))
return@list
wires[wireId]!!.add(Wire(wireId, WireType.LENGTH_DELIMITED, writeAtPath(path, currentIndex + 1, childReader, wireToWriteCallback)))
return@forEach
}
wires.add(Pair(tag, value))
wires[wireId]!!.add(value)
}
wires.forEach { (tag, value) ->
output.writeBuffer(tag, value)
if (currentIndex == path.size) {
wireToWriteCallback(EditorContext(wires))
}
wires.values.flatten().forEach(output::addWire)
return output.toByteArray()
}

View File

@ -1,6 +1,6 @@
package me.rhunk.snapenhance.util.protobuf
data class Wire(val type: Int, val value: Any)
data class Wire(val id: Int, val type: WireType, val value: Any)
class ProtoReader(private val buffer: ByteArray) {
private var offset: Int = 0
@ -32,19 +32,46 @@ class ProtoReader(private val buffer: ByteArray) {
while (offset < buffer.size) {
val tag = readVarInt().toInt()
val id = tag ushr 3
val type = tag and 0x7
val type = WireType.fromValue(tag and 0x7) ?: break
try {
val value = when (type) {
0 -> readVarInt().toString().toByteArray()
2 -> {
val length = readVarInt().toInt()
val value = buffer.copyOfRange(offset, offset + length)
offset += length
value
WireType.VARINT -> readVarInt()
WireType.FIXED64 -> {
val bytes = ByteArray(8)
for (i in 0..7) {
bytes[i] = readByte()
}
bytes
}
else -> break
WireType.LENGTH_DELIMITED -> {
val length = readVarInt().toInt()
val bytes = ByteArray(length)
for (i in 0 until length) {
bytes[i] = readByte()
}
bytes
}
WireType.START_GROUP -> {
val bytes = mutableListOf<Byte>()
while (true) {
val b = readByte()
if (b.toInt() == WireType.END_GROUP.value) {
break
}
bytes.add(b)
}
bytes.toByteArray()
}
WireType.FIXED32 -> {
val bytes = ByteArray(4)
for (i in 0..3) {
bytes[i] = readByte()
}
bytes
}
WireType.END_GROUP -> continue
}
values.getOrPut(id) { mutableListOf() }.add(Wire(type, value))
values.getOrPut(id) { mutableListOf() }.add(Wire(id, type, value))
} catch (t: Throwable) {
values.clear()
break
@ -52,13 +79,19 @@ class ProtoReader(private val buffer: ByteArray) {
}
}
fun readPath(vararg ids: Int, reader: (ProtoReader.() -> Unit)? = null): ProtoReader? {
fun followPath(vararg ids: Int, excludeLast: Boolean = false, reader: (ProtoReader.() -> Unit)? = null): ProtoReader? {
var thisReader = this
ids.forEach { id ->
if (!thisReader.exists(id)) {
ids.let {
if (excludeLast) {
it.sliceArray(0 until it.size - 1)
} else {
it
}
}.forEach { id ->
if (!thisReader.contains(id)) {
return null
}
thisReader = ProtoReader(thisReader.get(id) as ByteArray)
thisReader = ProtoReader(thisReader.getByteArray(id) ?: return null)
}
if (reader != null) {
thisReader.reader()
@ -66,65 +99,77 @@ class ProtoReader(private val buffer: ByteArray) {
return thisReader
}
fun pathExists(vararg ids: Int): Boolean {
fun containsPath(vararg ids: Int): Boolean {
var thisReader = this
ids.forEach { id ->
if (!thisReader.exists(id)) {
if (!thisReader.contains(id)) {
return false
}
thisReader = ProtoReader(thisReader.get(id) as ByteArray)
thisReader = ProtoReader(thisReader.getByteArray(id) ?: return false)
}
return true
}
fun getByteArray(id: Int) = values[id]?.first()?.value as ByteArray?
fun getByteArray(vararg ids: Int): ByteArray? {
if (ids.isEmpty() || ids.size < 2) {
return null
}
val lastId = ids.last()
var value: ByteArray? = null
readPath(*(ids.copyOfRange(0, ids.size - 1))) {
value = getByteArray(lastId)
}
return value
}
fun getString(id: Int) = getByteArray(id)?.toString(Charsets.UTF_8)
fun getString(vararg ids: Int) = getByteArray(*ids)?.toString(Charsets.UTF_8)
fun getInt(id: Int) = getString(id)?.toInt()
fun getInt(vararg ids: Int) = getString(*ids)?.toInt()
fun getLong(id: Int) = getString(id)?.toLong()
fun getLong(vararg ids: Int) = getString(*ids)?.toLong()
fun exists(id: Int) = values.containsKey(id)
fun get(id: Int) = values[id]!!.first().value
fun isValid() = values.isNotEmpty()
fun getCount(id: Int) = values[id]!!.size
fun each(id: Int, reader: ProtoReader.(index: Int) -> Unit) {
values[id]!!.forEachIndexed { index, _ ->
ProtoReader(values[id]!![index].value as ByteArray).reader(index)
}
}
fun list(reader: (id: Int, data: ByteArray) -> Unit) {
fun forEach(reader: (Int, Wire) -> Unit) {
values.forEach { (id, wires) ->
wires.forEachIndexed { index, _ ->
reader(id, wires[index].value as ByteArray)
wires.forEach { wire ->
reader(id, wire)
}
}
}
fun eachExists(id: Int, reader: ProtoReader.(index: Int) -> Unit) {
if (!exists(id)) {
return
fun forEach(vararg id: Int, reader: ProtoReader.() -> Unit) {
followPath(*id)?.eachBuffer { _, buffer ->
ProtoReader(buffer).reader()
}
each(id, reader)
}
fun eachBuffer(vararg ids: Int, reader: ProtoReader.() -> Unit) {
followPath(*ids, excludeLast = true)?.eachBuffer { id, buffer ->
if (id == ids.last()) {
ProtoReader(buffer).reader()
}
}
}
fun eachBuffer(reader: (Int, ByteArray) -> Unit) {
values.forEach { (id, wires) ->
wires.forEach { wire ->
if (wire.type == WireType.LENGTH_DELIMITED) {
reader(id, wire.value as ByteArray)
}
}
}
}
fun contains(id: Int) = values.containsKey(id)
fun getWire(id: Int) = values[id]?.firstOrNull()
fun getRawValue(id: Int) = getWire(id)?.value
fun getByteArray(id: Int) = getRawValue(id) as? ByteArray
fun getByteArray(vararg ids: Int) = followPath(*ids, excludeLast = true)?.getByteArray(ids.last())
fun getString(id: Int) = getByteArray(id)?.toString(Charsets.UTF_8)
fun getString(vararg ids: Int) = followPath(*ids, excludeLast = true)?.getString(ids.last())
fun getVarInt(id: Int) = getRawValue(id) as? Long
fun getVarInt(vararg ids: Int) = followPath(*ids, excludeLast = true)?.getVarInt(ids.last())
fun getCount(id: Int) = values[id]?.size ?: 0
fun getFixed64(id: Int): Long {
val bytes = getByteArray(id) ?: return 0L
var value = 0L
for (i in 0..7) {
value = value or ((bytes[i].toLong() and 0xFF) shl (i * 8))
}
return value
}
fun getFixed32(id: Int): Int {
val bytes = getByteArray(id) ?: return 0
var value = 0
for (i in 0..3) {
value = value or ((bytes[i].toInt() and 0xFF) shl (i * 8))
}
return value
}
}

View File

@ -23,43 +23,94 @@ class ProtoWriter {
stream.write(v.toInt())
}
fun writeBuffer(id: Int, value: ByteArray) {
writeVarInt(id shl 3 or 2)
fun addBuffer(id: Int, value: ByteArray) {
writeVarInt(id shl 3 or WireType.LENGTH_DELIMITED.value)
writeVarInt(value.size)
stream.write(value)
}
fun writeConstant(id: Int, value: Int) {
writeVarInt(id shl 3)
writeVarInt(value)
}
fun addVarInt(id: Int, value: Int) = addVarInt(id, value.toLong())
fun writeConstant(id: Int, value: Long) {
fun addVarInt(id: Int, value: Long) {
writeVarInt(id shl 3)
writeVarLong(value)
}
fun writeString(id: Int, value: String) = writeBuffer(id, value.toByteArray())
fun addString(id: Int, value: String) = addBuffer(id, value.toByteArray())
fun write(id: Int, writer: ProtoWriter.() -> Unit) {
val writerStream = ProtoWriter()
writer(writerStream)
writeBuffer(id, writerStream.stream.toByteArray())
fun addFixed32(id: Int, value: Int) {
writeVarInt(id shl 3 or WireType.FIXED32.value)
val bytes = ByteArray(4)
for (i in 0..3) {
bytes[i] = (value shr (i * 8)).toByte()
}
stream.write(bytes)
}
fun write(vararg ids: Int, writer: ProtoWriter.() -> Unit) {
fun addFixed64(id: Int, value: Long) {
writeVarInt(id shl 3 or WireType.FIXED64.value)
val bytes = ByteArray(8)
for (i in 0..7) {
bytes[i] = (value shr (i * 8)).toByte()
}
stream.write(bytes)
}
fun from(id: Int, writer: ProtoWriter.() -> Unit) {
val writerStream = ProtoWriter()
writer(writerStream)
addBuffer(id, writerStream.stream.toByteArray())
}
fun from(vararg ids: Int, writer: ProtoWriter.() -> Unit) {
val writerStream = ProtoWriter()
writer(writerStream)
var stream = writerStream.stream.toByteArray()
ids.reversed().forEach { id ->
with(ProtoWriter()) {
writeBuffer(id, stream)
addBuffer(id, stream)
stream = this.stream.toByteArray()
}
}
stream.let(this.stream::write)
}
fun addWire(wire: Wire) {
writeVarInt(wire.id shl 3 or wire.type.value)
when (wire.type) {
WireType.VARINT -> writeVarLong(wire.value as Long)
WireType.FIXED64, WireType.FIXED32 -> {
when (wire.value) {
is Int -> {
val bytes = ByteArray(4)
for (i in 0..3) {
bytes[i] = (wire.value shr (i * 8)).toByte()
}
stream.write(bytes)
}
is Long -> {
val bytes = ByteArray(8)
for (i in 0..7) {
bytes[i] = (wire.value shr (i * 8)).toByte()
}
stream.write(bytes)
}
is ByteArray -> stream.write(wire.value)
}
}
WireType.LENGTH_DELIMITED -> {
val value = wire.value as ByteArray
writeVarInt(value.size)
stream.write(value)
}
WireType.START_GROUP -> {
val value = wire.value as ByteArray
stream.write(value)
}
WireType.END_GROUP -> return
}
}
fun toByteArray(): ByteArray {
return stream.toByteArray()
}

View File

@ -0,0 +1,14 @@
package me.rhunk.snapenhance.util.protobuf;
enum class WireType(val value: Int) {
VARINT(0),
FIXED64(1),
LENGTH_DELIMITED(2),
START_GROUP(3),
END_GROUP(4),
FIXED32(5);
companion object {
fun fromValue(value: Int) = values().firstOrNull { it.value == value }
}
}

View File

@ -13,12 +13,12 @@ import javax.crypto.spec.SecretKeySpec
object EncryptionHelper {
fun getEncryptionKeys(contentType: ContentType, messageProto: ProtoReader, isArroyo: Boolean): Pair<ByteArray, ByteArray>? {
val messageMediaInfo = MediaDownloaderHelper.getMessageMediaInfo(messageProto, contentType, isArroyo) ?: return null
val encryptionProtoIndex = if (messageMediaInfo.exists(Constants.ENCRYPTION_PROTO_INDEX_V2)) {
val encryptionProtoIndex = if (messageMediaInfo.contains(Constants.ENCRYPTION_PROTO_INDEX_V2)) {
Constants.ENCRYPTION_PROTO_INDEX_V2
} else {
Constants.ENCRYPTION_PROTO_INDEX
}
val encryptionProto = messageMediaInfo.readPath(encryptionProtoIndex) ?: return null
val encryptionProto = messageMediaInfo.followPath(encryptionProtoIndex) ?: return null
var key: ByteArray = encryptionProto.getByteArray(1)!!
var iv: ByteArray = encryptionProto.getByteArray(2)!!

View File

@ -19,12 +19,12 @@ import java.util.zip.ZipInputStream
object MediaDownloaderHelper {
fun getMessageMediaInfo(protoReader: ProtoReader, contentType: ContentType, isArroyo: Boolean): ProtoReader? {
val messageContainerPath = if (isArroyo) protoReader.readPath(*Constants.ARROYO_MEDIA_CONTAINER_PROTO_PATH)!! else protoReader
val messageContainerPath = if (isArroyo) protoReader.followPath(*Constants.ARROYO_MEDIA_CONTAINER_PROTO_PATH)!! else protoReader
val mediaContainerPath = if (contentType == ContentType.NOTE) intArrayOf(6, 1, 1) else intArrayOf(5, 1, 1)
return when (contentType) {
ContentType.NOTE -> messageContainerPath.readPath(*mediaContainerPath)
ContentType.SNAP -> messageContainerPath.readPath(*(intArrayOf(11) + mediaContainerPath))
ContentType.NOTE -> messageContainerPath.followPath(*mediaContainerPath)
ContentType.SNAP -> messageContainerPath.followPath(*(intArrayOf(11) + mediaContainerPath))
ContentType.EXTERNAL_MEDIA -> {
val externalMediaTypes = arrayOf(
intArrayOf(3, 3), //normal external media
@ -32,7 +32,7 @@ object MediaDownloaderHelper {
intArrayOf(7, 3) //original story reply
)
externalMediaTypes.forEach { path ->
messageContainerPath.readPath(*(path + mediaContainerPath))?.also { return it }
messageContainerPath.followPath(*(path + mediaContainerPath))?.also { return it }
}
null
}