Implemented app id and updated unit tests.

This commit is contained in:
Koen J 2025-04-28 13:59:50 +02:00
parent daa91986ef
commit c4623c80ff
7 changed files with 664 additions and 57 deletions

View File

@ -3,19 +3,21 @@ package com.futo.platformplayer
import com.futo.platformplayer.noise.protocol.Noise
import com.futo.platformplayer.sync.internal.*
import kotlinx.coroutines.*
import kotlinx.coroutines.selects.select
import org.junit.Assert.*
import org.junit.Test
import java.net.Socket
import java.nio.ByteBuffer
import kotlin.random.Random
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.seconds
class SyncServerTests {
//private val relayHost = "relay.grayjay.app"
//private val relayKey = "xGbHRzDOvE6plRbQaFgSen82eijF+gxS0yeUaeEErkw="
private val relayKey = "XlUaSpIlRaCg0TGzZ7JYmPupgUHDqTZXUUBco2K7ejw="
private val relayHost = "192.168.1.175"
private val relayHost = "192.168.1.138"
private val relayPort = 9000
/** Creates a client connected to the live relay server. */
@ -23,7 +25,8 @@ class SyncServerTests {
onHandshakeComplete: ((SyncSocketSession) -> Unit)? = null,
onData: ((SyncSocketSession, UByte, UByte, ByteBuffer) -> Unit)? = null,
onNewChannel: ((SyncSocketSession, ChannelRelayed) -> Unit)? = null,
isHandshakeAllowed: ((SyncSocketSession, String, String?) -> Boolean)? = null
isHandshakeAllowed: ((LinkType, SyncSocketSession, String, String?, UInt) -> Boolean)? = null,
onException: ((Throwable) -> Unit)? = null
): SyncSocketSession = withContext(Dispatchers.IO) {
val p = Noise.createDH("25519")
p.generateKeyPair()
@ -43,10 +46,14 @@ class SyncServerTests {
},
onData = onData ?: { _, _, _, _ -> },
onNewChannel = onNewChannel ?: { _, _ -> },
isHandshakeAllowed = isHandshakeAllowed ?: { _, _, _ -> true }
isHandshakeAllowed = isHandshakeAllowed ?: { _, _, _, _, _ -> true }
)
socketSession.authorizable = AlwaysAuthorized()
try {
socketSession.startAsInitiator(relayKey)
} catch (e: Throwable) {
onException?.invoke(e)
}
withTimeout(5000.milliseconds) { tcs.await() }
return@withContext socketSession
}
@ -259,6 +266,71 @@ class SyncServerTests {
clientA.stop()
clientB.stop()
}
@Test
fun relayedTransport_WithValidAppId_Success() = runBlocking {
// Arrange: Set up clients
val allowedAppId = 1234u
val tcsB = CompletableDeferred<ChannelRelayed>()
// Client B requires appId 1234
val clientB = createClient(
onNewChannel = { _, c -> tcsB.complete(c) },
isHandshakeAllowed = { linkType, _, _, _, appId -> linkType == LinkType.Relayed && appId == allowedAppId }
)
val clientA = createClient()
// Act: Start relayed channel with valid appId
val channelTask = async { clientA.startRelayedChannel(clientB.localPublicKey, appId = allowedAppId) }
val channelB = withTimeout(5.seconds) { tcsB.await() }
withTimeout(5.seconds) { channelTask.await() }
// Assert: Channel is established
assertNotNull("Channel should be created on target with valid appId", channelB)
// Clean up
clientA.stop()
clientB.stop()
}
@Test
fun relayedTransport_WithInvalidAppId_Fails() = runBlocking {
// Arrange: Set up clients
val allowedAppId = 1234u
val invalidAppId = 5678u
val tcsB = CompletableDeferred<ChannelRelayed>()
// Client B requires appId 1234
val clientB = createClient(
onNewChannel = { _, c -> tcsB.complete(c) },
isHandshakeAllowed = { linkType, _, _, _, appId -> linkType == LinkType.Relayed && appId == allowedAppId },
onException = { }
)
val clientA = createClient()
// Act & Assert: Attempt with invalid appId should fail
try {
withTimeout(5.seconds) {
clientA.startRelayedChannel(clientB.localPublicKey, appId = invalidAppId)
}
fail("Starting relayed channel with invalid appId should fail")
} catch (e: Throwable) {
// Expected: The channel creation should time out or fail
}
// Ensure no channel was created on client B
val completedTask = select {
tcsB.onAwait { "channel" }
async { delay(1.seconds); "timeout" }.onAwait { "timeout" }
}
assertEquals("No channel should be created with invalid appId", "timeout", completedTask)
// Clean up
clientA.stop()
clientB.stop()
}
}
class AlwaysAuthorized : IAuthorizable {

View File

@ -0,0 +1,512 @@
package com.futo.platformplayer
import com.futo.platformplayer.noise.protocol.DHState
import com.futo.platformplayer.noise.protocol.Noise
import com.futo.platformplayer.sync.internal.*
import kotlinx.coroutines.*
import org.junit.Assert.*
import org.junit.Test
import java.io.PipedInputStream
import java.io.PipedOutputStream
import java.nio.ByteBuffer
import kotlin.random.Random
import java.io.InputStream
import java.io.OutputStream
import kotlin.time.Duration.Companion.seconds
data class PipeStreams(
val initiatorInput: LittleEndianDataInputStream,
val initiatorOutput: LittleEndianDataOutputStream,
val responderInput: LittleEndianDataInputStream,
val responderOutput: LittleEndianDataOutputStream
)
typealias OnHandshakeComplete = (SyncSocketSession) -> Unit
typealias IsHandshakeAllowed = (LinkType, SyncSocketSession, String, String?, UInt) -> Boolean
typealias OnClose = (SyncSocketSession) -> Unit
typealias OnData = (SyncSocketSession, UByte, UByte, ByteBuffer) -> Unit
class SyncSocketTests {
private fun createPipeStreams(): PipeStreams {
val initiatorOutput = PipedOutputStream()
val responderOutput = PipedOutputStream()
val responderInput = PipedInputStream(initiatorOutput)
val initiatorInput = PipedInputStream(responderOutput)
return PipeStreams(
LittleEndianDataInputStream(initiatorInput), LittleEndianDataOutputStream(initiatorOutput),
LittleEndianDataInputStream(responderInput), LittleEndianDataOutputStream(responderOutput)
)
}
fun generateKeyPair(): DHState {
val p = Noise.createDH("25519")
p.generateKeyPair()
return p
}
private fun createSessions(
initiatorInput: LittleEndianDataInputStream,
initiatorOutput: LittleEndianDataOutputStream,
responderInput: LittleEndianDataInputStream,
responderOutput: LittleEndianDataOutputStream,
initiatorKeyPair: DHState,
responderKeyPair: DHState,
onInitiatorHandshakeComplete: OnHandshakeComplete,
onResponderHandshakeComplete: OnHandshakeComplete,
onInitiatorClose: OnClose? = null,
onResponderClose: OnClose? = null,
onClose: OnClose? = null,
isHandshakeAllowed: IsHandshakeAllowed? = null,
onDataA: OnData? = null,
onDataB: OnData? = null
): Pair<SyncSocketSession, SyncSocketSession> {
val initiatorSession = SyncSocketSession(
"", initiatorKeyPair, initiatorInput, initiatorOutput,
onClose = {
onClose?.invoke(it)
onInitiatorClose?.invoke(it)
},
onHandshakeComplete = onInitiatorHandshakeComplete,
onData = onDataA,
isHandshakeAllowed = isHandshakeAllowed
)
val responderSession = SyncSocketSession(
"", responderKeyPair, responderInput, responderOutput,
onClose = {
onClose?.invoke(it)
onResponderClose?.invoke(it)
},
onHandshakeComplete = onResponderHandshakeComplete,
onData = onDataB,
isHandshakeAllowed = isHandshakeAllowed
)
return Pair(initiatorSession, responderSession)
}
@Test
fun handshake_WithValidPairingCode_Succeeds(): Unit = runBlocking {
val (initiatorInput, initiatorOutput, responderInput, responderOutput) = createPipeStreams()
val initiatorKeyPair = generateKeyPair()
val responderKeyPair = generateKeyPair()
val validPairingCode = "secret"
val handshakeInitiatorCompleted = CompletableDeferred<Boolean>()
val handshakeResponderCompleted = CompletableDeferred<Boolean>()
val (initiatorSession, responderSession) = createSessions(
initiatorInput, initiatorOutput, responderInput, responderOutput,
initiatorKeyPair, responderKeyPair,
{ handshakeInitiatorCompleted.complete(true) },
{ handshakeResponderCompleted.complete(true) },
isHandshakeAllowed = { _, _, _, pairingCode, _ -> pairingCode == validPairingCode }
)
initiatorSession.startAsInitiator(responderSession.localPublicKey, pairingCode = validPairingCode)
responderSession.startAsResponder()
withTimeout(5.seconds) {
handshakeInitiatorCompleted.await()
handshakeResponderCompleted.await()
}
}
@Test
fun handshake_WithInvalidPairingCode_Fails() = runBlocking {
val (initiatorInput, initiatorOutput, responderInput, responderOutput) = createPipeStreams()
val initiatorKeyPair = generateKeyPair()
val responderKeyPair = generateKeyPair()
val validPairingCode = "secret"
val invalidPairingCode = "wrong"
val handshakeInitiatorCompleted = CompletableDeferred<Boolean>()
val handshakeResponderCompleted = CompletableDeferred<Boolean>()
val initiatorClosed = CompletableDeferred<Boolean>()
val responderClosed = CompletableDeferred<Boolean>()
val (initiatorSession, responderSession) = createSessions(
initiatorInput, initiatorOutput, responderInput, responderOutput,
initiatorKeyPair, responderKeyPair,
{ handshakeInitiatorCompleted.complete(true) },
{ handshakeResponderCompleted.complete(true) },
onInitiatorClose = {
initiatorClosed.complete(true)
},
onResponderClose = {
responderClosed.complete(true)
},
isHandshakeAllowed = { _, _, _, pairingCode, _ -> pairingCode == validPairingCode }
)
initiatorSession.startAsInitiator(responderSession.localPublicKey, pairingCode = invalidPairingCode)
responderSession.startAsResponder()
withTimeout(100.seconds) {
initiatorClosed.await()
responderClosed.await()
}
assertFalse(handshakeInitiatorCompleted.isCompleted)
assertFalse(handshakeResponderCompleted.isCompleted)
}
@Test
fun handshake_WithoutPairingCodeWhenRequired_Fails() = runBlocking {
val (initiatorInput, initiatorOutput, responderInput, responderOutput) = createPipeStreams()
val initiatorKeyPair = generateKeyPair()
val responderKeyPair = generateKeyPair()
val validPairingCode = "secret"
val handshakeInitiatorCompleted = CompletableDeferred<Boolean>()
val handshakeResponderCompleted = CompletableDeferred<Boolean>()
val initiatorClosed = CompletableDeferred<Boolean>()
val responderClosed = CompletableDeferred<Boolean>()
val (initiatorSession, responderSession) = createSessions(
initiatorInput, initiatorOutput, responderInput, responderOutput,
initiatorKeyPair, responderKeyPair,
{ handshakeInitiatorCompleted.complete(true) },
{ handshakeResponderCompleted.complete(true) },
onInitiatorClose = {
initiatorClosed.complete(true)
},
onResponderClose = {
responderClosed.complete(true)
},
isHandshakeAllowed = { _, _, _, pairingCode, _ -> pairingCode == validPairingCode }
)
initiatorSession.startAsInitiator(responderSession.localPublicKey) // No pairing code
responderSession.startAsResponder()
withTimeout(5.seconds) {
initiatorClosed.await()
responderClosed.await()
}
assertFalse(handshakeInitiatorCompleted.isCompleted)
assertFalse(handshakeResponderCompleted.isCompleted)
}
@Test
fun handshake_WithPairingCodeWhenNotRequired_Succeeds(): Unit = runBlocking {
val (initiatorInput, initiatorOutput, responderInput, responderOutput) = createPipeStreams()
val initiatorKeyPair = generateKeyPair()
val responderKeyPair = generateKeyPair()
val pairingCode = "unnecessary"
val handshakeInitiatorCompleted = CompletableDeferred<Boolean>()
val handshakeResponderCompleted = CompletableDeferred<Boolean>()
val (initiatorSession, responderSession) = createSessions(
initiatorInput, initiatorOutput, responderInput, responderOutput,
initiatorKeyPair, responderKeyPair,
{ handshakeInitiatorCompleted.complete(true) },
{ handshakeResponderCompleted.complete(true) },
isHandshakeAllowed = { _, _, _, _, _ -> true } // Always allow
)
initiatorSession.startAsInitiator(responderSession.localPublicKey, pairingCode = pairingCode)
responderSession.startAsResponder()
withTimeout(10.seconds) {
handshakeInitiatorCompleted.await()
handshakeResponderCompleted.await()
}
}
@Test
fun sendAndReceive_SmallDataPacket_Succeeds() = runBlocking {
val (initiatorInput, initiatorOutput, responderInput, responderOutput) = createPipeStreams()
val initiatorKeyPair = generateKeyPair()
val responderKeyPair = generateKeyPair()
val handshakeInitiatorCompleted = CompletableDeferred<Boolean>()
val handshakeResponderCompleted = CompletableDeferred<Boolean>()
val tcsDataReceived = CompletableDeferred<ByteArray>()
val (initiatorSession, responderSession) = createSessions(
initiatorInput, initiatorOutput, responderInput, responderOutput,
initiatorKeyPair, responderKeyPair,
{ handshakeInitiatorCompleted.complete(true) },
{ handshakeResponderCompleted.complete(true) },
onDataB = { _, opcode, subOpcode, data ->
if (opcode == Opcode.DATA.value && subOpcode == 0u.toUByte()) {
val b = ByteArray(data.remaining())
data.get(b)
tcsDataReceived.complete(b)
}
}
)
initiatorSession.startAsInitiator(responderSession.localPublicKey)
responderSession.startAsResponder()
withTimeout(10.seconds) {
handshakeInitiatorCompleted.await()
handshakeResponderCompleted.await()
}
// Ensure both sessions are authorized
initiatorSession.authorizable = Authorized()
responderSession.authorizable = Authorized()
val smallData = byteArrayOf(1, 2, 3)
initiatorSession.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(smallData))
val receivedData = withTimeout(10.seconds) { tcsDataReceived.await() }
assertArrayEquals(smallData, receivedData)
}
@Test
fun sendAndReceive_ExactlyMaximumPacketSize_Succeeds() = runBlocking {
val (initiatorInput, initiatorOutput, responderInput, responderOutput) = createPipeStreams()
val initiatorKeyPair = generateKeyPair()
val responderKeyPair = generateKeyPair()
val handshakeInitiatorCompleted = CompletableDeferred<Boolean>()
val handshakeResponderCompleted = CompletableDeferred<Boolean>()
val tcsDataReceived = CompletableDeferred<ByteArray>()
val (initiatorSession, responderSession) = createSessions(
initiatorInput, initiatorOutput, responderInput, responderOutput,
initiatorKeyPair, responderKeyPair,
{ handshakeInitiatorCompleted.complete(true) },
{ handshakeResponderCompleted.complete(true) },
onDataB = { _, opcode, subOpcode, data ->
if (opcode == Opcode.DATA.value && subOpcode == 0u.toUByte()) {
val b = ByteArray(data.remaining())
data.get(b)
tcsDataReceived.complete(b)
}
}
)
initiatorSession.startAsInitiator(responderSession.localPublicKey)
responderSession.startAsResponder()
withTimeout(10.seconds) {
handshakeInitiatorCompleted.await()
handshakeResponderCompleted.await()
}
// Ensure both sessions are authorized
initiatorSession.authorizable = Authorized()
responderSession.authorizable = Authorized()
val maxData = ByteArray(SyncSocketSession.MAXIMUM_PACKET_SIZE - SyncSocketSession.HEADER_SIZE).apply { Random.nextBytes(this) }
initiatorSession.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(maxData))
val receivedData = withTimeout(10.seconds) { tcsDataReceived.await() }
assertArrayEquals(maxData, receivedData)
}
@Test
fun stream_LargeData_Succeeds() = runBlocking {
val (initiatorInput, initiatorOutput, responderInput, responderOutput) = createPipeStreams()
val initiatorKeyPair = generateKeyPair()
val responderKeyPair = generateKeyPair()
val handshakeInitiatorCompleted = CompletableDeferred<Boolean>()
val handshakeResponderCompleted = CompletableDeferred<Boolean>()
val tcsDataReceived = CompletableDeferred<ByteArray>()
val (initiatorSession, responderSession) = createSessions(
initiatorInput, initiatorOutput, responderInput, responderOutput,
initiatorKeyPair, responderKeyPair,
{ handshakeInitiatorCompleted.complete(true) },
{ handshakeResponderCompleted.complete(true) },
onDataB = { _, opcode, subOpcode, data ->
if (opcode == Opcode.DATA.value && subOpcode == 0u.toUByte()) {
val b = ByteArray(data.remaining())
data.get(b)
tcsDataReceived.complete(b)
}
}
)
initiatorSession.startAsInitiator(responderSession.localPublicKey)
responderSession.startAsResponder()
withTimeout(10.seconds) {
handshakeInitiatorCompleted.await()
handshakeResponderCompleted.await()
}
// Ensure both sessions are authorized
initiatorSession.authorizable = Authorized()
responderSession.authorizable = Authorized()
val largeData = ByteArray(2 * (SyncSocketSession.MAXIMUM_PACKET_SIZE - SyncSocketSession.HEADER_SIZE)).apply { Random.nextBytes(this) }
initiatorSession.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(largeData))
val receivedData = withTimeout(10.seconds) { tcsDataReceived.await() }
assertArrayEquals(largeData, receivedData)
}
@Test
fun authorizedSession_CanSendData() = runBlocking {
val (initiatorInput, initiatorOutput, responderInput, responderOutput) = createPipeStreams()
val initiatorKeyPair = generateKeyPair()
val responderKeyPair = generateKeyPair()
val handshakeInitiatorCompleted = CompletableDeferred<Boolean>()
val handshakeResponderCompleted = CompletableDeferred<Boolean>()
val tcsDataReceived = CompletableDeferred<ByteArray>()
val (initiatorSession, responderSession) = createSessions(
initiatorInput, initiatorOutput, responderInput, responderOutput,
initiatorKeyPair, responderKeyPair,
{ handshakeInitiatorCompleted.complete(true) },
{ handshakeResponderCompleted.complete(true) },
onDataB = { _, opcode, subOpcode, data ->
if (opcode == Opcode.DATA.value && subOpcode == 0u.toUByte()) {
val b = ByteArray(data.remaining())
data.get(b)
tcsDataReceived.complete(b)
}
}
)
initiatorSession.startAsInitiator(responderSession.localPublicKey)
responderSession.startAsResponder()
withTimeout(10.seconds) {
handshakeInitiatorCompleted.await()
handshakeResponderCompleted.await()
}
// Authorize both sessions
initiatorSession.authorizable = Authorized()
responderSession.authorizable = Authorized()
val data = byteArrayOf(1, 2, 3)
initiatorSession.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(data))
val receivedData = withTimeout(10.seconds) { tcsDataReceived.await() }
assertArrayEquals(data, receivedData)
}
@Test
fun unauthorizedSession_CannotSendData() = runBlocking {
val (initiatorInput, initiatorOutput, responderInput, responderOutput) = createPipeStreams()
val initiatorKeyPair = generateKeyPair()
val responderKeyPair = generateKeyPair()
val handshakeInitiatorCompleted = CompletableDeferred<Boolean>()
val handshakeResponderCompleted = CompletableDeferred<Boolean>()
val tcsDataReceived = CompletableDeferred<ByteArray>()
val (initiatorSession, responderSession) = createSessions(
initiatorInput, initiatorOutput, responderInput, responderOutput,
initiatorKeyPair, responderKeyPair,
{ handshakeInitiatorCompleted.complete(true) },
{ handshakeResponderCompleted.complete(true) },
onDataB = { _, _, _, _ -> }
)
initiatorSession.startAsInitiator(responderSession.localPublicKey)
responderSession.startAsResponder()
withTimeout(10.seconds) {
handshakeInitiatorCompleted.await()
handshakeResponderCompleted.await()
}
// Authorize initiator but not responder
initiatorSession.authorizable = Authorized()
responderSession.authorizable = Unauthorized()
val data = byteArrayOf(1, 2, 3)
initiatorSession.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(data))
delay(1.seconds)
assertFalse(tcsDataReceived.isCompleted)
}
@Test
fun directHandshake_WithValidAppId_Succeeds() = runBlocking {
val (initiatorInput, initiatorOutput, responderInput, responderOutput) = createPipeStreams()
val initiatorKeyPair = generateKeyPair()
val responderKeyPair = generateKeyPair()
val allowedAppId = 1234u
val handshakeInitiatorCompleted = CompletableDeferred<Boolean>()
val handshakeResponderCompleted = CompletableDeferred<Boolean>()
val responderIsHandshakeAllowed = { linkType: LinkType, _: SyncSocketSession, _: String, _: String?, appId: UInt ->
linkType == LinkType.Direct && appId == allowedAppId
}
val (initiatorSession, responderSession) = createSessions(
initiatorInput, initiatorOutput, responderInput, responderOutput,
initiatorKeyPair, responderKeyPair,
{ handshakeInitiatorCompleted.complete(true) },
{ handshakeResponderCompleted.complete(true) },
isHandshakeAllowed = responderIsHandshakeAllowed
)
initiatorSession.startAsInitiator(responderSession.localPublicKey, appId = allowedAppId)
responderSession.startAsResponder()
withTimeout(5.seconds) {
handshakeInitiatorCompleted.await()
handshakeResponderCompleted.await()
}
assertNotNull(initiatorSession.remotePublicKey)
assertNotNull(responderSession.remotePublicKey)
}
@Test
fun directHandshake_WithInvalidAppId_Fails() = runBlocking {
val (initiatorInput, initiatorOutput, responderInput, responderOutput) = createPipeStreams()
val initiatorKeyPair = generateKeyPair()
val responderKeyPair = generateKeyPair()
val allowedAppId = 1234u
val invalidAppId = 5678u
val handshakeInitiatorCompleted = CompletableDeferred<Boolean>()
val handshakeResponderCompleted = CompletableDeferred<Boolean>()
val initiatorClosed = CompletableDeferred<Boolean>()
val responderClosed = CompletableDeferred<Boolean>()
val responderIsHandshakeAllowed = { linkType: LinkType, _: SyncSocketSession, _: String, _: String?, appId: UInt ->
linkType == LinkType.Direct && appId == allowedAppId
}
val (initiatorSession, responderSession) = createSessions(
initiatorInput, initiatorOutput, responderInput, responderOutput,
initiatorKeyPair, responderKeyPair,
{ handshakeInitiatorCompleted.complete(true) },
{ handshakeResponderCompleted.complete(true) },
onInitiatorClose = {
initiatorClosed.complete(true)
},
onResponderClose = {
responderClosed.complete(true)
},
isHandshakeAllowed = responderIsHandshakeAllowed
)
initiatorSession.startAsInitiator(responderSession.localPublicKey, appId = invalidAppId)
responderSession.startAsResponder()
withTimeout(5.seconds) {
initiatorClosed.await()
responderClosed.await()
}
assertFalse(handshakeInitiatorCompleted.isCompleted)
assertFalse(handshakeResponderCompleted.isCompleted)
}
}
class Authorized : IAuthorizable {
override val isAuthorized: Boolean = true
}
class Unauthorized : IAuthorizable {
override val isAuthorized: Boolean = false
}

