diff --git a/app/src/androidTest/java/com/futo/platformplayer/SyncServerTests.kt b/app/src/androidTest/java/com/futo/platformplayer/SyncServerTests.kt index 7348c3c7..7607a2c9 100644 --- a/app/src/androidTest/java/com/futo/platformplayer/SyncServerTests.kt +++ b/app/src/androidTest/java/com/futo/platformplayer/SyncServerTests.kt @@ -3,19 +3,21 @@ package com.futo.platformplayer import com.futo.platformplayer.noise.protocol.Noise import com.futo.platformplayer.sync.internal.* import kotlinx.coroutines.* +import kotlinx.coroutines.selects.select import org.junit.Assert.* import org.junit.Test import java.net.Socket import java.nio.ByteBuffer import kotlin.random.Random import kotlin.time.Duration.Companion.milliseconds +import kotlin.time.Duration.Companion.seconds class SyncServerTests { //private val relayHost = "relay.grayjay.app" //private val relayKey = "xGbHRzDOvE6plRbQaFgSen82eijF+gxS0yeUaeEErkw=" private val relayKey = "XlUaSpIlRaCg0TGzZ7JYmPupgUHDqTZXUUBco2K7ejw=" - private val relayHost = "192.168.1.175" + private val relayHost = "192.168.1.138" private val relayPort = 9000 /** Creates a client connected to the live relay server. */ @@ -23,7 +25,8 @@ class SyncServerTests { onHandshakeComplete: ((SyncSocketSession) -> Unit)? = null, onData: ((SyncSocketSession, UByte, UByte, ByteBuffer) -> Unit)? = null, onNewChannel: ((SyncSocketSession, ChannelRelayed) -> Unit)? = null, - isHandshakeAllowed: ((SyncSocketSession, String, String?) -> Boolean)? = null + isHandshakeAllowed: ((LinkType, SyncSocketSession, String, String?, UInt) -> Boolean)? = null, + onException: ((Throwable) -> Unit)? = null ): SyncSocketSession = withContext(Dispatchers.IO) { val p = Noise.createDH("25519") p.generateKeyPair() @@ -43,10 +46,14 @@ class SyncServerTests { }, onData = onData ?: { _, _, _, _ -> }, onNewChannel = onNewChannel ?: { _, _ -> }, - isHandshakeAllowed = isHandshakeAllowed ?: { _, _, _ -> true } + isHandshakeAllowed = isHandshakeAllowed ?: { _, _, _, _, _ -> true } ) socketSession.authorizable = AlwaysAuthorized() - socketSession.startAsInitiator(relayKey) + try { + socketSession.startAsInitiator(relayKey) + } catch (e: Throwable) { + onException?.invoke(e) + } withTimeout(5000.milliseconds) { tcs.await() } return@withContext socketSession } @@ -259,6 +266,71 @@ class SyncServerTests { clientA.stop() clientB.stop() } + + @Test + fun relayedTransport_WithValidAppId_Success() = runBlocking { + // Arrange: Set up clients + val allowedAppId = 1234u + val tcsB = CompletableDeferred() + + // Client B requires appId 1234 + val clientB = createClient( + onNewChannel = { _, c -> tcsB.complete(c) }, + isHandshakeAllowed = { linkType, _, _, _, appId -> linkType == LinkType.Relayed && appId == allowedAppId } + ) + + val clientA = createClient() + + // Act: Start relayed channel with valid appId + val channelTask = async { clientA.startRelayedChannel(clientB.localPublicKey, appId = allowedAppId) } + val channelB = withTimeout(5.seconds) { tcsB.await() } + withTimeout(5.seconds) { channelTask.await() } + + // Assert: Channel is established + assertNotNull("Channel should be created on target with valid appId", channelB) + + // Clean up + clientA.stop() + clientB.stop() + } + + @Test + fun relayedTransport_WithInvalidAppId_Fails() = runBlocking { + // Arrange: Set up clients + val allowedAppId = 1234u + val invalidAppId = 5678u + val tcsB = CompletableDeferred() + + // Client B requires appId 1234 + val clientB = createClient( + onNewChannel = { _, c -> tcsB.complete(c) }, + isHandshakeAllowed = { linkType, _, _, _, appId -> linkType == LinkType.Relayed && appId == allowedAppId }, + onException = { } + ) + + val clientA = createClient() + + // Act & Assert: Attempt with invalid appId should fail + try { + withTimeout(5.seconds) { + clientA.startRelayedChannel(clientB.localPublicKey, appId = invalidAppId) + } + fail("Starting relayed channel with invalid appId should fail") + } catch (e: Throwable) { + // Expected: The channel creation should time out or fail + } + + // Ensure no channel was created on client B + val completedTask = select { + tcsB.onAwait { "channel" } + async { delay(1.seconds); "timeout" }.onAwait { "timeout" } + } + assertEquals("No channel should be created with invalid appId", "timeout", completedTask) + + // Clean up + clientA.stop() + clientB.stop() + } } class AlwaysAuthorized : IAuthorizable { diff --git a/app/src/androidTest/java/com/futo/platformplayer/SyncTests.kt b/app/src/androidTest/java/com/futo/platformplayer/SyncTests.kt new file mode 100644 index 00000000..1b9f19cd --- /dev/null +++ b/app/src/androidTest/java/com/futo/platformplayer/SyncTests.kt @@ -0,0 +1,512 @@ +package com.futo.platformplayer + +import com.futo.platformplayer.noise.protocol.DHState +import com.futo.platformplayer.noise.protocol.Noise +import com.futo.platformplayer.sync.internal.* +import kotlinx.coroutines.* +import org.junit.Assert.* +import org.junit.Test +import java.io.PipedInputStream +import java.io.PipedOutputStream +import java.nio.ByteBuffer +import kotlin.random.Random +import java.io.InputStream +import java.io.OutputStream +import kotlin.time.Duration.Companion.seconds + +data class PipeStreams( + val initiatorInput: LittleEndianDataInputStream, + val initiatorOutput: LittleEndianDataOutputStream, + val responderInput: LittleEndianDataInputStream, + val responderOutput: LittleEndianDataOutputStream +) + +typealias OnHandshakeComplete = (SyncSocketSession) -> Unit +typealias IsHandshakeAllowed = (LinkType, SyncSocketSession, String, String?, UInt) -> Boolean +typealias OnClose = (SyncSocketSession) -> Unit +typealias OnData = (SyncSocketSession, UByte, UByte, ByteBuffer) -> Unit + +class SyncSocketTests { + private fun createPipeStreams(): PipeStreams { + val initiatorOutput = PipedOutputStream() + val responderOutput = PipedOutputStream() + val responderInput = PipedInputStream(initiatorOutput) + val initiatorInput = PipedInputStream(responderOutput) + return PipeStreams( + LittleEndianDataInputStream(initiatorInput), LittleEndianDataOutputStream(initiatorOutput), + LittleEndianDataInputStream(responderInput), LittleEndianDataOutputStream(responderOutput) + ) + } + + fun generateKeyPair(): DHState { + val p = Noise.createDH("25519") + p.generateKeyPair() + return p + } + + private fun createSessions( + initiatorInput: LittleEndianDataInputStream, + initiatorOutput: LittleEndianDataOutputStream, + responderInput: LittleEndianDataInputStream, + responderOutput: LittleEndianDataOutputStream, + initiatorKeyPair: DHState, + responderKeyPair: DHState, + onInitiatorHandshakeComplete: OnHandshakeComplete, + onResponderHandshakeComplete: OnHandshakeComplete, + onInitiatorClose: OnClose? = null, + onResponderClose: OnClose? = null, + onClose: OnClose? = null, + isHandshakeAllowed: IsHandshakeAllowed? = null, + onDataA: OnData? = null, + onDataB: OnData? = null + ): Pair { + val initiatorSession = SyncSocketSession( + "", initiatorKeyPair, initiatorInput, initiatorOutput, + onClose = { + onClose?.invoke(it) + onInitiatorClose?.invoke(it) + }, + onHandshakeComplete = onInitiatorHandshakeComplete, + onData = onDataA, + isHandshakeAllowed = isHandshakeAllowed + ) + + val responderSession = SyncSocketSession( + "", responderKeyPair, responderInput, responderOutput, + onClose = { + onClose?.invoke(it) + onResponderClose?.invoke(it) + }, + onHandshakeComplete = onResponderHandshakeComplete, + onData = onDataB, + isHandshakeAllowed = isHandshakeAllowed + ) + + return Pair(initiatorSession, responderSession) + } + + @Test + fun handshake_WithValidPairingCode_Succeeds(): Unit = runBlocking { + val (initiatorInput, initiatorOutput, responderInput, responderOutput) = createPipeStreams() + val initiatorKeyPair = generateKeyPair() + val responderKeyPair = generateKeyPair() + val validPairingCode = "secret" + + val handshakeInitiatorCompleted = CompletableDeferred() + val handshakeResponderCompleted = CompletableDeferred() + + val (initiatorSession, responderSession) = createSessions( + initiatorInput, initiatorOutput, responderInput, responderOutput, + initiatorKeyPair, responderKeyPair, + { handshakeInitiatorCompleted.complete(true) }, + { handshakeResponderCompleted.complete(true) }, + isHandshakeAllowed = { _, _, _, pairingCode, _ -> pairingCode == validPairingCode } + ) + + initiatorSession.startAsInitiator(responderSession.localPublicKey, pairingCode = validPairingCode) + responderSession.startAsResponder() + + withTimeout(5.seconds) { + handshakeInitiatorCompleted.await() + handshakeResponderCompleted.await() + } + } + + @Test + fun handshake_WithInvalidPairingCode_Fails() = runBlocking { + val (initiatorInput, initiatorOutput, responderInput, responderOutput) = createPipeStreams() + val initiatorKeyPair = generateKeyPair() + val responderKeyPair = generateKeyPair() + val validPairingCode = "secret" + val invalidPairingCode = "wrong" + + val handshakeInitiatorCompleted = CompletableDeferred() + val handshakeResponderCompleted = CompletableDeferred() + val initiatorClosed = CompletableDeferred() + val responderClosed = CompletableDeferred() + + val (initiatorSession, responderSession) = createSessions( + initiatorInput, initiatorOutput, responderInput, responderOutput, + initiatorKeyPair, responderKeyPair, + { handshakeInitiatorCompleted.complete(true) }, + { handshakeResponderCompleted.complete(true) }, + onInitiatorClose = { + initiatorClosed.complete(true) + }, + onResponderClose = { + responderClosed.complete(true) + }, + isHandshakeAllowed = { _, _, _, pairingCode, _ -> pairingCode == validPairingCode } + ) + + initiatorSession.startAsInitiator(responderSession.localPublicKey, pairingCode = invalidPairingCode) + responderSession.startAsResponder() + + withTimeout(100.seconds) { + initiatorClosed.await() + responderClosed.await() + } + + assertFalse(handshakeInitiatorCompleted.isCompleted) + assertFalse(handshakeResponderCompleted.isCompleted) + } + + @Test + fun handshake_WithoutPairingCodeWhenRequired_Fails() = runBlocking { + val (initiatorInput, initiatorOutput, responderInput, responderOutput) = createPipeStreams() + val initiatorKeyPair = generateKeyPair() + val responderKeyPair = generateKeyPair() + val validPairingCode = "secret" + + val handshakeInitiatorCompleted = CompletableDeferred() + val handshakeResponderCompleted = CompletableDeferred() + val initiatorClosed = CompletableDeferred() + val responderClosed = CompletableDeferred() + + val (initiatorSession, responderSession) = createSessions( + initiatorInput, initiatorOutput, responderInput, responderOutput, + initiatorKeyPair, responderKeyPair, + { handshakeInitiatorCompleted.complete(true) }, + { handshakeResponderCompleted.complete(true) }, + onInitiatorClose = { + initiatorClosed.complete(true) + }, + onResponderClose = { + responderClosed.complete(true) + }, + isHandshakeAllowed = { _, _, _, pairingCode, _ -> pairingCode == validPairingCode } + ) + + initiatorSession.startAsInitiator(responderSession.localPublicKey) // No pairing code + responderSession.startAsResponder() + + withTimeout(5.seconds) { + initiatorClosed.await() + responderClosed.await() + } + + assertFalse(handshakeInitiatorCompleted.isCompleted) + assertFalse(handshakeResponderCompleted.isCompleted) + } + + @Test + fun handshake_WithPairingCodeWhenNotRequired_Succeeds(): Unit = runBlocking { + val (initiatorInput, initiatorOutput, responderInput, responderOutput) = createPipeStreams() + val initiatorKeyPair = generateKeyPair() + val responderKeyPair = generateKeyPair() + val pairingCode = "unnecessary" + + val handshakeInitiatorCompleted = CompletableDeferred() + val handshakeResponderCompleted = CompletableDeferred() + + val (initiatorSession, responderSession) = createSessions( + initiatorInput, initiatorOutput, responderInput, responderOutput, + initiatorKeyPair, responderKeyPair, + { handshakeInitiatorCompleted.complete(true) }, + { handshakeResponderCompleted.complete(true) }, + isHandshakeAllowed = { _, _, _, _, _ -> true } // Always allow + ) + + initiatorSession.startAsInitiator(responderSession.localPublicKey, pairingCode = pairingCode) + responderSession.startAsResponder() + + withTimeout(10.seconds) { + handshakeInitiatorCompleted.await() + handshakeResponderCompleted.await() + } + } + + @Test + fun sendAndReceive_SmallDataPacket_Succeeds() = runBlocking { + val (initiatorInput, initiatorOutput, responderInput, responderOutput) = createPipeStreams() + val initiatorKeyPair = generateKeyPair() + val responderKeyPair = generateKeyPair() + + val handshakeInitiatorCompleted = CompletableDeferred() + val handshakeResponderCompleted = CompletableDeferred() + val tcsDataReceived = CompletableDeferred() + + val (initiatorSession, responderSession) = createSessions( + initiatorInput, initiatorOutput, responderInput, responderOutput, + initiatorKeyPair, responderKeyPair, + { handshakeInitiatorCompleted.complete(true) }, + { handshakeResponderCompleted.complete(true) }, + onDataB = { _, opcode, subOpcode, data -> + if (opcode == Opcode.DATA.value && subOpcode == 0u.toUByte()) { + val b = ByteArray(data.remaining()) + data.get(b) + tcsDataReceived.complete(b) + } + } + ) + + initiatorSession.startAsInitiator(responderSession.localPublicKey) + responderSession.startAsResponder() + + withTimeout(10.seconds) { + handshakeInitiatorCompleted.await() + handshakeResponderCompleted.await() + } + + // Ensure both sessions are authorized + initiatorSession.authorizable = Authorized() + responderSession.authorizable = Authorized() + + val smallData = byteArrayOf(1, 2, 3) + initiatorSession.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(smallData)) + + val receivedData = withTimeout(10.seconds) { tcsDataReceived.await() } + assertArrayEquals(smallData, receivedData) + } + + @Test + fun sendAndReceive_ExactlyMaximumPacketSize_Succeeds() = runBlocking { + val (initiatorInput, initiatorOutput, responderInput, responderOutput) = createPipeStreams() + val initiatorKeyPair = generateKeyPair() + val responderKeyPair = generateKeyPair() + + val handshakeInitiatorCompleted = CompletableDeferred() + val handshakeResponderCompleted = CompletableDeferred() + val tcsDataReceived = CompletableDeferred() + + val (initiatorSession, responderSession) = createSessions( + initiatorInput, initiatorOutput, responderInput, responderOutput, + initiatorKeyPair, responderKeyPair, + { handshakeInitiatorCompleted.complete(true) }, + { handshakeResponderCompleted.complete(true) }, + onDataB = { _, opcode, subOpcode, data -> + if (opcode == Opcode.DATA.value && subOpcode == 0u.toUByte()) { + val b = ByteArray(data.remaining()) + data.get(b) + tcsDataReceived.complete(b) + } + } + ) + + initiatorSession.startAsInitiator(responderSession.localPublicKey) + responderSession.startAsResponder() + + withTimeout(10.seconds) { + handshakeInitiatorCompleted.await() + handshakeResponderCompleted.await() + } + + // Ensure both sessions are authorized + initiatorSession.authorizable = Authorized() + responderSession.authorizable = Authorized() + + val maxData = ByteArray(SyncSocketSession.MAXIMUM_PACKET_SIZE - SyncSocketSession.HEADER_SIZE).apply { Random.nextBytes(this) } + initiatorSession.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(maxData)) + + val receivedData = withTimeout(10.seconds) { tcsDataReceived.await() } + assertArrayEquals(maxData, receivedData) + } + + @Test + fun stream_LargeData_Succeeds() = runBlocking { + val (initiatorInput, initiatorOutput, responderInput, responderOutput) = createPipeStreams() + val initiatorKeyPair = generateKeyPair() + val responderKeyPair = generateKeyPair() + + val handshakeInitiatorCompleted = CompletableDeferred() + val handshakeResponderCompleted = CompletableDeferred() + val tcsDataReceived = CompletableDeferred() + + val (initiatorSession, responderSession) = createSessions( + initiatorInput, initiatorOutput, responderInput, responderOutput, + initiatorKeyPair, responderKeyPair, + { handshakeInitiatorCompleted.complete(true) }, + { handshakeResponderCompleted.complete(true) }, + onDataB = { _, opcode, subOpcode, data -> + if (opcode == Opcode.DATA.value && subOpcode == 0u.toUByte()) { + val b = ByteArray(data.remaining()) + data.get(b) + tcsDataReceived.complete(b) + } + } + ) + + initiatorSession.startAsInitiator(responderSession.localPublicKey) + responderSession.startAsResponder() + + withTimeout(10.seconds) { + handshakeInitiatorCompleted.await() + handshakeResponderCompleted.await() + } + + // Ensure both sessions are authorized + initiatorSession.authorizable = Authorized() + responderSession.authorizable = Authorized() + + val largeData = ByteArray(2 * (SyncSocketSession.MAXIMUM_PACKET_SIZE - SyncSocketSession.HEADER_SIZE)).apply { Random.nextBytes(this) } + initiatorSession.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(largeData)) + + val receivedData = withTimeout(10.seconds) { tcsDataReceived.await() } + assertArrayEquals(largeData, receivedData) + } + + @Test + fun authorizedSession_CanSendData() = runBlocking { + val (initiatorInput, initiatorOutput, responderInput, responderOutput) = createPipeStreams() + val initiatorKeyPair = generateKeyPair() + val responderKeyPair = generateKeyPair() + + val handshakeInitiatorCompleted = CompletableDeferred() + val handshakeResponderCompleted = CompletableDeferred() + val tcsDataReceived = CompletableDeferred() + + val (initiatorSession, responderSession) = createSessions( + initiatorInput, initiatorOutput, responderInput, responderOutput, + initiatorKeyPair, responderKeyPair, + { handshakeInitiatorCompleted.complete(true) }, + { handshakeResponderCompleted.complete(true) }, + onDataB = { _, opcode, subOpcode, data -> + if (opcode == Opcode.DATA.value && subOpcode == 0u.toUByte()) { + val b = ByteArray(data.remaining()) + data.get(b) + tcsDataReceived.complete(b) + } + } + ) + + initiatorSession.startAsInitiator(responderSession.localPublicKey) + responderSession.startAsResponder() + + withTimeout(10.seconds) { + handshakeInitiatorCompleted.await() + handshakeResponderCompleted.await() + } + + // Authorize both sessions + initiatorSession.authorizable = Authorized() + responderSession.authorizable = Authorized() + + val data = byteArrayOf(1, 2, 3) + initiatorSession.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(data)) + + val receivedData = withTimeout(10.seconds) { tcsDataReceived.await() } + assertArrayEquals(data, receivedData) + } + + @Test + fun unauthorizedSession_CannotSendData() = runBlocking { + val (initiatorInput, initiatorOutput, responderInput, responderOutput) = createPipeStreams() + val initiatorKeyPair = generateKeyPair() + val responderKeyPair = generateKeyPair() + + val handshakeInitiatorCompleted = CompletableDeferred() + val handshakeResponderCompleted = CompletableDeferred() + val tcsDataReceived = CompletableDeferred() + + val (initiatorSession, responderSession) = createSessions( + initiatorInput, initiatorOutput, responderInput, responderOutput, + initiatorKeyPair, responderKeyPair, + { handshakeInitiatorCompleted.complete(true) }, + { handshakeResponderCompleted.complete(true) }, + onDataB = { _, _, _, _ -> } + ) + + initiatorSession.startAsInitiator(responderSession.localPublicKey) + responderSession.startAsResponder() + + withTimeout(10.seconds) { + handshakeInitiatorCompleted.await() + handshakeResponderCompleted.await() + } + + // Authorize initiator but not responder + initiatorSession.authorizable = Authorized() + responderSession.authorizable = Unauthorized() + + val data = byteArrayOf(1, 2, 3) + initiatorSession.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(data)) + + delay(1.seconds) + assertFalse(tcsDataReceived.isCompleted) + } + + @Test + fun directHandshake_WithValidAppId_Succeeds() = runBlocking { + val (initiatorInput, initiatorOutput, responderInput, responderOutput) = createPipeStreams() + val initiatorKeyPair = generateKeyPair() + val responderKeyPair = generateKeyPair() + val allowedAppId = 1234u + + val handshakeInitiatorCompleted = CompletableDeferred() + val handshakeResponderCompleted = CompletableDeferred() + + val responderIsHandshakeAllowed = { linkType: LinkType, _: SyncSocketSession, _: String, _: String?, appId: UInt -> + linkType == LinkType.Direct && appId == allowedAppId + } + + val (initiatorSession, responderSession) = createSessions( + initiatorInput, initiatorOutput, responderInput, responderOutput, + initiatorKeyPair, responderKeyPair, + { handshakeInitiatorCompleted.complete(true) }, + { handshakeResponderCompleted.complete(true) }, + isHandshakeAllowed = responderIsHandshakeAllowed + ) + + initiatorSession.startAsInitiator(responderSession.localPublicKey, appId = allowedAppId) + responderSession.startAsResponder() + + withTimeout(5.seconds) { + handshakeInitiatorCompleted.await() + handshakeResponderCompleted.await() + } + + assertNotNull(initiatorSession.remotePublicKey) + assertNotNull(responderSession.remotePublicKey) + } + + @Test + fun directHandshake_WithInvalidAppId_Fails() = runBlocking { + val (initiatorInput, initiatorOutput, responderInput, responderOutput) = createPipeStreams() + val initiatorKeyPair = generateKeyPair() + val responderKeyPair = generateKeyPair() + val allowedAppId = 1234u + val invalidAppId = 5678u + + val handshakeInitiatorCompleted = CompletableDeferred() + val handshakeResponderCompleted = CompletableDeferred() + val initiatorClosed = CompletableDeferred() + val responderClosed = CompletableDeferred() + + val responderIsHandshakeAllowed = { linkType: LinkType, _: SyncSocketSession, _: String, _: String?, appId: UInt -> + linkType == LinkType.Direct && appId == allowedAppId + } + + val (initiatorSession, responderSession) = createSessions( + initiatorInput, initiatorOutput, responderInput, responderOutput, + initiatorKeyPair, responderKeyPair, + { handshakeInitiatorCompleted.complete(true) }, + { handshakeResponderCompleted.complete(true) }, + onInitiatorClose = { + initiatorClosed.complete(true) + }, + onResponderClose = { + responderClosed.complete(true) + }, + isHandshakeAllowed = responderIsHandshakeAllowed + ) + + initiatorSession.startAsInitiator(responderSession.localPublicKey, appId = invalidAppId) + responderSession.startAsResponder() + + withTimeout(5.seconds) { + initiatorClosed.await() + responderClosed.await() + } + + assertFalse(handshakeInitiatorCompleted.isCompleted) + assertFalse(handshakeResponderCompleted.isCompleted) + } +} + +class Authorized : IAuthorizable { + override val isAuthorized: Boolean = true +} + +class Unauthorized : IAuthorizable { + override val isAuthorized: Boolean = false +} \ No newline at end of file diff --git a/app/src/main/java/com/futo/platformplayer/Utility.kt b/app/src/main/java/com/futo/platformplayer/Utility.kt index 7ddefc79..c1c7d3e8 100644 --- a/app/src/main/java/com/futo/platformplayer/Utility.kt +++ b/app/src/main/java/com/futo/platformplayer/Utility.kt @@ -69,7 +69,14 @@ fun warnIfMainThread(context: String) { } fun ensureNotMainThread() { - if (Looper.myLooper() == Looper.getMainLooper()) { + val isMainLooper = try { + Looper.myLooper() == Looper.getMainLooper() + } catch (e: Throwable) { + //Ignore, for unit tests where its not mocked + false + } + + if (isMainLooper) { Logger.e("Utility", "Throwing exception because a function that should not be called on main thread, is called on main thread") throw IllegalStateException("Cannot run on main thread") } diff --git a/app/src/main/java/com/futo/platformplayer/states/StateSync.kt b/app/src/main/java/com/futo/platformplayer/states/StateSync.kt index 5f580d55..7e5d02e9 100644 --- a/app/src/main/java/com/futo/platformplayer/states/StateSync.kt +++ b/app/src/main/java/com/futo/platformplayer/states/StateSync.kt @@ -226,7 +226,7 @@ class StateSync { keyPair!!, LittleEndianDataInputStream(socket.getInputStream()), LittleEndianDataOutputStream(socket.getOutputStream()), - isHandshakeAllowed = { linkType, syncSocketSession, publicKey, pairingCode -> isHandshakeAllowed(linkType, syncSocketSession, publicKey, pairingCode) }, + isHandshakeAllowed = { linkType, syncSocketSession, publicKey, pairingCode, appId -> isHandshakeAllowed(linkType, syncSocketSession, publicKey, pairingCode, appId) }, onNewChannel = { _, c -> val remotePublicKey = c.remotePublicKey if (remotePublicKey == null) { @@ -297,7 +297,7 @@ class StateSync { if (connectionInfo.allowRemoteRelayed && Settings.instance.synchronization.connectThroughRelay) { try { Log.v(TAG, "Attempting relayed connection with '$targetKey'.") - runBlocking { relaySession.startRelayedChannel(targetKey, null) } + runBlocking { relaySession.startRelayedChannel(targetKey, APP_ID, null) } } catch (e: Throwable) { Log.e(TAG, "Failed to start relayed channel with $targetKey.", e) } @@ -318,7 +318,7 @@ class StateSync { override val isAuthorized: Boolean get() = true } - _relaySession!!.startAsInitiator(RELAY_PUBLIC_KEY, null) + _relaySession!!.startAsInitiator(RELAY_PUBLIC_KEY, APP_ID, null) Log.i(TAG, "Started relay session.") } catch (e: Throwable) { @@ -731,8 +731,8 @@ class StateSync { ) } - private fun isHandshakeAllowed(linkType: LinkType, syncSocketSession: SyncSocketSession, publicKey: String, pairingCode: String?): Boolean { - Log.v(TAG, "Check if handshake allowed from '$publicKey'.") + private fun isHandshakeAllowed(linkType: LinkType, syncSocketSession: SyncSocketSession, publicKey: String, pairingCode: String?, appId: UInt): Boolean { + Log.v(TAG, "Check if handshake allowed from '$publicKey' (app id: $appId).") if (publicKey == RELAY_PUBLIC_KEY) return true @@ -744,7 +744,7 @@ class StateSync { } } - Log.v(TAG, "Check if handshake allowed with pairing code '$pairingCode' with active pairing code '$_pairingCode'.") + Log.v(TAG, "Check if handshake allowed with pairing code '$pairingCode' with active pairing code '$_pairingCode' (app id: $appId).") if (_pairingCode == null || pairingCode.isNullOrEmpty()) return false @@ -766,7 +766,7 @@ class StateSync { if (channelSocket != null) session?.removeChannel(channelSocket!!) }, - isHandshakeAllowed = { linkType, syncSocketSession, publicKey, pairingCode -> isHandshakeAllowed(linkType, syncSocketSession, publicKey, pairingCode) }, + isHandshakeAllowed = { linkType, syncSocketSession, publicKey, pairingCode, appId -> isHandshakeAllowed(linkType, syncSocketSession, publicKey, pairingCode, appId) }, onHandshakeComplete = { s -> val remotePublicKey = s.remotePublicKey if (remotePublicKey == null) { @@ -930,7 +930,7 @@ class StateSync { _remotePendingStatusUpdate[deviceInfo.publicKey] = onStatusUpdate } } - relaySession.startRelayedChannel(deviceInfo.publicKey, deviceInfo.pairingCode) + relaySession.startRelayedChannel(deviceInfo.publicKey, APP_ID, deviceInfo.pairingCode) } } else { throw e @@ -950,7 +950,7 @@ class StateSync { } } - session.startAsInitiator(publicKey, pairingCode) + session.startAsInitiator(publicKey, APP_ID, pairingCode) return session } @@ -1008,6 +1008,7 @@ class StateSync { val version = 1 val RELAY_SERVER = "relay.grayjay.app" val RELAY_PUBLIC_KEY = "xGbHRzDOvE6plRbQaFgSen82eijF+gxS0yeUaeEErkw=" + val APP_ID = 0x534A5247u //GRayJaySync (GRJS) private const val TAG = "StateSync" const val PORT = 12315 diff --git a/app/src/main/java/com/futo/platformplayer/sync/internal/Channel.kt b/app/src/main/java/com/futo/platformplayer/sync/internal/Channel.kt index bfcee6fd..84c1445c 100644 --- a/app/src/main/java/com/futo/platformplayer/sync/internal/Channel.kt +++ b/app/src/main/java/com/futo/platformplayer/sync/internal/Channel.kt @@ -246,7 +246,7 @@ class ChannelRelayed( } } - fun sendRequestTransport(requestId: Int, publicKey: String, pairingCode: String? = null) { + fun sendRequestTransport(requestId: Int, publicKey: String, appId: UInt, pairingCode: String? = null) { throwIfDisposed() synchronized(sendLock) { @@ -270,10 +270,11 @@ class ChannelRelayed( 0 to ByteArray(0) } - val packetSize = 4 + 32 + 4 + pairingMessageLength + 4 + channelBytesWritten + val packetSize = 4 + 4 + 32 + 4 + pairingMessageLength + 4 + channelBytesWritten val packet = ByteArray(packetSize) ByteBuffer.wrap(packet).order(ByteOrder.LITTLE_ENDIAN).apply { putInt(requestId) + putInt(appId.toInt()) put(publicKeyBytes) putInt(pairingMessageLength) if (pairingMessageLength > 0) put(pairingMessage) diff --git a/app/src/main/java/com/futo/platformplayer/sync/internal/SyncSocketSession.kt b/app/src/main/java/com/futo/platformplayer/sync/internal/SyncSocketSession.kt index 2b3a7e10..ad928698 100644 --- a/app/src/main/java/com/futo/platformplayer/sync/internal/SyncSocketSession.kt +++ b/app/src/main/java/com/futo/platformplayer/sync/internal/SyncSocketSession.kt @@ -38,12 +38,13 @@ class SyncSocketSession { private val _onHandshakeComplete: ((session: SyncSocketSession) -> Unit)? private val _onNewChannel: ((session: SyncSocketSession, channel: ChannelRelayed) -> Unit)? private val _onChannelEstablished: ((session: SyncSocketSession, channel: ChannelRelayed, isResponder: Boolean) -> Unit)? - private val _isHandshakeAllowed: ((linkType: LinkType, session: SyncSocketSession, remotePublicKey: String, pairingCode: String?) -> Boolean)? + private val _isHandshakeAllowed: ((linkType: LinkType, session: SyncSocketSession, remotePublicKey: String, pairingCode: String?, appId: UInt) -> Boolean)? private var _cipherStatePair: CipherStatePair? = null private var _remotePublicKey: String? = null val remotePublicKey: String? get() = _remotePublicKey private var _started: Boolean = false private val _localKeyPair: DHState + private var _thread: Thread? = null private var _localPublicKey: String val localPublicKey: String get() = _localPublicKey private val _onData: ((session: SyncSocketSession, opcode: UByte, subOpcode: UByte, data: ByteBuffer) -> Unit)? @@ -87,7 +88,7 @@ class SyncSocketSession { onData: ((session: SyncSocketSession, opcode: UByte, subOpcode: UByte, data: ByteBuffer) -> Unit)? = null, onNewChannel: ((session: SyncSocketSession, channel: ChannelRelayed) -> Unit)? = null, onChannelEstablished: ((session: SyncSocketSession, channel: ChannelRelayed, isResponder: Boolean) -> Unit)? = null, - isHandshakeAllowed: ((linkType: LinkType, session: SyncSocketSession, remotePublicKey: String, pairingCode: String?) -> Boolean)? = null + isHandshakeAllowed: ((linkType: LinkType, session: SyncSocketSession, remotePublicKey: String, pairingCode: String?, appId: UInt) -> Boolean)? = null ) { _inputStream = inputStream _outputStream = outputStream @@ -105,31 +106,35 @@ class SyncSocketSession { _localPublicKey = Base64.getEncoder().encodeToString(localPublicKey) } - fun startAsInitiator(remotePublicKey: String, pairingCode: String? = null) { + fun startAsInitiator(remotePublicKey: String, appId: UInt = 0u, pairingCode: String? = null) { _started = true - try { - handshakeAsInitiator(remotePublicKey, pairingCode) - _onHandshakeComplete?.invoke(this) - receiveLoop() - } catch (e: Throwable) { - Logger.e(TAG, "Failed to run as initiator", e) - } finally { - stop() - } + _thread = Thread { + try { + handshakeAsInitiator(remotePublicKey, appId, pairingCode) + _onHandshakeComplete?.invoke(this) + receiveLoop() + } catch (e: Throwable) { + Logger.e(TAG, "Failed to run as initiator", e) + } finally { + stop() + } + }.apply { start() } } fun startAsResponder() { _started = true - try { - if (handshakeAsResponder()) { - _onHandshakeComplete?.invoke(this) - receiveLoop() + _thread = Thread { + try { + if (handshakeAsResponder()) { + _onHandshakeComplete?.invoke(this) + receiveLoop() + } + } catch (e: Throwable) { + Logger.e(TAG, "Failed to run as responder", e) + } finally { + stop() } - } catch (e: Throwable) { - Logger.e(TAG, "Failed to run as responder", e) - } finally { - stop() - } + }.apply { start() } } private fun receiveLoop() { @@ -187,12 +192,13 @@ class SyncSocketSession { _onClose?.invoke(this) _inputStream.close() _outputStream.close() + _thread = null _cipherStatePair?.sender?.destroy() _cipherStatePair?.receiver?.destroy() Logger.i(TAG, "Session closed") } - private fun handshakeAsInitiator(remotePublicKey: String, pairingCode: String?) { + private fun handshakeAsInitiator(remotePublicKey: String, appId: UInt, pairingCode: String?) { performVersionCheck() val initiator = HandshakeState(StateSync.protocolName, HandshakeState.INITIATOR) @@ -218,7 +224,8 @@ class SyncSocketSession { val mainBuffer = ByteArray(512) val mainLength = initiator.writeMessage(mainBuffer, 0, null, 0, 0) - val messageData = ByteBuffer.allocate(4 + pairingMessageLength + mainLength).order(ByteOrder.LITTLE_ENDIAN) + val messageData = ByteBuffer.allocate(4 + 4 + pairingMessageLength + mainLength).order(ByteOrder.LITTLE_ENDIAN) + messageData.putInt(appId.toInt()) messageData.putInt(pairingMessageLength) if (pairingMessageLength > 0) messageData.put(pairingMessage) messageData.put(mainBuffer, 0, mainLength) @@ -250,9 +257,10 @@ class SyncSocketSession { _inputStream.readFully(message) val messageBuffer = ByteBuffer.wrap(message).order(ByteOrder.LITTLE_ENDIAN) + val appId = messageBuffer.int.toUInt() val pairingMessageLength = messageBuffer.int val pairingMessage = if (pairingMessageLength > 0) ByteArray(pairingMessageLength).also { messageBuffer.get(it) } else byteArrayOf() - val mainLength = messageSize - 4 - pairingMessageLength + val mainLength = messageSize - 4 - 4 - pairingMessageLength val mainMessage = ByteArray(mainLength).also { messageBuffer.get(it) } var pairingCode: String? = null @@ -267,6 +275,15 @@ class SyncSocketSession { val plaintext = ByteArray(512) responder.readMessage(mainMessage, 0, mainLength, plaintext, 0) + val remoteKeyBytes = ByteArray(responder.remotePublicKey.publicKeyLength) + responder.remotePublicKey.getPublicKey(remoteKeyBytes, 0) + val remotePublicKey = Base64.getEncoder().encodeToString(remoteKeyBytes) + + val isAllowedToConnect = remotePublicKey != _localPublicKey && (_isHandshakeAllowed?.invoke(LinkType.Direct, this, remotePublicKey, pairingCode, appId) ?: true) + if (!isAllowedToConnect) { + stop() + return false + } val responseBuffer = ByteArray(512) val responseLength = responder.writeMessage(responseBuffer, 0, null, 0, 0) @@ -274,13 +291,8 @@ class SyncSocketSession { _outputStream.write(responseBuffer, 0, responseLength) _cipherStatePair = responder.split() - val remoteKeyBytes = ByteArray(responder.remotePublicKey.publicKeyLength) - responder.remotePublicKey.getPublicKey(remoteKeyBytes, 0) - _remotePublicKey = Base64.getEncoder().encodeToString(remoteKeyBytes) - - return (_remotePublicKey != _localPublicKey && (_isHandshakeAllowed?.invoke(LinkType.Direct, this, _remotePublicKey!!, pairingCode) ?: true)).also { - if (!it) stop() - } + _remotePublicKey = remotePublicKey + return true } private fun performVersionCheck() { @@ -400,13 +412,14 @@ class SyncSocketSession { val remoteVersion = data.int val connectionId = data.long val requestId = data.int + val appId = data.int.toUInt() val publicKeyBytes = ByteArray(32).also { data.get(it) } val pairingMessageLength = data.int - if (pairingMessageLength > 128) throw IllegalArgumentException("Pairing message length ($pairingMessageLength) exceeds maximum (128)") + if (pairingMessageLength > 128) throw IllegalArgumentException("Pairing message length ($pairingMessageLength) exceeds maximum (128) (app id: $appId)") val pairingMessage = if (pairingMessageLength > 0) ByteArray(pairingMessageLength).also { data.get(it) } else ByteArray(0) val channelMessageLength = data.int if (data.remaining() != channelMessageLength) { - Logger.e(TAG, "Invalid packet size. Expected ${52 + pairingMessageLength + 4 + channelMessageLength}, got ${data.capacity()}") + Logger.e(TAG, "Invalid packet size. Expected ${52 + pairingMessageLength + 4 + channelMessageLength}, got ${data.capacity()} (app id: $appId)") return } val channelHandshakeMessage = ByteArray(channelMessageLength).also { data.get(it) } @@ -420,7 +433,7 @@ class SyncSocketSession { val length = pairingProtocol.readMessage(pairingMessage, 0, pairingMessageLength, plaintext, 0) String(plaintext, 0, length, Charsets.UTF_8) } else null - val isAllowed = publicKey != _localPublicKey && (_isHandshakeAllowed?.invoke(LinkType.Relayed, this, publicKey, pairingCode) ?: true) + val isAllowed = publicKey != _localPublicKey && (_isHandshakeAllowed?.invoke(LinkType.Relayed, this, publicKey, pairingCode, appId) ?: true) if (!isAllowed) { val rp = ByteBuffer.allocate(16).order(ByteOrder.LITTLE_ENDIAN) rp.putInt(2) // Status code for not allowed @@ -876,14 +889,14 @@ class SyncSocketSession { return deferred.await() } - suspend fun startRelayedChannel(publicKey: String, pairingCode: String? = null): ChannelRelayed? { + suspend fun startRelayedChannel(publicKey: String, appId: UInt = 0u, pairingCode: String? = null): ChannelRelayed? { val requestId = generateRequestId() val deferred = CompletableDeferred() val channel = ChannelRelayed(this, _localKeyPair, publicKey, true) _onNewChannel?.invoke(this, channel) _pendingChannels[requestId] = channel to deferred try { - channel.sendRequestTransport(requestId, publicKey, pairingCode) + channel.sendRequestTransport(requestId, publicKey, appId, pairingCode) } catch (e: Exception) { _pendingChannels.remove(requestId)?.let { it.first.close(); it.second.completeExceptionally(e) } throw e diff --git a/app/src/test/java/com/futo/platformplayer/NoiseProtocolTests.kt b/app/src/test/java/com/futo/platformplayer/NoiseProtocolTests.kt index 33b640f9..189fc767 100644 --- a/app/src/test/java/com/futo/platformplayer/NoiseProtocolTests.kt +++ b/app/src/test/java/com/futo/platformplayer/NoiseProtocolTests.kt @@ -9,6 +9,7 @@ import com.futo.platformplayer.noise.protocol.HandshakeState import com.futo.platformplayer.noise.protocol.Noise import com.futo.platformplayer.states.StateSync import com.futo.platformplayer.sync.internal.IAuthorizable +import com.futo.platformplayer.sync.internal.Opcode import com.futo.platformplayer.sync.internal.SyncSocketSession import com.futo.platformplayer.sync.internal.SyncStream import junit.framework.TestCase.assertEquals @@ -586,16 +587,16 @@ class NoiseProtocolTest { handshakeLatch.await(10, TimeUnit.SECONDS) // Simulate initiator sending a PING and responder replying with PONG - initiatorSession.send(SyncSocketSession.Opcode.PING.value) - responderSession.send(SyncSocketSession.Opcode.PONG.value) + initiatorSession.send(Opcode.PING.value) + responderSession.send(Opcode.PONG.value) // Test data transfer - responderSession.send(SyncSocketSession.Opcode.DATA.value, 0u, randomBytesExactlyOnePacket) - initiatorSession.send(SyncSocketSession.Opcode.DATA.value, 1u, randomBytes) + responderSession.send(Opcode.DATA.value, 0u, randomBytesExactlyOnePacket) + initiatorSession.send(Opcode.DATA.value, 1u, randomBytes) // Send large data to test stream handling val start = System.currentTimeMillis() - responderSession.send(SyncSocketSession.Opcode.DATA.value, 0u, randomBytesBig) + responderSession.send(Opcode.DATA.value, 0u, randomBytesBig) println("Sent 10MB in ${System.currentTimeMillis() - start}ms") // Wait for a brief period to simulate delay and allow communication