perf(e2ee): cache optimization

- shared secret key cache
- remove uncommitted messages from the cache
This commit is contained in:
rhunk 2024-01-17 16:18:20 +01:00
parent 8e36425a36
commit ec0fd2cf08
2 changed files with 31 additions and 27 deletions

View File

@ -3,6 +3,7 @@ package me.rhunk.snapenhance.e2ee
import me.rhunk.snapenhance.RemoteSideContext import me.rhunk.snapenhance.RemoteSideContext
import me.rhunk.snapenhance.bridge.e2ee.E2eeInterface import me.rhunk.snapenhance.bridge.e2ee.E2eeInterface
import me.rhunk.snapenhance.bridge.e2ee.EncryptionResult import me.rhunk.snapenhance.bridge.e2ee.EncryptionResult
import me.rhunk.snapenhance.core.util.EvictingMap
import org.bouncycastle.pqc.crypto.crystals.kyber.* import org.bouncycastle.pqc.crypto.crystals.kyber.*
import java.io.File import java.io.File
import java.security.MessageDigest import java.security.MessageDigest
@ -23,25 +24,33 @@ class E2EEImplementation (
}} }}
private val pairingFolder by lazy { File(context.androidContext.cacheDir, "e2ee-pairing").also { private val pairingFolder by lazy { File(context.androidContext.cacheDir, "e2ee-pairing").also {
if (!it.exists()) it.mkdirs() if (!it.exists()) it.mkdirs()
else {
it.deleteRecursively()
it.mkdirs()
}
} } } }
private val sharedSecretKeyCache = EvictingMap<String, ByteArray?>(100)
fun storeSharedSecretKey(friendId: String, key: ByteArray) { fun storeSharedSecretKey(friendId: String, key: ByteArray) {
File(e2eeFolder, "$friendId.key").writeBytes(key) File(e2eeFolder, "$friendId.key").writeBytes(key)
sharedSecretKeyCache[friendId] = key
} }
fun getSharedSecretKey(friendId: String): ByteArray? { fun getSharedSecretKey(friendId: String): ByteArray? {
return runCatching { return sharedSecretKeyCache.getOrPut(friendId) {
runCatching {
File(e2eeFolder, "$friendId.key").readBytes() File(e2eeFolder, "$friendId.key").readBytes()
}.onFailure { }.onFailure {
context.log.error("Failed to read shared secret key", it) context.log.error("Failed to read shared secret key", it)
}.getOrNull() }.getOrNull()
} }
}
fun deleteSharedSecretKey(friendId: String) { fun deleteSharedSecretKey(friendId: String) {
File(e2eeFolder, "$friendId.key").delete() File(e2eeFolder, "$friendId.key").delete()
} }
override fun createKeyExchange(friendId: String): ByteArray? { override fun createKeyExchange(friendId: String): ByteArray? {
val keyPairGenerator = KyberKeyPairGenerator() val keyPairGenerator = KyberKeyPairGenerator()
keyPairGenerator.init( keyPairGenerator.init(
@ -117,12 +126,7 @@ class E2EEImplementation (
} }
override fun getSecretFingerprint(friendId: String): String? { override fun getSecretFingerprint(friendId: String): String? {
val sharedSecretKey = runCatching { val sharedSecretKey = getSharedSecretKey(friendId) ?: return null
File(e2eeFolder, "$friendId.key").readBytes()
}.onFailure {
context.log.error("Failed to read shared secret key", it)
return null
}.getOrThrow()
return MessageDigest.getInstance("SHA-256") return MessageDigest.getInstance("SHA-256")
.digest(sharedSecretKey) .digest(sharedSecretKey)
@ -132,11 +136,7 @@ class E2EEImplementation (
} }
override fun encryptMessage(friendId: String, message: ByteArray): EncryptionResult? { override fun encryptMessage(friendId: String, message: ByteArray): EncryptionResult? {
val encryptionKey = runCatching { val encryptionKey = getSharedSecretKey(friendId) ?: return null
File(e2eeFolder, "$friendId.key").readBytes()
}.onFailure {
context.log.error("Failed to read shared secret key", it)
}.getOrNull()
return runCatching { return runCatching {
val iv = ByteArray(16).apply { secureRandom.nextBytes(this) } val iv = ByteArray(16).apply { secureRandom.nextBytes(this) }
@ -152,12 +152,7 @@ class E2EEImplementation (
} }
override fun decryptMessage(friendId: String, message: ByteArray, iv: ByteArray): ByteArray? { override fun decryptMessage(friendId: String, message: ByteArray, iv: ByteArray): ByteArray? {
val encryptionKey = runCatching { val encryptionKey = getSharedSecretKey(friendId) ?: return null
File(e2eeFolder, "$friendId.key").readBytes()
}.onFailure {
context.log.error("Failed to read shared secret key", it)
return null
}.getOrNull()
return runCatching { return runCatching {
val cipher = Cipher.getInstance("AES/CBC/PKCS5Padding") val cipher = Cipher.getInstance("AES/CBC/PKCS5Padding")

View File

@ -16,6 +16,7 @@ import androidx.compose.runtime.remember
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import me.rhunk.snapenhance.common.data.ContentType import me.rhunk.snapenhance.common.data.ContentType
import me.rhunk.snapenhance.common.data.MessageState
import me.rhunk.snapenhance.common.data.MessagingRuleType import me.rhunk.snapenhance.common.data.MessagingRuleType
import me.rhunk.snapenhance.common.data.RuleState import me.rhunk.snapenhance.common.data.RuleState
import me.rhunk.snapenhance.common.util.protobuf.ProtoEditor import me.rhunk.snapenhance.common.util.protobuf.ProtoEditor
@ -312,7 +313,7 @@ class EndToEndEncryption : MessagingRuleFeature(
if (messageTypeId == ENCRYPTED_MESSAGE_ID) { if (messageTypeId == ENCRYPTED_MESSAGE_ID) {
runCatching { runCatching {
eachBuffer(2) { eachBuffer(2) {
if (encryptedMessages.contains(clientMessageId)) return@eachBuffer if (decryptedMessageCache.containsKey(clientMessageId)) return@eachBuffer
val participantIdHash = getByteArray(1) ?: return@eachBuffer val participantIdHash = getByteArray(1) ?: return@eachBuffer
val iv = getByteArray(2) ?: return@eachBuffer val iv = getByteArray(2) ?: return@eachBuffer
@ -373,10 +374,15 @@ class EndToEndEncryption : MessagingRuleFeature(
return outputContentType to outputBuffer return outputContentType to outputBuffer
} }
private fun messageHook(conversationId: String, messageId: Long, senderId: String, messageContent: MessageContent) { private fun messageHook(conversationId: String, messageId: Long, senderId: String, messageContent: MessageContent, committed: Boolean) {
val (contentType, buffer) = tryDecryptMessage(senderId, messageId, conversationId, messageContent.contentType ?: ContentType.CHAT, messageContent.content!!) val (contentType, buffer) = tryDecryptMessage(senderId, messageId, conversationId, messageContent.contentType ?: ContentType.CHAT, messageContent.content!!)
messageContent.contentType = contentType messageContent.contentType = contentType
messageContent.content = buffer messageContent.content = buffer
// remove messages currently being sent from the cache
if (!committed) {
decryptedMessageCache.remove(messageId)
encryptedMessages.remove(messageId)
}
} }
override fun asyncInit() { override fun asyncInit() {
@ -520,11 +526,13 @@ class EndToEndEncryption : MessagingRuleFeature(
context.event.subscribe(BuildMessageEvent::class, priority = 0) { event -> context.event.subscribe(BuildMessageEvent::class, priority = 0) { event ->
val message = event.message val message = event.message
val conversationId = message.messageDescriptor!!.conversationId.toString() val conversationId = message.messageDescriptor!!.conversationId.toString()
val isMessageCommitted = message.messageState == MessageState.COMMITTED
messageHook( messageHook(
conversationId = conversationId, conversationId = conversationId,
messageId = message.messageDescriptor!!.messageId!!, messageId = message.messageDescriptor!!.messageId!!,
senderId = message.senderId.toString(), senderId = message.senderId.toString(),
messageContent = message.messageContent!! messageContent = message.messageContent!!,
committed = isMessageCommitted
) )
message.messageContent!!.instanceNonNull() message.messageContent!!.instanceNonNull()
@ -535,7 +543,8 @@ class EndToEndEncryption : MessagingRuleFeature(
conversationId = conversationId, conversationId = conversationId,
messageId = quotedMessage.getObjectField("mMessageId")?.toString()?.toLong() ?: return@also, messageId = quotedMessage.getObjectField("mMessageId")?.toString()?.toLong() ?: return@also,
senderId = SnapUUID(quotedMessage.getObjectField("mSenderId")).toString(), senderId = SnapUUID(quotedMessage.getObjectField("mSenderId")).toString(),
messageContent = MessageContent(quotedMessage) messageContent = MessageContent(quotedMessage),
committed = isMessageCommitted
) )
} }
} }