View File

@ -69,7 +69,14 @@ fun warnIfMainThread(context: String) {
}
fun ensureNotMainThread() {
if (Looper.myLooper() == Looper.getMainLooper()) {
val isMainLooper = try {
Looper.myLooper() == Looper.getMainLooper()
} catch (e: Throwable) {
//Ignore, for unit tests where its not mocked
false
}
if (isMainLooper) {
Logger.e("Utility", "Throwing exception because a function that should not be called on main thread, is called on main thread")
throw IllegalStateException("Cannot run on main thread")
}

View File

@ -226,7 +226,7 @@ class StateSync {
keyPair!!,
LittleEndianDataInputStream(socket.getInputStream()),
LittleEndianDataOutputStream(socket.getOutputStream()),
isHandshakeAllowed = { linkType, syncSocketSession, publicKey, pairingCode -> isHandshakeAllowed(linkType, syncSocketSession, publicKey, pairingCode) },
isHandshakeAllowed = { linkType, syncSocketSession, publicKey, pairingCode, appId -> isHandshakeAllowed(linkType, syncSocketSession, publicKey, pairingCode, appId) },
onNewChannel = { _, c ->
val remotePublicKey = c.remotePublicKey
if (remotePublicKey == null) {
@ -297,7 +297,7 @@ class StateSync {
if (connectionInfo.allowRemoteRelayed && Settings.instance.synchronization.connectThroughRelay) {
try {
Log.v(TAG, "Attempting relayed connection with '$targetKey'.")
runBlocking { relaySession.startRelayedChannel(targetKey, null) }
runBlocking { relaySession.startRelayedChannel(targetKey, APP_ID, null) }
} catch (e: Throwable) {
Log.e(TAG, "Failed to start relayed channel with $targetKey.", e)
}
@ -318,7 +318,7 @@ class StateSync {
override val isAuthorized: Boolean get() = true
}
_relaySession!!.startAsInitiator(RELAY_PUBLIC_KEY, null)
_relaySession!!.startAsInitiator(RELAY_PUBLIC_KEY, APP_ID, null)
Log.i(TAG, "Started relay session.")
} catch (e: Throwable) {
@ -731,8 +731,8 @@ class StateSync {
)
}
private fun isHandshakeAllowed(linkType: LinkType, syncSocketSession: SyncSocketSession, publicKey: String, pairingCode: String?): Boolean {
Log.v(TAG, "Check if handshake allowed from '$publicKey'.")
private fun isHandshakeAllowed(linkType: LinkType, syncSocketSession: SyncSocketSession, publicKey: String, pairingCode: String?, appId: UInt): Boolean {
Log.v(TAG, "Check if handshake allowed from '$publicKey' (app id: $appId).")
if (publicKey == RELAY_PUBLIC_KEY)
return true
@ -744,7 +744,7 @@ class StateSync {
}
}
Log.v(TAG, "Check if handshake allowed with pairing code '$pairingCode' with active pairing code '$_pairingCode'.")
Log.v(TAG, "Check if handshake allowed with pairing code '$pairingCode' with active pairing code '$_pairingCode' (app id: $appId).")
if (_pairingCode == null || pairingCode.isNullOrEmpty())
return false
@ -766,7 +766,7 @@ class StateSync {
if (channelSocket != null)
session?.removeChannel(channelSocket!!)
},
isHandshakeAllowed = { linkType, syncSocketSession, publicKey, pairingCode -> isHandshakeAllowed(linkType, syncSocketSession, publicKey, pairingCode) },
isHandshakeAllowed = { linkType, syncSocketSession, publicKey, pairingCode, appId -> isHandshakeAllowed(linkType, syncSocketSession, publicKey, pairingCode, appId) },
onHandshakeComplete = { s ->
val remotePublicKey = s.remotePublicKey
if (remotePublicKey == null) {
@ -930,7 +930,7 @@ class StateSync {
_remotePendingStatusUpdate[deviceInfo.publicKey] = onStatusUpdate
}
}
relaySession.startRelayedChannel(deviceInfo.publicKey, deviceInfo.pairingCode)
relaySession.startRelayedChannel(deviceInfo.publicKey, APP_ID, deviceInfo.pairingCode)
}
} else {
throw e
@ -950,7 +950,7 @@ class StateSync {
}
}
session.startAsInitiator(publicKey, pairingCode)
session.startAsInitiator(publicKey, APP_ID, pairingCode)
return session
}
@ -1008,6 +1008,7 @@ class StateSync {
val version = 1
val RELAY_SERVER = "relay.grayjay.app"
val RELAY_PUBLIC_KEY = "xGbHRzDOvE6plRbQaFgSen82eijF+gxS0yeUaeEErkw="
val APP_ID = 0x534A5247u //GRayJaySync (GRJS)
private const val TAG = "StateSync"
const val PORT = 12315

View File

@ -246,7 +246,7 @@ class ChannelRelayed(
}
}
fun sendRequestTransport(requestId: Int, publicKey: String, pairingCode: String? = null) {
fun sendRequestTransport(requestId: Int, publicKey: String, appId: UInt, pairingCode: String? = null) {
throwIfDisposed()
synchronized(sendLock) {
@ -270,10 +270,11 @@ class ChannelRelayed(
0 to ByteArray(0)
}
val packetSize = 4 + 32 + 4 + pairingMessageLength + 4 + channelBytesWritten
val packetSize = 4 + 4 + 32 + 4 + pairingMessageLength + 4 + channelBytesWritten
val packet = ByteArray(packetSize)
ByteBuffer.wrap(packet).order(ByteOrder.LITTLE_ENDIAN).apply {
putInt(requestId)
putInt(appId.toInt())
put(publicKeyBytes)
putInt(pairingMessageLength)
if (pairingMessageLength > 0) put(pairingMessage)

View File

@ -38,12 +38,13 @@ class SyncSocketSession {
private val _onHandshakeComplete: ((session: SyncSocketSession) -> Unit)?
private val _onNewChannel: ((session: SyncSocketSession, channel: ChannelRelayed) -> Unit)?
private val _onChannelEstablished: ((session: SyncSocketSession, channel: ChannelRelayed, isResponder: Boolean) -> Unit)?
private val _isHandshakeAllowed: ((linkType: LinkType, session: SyncSocketSession, remotePublicKey: String, pairingCode: String?) -> Boolean)?
private val _isHandshakeAllowed: ((linkType: LinkType, session: SyncSocketSession, remotePublicKey: String, pairingCode: String?, appId: UInt) -> Boolean)?
private var _cipherStatePair: CipherStatePair? = null
private var _remotePublicKey: String? = null
val remotePublicKey: String? get() = _remotePublicKey
private var _started: Boolean = false
private val _localKeyPair: DHState
private var _thread: Thread? = null
private var _localPublicKey: String
val localPublicKey: String get() = _localPublicKey
private val _onData: ((session: SyncSocketSession, opcode: UByte, subOpcode: UByte, data: ByteBuffer) -> Unit)?
@ -87,7 +88,7 @@ class SyncSocketSession {
onData: ((session: SyncSocketSession, opcode: UByte, subOpcode: UByte, data: ByteBuffer) -> Unit)? = null,
onNewChannel: ((session: SyncSocketSession, channel: ChannelRelayed) -> Unit)? = null,
onChannelEstablished: ((session: SyncSocketSession, channel: ChannelRelayed, isResponder: Boolean) -> Unit)? = null,
isHandshakeAllowed: ((linkType: LinkType, session: SyncSocketSession, remotePublicKey: String, pairingCode: String?) -> Boolean)? = null
isHandshakeAllowed: ((linkType: LinkType, session: SyncSocketSession, remotePublicKey: String, pairingCode: String?, appId: UInt) -> Boolean)? = null
) {
_inputStream = inputStream
_outputStream = outputStream
@ -105,10 +106,11 @@ class SyncSocketSession {
_localPublicKey = Base64.getEncoder().encodeToString(localPublicKey)
}
fun startAsInitiator(remotePublicKey: String, pairingCode: String? = null) {
fun startAsInitiator(remotePublicKey: String, appId: UInt = 0u, pairingCode: String? = null) {
_started = true
_thread = Thread {
try {
handshakeAsInitiator(remotePublicKey, pairingCode)
handshakeAsInitiator(remotePublicKey, appId, pairingCode)
_onHandshakeComplete?.invoke(this)
receiveLoop()
} catch (e: Throwable) {
@ -116,10 +118,12 @@ class SyncSocketSession {
} finally {
stop()
}
}.apply { start() }
}
fun startAsResponder() {
_started = true
_thread = Thread {
try {
if (handshakeAsResponder()) {
_onHandshakeComplete?.invoke(this)
@ -130,6 +134,7 @@ class SyncSocketSession {
} finally {
stop()
}
}.apply { start() }
}
private fun receiveLoop() {
@ -187,12 +192,13 @@ class SyncSocketSession {
_onClose?.invoke(this)
_inputStream.close()
_outputStream.close()
_thread = null
_cipherStatePair?.sender?.destroy()
_cipherStatePair?.receiver?.destroy()
Logger.i(TAG, "Session closed")
}
private fun handshakeAsInitiator(remotePublicKey: String, pairingCode: String?) {
private fun handshakeAsInitiator(remotePublicKey: String, appId: UInt, pairingCode: String?) {
performVersionCheck()
val initiator = HandshakeState(StateSync.protocolName, HandshakeState.INITIATOR)
@ -218,7 +224,8 @@ class SyncSocketSession {
val mainBuffer = ByteArray(512)
val mainLength = initiator.writeMessage(mainBuffer, 0, null, 0, 0)
val messageData = ByteBuffer.allocate(4 + pairingMessageLength + mainLength).order(ByteOrder.LITTLE_ENDIAN)
val messageData = ByteBuffer.allocate(4 + 4 + pairingMessageLength + mainLength).order(ByteOrder.LITTLE_ENDIAN)
messageData.putInt(appId.toInt())
messageData.putInt(pairingMessageLength)
if (pairingMessageLength > 0) messageData.put(pairingMessage)
messageData.put(mainBuffer, 0, mainLength)
@ -250,9 +257,10 @@ class SyncSocketSession {
_inputStream.readFully(message)
val messageBuffer = ByteBuffer.wrap(message).order(ByteOrder.LITTLE_ENDIAN)
val appId = messageBuffer.int.toUInt()
val pairingMessageLength = messageBuffer.int
val pairingMessage = if (pairingMessageLength > 0) ByteArray(pairingMessageLength).also { messageBuffer.get(it) } else byteArrayOf()
val mainLength = messageSize - 4 - pairingMessageLength
val mainLength = messageSize - 4 - 4 - pairingMessageLength
val mainMessage = ByteArray(mainLength).also { messageBuffer.get(it) }
var pairingCode: String? = null
@ -267,6 +275,15 @@ class SyncSocketSession {
val plaintext = ByteArray(512)
responder.readMessage(mainMessage, 0, mainLength, plaintext, 0)
val remoteKeyBytes = ByteArray(responder.remotePublicKey.publicKeyLength)
responder.remotePublicKey.getPublicKey(remoteKeyBytes, 0)
val remotePublicKey = Base64.getEncoder().encodeToString(remoteKeyBytes)
val isAllowedToConnect = remotePublicKey != _localPublicKey && (_isHandshakeAllowed?.invoke(LinkType.Direct, this, remotePublicKey, pairingCode, appId) ?: true)
if (!isAllowedToConnect) {
stop()
return false
}
val responseBuffer = ByteArray(512)
val responseLength = responder.writeMessage(responseBuffer, 0, null, 0, 0)
@ -274,13 +291,8 @@ class SyncSocketSession {
_outputStream.write(responseBuffer, 0, responseLength)
_cipherStatePair = responder.split()
val remoteKeyBytes = ByteArray(responder.remotePublicKey.publicKeyLength)
responder.remotePublicKey.getPublicKey(remoteKeyBytes, 0)
_remotePublicKey = Base64.getEncoder().encodeToString(remoteKeyBytes)
return (_remotePublicKey != _localPublicKey && (_isHandshakeAllowed?.invoke(LinkType.Direct, this, _remotePublicKey!!, pairingCode) ?: true)).also {
if (!it) stop()
}
_remotePublicKey = remotePublicKey
return true
}
private fun performVersionCheck() {
@ -400,13 +412,14 @@ class SyncSocketSession {
val remoteVersion = data.int
val connectionId = data.long
val requestId = data.int
val appId = data.int.toUInt()
val publicKeyBytes = ByteArray(32).also { data.get(it) }
val pairingMessageLength = data.int
if (pairingMessageLength > 128) throw IllegalArgumentException("Pairing message length ($pairingMessageLength) exceeds maximum (128)")
if (pairingMessageLength > 128) throw IllegalArgumentException("Pairing message length ($pairingMessageLength) exceeds maximum (128) (app id: $appId)")
val pairingMessage = if (pairingMessageLength > 0) ByteArray(pairingMessageLength).also { data.get(it) } else ByteArray(0)
val channelMessageLength = data.int
if (data.remaining() != channelMessageLength) {
Logger.e(TAG, "Invalid packet size. Expected ${52 + pairingMessageLength + 4 + channelMessageLength}, got ${data.capacity()}")
Logger.e(TAG, "Invalid packet size. Expected ${52 + pairingMessageLength + 4 + channelMessageLength}, got ${data.capacity()} (app id: $appId)")
return
}
val channelHandshakeMessage = ByteArray(channelMessageLength).also { data.get(it) }
@ -420,7 +433,7 @@ class SyncSocketSession {
val length = pairingProtocol.readMessage(pairingMessage, 0, pairingMessageLength, plaintext, 0)
String(plaintext, 0, length, Charsets.UTF_8)
} else null
val isAllowed = publicKey != _localPublicKey && (_isHandshakeAllowed?.invoke(LinkType.Relayed, this, publicKey, pairingCode) ?: true)
val isAllowed = publicKey != _localPublicKey && (_isHandshakeAllowed?.invoke(LinkType.Relayed, this, publicKey, pairingCode, appId) ?: true)
if (!isAllowed) {
val rp = ByteBuffer.allocate(16).order(ByteOrder.LITTLE_ENDIAN)
rp.putInt(2) // Status code for not allowed
@ -876,14 +889,14 @@ class SyncSocketSession {
return deferred.await()
}
suspend fun startRelayedChannel(publicKey: String, pairingCode: String? = null): ChannelRelayed? {
suspend fun startRelayedChannel(publicKey: String, appId: UInt = 0u, pairingCode: String? = null): ChannelRelayed? {
val requestId = generateRequestId()
val deferred = CompletableDeferred<ChannelRelayed>()
val channel = ChannelRelayed(this, _localKeyPair, publicKey, true)
_onNewChannel?.invoke(this, channel)
_pendingChannels[requestId] = channel to deferred
try {
channel.sendRequestTransport(requestId, publicKey, pairingCode)
channel.sendRequestTransport(requestId, publicKey, appId, pairingCode)
} catch (e: Exception) {
_pendingChannels.remove(requestId)?.let { it.first.close(); it.second.completeExceptionally(e) }
throw e

View File

@ -9,6 +9,7 @@ import com.futo.platformplayer.noise.protocol.HandshakeState
import com.futo.platformplayer.noise.protocol.Noise
import com.futo.platformplayer.states.StateSync
import com.futo.platformplayer.sync.internal.IAuthorizable
import com.futo.platformplayer.sync.internal.Opcode
import com.futo.platformplayer.sync.internal.SyncSocketSession
import com.futo.platformplayer.sync.internal.SyncStream
import junit.framework.TestCase.assertEquals
@ -586,16 +587,16 @@ class NoiseProtocolTest {
handshakeLatch.await(10, TimeUnit.SECONDS)
// Simulate initiator sending a PING and responder replying with PONG
initiatorSession.send(SyncSocketSession.Opcode.PING.value)
responderSession.send(SyncSocketSession.Opcode.PONG.value)
initiatorSession.send(Opcode.PING.value)
responderSession.send(Opcode.PONG.value)
// Test data transfer
responderSession.send(SyncSocketSession.Opcode.DATA.value, 0u, randomBytesExactlyOnePacket)
initiatorSession.send(SyncSocketSession.Opcode.DATA.value, 1u, randomBytes)
responderSession.send(Opcode.DATA.value, 0u, randomBytesExactlyOnePacket)
initiatorSession.send(Opcode.DATA.value, 1u, randomBytes)
// Send large data to test stream handling
val start = System.currentTimeMillis()
responderSession.send(SyncSocketSession.Opcode.DATA.value, 0u, randomBytesBig)
responderSession.send(Opcode.DATA.value, 0u, randomBytesBig)
println("Sent 10MB in ${System.currentTimeMillis() - start}ms")
// Wait for a brief period to simulate delay and allow communication