diff --git a/app/src/androidTest/java/com/futo/platformplayer/SyncServerTests.kt b/app/src/androidTest/java/com/futo/platformplayer/SyncServerTests.kt index d9060a79..7348c3c7 100644 --- a/app/src/androidTest/java/com/futo/platformplayer/SyncServerTests.kt +++ b/app/src/androidTest/java/com/futo/platformplayer/SyncServerTests.kt @@ -93,13 +93,17 @@ class SyncServerTests { val tcsDataB = CompletableDeferred() channelB.setDataHandler { _, _, o, so, d -> - if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataB.complete(d.array()) + val b = ByteArray(d.remaining()) + d.get(b) + if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataB.complete(b) } channelA.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(byteArrayOf(1, 2, 3))) val tcsDataA = CompletableDeferred() channelA.setDataHandler { _, _, o, so, d -> - if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataA.complete(d.array()) + val b = ByteArray(d.remaining()) + d.get(b) + if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataA.complete(b) } channelB.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(byteArrayOf(4, 5, 6))) @@ -231,7 +235,9 @@ class SyncServerTests { val tcsDataB = CompletableDeferred() channelB.setDataHandler { _, _, o, so, d -> - if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataB.complete(d.array()) + val b = ByteArray(d.remaining()) + d.get(b) + if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataB.complete(b) } channelA.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(largeData)) val receivedData = withTimeout(10000.milliseconds) { tcsDataB.await() } diff --git a/app/src/main/java/com/futo/platformplayer/sync/internal/Channel.kt b/app/src/main/java/com/futo/platformplayer/sync/internal/Channel.kt index ab49adcb..1c7e1b7e 100644 --- a/app/src/main/java/com/futo/platformplayer/sync/internal/Channel.kt +++ b/app/src/main/java/com/futo/platformplayer/sync/internal/Channel.kt @@ -14,7 +14,7 @@ interface IChannel : AutoCloseable { val remoteVersion: Int? var authorizable: IAuthorizable? fun setDataHandler(onData: ((SyncSocketSession, IChannel, UByte, UByte, ByteBuffer) -> Unit)?) - fun send(opcode: UByte, subOpcode: UByte, data: ByteBuffer? = null) + fun send(opcode: UByte, subOpcode: UByte = 0u, data: ByteBuffer? = null) fun setCloseHandler(onClose: ((IChannel) -> Unit)?) } @@ -326,12 +326,4 @@ class ChannelRelayed( completeHandshake(remoteVersion, transport) } } - - fun handleData(data: ByteBuffer) { - val size = data.int - if (size != data.remaining()) throw IllegalStateException("Incomplete packet received") - val opcode = data.get().toUByte() - val subOpcode = data.get().toUByte() - invokeDataHandler(opcode, subOpcode, data) - } } \ No newline at end of file diff --git a/app/src/main/java/com/futo/platformplayer/sync/internal/SyncSocketSession.kt b/app/src/main/java/com/futo/platformplayer/sync/internal/SyncSocketSession.kt index 61f5b8b0..24d64f87 100644 --- a/app/src/main/java/com/futo/platformplayer/sync/internal/SyncSocketSession.kt +++ b/app/src/main/java/com/futo/platformplayer/sync/internal/SyncSocketSession.kt @@ -155,7 +155,7 @@ class SyncSocketSession { val plen: Int = _cipherStatePair!!.receiver.decryptWithAd(null, _buffer, 0, _bufferDecrypted, 0, messageSize) //Logger.i(TAG, "Decrypted message (size = ${plen})") - handleData(_bufferDecrypted, plen) + handleData(_bufferDecrypted, plen, null) } catch (e: Throwable) { Logger.e(TAG, "Exception while receiving data", e) break @@ -374,22 +374,25 @@ class SyncSocketSession { } } - @OptIn(ExperimentalUnsignedTypes::class) - private fun handleData(data: ByteArray, length: Int) { + private fun handleData(data: ByteArray, length: Int, sourceChannel: ChannelRelayed?) { + return handleData(ByteBuffer.wrap(data, 0, length).order(ByteOrder.LITTLE_ENDIAN), sourceChannel) + } + + private fun handleData(data: ByteBuffer, sourceChannel: ChannelRelayed?) { + val length = data.remaining() if (length < HEADER_SIZE) throw Exception("Packet must be at least 6 bytes (header size)") - val size = ByteBuffer.wrap(data, 0, 4).order(ByteOrder.LITTLE_ENDIAN).int + val size = data.int if (size != length - 4) throw Exception("Incomplete packet received") - val opcode = data.asUByteArray()[4] - val subOpcode = data.asUByteArray()[5] - val packetData = ByteBuffer.wrap(data, HEADER_SIZE, size - 2) - handlePacket(opcode, subOpcode, packetData.order(ByteOrder.LITTLE_ENDIAN)) + val opcode = data.get().toUByte() + val subOpcode = data.get().toUByte() + handlePacket(opcode, subOpcode, data, sourceChannel) } - private fun handleRequest(subOpcode: UByte, data: ByteBuffer) { + private fun handleRequest(subOpcode: UByte, data: ByteBuffer, sourceChannel: ChannelRelayed?) { when (subOpcode) { RequestOpcode.TRANSPORT_RELAYED.value -> { Logger.i(TAG, "Received request for a relayed transport") @@ -440,7 +443,7 @@ class SyncSocketSession { } } - private fun handleResponse(subOpcode: UByte, data: ByteBuffer) { + private fun handleResponse(subOpcode: UByte, data: ByteBuffer, sourceChannel: ChannelRelayed?) { if (data.remaining() < 8) { Logger.e(TAG, "Response packet too short") return @@ -651,7 +654,7 @@ class SyncSocketSession { return ConnectionInfo(port, name, remoteIp, ipv4Addresses, ipv6Addresses, allowLocalDirect, allowRemoteDirect, allowRemoteHolePunched, allowRemoteProxied) } - private fun handleNotify(subOpcode: UByte, data: ByteBuffer) { + private fun handleNotify(subOpcode: UByte, data: ByteBuffer, sourceChannel: ChannelRelayed?) { when (subOpcode) { NotifyOpcode.AUTHORIZED.value, NotifyOpcode.UNAUTHORIZED.value -> _onData?.invoke(this, Opcode.NOTIFY.value, subOpcode, data) NotifyOpcode.CONNECTION_INFO.value -> { /* Handle connection info if needed */ } @@ -666,7 +669,7 @@ class SyncSocketSession { send(Opcode.RELAY.value, RelayOpcode.RELAY_ERROR.value, packet) } - private fun handleRelay(subOpcode: UByte, data: ByteBuffer) { + private fun handleRelay(subOpcode: UByte, data: ByteBuffer, sourceChannel: ChannelRelayed?) { when (subOpcode) { RelayOpcode.RELAYED_DATA.value -> { if (data.remaining() < 8) { @@ -680,7 +683,7 @@ class SyncSocketSession { } val decryptedPayload = channel.decrypt(data) try { - channel.handleData(decryptedPayload) + handleData(decryptedPayload, channel) } catch (e: Exception) { Logger.e(TAG, "Exception while handling relayed data", e) channel.sendError(SyncErrorCode.ConnectionClosed) @@ -726,33 +729,36 @@ class SyncSocketSession { } } - private fun handlePacket(opcode: UByte, subOpcode: UByte, data: ByteBuffer) { + private fun handlePacket(opcode: UByte, subOpcode: UByte, data: ByteBuffer, sourceChannel: ChannelRelayed?) { Logger.i(TAG, "Handle packet (opcode = ${opcode}, subOpcode = ${subOpcode})") when (opcode) { Opcode.PING.value -> { - send(Opcode.PONG.value) + if (sourceChannel != null) + sourceChannel.send(Opcode.PONG.value) + else + send(Opcode.PONG.value) //Logger.i(TAG, "Received ping, sent pong") return } Opcode.PONG.value -> { - //Logger.i(TAG, "Received pong") + Logger.v(TAG, "Received pong") return } Opcode.REQUEST.value -> { - handleRequest(subOpcode, data) + handleRequest(subOpcode, data, sourceChannel) return } Opcode.RESPONSE.value -> { - handleResponse(subOpcode, data) + handleResponse(subOpcode, data, sourceChannel) return } Opcode.NOTIFY.value -> { - handleNotify(subOpcode, data) + handleNotify(subOpcode, data, sourceChannel) return } Opcode.RELAY.value -> { - handleRelay(subOpcode, data) + handleRelay(subOpcode, data, sourceChannel) return } else -> if (isAuthorized) when (opcode) { @@ -809,12 +815,18 @@ class SyncSocketSession { throw Exception("After sync stream end, the stream must be complete") } - handlePacket(syncStream.opcode, syncStream.subOpcode, syncStream.getBytes().let { ByteBuffer.wrap(it).order(ByteOrder.LITTLE_ENDIAN) }) - } - else -> { - Logger.w(TAG, "Unknown opcode received (opcode = ${opcode}, subOpcode = ${subOpcode})") + handlePacket(syncStream.opcode, syncStream.subOpcode, syncStream.getBytes().let { ByteBuffer.wrap(it).order(ByteOrder.LITTLE_ENDIAN) }, sourceChannel) } } + Opcode.DATA.value -> { + if (sourceChannel != null) + sourceChannel.invokeDataHandler(opcode, subOpcode, data) + else + _onData?.invoke(this, opcode, subOpcode, data) + } + else -> { + Logger.w(TAG, "Unknown opcode received (opcode = ${opcode}, subOpcode = ${subOpcode})") + } } } @@ -995,18 +1007,27 @@ class SyncSocketSession { val deferred = CompletableDeferred() _pendingPublishRequests[requestId] = deferred try { - val MAX_PLAINTEXT_SIZE = 65535 - 16 // Adjust for tag size + val MAX_PLAINTEXT_SIZE = 65535 val HANDSHAKE_SIZE = 48 val LENGTH_SIZE = 4 val TAG_SIZE = 16 val chunkCount = (data.size + MAX_PLAINTEXT_SIZE - 1) / MAX_PLAINTEXT_SIZE - val blobSize = HANDSHAKE_SIZE + chunkCount * (LENGTH_SIZE + MAX_PLAINTEXT_SIZE + TAG_SIZE) + + var blobSize = HANDSHAKE_SIZE + var dataOffset = 0 + for (i in 0 until chunkCount) { + val chunkSize = minOf(MAX_PLAINTEXT_SIZE, data.size - dataOffset) + blobSize += LENGTH_SIZE + (chunkSize + TAG_SIZE) + dataOffset += chunkSize + } + val totalPacketSize = 4 + 1 + keyBytes.size + 1 + consumerPublicKeys.size * (32 + 4 + blobSize) val packet = ByteBuffer.allocate(totalPacketSize).order(ByteOrder.LITTLE_ENDIAN) packet.putInt(requestId) packet.put(keyBytes.size.toByte()) packet.put(keyBytes) packet.put(consumerPublicKeys.size.toByte()) + for (consumer in consumerPublicKeys) { val consumerBytes = Base64.getDecoder().decode(consumer) if (consumerBytes.size != 32) throw IllegalArgumentException("Consumer public key must be 32 bytes") @@ -1020,9 +1041,10 @@ class SyncSocketSession { val transportPair = protocol.split() packet.putInt(blobSize) packet.put(handshakeMessage) - var dataOffset = 0 + + dataOffset = 0 for (i in 0 until chunkCount) { - val chunkSize = min(MAX_PLAINTEXT_SIZE, data.size - dataOffset) + val chunkSize = minOf(MAX_PLAINTEXT_SIZE, data.size - dataOffset) val plaintext = data.copyOfRange(dataOffset, dataOffset + chunkSize) val ciphertext = ByteArray(chunkSize + TAG_SIZE) val written = transportPair.sender.encryptWithAd(null, plaintext, 0, ciphertext, 0, plaintext.size)