mirror of
https://gitlab.futo.org/videostreaming/grayjay.git
synced 2025-04-29 22:24:29 +02:00
Implemented app id and updated unit tests.
This commit is contained in:
parent
daa91986ef
commit
c4623c80ff
@ -3,19 +3,21 @@ package com.futo.platformplayer
|
|||||||
import com.futo.platformplayer.noise.protocol.Noise
|
import com.futo.platformplayer.noise.protocol.Noise
|
||||||
import com.futo.platformplayer.sync.internal.*
|
import com.futo.platformplayer.sync.internal.*
|
||||||
import kotlinx.coroutines.*
|
import kotlinx.coroutines.*
|
||||||
|
import kotlinx.coroutines.selects.select
|
||||||
import org.junit.Assert.*
|
import org.junit.Assert.*
|
||||||
import org.junit.Test
|
import org.junit.Test
|
||||||
import java.net.Socket
|
import java.net.Socket
|
||||||
import java.nio.ByteBuffer
|
import java.nio.ByteBuffer
|
||||||
import kotlin.random.Random
|
import kotlin.random.Random
|
||||||
import kotlin.time.Duration.Companion.milliseconds
|
import kotlin.time.Duration.Companion.milliseconds
|
||||||
|
import kotlin.time.Duration.Companion.seconds
|
||||||
|
|
||||||
class SyncServerTests {
|
class SyncServerTests {
|
||||||
|
|
||||||
//private val relayHost = "relay.grayjay.app"
|
//private val relayHost = "relay.grayjay.app"
|
||||||
//private val relayKey = "xGbHRzDOvE6plRbQaFgSen82eijF+gxS0yeUaeEErkw="
|
//private val relayKey = "xGbHRzDOvE6plRbQaFgSen82eijF+gxS0yeUaeEErkw="
|
||||||
private val relayKey = "XlUaSpIlRaCg0TGzZ7JYmPupgUHDqTZXUUBco2K7ejw="
|
private val relayKey = "XlUaSpIlRaCg0TGzZ7JYmPupgUHDqTZXUUBco2K7ejw="
|
||||||
private val relayHost = "192.168.1.175"
|
private val relayHost = "192.168.1.138"
|
||||||
private val relayPort = 9000
|
private val relayPort = 9000
|
||||||
|
|
||||||
/** Creates a client connected to the live relay server. */
|
/** Creates a client connected to the live relay server. */
|
||||||
@ -23,7 +25,8 @@ class SyncServerTests {
|
|||||||
onHandshakeComplete: ((SyncSocketSession) -> Unit)? = null,
|
onHandshakeComplete: ((SyncSocketSession) -> Unit)? = null,
|
||||||
onData: ((SyncSocketSession, UByte, UByte, ByteBuffer) -> Unit)? = null,
|
onData: ((SyncSocketSession, UByte, UByte, ByteBuffer) -> Unit)? = null,
|
||||||
onNewChannel: ((SyncSocketSession, ChannelRelayed) -> Unit)? = null,
|
onNewChannel: ((SyncSocketSession, ChannelRelayed) -> Unit)? = null,
|
||||||
isHandshakeAllowed: ((SyncSocketSession, String, String?) -> Boolean)? = null
|
isHandshakeAllowed: ((LinkType, SyncSocketSession, String, String?, UInt) -> Boolean)? = null,
|
||||||
|
onException: ((Throwable) -> Unit)? = null
|
||||||
): SyncSocketSession = withContext(Dispatchers.IO) {
|
): SyncSocketSession = withContext(Dispatchers.IO) {
|
||||||
val p = Noise.createDH("25519")
|
val p = Noise.createDH("25519")
|
||||||
p.generateKeyPair()
|
p.generateKeyPair()
|
||||||
@ -43,10 +46,14 @@ class SyncServerTests {
|
|||||||
},
|
},
|
||||||
onData = onData ?: { _, _, _, _ -> },
|
onData = onData ?: { _, _, _, _ -> },
|
||||||
onNewChannel = onNewChannel ?: { _, _ -> },
|
onNewChannel = onNewChannel ?: { _, _ -> },
|
||||||
isHandshakeAllowed = isHandshakeAllowed ?: { _, _, _ -> true }
|
isHandshakeAllowed = isHandshakeAllowed ?: { _, _, _, _, _ -> true }
|
||||||
)
|
)
|
||||||
socketSession.authorizable = AlwaysAuthorized()
|
socketSession.authorizable = AlwaysAuthorized()
|
||||||
socketSession.startAsInitiator(relayKey)
|
try {
|
||||||
|
socketSession.startAsInitiator(relayKey)
|
||||||
|
} catch (e: Throwable) {
|
||||||
|
onException?.invoke(e)
|
||||||
|
}
|
||||||
withTimeout(5000.milliseconds) { tcs.await() }
|
withTimeout(5000.milliseconds) { tcs.await() }
|
||||||
return@withContext socketSession
|
return@withContext socketSession
|
||||||
}
|
}
|
||||||
@ -259,6 +266,71 @@ class SyncServerTests {
|
|||||||
clientA.stop()
|
clientA.stop()
|
||||||
clientB.stop()
|
clientB.stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun relayedTransport_WithValidAppId_Success() = runBlocking {
|
||||||
|
// Arrange: Set up clients
|
||||||
|
val allowedAppId = 1234u
|
||||||
|
val tcsB = CompletableDeferred<ChannelRelayed>()
|
||||||
|
|
||||||
|
// Client B requires appId 1234
|
||||||
|
val clientB = createClient(
|
||||||
|
onNewChannel = { _, c -> tcsB.complete(c) },
|
||||||
|
isHandshakeAllowed = { linkType, _, _, _, appId -> linkType == LinkType.Relayed && appId == allowedAppId }
|
||||||
|
)
|
||||||
|
|
||||||
|
val clientA = createClient()
|
||||||
|
|
||||||
|
// Act: Start relayed channel with valid appId
|
||||||
|
val channelTask = async { clientA.startRelayedChannel(clientB.localPublicKey, appId = allowedAppId) }
|
||||||
|
val channelB = withTimeout(5.seconds) { tcsB.await() }
|
||||||
|
withTimeout(5.seconds) { channelTask.await() }
|
||||||
|
|
||||||
|
// Assert: Channel is established
|
||||||
|
assertNotNull("Channel should be created on target with valid appId", channelB)
|
||||||
|
|
||||||
|
// Clean up
|
||||||
|
clientA.stop()
|
||||||
|
clientB.stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun relayedTransport_WithInvalidAppId_Fails() = runBlocking {
|
||||||
|
// Arrange: Set up clients
|
||||||
|
val allowedAppId = 1234u
|
||||||
|
val invalidAppId = 5678u
|
||||||
|
val tcsB = CompletableDeferred<ChannelRelayed>()
|
||||||
|
|
||||||
|
// Client B requires appId 1234
|
||||||
|
val clientB = createClient(
|
||||||
|
onNewChannel = { _, c -> tcsB.complete(c) },
|
||||||
|
isHandshakeAllowed = { linkType, _, _, _, appId -> linkType == LinkType.Relayed && appId == allowedAppId },
|
||||||
|
onException = { }
|
||||||
|
)
|
||||||
|
|
||||||
|
val clientA = createClient()
|
||||||
|
|
||||||
|
// Act & Assert: Attempt with invalid appId should fail
|
||||||
|
try {
|
||||||
|
withTimeout(5.seconds) {
|
||||||
|
clientA.startRelayedChannel(clientB.localPublicKey, appId = invalidAppId)
|
||||||
|
}
|
||||||
|
fail("Starting relayed channel with invalid appId should fail")
|
||||||
|
} catch (e: Throwable) {
|
||||||
|
// Expected: The channel creation should time out or fail
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure no channel was created on client B
|
||||||
|
val completedTask = select {
|
||||||
|
tcsB.onAwait { "channel" }
|
||||||
|
async { delay(1.seconds); "timeout" }.onAwait { "timeout" }
|
||||||
|
}
|
||||||
|
assertEquals("No channel should be created with invalid appId", "timeout", completedTask)
|
||||||
|
|
||||||
|
// Clean up
|
||||||
|
clientA.stop()
|
||||||
|
clientB.stop()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class AlwaysAuthorized : IAuthorizable {
|
class AlwaysAuthorized : IAuthorizable {
|
||||||
|
512
app/src/androidTest/java/com/futo/platformplayer/SyncTests.kt
Normal file
512
app/src/androidTest/java/com/futo/platformplayer/SyncTests.kt
Normal file
@ -0,0 +1,512 @@
|
|||||||
|
package com.futo.platformplayer
|
||||||
|
|
||||||
|
import com.futo.platformplayer.noise.protocol.DHState
|
||||||
|
import com.futo.platformplayer.noise.protocol.Noise
|
||||||
|
import com.futo.platformplayer.sync.internal.*
|
||||||
|
import kotlinx.coroutines.*
|
||||||
|
import org.junit.Assert.*
|
||||||
|
import org.junit.Test
|
||||||
|
import java.io.PipedInputStream
|
||||||
|
import java.io.PipedOutputStream
|
||||||
|
import java.nio.ByteBuffer
|
||||||
|
import kotlin.random.Random
|
||||||
|
import java.io.InputStream
|
||||||
|
import java.io.OutputStream
|
||||||
|
import kotlin.time.Duration.Companion.seconds
|
||||||
|
|
||||||
|
data class PipeStreams(
|
||||||
|
val initiatorInput: LittleEndianDataInputStream,
|
||||||
|
val initiatorOutput: LittleEndianDataOutputStream,
|
||||||
|
val responderInput: LittleEndianDataInputStream,
|
||||||
|
val responderOutput: LittleEndianDataOutputStream
|
||||||
|
)
|
||||||
|
|
||||||
|
typealias OnHandshakeComplete = (SyncSocketSession) -> Unit
|
||||||
|
typealias IsHandshakeAllowed = (LinkType, SyncSocketSession, String, String?, UInt) -> Boolean
|
||||||
|
typealias OnClose = (SyncSocketSession) -> Unit
|
||||||
|
typealias OnData = (SyncSocketSession, UByte, UByte, ByteBuffer) -> Unit
|
||||||
|
|
||||||
|
class SyncSocketTests {
|
||||||
|
private fun createPipeStreams(): PipeStreams {
|
||||||
|
val initiatorOutput = PipedOutputStream()
|
||||||
|
val responderOutput = PipedOutputStream()
|
||||||
|
val responderInput = PipedInputStream(initiatorOutput)
|
||||||
|
val initiatorInput = PipedInputStream(responderOutput)
|
||||||
|
return PipeStreams(
|
||||||
|
LittleEndianDataInputStream(initiatorInput), LittleEndianDataOutputStream(initiatorOutput),
|
||||||
|
LittleEndianDataInputStream(responderInput), LittleEndianDataOutputStream(responderOutput)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun generateKeyPair(): DHState {
|
||||||
|
val p = Noise.createDH("25519")
|
||||||
|
p.generateKeyPair()
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun createSessions(
|
||||||
|
initiatorInput: LittleEndianDataInputStream,
|
||||||
|
initiatorOutput: LittleEndianDataOutputStream,
|
||||||
|
responderInput: LittleEndianDataInputStream,
|
||||||
|
responderOutput: LittleEndianDataOutputStream,
|
||||||
|
initiatorKeyPair: DHState,
|
||||||
|
responderKeyPair: DHState,
|
||||||
|
onInitiatorHandshakeComplete: OnHandshakeComplete,
|
||||||
|
onResponderHandshakeComplete: OnHandshakeComplete,
|
||||||
|
onInitiatorClose: OnClose? = null,
|
||||||
|
onResponderClose: OnClose? = null,
|
||||||
|
onClose: OnClose? = null,
|
||||||
|
isHandshakeAllowed: IsHandshakeAllowed? = null,
|
||||||
|
onDataA: OnData? = null,
|
||||||
|
onDataB: OnData? = null
|
||||||
|
): Pair<SyncSocketSession, SyncSocketSession> {
|
||||||
|
val initiatorSession = SyncSocketSession(
|
||||||
|
"", initiatorKeyPair, initiatorInput, initiatorOutput,
|
||||||
|
onClose = {
|
||||||
|
onClose?.invoke(it)
|
||||||
|
onInitiatorClose?.invoke(it)
|
||||||
|
},
|
||||||
|
onHandshakeComplete = onInitiatorHandshakeComplete,
|
||||||
|
onData = onDataA,
|
||||||
|
isHandshakeAllowed = isHandshakeAllowed
|
||||||
|
)
|
||||||
|
|
||||||
|
val responderSession = SyncSocketSession(
|
||||||
|
"", responderKeyPair, responderInput, responderOutput,
|
||||||
|
onClose = {
|
||||||
|
onClose?.invoke(it)
|
||||||
|
onResponderClose?.invoke(it)
|
||||||
|
},
|
||||||
|
onHandshakeComplete = onResponderHandshakeComplete,
|
||||||
|
onData = onDataB,
|
||||||
|
isHandshakeAllowed = isHandshakeAllowed
|
||||||
|
)
|
||||||
|
|
||||||
|
return Pair(initiatorSession, responderSession)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun handshake_WithValidPairingCode_Succeeds(): Unit = runBlocking {
|
||||||
|
val (initiatorInput, initiatorOutput, responderInput, responderOutput) = createPipeStreams()
|
||||||
|
val initiatorKeyPair = generateKeyPair()
|
||||||
|
val responderKeyPair = generateKeyPair()
|
||||||
|
val validPairingCode = "secret"
|
||||||
|
|
||||||
|
val handshakeInitiatorCompleted = CompletableDeferred<Boolean>()
|
||||||
|
val handshakeResponderCompleted = CompletableDeferred<Boolean>()
|
||||||
|
|
||||||
|
val (initiatorSession, responderSession) = createSessions(
|
||||||
|
initiatorInput, initiatorOutput, responderInput, responderOutput,
|
||||||
|
initiatorKeyPair, responderKeyPair,
|
||||||
|
{ handshakeInitiatorCompleted.complete(true) },
|
||||||
|
{ handshakeResponderCompleted.complete(true) },
|
||||||
|
isHandshakeAllowed = { _, _, _, pairingCode, _ -> pairingCode == validPairingCode }
|
||||||
|
)
|
||||||
|
|
||||||
|
initiatorSession.startAsInitiator(responderSession.localPublicKey, pairingCode = validPairingCode)
|
||||||
|
responderSession.startAsResponder()
|
||||||
|
|
||||||
|
withTimeout(5.seconds) {
|
||||||
|
handshakeInitiatorCompleted.await()
|
||||||
|
handshakeResponderCompleted.await()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun handshake_WithInvalidPairingCode_Fails() = runBlocking {
|
||||||
|
val (initiatorInput, initiatorOutput, responderInput, responderOutput) = createPipeStreams()
|
||||||
|
val initiatorKeyPair = generateKeyPair()
|
||||||
|
val responderKeyPair = generateKeyPair()
|
||||||
|
val validPairingCode = "secret"
|
||||||
|
val invalidPairingCode = "wrong"
|
||||||
|
|
||||||
|
val handshakeInitiatorCompleted = CompletableDeferred<Boolean>()
|
||||||
|
val handshakeResponderCompleted = CompletableDeferred<Boolean>()
|
||||||
|
val initiatorClosed = CompletableDeferred<Boolean>()
|
||||||
|
val responderClosed = CompletableDeferred<Boolean>()
|
||||||
|
|
||||||
|
val (initiatorSession, responderSession) = createSessions(
|
||||||
|
initiatorInput, initiatorOutput, responderInput, responderOutput,
|
||||||
|
initiatorKeyPair, responderKeyPair,
|
||||||
|
{ handshakeInitiatorCompleted.complete(true) },
|
||||||
|
{ handshakeResponderCompleted.complete(true) },
|
||||||
|
onInitiatorClose = {
|
||||||
|
initiatorClosed.complete(true)
|
||||||
|
},
|
||||||
|
onResponderClose = {
|
||||||
|
responderClosed.complete(true)
|
||||||
|
},
|
||||||
|
isHandshakeAllowed = { _, _, _, pairingCode, _ -> pairingCode == validPairingCode }
|
||||||
|
)
|
||||||
|
|
||||||
|
initiatorSession.startAsInitiator(responderSession.localPublicKey, pairingCode = invalidPairingCode)
|
||||||
|
responderSession.startAsResponder()
|
||||||
|
|
||||||
|
withTimeout(100.seconds) {
|
||||||
|
initiatorClosed.await()
|
||||||
|
responderClosed.await()
|
||||||
|
}
|
||||||
|
|
||||||
|
assertFalse(handshakeInitiatorCompleted.isCompleted)
|
||||||
|
assertFalse(handshakeResponderCompleted.isCompleted)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun handshake_WithoutPairingCodeWhenRequired_Fails() = runBlocking {
|
||||||
|
val (initiatorInput, initiatorOutput, responderInput, responderOutput) = createPipeStreams()
|
||||||
|
val initiatorKeyPair = generateKeyPair()
|
||||||
|
val responderKeyPair = generateKeyPair()
|
||||||
|
val validPairingCode = "secret"
|
||||||
|
|
||||||
|
val handshakeInitiatorCompleted = CompletableDeferred<Boolean>()
|
||||||
|
val handshakeResponderCompleted = CompletableDeferred<Boolean>()
|
||||||
|
val initiatorClosed = CompletableDeferred<Boolean>()
|
||||||
|
val responderClosed = CompletableDeferred<Boolean>()
|
||||||
|
|
||||||
|
val (initiatorSession, responderSession) = createSessions(
|
||||||
|
initiatorInput, initiatorOutput, responderInput, responderOutput,
|
||||||
|
initiatorKeyPair, responderKeyPair,
|
||||||
|
{ handshakeInitiatorCompleted.complete(true) },
|
||||||
|
{ handshakeResponderCompleted.complete(true) },
|
||||||
|
onInitiatorClose = {
|
||||||
|
initiatorClosed.complete(true)
|
||||||
|
},
|
||||||
|
onResponderClose = {
|
||||||
|
responderClosed.complete(true)
|
||||||
|
},
|
||||||
|
isHandshakeAllowed = { _, _, _, pairingCode, _ -> pairingCode == validPairingCode }
|
||||||
|
)
|
||||||
|
|
||||||
|
initiatorSession.startAsInitiator(responderSession.localPublicKey) // No pairing code
|
||||||
|
responderSession.startAsResponder()
|
||||||
|
|
||||||
|
withTimeout(5.seconds) {
|
||||||
|
initiatorClosed.await()
|
||||||
|
responderClosed.await()
|
||||||
|
}
|
||||||
|
|
||||||
|
assertFalse(handshakeInitiatorCompleted.isCompleted)
|
||||||
|
assertFalse(handshakeResponderCompleted.isCompleted)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun handshake_WithPairingCodeWhenNotRequired_Succeeds(): Unit = runBlocking {
|
||||||
|
val (initiatorInput, initiatorOutput, responderInput, responderOutput) = createPipeStreams()
|
||||||
|
val initiatorKeyPair = generateKeyPair()
|
||||||
|
val responderKeyPair = generateKeyPair()
|
||||||
|
val pairingCode = "unnecessary"
|
||||||
|
|
||||||
|
val handshakeInitiatorCompleted = CompletableDeferred<Boolean>()
|
||||||
|
val handshakeResponderCompleted = CompletableDeferred<Boolean>()
|
||||||
|
|
||||||
|
val (initiatorSession, responderSession) = createSessions(
|
||||||
|
initiatorInput, initiatorOutput, responderInput, responderOutput,
|
||||||
|
initiatorKeyPair, responderKeyPair,
|
||||||
|
{ handshakeInitiatorCompleted.complete(true) },
|
||||||
|
{ handshakeResponderCompleted.complete(true) },
|
||||||
|
isHandshakeAllowed = { _, _, _, _, _ -> true } // Always allow
|
||||||
|
)
|
||||||
|
|
||||||
|
initiatorSession.startAsInitiator(responderSession.localPublicKey, pairingCode = pairingCode)
|
||||||
|
responderSession.startAsResponder()
|
||||||
|
|
||||||
|
withTimeout(10.seconds) {
|
||||||
|
handshakeInitiatorCompleted.await()
|
||||||
|
handshakeResponderCompleted.await()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun sendAndReceive_SmallDataPacket_Succeeds() = runBlocking {
|
||||||
|
val (initiatorInput, initiatorOutput, responderInput, responderOutput) = createPipeStreams()
|
||||||
|
val initiatorKeyPair = generateKeyPair()
|
||||||
|
val responderKeyPair = generateKeyPair()
|
||||||
|
|
||||||
|
val handshakeInitiatorCompleted = CompletableDeferred<Boolean>()
|
||||||
|
val handshakeResponderCompleted = CompletableDeferred<Boolean>()
|
||||||
|
val tcsDataReceived = CompletableDeferred<ByteArray>()
|
||||||
|
|
||||||
|
val (initiatorSession, responderSession) = createSessions(
|
||||||
|
initiatorInput, initiatorOutput, responderInput, responderOutput,
|
||||||
|
initiatorKeyPair, responderKeyPair,
|
||||||
|
{ handshakeInitiatorCompleted.complete(true) },
|
||||||
|
{ handshakeResponderCompleted.complete(true) },
|
||||||
|
onDataB = { _, opcode, subOpcode, data ->
|
||||||
|
if (opcode == Opcode.DATA.value && subOpcode == 0u.toUByte()) {
|
||||||
|
val b = ByteArray(data.remaining())
|
||||||
|
data.get(b)
|
||||||
|
tcsDataReceived.complete(b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
initiatorSession.startAsInitiator(responderSession.localPublicKey)
|
||||||
|
responderSession.startAsResponder()
|
||||||
|
|
||||||
|
withTimeout(10.seconds) {
|
||||||
|
handshakeInitiatorCompleted.await()
|
||||||
|
handshakeResponderCompleted.await()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure both sessions are authorized
|
||||||
|
initiatorSession.authorizable = Authorized()
|
||||||
|
responderSession.authorizable = Authorized()
|
||||||
|
|
||||||
|
val smallData = byteArrayOf(1, 2, 3)
|
||||||
|
initiatorSession.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(smallData))
|
||||||
|
|
||||||
|
val receivedData = withTimeout(10.seconds) { tcsDataReceived.await() }
|
||||||
|
assertArrayEquals(smallData, receivedData)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun sendAndReceive_ExactlyMaximumPacketSize_Succeeds() = runBlocking {
|
||||||
|
val (initiatorInput, initiatorOutput, responderInput, responderOutput) = createPipeStreams()
|
||||||
|
val initiatorKeyPair = generateKeyPair()
|
||||||
|
val responderKeyPair = generateKeyPair()
|
||||||
|
|
||||||
|
val handshakeInitiatorCompleted = CompletableDeferred<Boolean>()
|
||||||
|
val handshakeResponderCompleted = CompletableDeferred<Boolean>()
|
||||||
|
val tcsDataReceived = CompletableDeferred<ByteArray>()
|
||||||
|
|
||||||
|
val (initiatorSession, responderSession) = createSessions(
|
||||||
|
initiatorInput, initiatorOutput, responderInput, responderOutput,
|
||||||
|
initiatorKeyPair, responderKeyPair,
|
||||||
|
{ handshakeInitiatorCompleted.complete(true) },
|
||||||
|
{ handshakeResponderCompleted.complete(true) },
|
||||||
|
onDataB = { _, opcode, subOpcode, data ->
|
||||||
|
if (opcode == Opcode.DATA.value && subOpcode == 0u.toUByte()) {
|
||||||
|
val b = ByteArray(data.remaining())
|
||||||
|
data.get(b)
|
||||||
|
tcsDataReceived.complete(b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
initiatorSession.startAsInitiator(responderSession.localPublicKey)
|
||||||
|
responderSession.startAsResponder()
|
||||||
|
|
||||||
|
withTimeout(10.seconds) {
|
||||||
|
handshakeInitiatorCompleted.await()
|
||||||
|
handshakeResponderCompleted.await()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure both sessions are authorized
|
||||||
|
initiatorSession.authorizable = Authorized()
|
||||||
|
responderSession.authorizable = Authorized()
|
||||||
|
|
||||||
|
val maxData = ByteArray(SyncSocketSession.MAXIMUM_PACKET_SIZE - SyncSocketSession.HEADER_SIZE).apply { Random.nextBytes(this) }
|
||||||
|
initiatorSession.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(maxData))
|
||||||
|
|
||||||
|
val receivedData = withTimeout(10.seconds) { tcsDataReceived.await() }
|
||||||
|
assertArrayEquals(maxData, receivedData)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun stream_LargeData_Succeeds() = runBlocking {
|
||||||
|
val (initiatorInput, initiatorOutput, responderInput, responderOutput) = createPipeStreams()
|
||||||
|
val initiatorKeyPair = generateKeyPair()
|
||||||
|
val responderKeyPair = generateKeyPair()
|
||||||
|
|
||||||
|
val handshakeInitiatorCompleted = CompletableDeferred<Boolean>()
|
||||||
|
val handshakeResponderCompleted = CompletableDeferred<Boolean>()
|
||||||
|
val tcsDataReceived = CompletableDeferred<ByteArray>()
|
||||||
|
|
||||||
|
val (initiatorSession, responderSession) = createSessions(
|
||||||
|
initiatorInput, initiatorOutput, responderInput, responderOutput,
|
||||||
|
initiatorKeyPair, responderKeyPair,
|
||||||
|
{ handshakeInitiatorCompleted.complete(true) },
|
||||||
|
{ handshakeResponderCompleted.complete(true) },
|
||||||
|
onDataB = { _, opcode, subOpcode, data ->
|
||||||
|
if (opcode == Opcode.DATA.value && subOpcode == 0u.toUByte()) {
|
||||||
|
val b = ByteArray(data.remaining())
|
||||||
|
data.get(b)
|
||||||
|
tcsDataReceived.complete(b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
initiatorSession.startAsInitiator(responderSession.localPublicKey)
|
||||||
|
responderSession.startAsResponder()
|
||||||
|
|
||||||
|
withTimeout(10.seconds) {
|
||||||
|
handshakeInitiatorCompleted.await()
|
||||||
|
handshakeResponderCompleted.await()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure both sessions are authorized
|
||||||
|
initiatorSession.authorizable = Authorized()
|
||||||
|
responderSession.authorizable = Authorized()
|
||||||
|
|
||||||
|
val largeData = ByteArray(2 * (SyncSocketSession.MAXIMUM_PACKET_SIZE - SyncSocketSession.HEADER_SIZE)).apply { Random.nextBytes(this) }
|
||||||
|
initiatorSession.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(largeData))
|
||||||
|
|
||||||
|
val receivedData = withTimeout(10.seconds) { tcsDataReceived.await() }
|
||||||
|
assertArrayEquals(largeData, receivedData)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun authorizedSession_CanSendData() = runBlocking {
|
||||||
|
val (initiatorInput, initiatorOutput, responderInput, responderOutput) = createPipeStreams()
|
||||||
|
val initiatorKeyPair = generateKeyPair()
|
||||||
|
val responderKeyPair = generateKeyPair()
|
||||||
|
|
||||||
|
val handshakeInitiatorCompleted = CompletableDeferred<Boolean>()
|
||||||
|
val handshakeResponderCompleted = CompletableDeferred<Boolean>()
|
||||||
|
val tcsDataReceived = CompletableDeferred<ByteArray>()
|
||||||
|
|
||||||
|
val (initiatorSession, responderSession) = createSessions(
|
||||||
|
initiatorInput, initiatorOutput, responderInput, responderOutput,
|
||||||
|
initiatorKeyPair, responderKeyPair,
|
||||||
|
{ handshakeInitiatorCompleted.complete(true) },
|
||||||
|
{ handshakeResponderCompleted.complete(true) },
|
||||||
|
onDataB = { _, opcode, subOpcode, data ->
|
||||||
|
if (opcode == Opcode.DATA.value && subOpcode == 0u.toUByte()) {
|
||||||
|
val b = ByteArray(data.remaining())
|
||||||
|
data.get(b)
|
||||||
|
tcsDataReceived.complete(b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
initiatorSession.startAsInitiator(responderSession.localPublicKey)
|
||||||
|
responderSession.startAsResponder()
|
||||||
|
|
||||||
|
withTimeout(10.seconds) {
|
||||||
|
handshakeInitiatorCompleted.await()
|
||||||
|
handshakeResponderCompleted.await()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authorize both sessions
|
||||||
|
initiatorSession.authorizable = Authorized()
|
||||||
|
responderSession.authorizable = Authorized()
|
||||||
|
|
||||||
|
val data = byteArrayOf(1, 2, 3)
|
||||||
|
initiatorSession.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(data))
|
||||||
|
|
||||||
|
val receivedData = withTimeout(10.seconds) { tcsDataReceived.await() }
|
||||||
|
assertArrayEquals(data, receivedData)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun unauthorizedSession_CannotSendData() = runBlocking {
|
||||||
|
val (initiatorInput, initiatorOutput, responderInput, responderOutput) = createPipeStreams()
|
||||||
|
val initiatorKeyPair = generateKeyPair()
|
||||||
|
val responderKeyPair = generateKeyPair()
|
||||||
|
|
||||||
|
val handshakeInitiatorCompleted = CompletableDeferred<Boolean>()
|
||||||
|
val handshakeResponderCompleted = CompletableDeferred<Boolean>()
|
||||||
|
val tcsDataReceived = CompletableDeferred<ByteArray>()
|
||||||
|
|
||||||
|
val (initiatorSession, responderSession) = createSessions(
|
||||||
|
initiatorInput, initiatorOutput, responderInput, responderOutput,
|
||||||
|
initiatorKeyPair, responderKeyPair,
|
||||||
|
{ handshakeInitiatorCompleted.complete(true) },
|
||||||
|
{ handshakeResponderCompleted.complete(true) },
|
||||||
|
onDataB = { _, _, _, _ -> }
|
||||||
|
)
|
||||||
|
|
||||||
|
initiatorSession.startAsInitiator(responderSession.localPublicKey)
|
||||||
|
responderSession.startAsResponder()
|
||||||
|
|
||||||
|
withTimeout(10.seconds) {
|
||||||
|
handshakeInitiatorCompleted.await()
|
||||||
|
handshakeResponderCompleted.await()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authorize initiator but not responder
|
||||||
|
initiatorSession.authorizable = Authorized()
|
||||||
|
responderSession.authorizable = Unauthorized()
|
||||||
|
|
||||||
|
val data = byteArrayOf(1, 2, 3)
|
||||||
|
initiatorSession.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(data))
|
||||||
|
|
||||||
|
delay(1.seconds)
|
||||||
|
assertFalse(tcsDataReceived.isCompleted)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun directHandshake_WithValidAppId_Succeeds() = runBlocking {
|
||||||
|
val (initiatorInput, initiatorOutput, responderInput, responderOutput) = createPipeStreams()
|
||||||
|
val initiatorKeyPair = generateKeyPair()
|
||||||
|
val responderKeyPair = generateKeyPair()
|
||||||
|
val allowedAppId = 1234u
|
||||||
|
|
||||||
|
val handshakeInitiatorCompleted = CompletableDeferred<Boolean>()
|
||||||
|
val handshakeResponderCompleted = CompletableDeferred<Boolean>()
|
||||||
|
|
||||||
|
val responderIsHandshakeAllowed = { linkType: LinkType, _: SyncSocketSession, _: String, _: String?, appId: UInt ->
|
||||||
|
linkType == LinkType.Direct && appId == allowedAppId
|
||||||
|
}
|
||||||
|
|
||||||
|
val (initiatorSession, responderSession) = createSessions(
|
||||||
|
initiatorInput, initiatorOutput, responderInput, responderOutput,
|
||||||
|
initiatorKeyPair, responderKeyPair,
|
||||||
|
{ handshakeInitiatorCompleted.complete(true) },
|
||||||
|
{ handshakeResponderCompleted.complete(true) },
|
||||||
|
isHandshakeAllowed = responderIsHandshakeAllowed
|
||||||
|
)
|
||||||
|
|
||||||
|
initiatorSession.startAsInitiator(responderSession.localPublicKey, appId = allowedAppId)
|
||||||
|
responderSession.startAsResponder()
|
||||||
|
|
||||||
|
withTimeout(5.seconds) {
|
||||||
|
handshakeInitiatorCompleted.await()
|
||||||
|
handshakeResponderCompleted.await()
|
||||||
|
}
|
||||||
|
|
||||||
|
assertNotNull(initiatorSession.remotePublicKey)
|
||||||
|
assertNotNull(responderSession.remotePublicKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun directHandshake_WithInvalidAppId_Fails() = runBlocking {
|
||||||
|
val (initiatorInput, initiatorOutput, responderInput, responderOutput) = createPipeStreams()
|
||||||
|
val initiatorKeyPair = generateKeyPair()
|
||||||
|
val responderKeyPair = generateKeyPair()
|
||||||
|
val allowedAppId = 1234u
|
||||||
|
val invalidAppId = 5678u
|
||||||
|
|
||||||
|
val handshakeInitiatorCompleted = CompletableDeferred<Boolean>()
|
||||||
|
val handshakeResponderCompleted = CompletableDeferred<Boolean>()
|
||||||
|
val initiatorClosed = CompletableDeferred<Boolean>()
|
||||||
|
val responderClosed = CompletableDeferred<Boolean>()
|
||||||
|
|
||||||
|
val responderIsHandshakeAllowed = { linkType: LinkType, _: SyncSocketSession, _: String, _: String?, appId: UInt ->
|
||||||
|
linkType == LinkType.Direct && appId == allowedAppId
|
||||||
|
}
|
||||||
|
|
||||||
|
val (initiatorSession, responderSession) = createSessions(
|
||||||
|
initiatorInput, initiatorOutput, responderInput, responderOutput,
|
||||||
|
initiatorKeyPair, responderKeyPair,
|
||||||
|
{ handshakeInitiatorCompleted.complete(true) },
|
||||||
|
{ handshakeResponderCompleted.complete(true) },
|
||||||
|
onInitiatorClose = {
|
||||||
|
initiatorClosed.complete(true)
|
||||||
|
},
|
||||||
|
onResponderClose = {
|
||||||
|
responderClosed.complete(true)
|
||||||
|
},
|
||||||
|
isHandshakeAllowed = responderIsHandshakeAllowed
|
||||||
|
)
|
||||||
|
|
||||||
|
initiatorSession.startAsInitiator(responderSession.localPublicKey, appId = invalidAppId)
|
||||||
|
responderSession.startAsResponder()
|
||||||
|
|
||||||
|
withTimeout(5.seconds) {
|
||||||
|
initiatorClosed.await()
|
||||||
|
responderClosed.await()
|
||||||
|
}
|
||||||
|
|
||||||
|
assertFalse(handshakeInitiatorCompleted.isCompleted)
|
||||||
|
assertFalse(handshakeResponderCompleted.isCompleted)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class Authorized : IAuthorizable {
|
||||||
|
override val isAuthorized: Boolean = true
|
||||||
|
}
|
||||||
|
|
||||||
|
class Unauthorized : IAuthorizable {
|
||||||
|
override val isAuthorized: Boolean = false
|
||||||
|
}
|
@ -69,7 +69,14 @@ fun warnIfMainThread(context: String) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fun ensureNotMainThread() {
|
fun ensureNotMainThread() {
|
||||||
if (Looper.myLooper() == Looper.getMainLooper()) {
|
val isMainLooper = try {
|
||||||
|
Looper.myLooper() == Looper.getMainLooper()
|
||||||
|
} catch (e: Throwable) {
|
||||||
|
//Ignore, for unit tests where its not mocked
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isMainLooper) {
|
||||||
Logger.e("Utility", "Throwing exception because a function that should not be called on main thread, is called on main thread")
|
Logger.e("Utility", "Throwing exception because a function that should not be called on main thread, is called on main thread")
|
||||||
throw IllegalStateException("Cannot run on main thread")
|
throw IllegalStateException("Cannot run on main thread")
|
||||||
}
|
}
|
||||||
|
@ -226,7 +226,7 @@ class StateSync {
|
|||||||
keyPair!!,
|
keyPair!!,
|
||||||
LittleEndianDataInputStream(socket.getInputStream()),
|
LittleEndianDataInputStream(socket.getInputStream()),
|
||||||
LittleEndianDataOutputStream(socket.getOutputStream()),
|
LittleEndianDataOutputStream(socket.getOutputStream()),
|
||||||
isHandshakeAllowed = { linkType, syncSocketSession, publicKey, pairingCode -> isHandshakeAllowed(linkType, syncSocketSession, publicKey, pairingCode) },
|
isHandshakeAllowed = { linkType, syncSocketSession, publicKey, pairingCode, appId -> isHandshakeAllowed(linkType, syncSocketSession, publicKey, pairingCode, appId) },
|
||||||
onNewChannel = { _, c ->
|
onNewChannel = { _, c ->
|
||||||
val remotePublicKey = c.remotePublicKey
|
val remotePublicKey = c.remotePublicKey
|
||||||
if (remotePublicKey == null) {
|
if (remotePublicKey == null) {
|
||||||
@ -297,7 +297,7 @@ class StateSync {
|
|||||||
if (connectionInfo.allowRemoteRelayed && Settings.instance.synchronization.connectThroughRelay) {
|
if (connectionInfo.allowRemoteRelayed && Settings.instance.synchronization.connectThroughRelay) {
|
||||||
try {
|
try {
|
||||||
Log.v(TAG, "Attempting relayed connection with '$targetKey'.")
|
Log.v(TAG, "Attempting relayed connection with '$targetKey'.")
|
||||||
runBlocking { relaySession.startRelayedChannel(targetKey, null) }
|
runBlocking { relaySession.startRelayedChannel(targetKey, APP_ID, null) }
|
||||||
} catch (e: Throwable) {
|
} catch (e: Throwable) {
|
||||||
Log.e(TAG, "Failed to start relayed channel with $targetKey.", e)
|
Log.e(TAG, "Failed to start relayed channel with $targetKey.", e)
|
||||||
}
|
}
|
||||||
@ -318,7 +318,7 @@ class StateSync {
|
|||||||
override val isAuthorized: Boolean get() = true
|
override val isAuthorized: Boolean get() = true
|
||||||
}
|
}
|
||||||
|
|
||||||
_relaySession!!.startAsInitiator(RELAY_PUBLIC_KEY, null)
|
_relaySession!!.startAsInitiator(RELAY_PUBLIC_KEY, APP_ID, null)
|
||||||
|
|
||||||
Log.i(TAG, "Started relay session.")
|
Log.i(TAG, "Started relay session.")
|
||||||
} catch (e: Throwable) {
|
} catch (e: Throwable) {
|
||||||
@ -731,8 +731,8 @@ class StateSync {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun isHandshakeAllowed(linkType: LinkType, syncSocketSession: SyncSocketSession, publicKey: String, pairingCode: String?): Boolean {
|
private fun isHandshakeAllowed(linkType: LinkType, syncSocketSession: SyncSocketSession, publicKey: String, pairingCode: String?, appId: UInt): Boolean {
|
||||||
Log.v(TAG, "Check if handshake allowed from '$publicKey'.")
|
Log.v(TAG, "Check if handshake allowed from '$publicKey' (app id: $appId).")
|
||||||
if (publicKey == RELAY_PUBLIC_KEY)
|
if (publicKey == RELAY_PUBLIC_KEY)
|
||||||
return true
|
return true
|
||||||
|
|
||||||
@ -744,7 +744,7 @@ class StateSync {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Log.v(TAG, "Check if handshake allowed with pairing code '$pairingCode' with active pairing code '$_pairingCode'.")
|
Log.v(TAG, "Check if handshake allowed with pairing code '$pairingCode' with active pairing code '$_pairingCode' (app id: $appId).")
|
||||||
if (_pairingCode == null || pairingCode.isNullOrEmpty())
|
if (_pairingCode == null || pairingCode.isNullOrEmpty())
|
||||||
return false
|
return false
|
||||||
|
|
||||||
@ -766,7 +766,7 @@ class StateSync {
|
|||||||
if (channelSocket != null)
|
if (channelSocket != null)
|
||||||
session?.removeChannel(channelSocket!!)
|
session?.removeChannel(channelSocket!!)
|
||||||
},
|
},
|
||||||
isHandshakeAllowed = { linkType, syncSocketSession, publicKey, pairingCode -> isHandshakeAllowed(linkType, syncSocketSession, publicKey, pairingCode) },
|
isHandshakeAllowed = { linkType, syncSocketSession, publicKey, pairingCode, appId -> isHandshakeAllowed(linkType, syncSocketSession, publicKey, pairingCode, appId) },
|
||||||
onHandshakeComplete = { s ->
|
onHandshakeComplete = { s ->
|
||||||
val remotePublicKey = s.remotePublicKey
|
val remotePublicKey = s.remotePublicKey
|
||||||
if (remotePublicKey == null) {
|
if (remotePublicKey == null) {
|
||||||
@ -930,7 +930,7 @@ class StateSync {
|
|||||||
_remotePendingStatusUpdate[deviceInfo.publicKey] = onStatusUpdate
|
_remotePendingStatusUpdate[deviceInfo.publicKey] = onStatusUpdate
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
relaySession.startRelayedChannel(deviceInfo.publicKey, deviceInfo.pairingCode)
|
relaySession.startRelayedChannel(deviceInfo.publicKey, APP_ID, deviceInfo.pairingCode)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
throw e
|
throw e
|
||||||
@ -950,7 +950,7 @@ class StateSync {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
session.startAsInitiator(publicKey, pairingCode)
|
session.startAsInitiator(publicKey, APP_ID, pairingCode)
|
||||||
return session
|
return session
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1008,6 +1008,7 @@ class StateSync {
|
|||||||
val version = 1
|
val version = 1
|
||||||
val RELAY_SERVER = "relay.grayjay.app"
|
val RELAY_SERVER = "relay.grayjay.app"
|
||||||
val RELAY_PUBLIC_KEY = "xGbHRzDOvE6plRbQaFgSen82eijF+gxS0yeUaeEErkw="
|
val RELAY_PUBLIC_KEY = "xGbHRzDOvE6plRbQaFgSen82eijF+gxS0yeUaeEErkw="
|
||||||
|
val APP_ID = 0x534A5247u //GRayJaySync (GRJS)
|
||||||
|
|
||||||
private const val TAG = "StateSync"
|
private const val TAG = "StateSync"
|
||||||
const val PORT = 12315
|
const val PORT = 12315
|
||||||
|
@ -246,7 +246,7 @@ class ChannelRelayed(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fun sendRequestTransport(requestId: Int, publicKey: String, pairingCode: String? = null) {
|
fun sendRequestTransport(requestId: Int, publicKey: String, appId: UInt, pairingCode: String? = null) {
|
||||||
throwIfDisposed()
|
throwIfDisposed()
|
||||||
|
|
||||||
synchronized(sendLock) {
|
synchronized(sendLock) {
|
||||||
@ -270,10 +270,11 @@ class ChannelRelayed(
|
|||||||
0 to ByteArray(0)
|
0 to ByteArray(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
val packetSize = 4 + 32 + 4 + pairingMessageLength + 4 + channelBytesWritten
|
val packetSize = 4 + 4 + 32 + 4 + pairingMessageLength + 4 + channelBytesWritten
|
||||||
val packet = ByteArray(packetSize)
|
val packet = ByteArray(packetSize)
|
||||||
ByteBuffer.wrap(packet).order(ByteOrder.LITTLE_ENDIAN).apply {
|
ByteBuffer.wrap(packet).order(ByteOrder.LITTLE_ENDIAN).apply {
|
||||||
putInt(requestId)
|
putInt(requestId)
|
||||||
|
putInt(appId.toInt())
|
||||||
put(publicKeyBytes)
|
put(publicKeyBytes)
|
||||||
putInt(pairingMessageLength)
|
putInt(pairingMessageLength)
|
||||||
if (pairingMessageLength > 0) put(pairingMessage)
|
if (pairingMessageLength > 0) put(pairingMessage)
|
||||||
|
@ -38,12 +38,13 @@ class SyncSocketSession {
|
|||||||
private val _onHandshakeComplete: ((session: SyncSocketSession) -> Unit)?
|
private val _onHandshakeComplete: ((session: SyncSocketSession) -> Unit)?
|
||||||
private val _onNewChannel: ((session: SyncSocketSession, channel: ChannelRelayed) -> Unit)?
|
private val _onNewChannel: ((session: SyncSocketSession, channel: ChannelRelayed) -> Unit)?
|
||||||
private val _onChannelEstablished: ((session: SyncSocketSession, channel: ChannelRelayed, isResponder: Boolean) -> Unit)?
|
private val _onChannelEstablished: ((session: SyncSocketSession, channel: ChannelRelayed, isResponder: Boolean) -> Unit)?
|
||||||
private val _isHandshakeAllowed: ((linkType: LinkType, session: SyncSocketSession, remotePublicKey: String, pairingCode: String?) -> Boolean)?
|
private val _isHandshakeAllowed: ((linkType: LinkType, session: SyncSocketSession, remotePublicKey: String, pairingCode: String?, appId: UInt) -> Boolean)?
|
||||||
private var _cipherStatePair: CipherStatePair? = null
|
private var _cipherStatePair: CipherStatePair? = null
|
||||||
private var _remotePublicKey: String? = null
|
private var _remotePublicKey: String? = null
|
||||||
val remotePublicKey: String? get() = _remotePublicKey
|
val remotePublicKey: String? get() = _remotePublicKey
|
||||||
private var _started: Boolean = false
|
private var _started: Boolean = false
|
||||||
private val _localKeyPair: DHState
|
private val _localKeyPair: DHState
|
||||||
|
private var _thread: Thread? = null
|
||||||
private var _localPublicKey: String
|
private var _localPublicKey: String
|
||||||
val localPublicKey: String get() = _localPublicKey
|
val localPublicKey: String get() = _localPublicKey
|
||||||
private val _onData: ((session: SyncSocketSession, opcode: UByte, subOpcode: UByte, data: ByteBuffer) -> Unit)?
|
private val _onData: ((session: SyncSocketSession, opcode: UByte, subOpcode: UByte, data: ByteBuffer) -> Unit)?
|
||||||
@ -87,7 +88,7 @@ class SyncSocketSession {
|
|||||||
onData: ((session: SyncSocketSession, opcode: UByte, subOpcode: UByte, data: ByteBuffer) -> Unit)? = null,
|
onData: ((session: SyncSocketSession, opcode: UByte, subOpcode: UByte, data: ByteBuffer) -> Unit)? = null,
|
||||||
onNewChannel: ((session: SyncSocketSession, channel: ChannelRelayed) -> Unit)? = null,
|
onNewChannel: ((session: SyncSocketSession, channel: ChannelRelayed) -> Unit)? = null,
|
||||||
onChannelEstablished: ((session: SyncSocketSession, channel: ChannelRelayed, isResponder: Boolean) -> Unit)? = null,
|
onChannelEstablished: ((session: SyncSocketSession, channel: ChannelRelayed, isResponder: Boolean) -> Unit)? = null,
|
||||||
isHandshakeAllowed: ((linkType: LinkType, session: SyncSocketSession, remotePublicKey: String, pairingCode: String?) -> Boolean)? = null
|
isHandshakeAllowed: ((linkType: LinkType, session: SyncSocketSession, remotePublicKey: String, pairingCode: String?, appId: UInt) -> Boolean)? = null
|
||||||
) {
|
) {
|
||||||
_inputStream = inputStream
|
_inputStream = inputStream
|
||||||
_outputStream = outputStream
|
_outputStream = outputStream
|
||||||
@ -105,31 +106,35 @@ class SyncSocketSession {
|
|||||||
_localPublicKey = Base64.getEncoder().encodeToString(localPublicKey)
|
_localPublicKey = Base64.getEncoder().encodeToString(localPublicKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun startAsInitiator(remotePublicKey: String, pairingCode: String? = null) {
|
fun startAsInitiator(remotePublicKey: String, appId: UInt = 0u, pairingCode: String? = null) {
|
||||||
_started = true
|
_started = true
|
||||||
try {
|
_thread = Thread {
|
||||||
handshakeAsInitiator(remotePublicKey, pairingCode)
|
try {
|
||||||
_onHandshakeComplete?.invoke(this)
|
handshakeAsInitiator(remotePublicKey, appId, pairingCode)
|
||||||
receiveLoop()
|
_onHandshakeComplete?.invoke(this)
|
||||||
} catch (e: Throwable) {
|
receiveLoop()
|
||||||
Logger.e(TAG, "Failed to run as initiator", e)
|
} catch (e: Throwable) {
|
||||||
} finally {
|
Logger.e(TAG, "Failed to run as initiator", e)
|
||||||
stop()
|
} finally {
|
||||||
}
|
stop()
|
||||||
|
}
|
||||||
|
}.apply { start() }
|
||||||
}
|
}
|
||||||
|
|
||||||
fun startAsResponder() {
|
fun startAsResponder() {
|
||||||
_started = true
|
_started = true
|
||||||
try {
|
_thread = Thread {
|
||||||
if (handshakeAsResponder()) {
|
try {
|
||||||
_onHandshakeComplete?.invoke(this)
|
if (handshakeAsResponder()) {
|
||||||
receiveLoop()
|
_onHandshakeComplete?.invoke(this)
|
||||||
|
receiveLoop()
|
||||||
|
}
|
||||||
|
} catch (e: Throwable) {
|
||||||
|
Logger.e(TAG, "Failed to run as responder", e)
|
||||||
|
} finally {
|
||||||
|
stop()
|
||||||
}
|
}
|
||||||
} catch (e: Throwable) {
|
}.apply { start() }
|
||||||
Logger.e(TAG, "Failed to run as responder", e)
|
|
||||||
} finally {
|
|
||||||
stop()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun receiveLoop() {
|
private fun receiveLoop() {
|
||||||
@ -187,12 +192,13 @@ class SyncSocketSession {
|
|||||||
_onClose?.invoke(this)
|
_onClose?.invoke(this)
|
||||||
_inputStream.close()
|
_inputStream.close()
|
||||||
_outputStream.close()
|
_outputStream.close()
|
||||||
|
_thread = null
|
||||||
_cipherStatePair?.sender?.destroy()
|
_cipherStatePair?.sender?.destroy()
|
||||||
_cipherStatePair?.receiver?.destroy()
|
_cipherStatePair?.receiver?.destroy()
|
||||||
Logger.i(TAG, "Session closed")
|
Logger.i(TAG, "Session closed")
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun handshakeAsInitiator(remotePublicKey: String, pairingCode: String?) {
|
private fun handshakeAsInitiator(remotePublicKey: String, appId: UInt, pairingCode: String?) {
|
||||||
performVersionCheck()
|
performVersionCheck()
|
||||||
|
|
||||||
val initiator = HandshakeState(StateSync.protocolName, HandshakeState.INITIATOR)
|
val initiator = HandshakeState(StateSync.protocolName, HandshakeState.INITIATOR)
|
||||||
@ -218,7 +224,8 @@ class SyncSocketSession {
|
|||||||
val mainBuffer = ByteArray(512)
|
val mainBuffer = ByteArray(512)
|
||||||
val mainLength = initiator.writeMessage(mainBuffer, 0, null, 0, 0)
|
val mainLength = initiator.writeMessage(mainBuffer, 0, null, 0, 0)
|
||||||
|
|
||||||
val messageData = ByteBuffer.allocate(4 + pairingMessageLength + mainLength).order(ByteOrder.LITTLE_ENDIAN)
|
val messageData = ByteBuffer.allocate(4 + 4 + pairingMessageLength + mainLength).order(ByteOrder.LITTLE_ENDIAN)
|
||||||
|
messageData.putInt(appId.toInt())
|
||||||
messageData.putInt(pairingMessageLength)
|
messageData.putInt(pairingMessageLength)
|
||||||
if (pairingMessageLength > 0) messageData.put(pairingMessage)
|
if (pairingMessageLength > 0) messageData.put(pairingMessage)
|
||||||
messageData.put(mainBuffer, 0, mainLength)
|
messageData.put(mainBuffer, 0, mainLength)
|
||||||
@ -250,9 +257,10 @@ class SyncSocketSession {
|
|||||||
_inputStream.readFully(message)
|
_inputStream.readFully(message)
|
||||||
val messageBuffer = ByteBuffer.wrap(message).order(ByteOrder.LITTLE_ENDIAN)
|
val messageBuffer = ByteBuffer.wrap(message).order(ByteOrder.LITTLE_ENDIAN)
|
||||||
|
|
||||||
|
val appId = messageBuffer.int.toUInt()
|
||||||
val pairingMessageLength = messageBuffer.int
|
val pairingMessageLength = messageBuffer.int
|
||||||
val pairingMessage = if (pairingMessageLength > 0) ByteArray(pairingMessageLength).also { messageBuffer.get(it) } else byteArrayOf()
|
val pairingMessage = if (pairingMessageLength > 0) ByteArray(pairingMessageLength).also { messageBuffer.get(it) } else byteArrayOf()
|
||||||
val mainLength = messageSize - 4 - pairingMessageLength
|
val mainLength = messageSize - 4 - 4 - pairingMessageLength
|
||||||
val mainMessage = ByteArray(mainLength).also { messageBuffer.get(it) }
|
val mainMessage = ByteArray(mainLength).also { messageBuffer.get(it) }
|
||||||
|
|
||||||
var pairingCode: String? = null
|
var pairingCode: String? = null
|
||||||
@ -267,6 +275,15 @@ class SyncSocketSession {
|
|||||||
|
|
||||||
val plaintext = ByteArray(512)
|
val plaintext = ByteArray(512)
|
||||||
responder.readMessage(mainMessage, 0, mainLength, plaintext, 0)
|
responder.readMessage(mainMessage, 0, mainLength, plaintext, 0)
|
||||||
|
val remoteKeyBytes = ByteArray(responder.remotePublicKey.publicKeyLength)
|
||||||
|
responder.remotePublicKey.getPublicKey(remoteKeyBytes, 0)
|
||||||
|
val remotePublicKey = Base64.getEncoder().encodeToString(remoteKeyBytes)
|
||||||
|
|
||||||
|
val isAllowedToConnect = remotePublicKey != _localPublicKey && (_isHandshakeAllowed?.invoke(LinkType.Direct, this, remotePublicKey, pairingCode, appId) ?: true)
|
||||||
|
if (!isAllowedToConnect) {
|
||||||
|
stop()
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
val responseBuffer = ByteArray(512)
|
val responseBuffer = ByteArray(512)
|
||||||
val responseLength = responder.writeMessage(responseBuffer, 0, null, 0, 0)
|
val responseLength = responder.writeMessage(responseBuffer, 0, null, 0, 0)
|
||||||
@ -274,13 +291,8 @@ class SyncSocketSession {
|
|||||||
_outputStream.write(responseBuffer, 0, responseLength)
|
_outputStream.write(responseBuffer, 0, responseLength)
|
||||||
|
|
||||||
_cipherStatePair = responder.split()
|
_cipherStatePair = responder.split()
|
||||||
val remoteKeyBytes = ByteArray(responder.remotePublicKey.publicKeyLength)
|
_remotePublicKey = remotePublicKey
|
||||||
responder.remotePublicKey.getPublicKey(remoteKeyBytes, 0)
|
return true
|
||||||
_remotePublicKey = Base64.getEncoder().encodeToString(remoteKeyBytes)
|
|
||||||
|
|
||||||
return (_remotePublicKey != _localPublicKey && (_isHandshakeAllowed?.invoke(LinkType.Direct, this, _remotePublicKey!!, pairingCode) ?: true)).also {
|
|
||||||
if (!it) stop()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun performVersionCheck() {
|
private fun performVersionCheck() {
|
||||||
@ -400,13 +412,14 @@ class SyncSocketSession {
|
|||||||
val remoteVersion = data.int
|
val remoteVersion = data.int
|
||||||
val connectionId = data.long
|
val connectionId = data.long
|
||||||
val requestId = data.int
|
val requestId = data.int
|
||||||
|
val appId = data.int.toUInt()
|
||||||
val publicKeyBytes = ByteArray(32).also { data.get(it) }
|
val publicKeyBytes = ByteArray(32).also { data.get(it) }
|
||||||
val pairingMessageLength = data.int
|
val pairingMessageLength = data.int
|
||||||
if (pairingMessageLength > 128) throw IllegalArgumentException("Pairing message length ($pairingMessageLength) exceeds maximum (128)")
|
if (pairingMessageLength > 128) throw IllegalArgumentException("Pairing message length ($pairingMessageLength) exceeds maximum (128) (app id: $appId)")
|
||||||
val pairingMessage = if (pairingMessageLength > 0) ByteArray(pairingMessageLength).also { data.get(it) } else ByteArray(0)
|
val pairingMessage = if (pairingMessageLength > 0) ByteArray(pairingMessageLength).also { data.get(it) } else ByteArray(0)
|
||||||
val channelMessageLength = data.int
|
val channelMessageLength = data.int
|
||||||
if (data.remaining() != channelMessageLength) {
|
if (data.remaining() != channelMessageLength) {
|
||||||
Logger.e(TAG, "Invalid packet size. Expected ${52 + pairingMessageLength + 4 + channelMessageLength}, got ${data.capacity()}")
|
Logger.e(TAG, "Invalid packet size. Expected ${52 + pairingMessageLength + 4 + channelMessageLength}, got ${data.capacity()} (app id: $appId)")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
val channelHandshakeMessage = ByteArray(channelMessageLength).also { data.get(it) }
|
val channelHandshakeMessage = ByteArray(channelMessageLength).also { data.get(it) }
|
||||||
@ -420,7 +433,7 @@ class SyncSocketSession {
|
|||||||
val length = pairingProtocol.readMessage(pairingMessage, 0, pairingMessageLength, plaintext, 0)
|
val length = pairingProtocol.readMessage(pairingMessage, 0, pairingMessageLength, plaintext, 0)
|
||||||
String(plaintext, 0, length, Charsets.UTF_8)
|
String(plaintext, 0, length, Charsets.UTF_8)
|
||||||
} else null
|
} else null
|
||||||
val isAllowed = publicKey != _localPublicKey && (_isHandshakeAllowed?.invoke(LinkType.Relayed, this, publicKey, pairingCode) ?: true)
|
val isAllowed = publicKey != _localPublicKey && (_isHandshakeAllowed?.invoke(LinkType.Relayed, this, publicKey, pairingCode, appId) ?: true)
|
||||||
if (!isAllowed) {
|
if (!isAllowed) {
|
||||||
val rp = ByteBuffer.allocate(16).order(ByteOrder.LITTLE_ENDIAN)
|
val rp = ByteBuffer.allocate(16).order(ByteOrder.LITTLE_ENDIAN)
|
||||||
rp.putInt(2) // Status code for not allowed
|
rp.putInt(2) // Status code for not allowed
|
||||||
@ -876,14 +889,14 @@ class SyncSocketSession {
|
|||||||
return deferred.await()
|
return deferred.await()
|
||||||
}
|
}
|
||||||
|
|
||||||
suspend fun startRelayedChannel(publicKey: String, pairingCode: String? = null): ChannelRelayed? {
|
suspend fun startRelayedChannel(publicKey: String, appId: UInt = 0u, pairingCode: String? = null): ChannelRelayed? {
|
||||||
val requestId = generateRequestId()
|
val requestId = generateRequestId()
|
||||||
val deferred = CompletableDeferred<ChannelRelayed>()
|
val deferred = CompletableDeferred<ChannelRelayed>()
|
||||||
val channel = ChannelRelayed(this, _localKeyPair, publicKey, true)
|
val channel = ChannelRelayed(this, _localKeyPair, publicKey, true)
|
||||||
_onNewChannel?.invoke(this, channel)
|
_onNewChannel?.invoke(this, channel)
|
||||||
_pendingChannels[requestId] = channel to deferred
|
_pendingChannels[requestId] = channel to deferred
|
||||||
try {
|
try {
|
||||||
channel.sendRequestTransport(requestId, publicKey, pairingCode)
|
channel.sendRequestTransport(requestId, publicKey, appId, pairingCode)
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
_pendingChannels.remove(requestId)?.let { it.first.close(); it.second.completeExceptionally(e) }
|
_pendingChannels.remove(requestId)?.let { it.first.close(); it.second.completeExceptionally(e) }
|
||||||
throw e
|
throw e
|
||||||
|
@ -9,6 +9,7 @@ import com.futo.platformplayer.noise.protocol.HandshakeState
|
|||||||
import com.futo.platformplayer.noise.protocol.Noise
|
import com.futo.platformplayer.noise.protocol.Noise
|
||||||
import com.futo.platformplayer.states.StateSync
|
import com.futo.platformplayer.states.StateSync
|
||||||
import com.futo.platformplayer.sync.internal.IAuthorizable
|
import com.futo.platformplayer.sync.internal.IAuthorizable
|
||||||
|
import com.futo.platformplayer.sync.internal.Opcode
|
||||||
import com.futo.platformplayer.sync.internal.SyncSocketSession
|
import com.futo.platformplayer.sync.internal.SyncSocketSession
|
||||||
import com.futo.platformplayer.sync.internal.SyncStream
|
import com.futo.platformplayer.sync.internal.SyncStream
|
||||||
import junit.framework.TestCase.assertEquals
|
import junit.framework.TestCase.assertEquals
|
||||||
@ -586,16 +587,16 @@ class NoiseProtocolTest {
|
|||||||
handshakeLatch.await(10, TimeUnit.SECONDS)
|
handshakeLatch.await(10, TimeUnit.SECONDS)
|
||||||
|
|
||||||
// Simulate initiator sending a PING and responder replying with PONG
|
// Simulate initiator sending a PING and responder replying with PONG
|
||||||
initiatorSession.send(SyncSocketSession.Opcode.PING.value)
|
initiatorSession.send(Opcode.PING.value)
|
||||||
responderSession.send(SyncSocketSession.Opcode.PONG.value)
|
responderSession.send(Opcode.PONG.value)
|
||||||
|
|
||||||
// Test data transfer
|
// Test data transfer
|
||||||
responderSession.send(SyncSocketSession.Opcode.DATA.value, 0u, randomBytesExactlyOnePacket)
|
responderSession.send(Opcode.DATA.value, 0u, randomBytesExactlyOnePacket)
|
||||||
initiatorSession.send(SyncSocketSession.Opcode.DATA.value, 1u, randomBytes)
|
initiatorSession.send(Opcode.DATA.value, 1u, randomBytes)
|
||||||
|
|
||||||
// Send large data to test stream handling
|
// Send large data to test stream handling
|
||||||
val start = System.currentTimeMillis()
|
val start = System.currentTimeMillis()
|
||||||
responderSession.send(SyncSocketSession.Opcode.DATA.value, 0u, randomBytesBig)
|
responderSession.send(Opcode.DATA.value, 0u, randomBytesBig)
|
||||||
println("Sent 10MB in ${System.currentTimeMillis() - start}ms")
|
println("Sent 10MB in ${System.currentTimeMillis() - start}ms")
|
||||||
|
|
||||||
// Wait for a brief period to simulate delay and allow communication
|
// Wait for a brief period to simulate delay and allow communication
|
||||||
|
Loading…
x
Reference in New Issue
Block a user