diff --git a/app/src/androidTest/java/com/futo/platformplayer/SyncServerTests.kt b/app/src/androidTest/java/com/futo/platformplayer/SyncServerTests.kt new file mode 100644 index 00000000..7348c3c7 --- /dev/null +++ b/app/src/androidTest/java/com/futo/platformplayer/SyncServerTests.kt @@ -0,0 +1,266 @@ +package com.futo.platformplayer + +import com.futo.platformplayer.noise.protocol.Noise +import com.futo.platformplayer.sync.internal.* +import kotlinx.coroutines.* +import org.junit.Assert.* +import org.junit.Test +import java.net.Socket +import java.nio.ByteBuffer +import kotlin.random.Random +import kotlin.time.Duration.Companion.milliseconds + +class SyncServerTests { + + //private val relayHost = "relay.grayjay.app" + //private val relayKey = "xGbHRzDOvE6plRbQaFgSen82eijF+gxS0yeUaeEErkw=" + private val relayKey = "XlUaSpIlRaCg0TGzZ7JYmPupgUHDqTZXUUBco2K7ejw=" + private val relayHost = "192.168.1.175" + private val relayPort = 9000 + + /** Creates a client connected to the live relay server. */ + private suspend fun createClient( + onHandshakeComplete: ((SyncSocketSession) -> Unit)? = null, + onData: ((SyncSocketSession, UByte, UByte, ByteBuffer) -> Unit)? = null, + onNewChannel: ((SyncSocketSession, ChannelRelayed) -> Unit)? = null, + isHandshakeAllowed: ((SyncSocketSession, String, String?) -> Boolean)? = null + ): SyncSocketSession = withContext(Dispatchers.IO) { + val p = Noise.createDH("25519") + p.generateKeyPair() + val socket = Socket(relayHost, relayPort) + val inputStream = LittleEndianDataInputStream(socket.getInputStream()) + val outputStream = LittleEndianDataOutputStream(socket.getOutputStream()) + val tcs = CompletableDeferred() + val socketSession = SyncSocketSession( + relayHost, + p, + inputStream, + outputStream, + onClose = { socket.close() }, + onHandshakeComplete = { s -> + onHandshakeComplete?.invoke(s) + tcs.complete(true) + }, + onData = onData ?: { _, _, _, _ -> }, + onNewChannel = onNewChannel ?: { _, _ -> }, + isHandshakeAllowed = isHandshakeAllowed ?: { _, _, _ -> true } + ) + socketSession.authorizable = AlwaysAuthorized() + socketSession.startAsInitiator(relayKey) + withTimeout(5000.milliseconds) { tcs.await() } + return@withContext socketSession + } + + @Test + fun multipleClientsHandshake_Success() = runBlocking { + val client1 = createClient() + val client2 = createClient() + assertNotNull(client1.remotePublicKey, "Client 1 handshake failed") + assertNotNull(client2.remotePublicKey, "Client 2 handshake failed") + client1.stop() + client2.stop() + } + + @Test + fun publishAndRequestConnectionInfo_Authorized_Success() = runBlocking { + val clientA = createClient() + val clientB = createClient() + val clientC = createClient() + clientA.publishConnectionInformation(arrayOf(clientB.localPublicKey), 12345, true, true, true, true) + delay(100.milliseconds) + val infoB = clientB.requestConnectionInfo(clientA.localPublicKey) + val infoC = clientC.requestConnectionInfo(clientA.localPublicKey) + assertNotNull("Client B should receive connection info", infoB) + assertEquals(12345.toUShort(), infoB!!.port) + assertNull("Client C should not receive connection info (unauthorized)", infoC) + clientA.stop() + clientB.stop() + clientC.stop() + } + + @Test + fun relayedTransport_Bidirectional_Success() = runBlocking { + val tcsA = CompletableDeferred() + val tcsB = CompletableDeferred() + val clientA = createClient(onNewChannel = { _, c -> tcsA.complete(c) }) + val clientB = createClient(onNewChannel = { _, c -> tcsB.complete(c) }) + val channelTask = async { clientA.startRelayedChannel(clientB.localPublicKey) } + val channelA = withTimeout(5000.milliseconds) { tcsA.await() } + channelA.authorizable = AlwaysAuthorized() + val channelB = withTimeout(5000.milliseconds) { tcsB.await() } + channelB.authorizable = AlwaysAuthorized() + channelTask.await() + + val tcsDataB = CompletableDeferred() + channelB.setDataHandler { _, _, o, so, d -> + val b = ByteArray(d.remaining()) + d.get(b) + if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataB.complete(b) + } + channelA.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(byteArrayOf(1, 2, 3))) + + val tcsDataA = CompletableDeferred() + channelA.setDataHandler { _, _, o, so, d -> + val b = ByteArray(d.remaining()) + d.get(b) + if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataA.complete(b) + } + channelB.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(byteArrayOf(4, 5, 6))) + + val receivedB = withTimeout(5000.milliseconds) { tcsDataB.await() } + val receivedA = withTimeout(5000.milliseconds) { tcsDataA.await() } + assertArrayEquals(byteArrayOf(1, 2, 3), receivedB) + assertArrayEquals(byteArrayOf(4, 5, 6), receivedA) + clientA.stop() + clientB.stop() + } + + @Test + fun relayedTransport_MaximumMessageSize_Success() = runBlocking { + val MAX_DATA_PER_PACKET = SyncSocketSession.MAXIMUM_PACKET_SIZE - SyncSocketSession.HEADER_SIZE - 8 - 16 - 16 + val maxSizeData = ByteArray(MAX_DATA_PER_PACKET).apply { Random.nextBytes(this) } + val tcsA = CompletableDeferred() + val tcsB = CompletableDeferred() + val clientA = createClient(onNewChannel = { _, c -> tcsA.complete(c) }) + val clientB = createClient(onNewChannel = { _, c -> tcsB.complete(c) }) + val channelTask = async { clientA.startRelayedChannel(clientB.localPublicKey) } + val channelA = withTimeout(5000.milliseconds) { tcsA.await() } + channelA.authorizable = AlwaysAuthorized() + val channelB = withTimeout(5000.milliseconds) { tcsB.await() } + channelB.authorizable = AlwaysAuthorized() + channelTask.await() + + val tcsDataB = CompletableDeferred() + channelB.setDataHandler { _, _, o, so, d -> + val b = ByteArray(d.remaining()) + d.get(b) + if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataB.complete(b) + } + channelA.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(maxSizeData)) + val receivedData = withTimeout(5000.milliseconds) { tcsDataB.await() } + assertArrayEquals(maxSizeData, receivedData) + clientA.stop() + clientB.stop() + } + + @Test + fun publishAndGetRecord_Success() = runBlocking { + val clientA = createClient() + val clientB = createClient() + val clientC = createClient() + val data = byteArrayOf(1, 2, 3) + val success = clientA.publishRecords(listOf(clientB.localPublicKey), "testKey", data) + val recordB = clientB.getRecord(clientA.localPublicKey, "testKey") + val recordC = clientC.getRecord(clientA.localPublicKey, "testKey") + assertTrue(success) + assertNotNull(recordB) + assertArrayEquals(data, recordB!!.first) + assertNull("Unauthorized client should not access record", recordC) + clientA.stop() + clientB.stop() + clientC.stop() + } + + @Test + fun getNonExistentRecord_ReturnsNull() = runBlocking { + val clientA = createClient() + val clientB = createClient() + val record = clientB.getRecord(clientA.localPublicKey, "nonExistentKey") + assertNull("Getting non-existent record should return null", record) + clientA.stop() + clientB.stop() + } + + @Test + fun updateRecord_TimestampUpdated() = runBlocking { + val clientA = createClient() + val clientB = createClient() + val key = "updateKey" + val data1 = byteArrayOf(1) + val data2 = byteArrayOf(2) + clientA.publishRecords(listOf(clientB.localPublicKey), key, data1) + val record1 = clientB.getRecord(clientA.localPublicKey, key) + delay(1000.milliseconds) + clientA.publishRecords(listOf(clientB.localPublicKey), key, data2) + val record2 = clientB.getRecord(clientA.localPublicKey, key) + assertNotNull(record1) + assertNotNull(record2) + assertTrue(record2!!.second > record1!!.second) + assertArrayEquals(data2, record2.first) + clientA.stop() + clientB.stop() + } + + @Test + fun deleteRecord_Success() = runBlocking { + val clientA = createClient() + val clientB = createClient() + val data = byteArrayOf(1, 2, 3) + clientA.publishRecords(listOf(clientB.localPublicKey), "toDelete", data) + val success = clientB.deleteRecords(clientA.localPublicKey, clientB.localPublicKey, listOf("toDelete")) + val record = clientB.getRecord(clientA.localPublicKey, "toDelete") + assertTrue(success) + assertNull(record) + clientA.stop() + clientB.stop() + } + + @Test + fun listRecordKeys_Success() = runBlocking { + val clientA = createClient() + val clientB = createClient() + val keys = arrayOf("key1", "key2", "key3") + keys.forEach { key -> + clientA.publishRecords(listOf(clientB.localPublicKey), key, byteArrayOf(1)) + } + val listedKeys = clientB.listRecordKeys(clientA.localPublicKey, clientB.localPublicKey) + assertArrayEquals(keys, listedKeys.map { it.first }.toTypedArray()) + clientA.stop() + clientB.stop() + } + + @Test + fun singleLargeMessageViaRelayedChannel_Success() = runBlocking { + val largeData = ByteArray(100000).apply { Random.nextBytes(this) } + val tcsA = CompletableDeferred() + val tcsB = CompletableDeferred() + val clientA = createClient(onNewChannel = { _, c -> tcsA.complete(c) }) + val clientB = createClient(onNewChannel = { _, c -> tcsB.complete(c) }) + val channelTask = async { clientA.startRelayedChannel(clientB.localPublicKey) } + val channelA = withTimeout(5000.milliseconds) { tcsA.await() } + channelA.authorizable = AlwaysAuthorized() + val channelB = withTimeout(5000.milliseconds) { tcsB.await() } + channelB.authorizable = AlwaysAuthorized() + channelTask.await() + + val tcsDataB = CompletableDeferred() + channelB.setDataHandler { _, _, o, so, d -> + val b = ByteArray(d.remaining()) + d.get(b) + if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataB.complete(b) + } + channelA.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(largeData)) + val receivedData = withTimeout(10000.milliseconds) { tcsDataB.await() } + assertArrayEquals(largeData, receivedData) + clientA.stop() + clientB.stop() + } + + @Test + fun publishAndGetLargeRecord_Success() = runBlocking { + val largeData = ByteArray(1000000).apply { Random.nextBytes(this) } + val clientA = createClient() + val clientB = createClient() + val success = clientA.publishRecords(listOf(clientB.localPublicKey), "largeRecord", largeData) + val record = clientB.getRecord(clientA.localPublicKey, "largeRecord") + assertTrue(success) + assertNotNull(record) + assertArrayEquals(largeData, record!!.first) + clientA.stop() + clientB.stop() + } +} + +class AlwaysAuthorized : IAuthorizable { + override val isAuthorized: Boolean get() = true +} \ No newline at end of file diff --git a/app/src/main/java/com/futo/platformplayer/Utility.kt b/app/src/main/java/com/futo/platformplayer/Utility.kt index 5cd5d26f..7ddefc79 100644 --- a/app/src/main/java/com/futo/platformplayer/Utility.kt +++ b/app/src/main/java/com/futo/platformplayer/Utility.kt @@ -28,12 +28,11 @@ import com.futo.platformplayer.models.PlatformVideoWithTime import com.futo.platformplayer.others.PlatformLinkMovementMethod import java.io.ByteArrayInputStream import java.io.ByteArrayOutputStream -import java.io.File import java.io.IOException import java.io.InputStream import java.io.OutputStream import java.nio.ByteBuffer -import java.nio.ByteOrder +import java.security.SecureRandom import java.time.OffsetDateTime import java.util.* import java.util.concurrent.ThreadLocalRandom @@ -284,6 +283,18 @@ fun ByteBuffer.toUtf8String(): String { return String(remainingBytes, Charsets.UTF_8) } +fun generateReadablePassword(length: Int): String { + val validChars = "ABCDEFGHJKLMNPQRSTUVWXYZabcdefghjkmnpqrstuvwxyz23456789" + val secureRandom = SecureRandom() + val randomBytes = ByteArray(length) + secureRandom.nextBytes(randomBytes) + val sb = StringBuilder(length) + for (byte in randomBytes) { + val index = (byte.toInt() and 0xFF) % validChars.length + sb.append(validChars[index]) + } + return sb.toString() +} fun ByteArray.toGzip(): ByteArray { if (this == null || this.isEmpty()) return ByteArray(0) diff --git a/app/src/main/java/com/futo/platformplayer/activities/SyncHomeActivity.kt b/app/src/main/java/com/futo/platformplayer/activities/SyncHomeActivity.kt index d1cd7706..2b7e3a72 100644 --- a/app/src/main/java/com/futo/platformplayer/activities/SyncHomeActivity.kt +++ b/app/src/main/java/com/futo/platformplayer/activities/SyncHomeActivity.kt @@ -100,7 +100,8 @@ class SyncHomeActivity : AppCompatActivity() { private fun updateDeviceView(syncDeviceView: SyncDeviceView, publicKey: String, session: SyncSession?): SyncDeviceView { val connected = session?.connected ?: false - syncDeviceView.setLinkType(if (connected) LinkType.Local else LinkType.None) + + syncDeviceView.setLinkType(session?.linkType ?: LinkType.None) .setName(session?.displayName ?: StateSync.instance.getCachedName(publicKey) ?: publicKey) //TODO: also display public key? .setStatus(if (connected) "Connected" else "Disconnected") diff --git a/app/src/main/java/com/futo/platformplayer/activities/SyncPairActivity.kt b/app/src/main/java/com/futo/platformplayer/activities/SyncPairActivity.kt index a7030b97..5e808977 100644 --- a/app/src/main/java/com/futo/platformplayer/activities/SyncPairActivity.kt +++ b/app/src/main/java/com/futo/platformplayer/activities/SyncPairActivity.kt @@ -109,9 +109,9 @@ class SyncPairActivity : AppCompatActivity() { lifecycleScope.launch(Dispatchers.IO) { try { - StateSync.instance.connect(deviceInfo) { session, complete, message -> + StateSync.instance.connect(deviceInfo) { complete, message -> lifecycleScope.launch(Dispatchers.Main) { - if (complete) { + if (complete != null && complete) { _layoutPairingSuccess.visibility = View.VISIBLE _layoutPairing.visibility = View.GONE } else { diff --git a/app/src/main/java/com/futo/platformplayer/activities/SyncShowPairingCodeActivity.kt b/app/src/main/java/com/futo/platformplayer/activities/SyncShowPairingCodeActivity.kt index 2fbb4b97..b9f3d437 100644 --- a/app/src/main/java/com/futo/platformplayer/activities/SyncShowPairingCodeActivity.kt +++ b/app/src/main/java/com/futo/platformplayer/activities/SyncShowPairingCodeActivity.kt @@ -67,7 +67,7 @@ class SyncShowPairingCodeActivity : AppCompatActivity() { } val ips = getIPs() - val selfDeviceInfo = SyncDeviceInfo(StateSync.instance.publicKey!!, ips.toTypedArray(), StateSync.PORT) + val selfDeviceInfo = SyncDeviceInfo(StateSync.instance.publicKey!!, ips.toTypedArray(), StateSync.PORT, StateSync.instance.pairingCode) val json = Json.encodeToString(selfDeviceInfo) val base64 = Base64.encodeToString(json.toByteArray(), Base64.URL_SAFE or Base64.NO_PADDING or Base64.NO_WRAP) val url = "grayjay://sync/${base64}" diff --git a/app/src/main/java/com/futo/platformplayer/states/StateSync.kt b/app/src/main/java/com/futo/platformplayer/states/StateSync.kt index 96c25f9d..57bcde5f 100644 --- a/app/src/main/java/com/futo/platformplayer/states/StateSync.kt +++ b/app/src/main/java/com/futo/platformplayer/states/StateSync.kt @@ -6,38 +6,62 @@ import com.futo.platformplayer.LittleEndianDataInputStream import com.futo.platformplayer.LittleEndianDataOutputStream import com.futo.platformplayer.Settings import com.futo.platformplayer.UIDialogs +import com.futo.platformplayer.activities.MainActivity import com.futo.platformplayer.activities.SyncShowPairingCodeActivity +import com.futo.platformplayer.api.media.Serializer import com.futo.platformplayer.constructs.Event1 import com.futo.platformplayer.constructs.Event2 import com.futo.platformplayer.encryption.GEncryptionProvider +import com.futo.platformplayer.generateReadablePassword import com.futo.platformplayer.getConnectedSocket import com.futo.platformplayer.logging.Logger import com.futo.platformplayer.mdns.DnsService import com.futo.platformplayer.mdns.ServiceDiscoverer +import com.futo.platformplayer.models.HistoryVideo +import com.futo.platformplayer.models.Subscription import com.futo.platformplayer.noise.protocol.DHState import com.futo.platformplayer.noise.protocol.Noise +import com.futo.platformplayer.smartMerge import com.futo.platformplayer.stores.FragmentedStorage import com.futo.platformplayer.stores.StringStringMapStorage import com.futo.platformplayer.stores.StringArrayStorage import com.futo.platformplayer.stores.StringStorage import com.futo.platformplayer.stores.StringTMapStorage import com.futo.platformplayer.sync.SyncSessionData +import com.futo.platformplayer.sync.internal.ChannelSocket import com.futo.platformplayer.sync.internal.GJSyncOpcodes +import com.futo.platformplayer.sync.internal.IAuthorizable +import com.futo.platformplayer.sync.internal.IChannel +import com.futo.platformplayer.sync.internal.Opcode import com.futo.platformplayer.sync.internal.SyncDeviceInfo import com.futo.platformplayer.sync.internal.SyncKeyPair import com.futo.platformplayer.sync.internal.SyncSession +import com.futo.platformplayer.sync.internal.SyncSession.Companion import com.futo.platformplayer.sync.internal.SyncSocketSession +import com.futo.platformplayer.sync.models.SendToDevicePackage +import com.futo.platformplayer.sync.models.SyncPlaylistsPackage +import com.futo.platformplayer.sync.models.SyncSubscriptionGroupsPackage +import com.futo.platformplayer.sync.models.SyncSubscriptionsPackage +import com.futo.platformplayer.sync.models.SyncWatchLaterPackage import com.futo.polycentric.core.base64ToByteArray import com.futo.polycentric.core.toBase64 import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking import kotlinx.coroutines.withContext import kotlinx.serialization.encodeToString import kotlinx.serialization.json.Json +import java.io.ByteArrayInputStream +import java.lang.Thread.sleep import java.net.InetAddress import java.net.InetSocketAddress import java.net.ServerSocket import java.net.Socket +import java.nio.ByteBuffer +import java.nio.channels.Channel +import java.time.Instant +import java.time.OffsetDateTime +import java.time.ZoneOffset import java.util.Base64 import java.util.Locale import kotlin.system.measureTimeMillis @@ -59,13 +83,19 @@ class StateSync { //TODO: Should sync mdns and casting mdns be merged? //TODO: Decrease interval that devices are updated //TODO: Send less data - val _serviceDiscoverer = ServiceDiscoverer(arrayOf("_gsync._tcp.local")) { handleServiceUpdated(it) } + private val _serviceDiscoverer = ServiceDiscoverer(arrayOf("_gsync._tcp.local")) { handleServiceUpdated(it) } + private val _pairingCode: String? = generateReadablePassword(8) + val pairingCode: String? get() = _pairingCode + private var _relaySession: SyncSocketSession? = null + private var _threadRelay: Thread? = null var keyPair: DHState? = null var publicKey: String? = null val deviceRemoved: Event1 = Event1() val deviceUpdatedOrAdded: Event2 = Event2() + //TODO: Should authorize acknowledge be implemented? + fun hasAuthorizedDevice(): Boolean { synchronized(_sessions) { return _sessions.any{ it.value.connected && it.value.isAuthorized }; @@ -127,7 +157,7 @@ class StateSync { while (_started) { val socket = serverSocket.accept() - val session = createSocketSession(socket, true) { session, socketSession -> + val session = createSocketSession(socket, true) { session -> } @@ -164,7 +194,7 @@ class StateSync { for (connectPair in addressesToConnect) { try { - val syncDeviceInfo = SyncDeviceInfo(connectPair.first, arrayOf(connectPair.second), PORT) + val syncDeviceInfo = SyncDeviceInfo(connectPair.first, arrayOf(connectPair.second), PORT, null) val now = System.currentTimeMillis() val lastConnectTime = synchronized(_lastConnectTimesIp) { @@ -188,6 +218,138 @@ class StateSync { } }.apply { start() } } + + _threadRelay = Thread { + while (_started) { + try { + Log.i(TAG, "Starting relay session...") + + var socketClosed = false; + val socket = Socket(RELAY_SERVER, 9000) + _relaySession = SyncSocketSession( + (socket.remoteSocketAddress as InetSocketAddress).address.hostAddress!!, + keyPair!!, + LittleEndianDataInputStream(socket.getInputStream()), + LittleEndianDataOutputStream(socket.getOutputStream()), + isHandshakeAllowed = { _, pk, pairingCode -> + Log.v(TAG, "Check if handshake allowed from '$pk'.") + if (pk == RELAY_PUBLIC_KEY) + return@SyncSocketSession true + + synchronized(_authorizedDevices) { + if (_authorizedDevices.values.contains(pk)) + return@SyncSocketSession true + } + + Log.v(TAG, "Check if handshake allowed with pairing code '$pairingCode' with active pairing code '$_pairingCode'.") + if (_pairingCode == null || pairingCode.isNullOrEmpty()) + return@SyncSocketSession false + + _pairingCode == pairingCode + }, + onNewChannel = { _, c -> + val remotePublicKey = c.remotePublicKey + if (remotePublicKey == null) { + Log.e(TAG, "Remote public key should never be null in onNewChannel.") + return@SyncSocketSession + } + + Log.i(TAG, "New channel established from relay (pk: '$remotePublicKey').") + + var session: SyncSession? + synchronized(_sessions) { + session = _sessions[remotePublicKey] + if (session == null) { + val remoteDeviceName = synchronized(_nameStorage) { + _nameStorage.get(remotePublicKey) + } + session = createNewSyncSession(remotePublicKey, remoteDeviceName) { } + _sessions[remotePublicKey] = session!! + } + session!!.addChannel(c) + } + + c.setDataHandler { _, channel, opcode, subOpcode, data -> + session?.handlePacket(opcode, subOpcode, data) + } + c.setCloseHandler { channel -> + session?.removeChannel(channel) + } + }, + onChannelEstablished = { _, channel, isResponder -> + handleAuthorization(channel, isResponder) + }, + onClose = { socketClosed = true }, + onHandshakeComplete = { relaySession -> + Thread { + try { + while (_started && !socketClosed) { + val unconnectedAuthorizedDevices = synchronized(_authorizedDevices) { + _authorizedDevices.values.filter { !isConnected(it) }.toTypedArray() + } + + relaySession.publishConnectionInformation(unconnectedAuthorizedDevices, PORT, true, false, false, true) + + val connectionInfos = runBlocking { relaySession.requestBulkConnectionInfo(unconnectedAuthorizedDevices) } + + for ((targetKey, connectionInfo) in connectionInfos) { + val potentialLocalAddresses = connectionInfo.ipv4Addresses.union(connectionInfo.ipv6Addresses) + .filter { it != connectionInfo.remoteIp } + if (connectionInfo.allowLocalDirect) { + Thread { + try { + Log.v(TAG, "Attempting to connect directly, locally to '$targetKey'.") + connect(potentialLocalAddresses.map { it }.toTypedArray(), PORT, targetKey, null) + } catch (e: Throwable) { + Log.e(TAG, "Failed to start direct connection using connection info with $targetKey.", e) + } + }.start() + } + + if (connectionInfo.allowRemoteDirect) { + // TODO: Implement direct remote connection if needed + } + + if (connectionInfo.allowRemoteHolePunched) { + // TODO: Implement hole punching if needed + } + + if (connectionInfo.allowRemoteProxied) { + try { + Log.v(TAG, "Attempting relayed connection with '$targetKey'.") + runBlocking { relaySession.startRelayedChannel(targetKey, null) } + } catch (e: Throwable) { + Log.e(TAG, "Failed to start relayed channel with $targetKey.", e) + } + } + } + + Thread.sleep(15000) + } + } catch (e: Throwable) { + Log.e(TAG, "Unhandled exception in relay session.", e) + relaySession.stop() + } + }.start() + } + ) + + _relaySession!!.authorizable = object : IAuthorizable { + override val isAuthorized: Boolean get() = true + } + + _relaySession!!.startAsInitiator(RELAY_PUBLIC_KEY, null) + + Log.i(TAG, "Started relay session.") + } catch (e: Throwable) { + Log.e(TAG, "Relay session failed.", e) + Thread.sleep(5000) + } finally { + _relaySession?.stop() + _relaySession = null + } + } + }.apply { start() } } private fun getDeviceName(): String { @@ -219,14 +381,14 @@ class StateSync { } } fun getSessions(): List { - return synchronized(_sessions) { + synchronized(_sessions) { return _sessions.values.toList() - }; + } } fun getAuthorizedSessions(): List { - return synchronized(_sessions) { + synchronized(_sessions) { return _sessions.values.filter { it.isAuthorized }.toList() - }; + } } fun getSyncSessionData(key: String): SyncSessionData { @@ -253,7 +415,7 @@ class StateSync { val urlSafePkey = s.texts.firstOrNull { it.startsWith("pk=") }?.substring("pk=".length) ?: continue val pkey = Base64.getEncoder().encodeToString(Base64.getDecoder().decode(urlSafePkey.replace('-', '+').replace('_', '/'))) - val syncDeviceInfo = SyncDeviceInfo(pkey, addresses, port) + val syncDeviceInfo = SyncDeviceInfo(pkey, addresses, port, null) val authorized = isAuthorized(pkey) if (authorized && !isConnected(pkey)) { @@ -288,11 +450,313 @@ class StateSync { deviceRemoved.emit(remotePublicKey) } - private fun createSocketSession(socket: Socket, isResponder: Boolean, onAuthorized: (session: SyncSession, socketSession: SyncSocketSession) -> Unit): SyncSocketSession { + + private fun handleSyncSubscriptionPackage(origin: SyncSession, pack: SyncSubscriptionsPackage) { + val added = mutableListOf() + for(sub in pack.subscriptions) { + if(!StateSubscriptions.instance.isSubscribed(sub.channel)) { + val removalTime = StateSubscriptions.instance.getSubscriptionRemovalTime(sub.channel.url); + if(sub.creationTime > removalTime) { + val newSub = StateSubscriptions.instance.addSubscription(sub.channel, sub.creationTime); + added.add(newSub); + } + } + } + if(added.size > 3) + UIDialogs.appToast("${added.size} Subscriptions from ${origin.remotePublicKey.substring(0, Math.min(8, origin.remotePublicKey.length))}"); + else if(added.size > 0) + UIDialogs.appToast("Subscriptions from ${origin.remotePublicKey.substring(0, Math.min(8, origin.remotePublicKey.length))}:\n" + + added.map { it.channel.name }.joinToString("\n")); + + + if(pack.subscriptions.isNotEmpty()) { + for (subRemoved in pack.subscriptionRemovals) { + val removed = StateSubscriptions.instance.applySubscriptionRemovals(pack.subscriptionRemovals); + if(removed.size > 3) { + UIDialogs.appToast("Removed ${removed.size} Subscriptions from ${origin.remotePublicKey.substring(0, 8.coerceAtMost(origin.remotePublicKey.length))}"); + } else if(removed.isNotEmpty()) { + UIDialogs.appToast("Subscriptions removed from ${origin.remotePublicKey.substring(0, 8.coerceAtMost(origin.remotePublicKey.length))}:\n" + removed.map { it.channel.name }.joinToString("\n")); + } + } + } + } + + private fun handleData(session: SyncSession, opcode: UByte, subOpcode: UByte, data: ByteBuffer) { + val remotePublicKey = session.remotePublicKey + when (subOpcode) { + GJSyncOpcodes.sendToDevices -> { + StateApp.instance.scopeOrNull?.launch(Dispatchers.Main) { + val context = StateApp.instance.contextOrNull; + if (context != null && context is MainActivity) { + val dataBody = ByteArray(data.remaining()); + val remainder = data.remaining(); + data.get(dataBody, 0, remainder); + val json = String(dataBody, Charsets.UTF_8); + val obj = Json.decodeFromString(json); + UIDialogs.appToast("Received url from device [${session.remotePublicKey}]:\n{${obj.url}"); + context.handleUrl(obj.url, obj.position); + } + }; + } + + GJSyncOpcodes.syncStateExchange -> { + val dataBody = ByteArray(data.remaining()); + data.get(dataBody); + val json = String(dataBody, Charsets.UTF_8); + val syncSessionData = Serializer.json.decodeFromString(json); + + Logger.i(TAG, "Received SyncSessionData from $remotePublicKey"); + + + session.sendData(GJSyncOpcodes.syncSubscriptions, StateSubscriptions.instance.getSyncSubscriptionsPackageString()); + session.sendData(GJSyncOpcodes.syncSubscriptionGroups, StateSubscriptionGroups.instance.getSyncSubscriptionGroupsPackageString()); + session.sendData(GJSyncOpcodes.syncPlaylists, StatePlaylists.instance.getSyncPlaylistsPackageString()) + + session.sendData(GJSyncOpcodes.syncWatchLater, Json.encodeToString(StatePlaylists.instance.getWatchLaterSyncPacket(false))); + + val recentHistory = StateHistory.instance.getRecentHistory(syncSessionData.lastHistory); + if(recentHistory.isNotEmpty()) + session.sendJsonData(GJSyncOpcodes.syncHistory, recentHistory); + } + + GJSyncOpcodes.syncExport -> { + val dataBody = ByteArray(data.remaining()); + val bytesStr = ByteArrayInputStream(data.array(), data.position(), data.remaining()); + bytesStr.use { bytesStrBytes -> + val exportStruct = StateBackup.ExportStructure.fromZipBytes(bytesStrBytes); + for (store in exportStruct.stores) { + if (store.key.equals("subscriptions", true)) { + val subStore = + StateSubscriptions.instance.getUnderlyingSubscriptionsStore(); + StateApp.instance.scopeOrNull?.launch(Dispatchers.IO) { + val pack = SyncSubscriptionsPackage( + store.value.map { + subStore.fromReconstruction(it, exportStruct.cache) + }, + StateSubscriptions.instance.getSubscriptionRemovals() + ); + handleSyncSubscriptionPackage(session, pack); + } + } + } + } + } + + GJSyncOpcodes.syncSubscriptions -> { + val dataBody = ByteArray(data.remaining()); + data.get(dataBody); + val json = String(dataBody, Charsets.UTF_8); + val subPackage = Serializer.json.decodeFromString(json); + handleSyncSubscriptionPackage(session, subPackage); + + val newestSub = subPackage.subscriptions.maxOf { it.creationTime }; + + val sesData = getSyncSessionData(remotePublicKey); + if(newestSub > sesData.lastSubscription) { + sesData.lastSubscription = newestSub; + saveSyncSessionData(sesData); + } + } + + GJSyncOpcodes.syncSubscriptionGroups -> { + val dataBody = ByteArray(data.remaining()); + data.get(dataBody); + val json = String(dataBody, Charsets.UTF_8); + val pack = Serializer.json.decodeFromString(json); + + var lastSubgroupChange = OffsetDateTime.MIN; + for(group in pack.groups){ + if(group.lastChange > lastSubgroupChange) + lastSubgroupChange = group.lastChange; + + val existing = StateSubscriptionGroups.instance.getSubscriptionGroup(group.id); + + if(existing == null) + StateSubscriptionGroups.instance.updateSubscriptionGroup(group, false, true); + else if(existing.lastChange < group.lastChange) { + existing.name = group.name; + existing.urls = group.urls; + existing.image = group.image; + existing.priority = group.priority; + existing.lastChange = group.lastChange; + StateSubscriptionGroups.instance.updateSubscriptionGroup(existing, false, true); + } + } + for(removal in pack.groupRemovals) { + val creation = StateSubscriptionGroups.instance.getSubscriptionGroup(removal.key); + val removalTime = OffsetDateTime.ofInstant(Instant.ofEpochSecond(removal.value, 0), ZoneOffset.UTC); + if(creation != null && creation.creationTime < removalTime) + StateSubscriptionGroups.instance.deleteSubscriptionGroup(removal.key, false); + } + } + + GJSyncOpcodes.syncPlaylists -> { + val dataBody = ByteArray(data.remaining()); + data.get(dataBody); + val json = String(dataBody, Charsets.UTF_8); + val pack = Serializer.json.decodeFromString(json); + + for(playlist in pack.playlists) { + val existing = StatePlaylists.instance.getPlaylist(playlist.id); + + if(existing == null) + StatePlaylists.instance.createOrUpdatePlaylist(playlist, false); + else if(existing.dateUpdate.toLocalDateTime() < playlist.dateUpdate.toLocalDateTime()) { + existing.dateUpdate = playlist.dateUpdate; + existing.name = playlist.name; + existing.videos = playlist.videos; + existing.dateCreation = playlist.dateCreation; + existing.datePlayed = playlist.datePlayed; + StatePlaylists.instance.createOrUpdatePlaylist(existing, false); + } + } + for(removal in pack.playlistRemovals) { + val creation = StatePlaylists.instance.getPlaylist(removal.key); + val removalTime = OffsetDateTime.ofInstant(Instant.ofEpochSecond(removal.value, 0), ZoneOffset.UTC); + if(creation != null && creation.dateCreation < removalTime) + StatePlaylists.instance.removePlaylist(creation, false); + + } + } + + GJSyncOpcodes.syncWatchLater -> { + val dataBody = ByteArray(data.remaining()); + data.get(dataBody); + val json = String(dataBody, Charsets.UTF_8); + val pack = Serializer.json.decodeFromString(json); + + Logger.i(TAG, "SyncWatchLater received ${pack.videos.size} (${pack.videoAdds?.size}, ${pack.videoRemovals?.size})"); + + val allExisting = StatePlaylists.instance.getWatchLater(); + for(video in pack.videos) { + val existing = allExisting.firstOrNull { it.url == video.url }; + val time = if(pack.videoAdds != null && pack.videoAdds.containsKey(video.url)) OffsetDateTime.ofInstant(Instant.ofEpochSecond(pack.videoAdds[video.url] ?: 0), ZoneOffset.UTC) else OffsetDateTime.MIN; + + if(existing == null) { + StatePlaylists.instance.addToWatchLater(video, false); + if(time > OffsetDateTime.MIN) + StatePlaylists.instance.setWatchLaterAddTime(video.url, time); + } + } + for(removal in pack.videoRemovals) { + val watchLater = allExisting.firstOrNull { it.url == removal.key } ?: continue; + val creation = StatePlaylists.instance.getWatchLaterRemovalTime(watchLater.url) ?: OffsetDateTime.MIN; + val removalTime = OffsetDateTime.ofInstant(Instant.ofEpochSecond(removal.value), ZoneOffset.UTC); + if(creation < removalTime) + StatePlaylists.instance.removeFromWatchLater(watchLater, false, removalTime); + } + + val packReorderTime = OffsetDateTime.ofInstant(Instant.ofEpochSecond(pack.reorderTime), ZoneOffset.UTC); + val localReorderTime = StatePlaylists.instance.getWatchLaterLastReorderTime(); + if(localReorderTime < packReorderTime && pack.ordering != null) { + StatePlaylists.instance.updateWatchLaterOrdering(smartMerge(pack.ordering!!, StatePlaylists.instance.getWatchLaterOrdering()), true); + } + } + + GJSyncOpcodes.syncHistory -> { + val dataBody = ByteArray(data.remaining()); + data.get(dataBody); + val json = String(dataBody, Charsets.UTF_8); + val history = Serializer.json.decodeFromString>(json); + Logger.i(TAG, "SyncHistory received ${history.size} videos from ${remotePublicKey}"); + + var lastHistory = OffsetDateTime.MIN; + for(video in history){ + val hist = StateHistory.instance.getHistoryByVideo(video.video, true, video.date); + if(hist != null) + StateHistory.instance.updateHistoryPosition(video.video, hist, true, video.position, video.date) + if(lastHistory < video.date) + lastHistory = video.date; + } + + if(lastHistory != OffsetDateTime.MIN && history.size > 1) { + val sesData = getSyncSessionData(remotePublicKey); + if (lastHistory > sesData.lastHistory) { + sesData.lastHistory = lastHistory; + saveSyncSessionData(sesData); + } + } + } + } + } + + private fun createNewSyncSession(remotePublicKey: String, remoteDeviceName: String?, onAuthorized: ((SyncSession) -> Unit)?): SyncSession { + return SyncSession( + remotePublicKey, + onAuthorized = { it, isNewlyAuthorized, isNewSession -> + if (!isNewSession) { + return@SyncSession + } + + it.remoteDeviceName?.let { remoteDeviceName -> + synchronized(_nameStorage) { + _nameStorage.setAndSave(remotePublicKey, remoteDeviceName) + } + } + + Logger.i(TAG, "${remotePublicKey} authorized (name: ${it.displayName})") + onAuthorized?.invoke(it) + _authorizedDevices.addDistinct(remotePublicKey) + _authorizedDevices.save() + deviceUpdatedOrAdded.emit(it.remotePublicKey, it) + + checkForSync(it); + }, + onUnauthorized = { + unauthorize(remotePublicKey) + + synchronized(_sessions) { + it.close() + _sessions.remove(remotePublicKey) + } + }, + onConnectedChanged = { it, connected -> + Logger.i(TAG, "$remotePublicKey connected: " + connected) + deviceUpdatedOrAdded.emit(it.remotePublicKey, it) + }, + onClose = { + Logger.i(TAG, "$remotePublicKey closed") + + synchronized(_sessions) + { + _sessions.remove(it.remotePublicKey) + } + + deviceRemoved.emit(it.remotePublicKey) + }, + dataHandler = { it, opcode, subOpcode, data -> + handleData(it, opcode, subOpcode, data) + }, + remoteDeviceName + ) + } + + private fun createSocketSession(socket: Socket, isResponder: Boolean, onAuthorized: (session: SyncSession) -> Unit): SyncSocketSession { var session: SyncSession? = null - return SyncSocketSession((socket.remoteSocketAddress as InetSocketAddress).address.hostAddress!!, keyPair!!, LittleEndianDataInputStream(socket.getInputStream()), LittleEndianDataOutputStream(socket.getOutputStream()), + var channelSocket: ChannelSocket? = null + return SyncSocketSession( + (socket.remoteSocketAddress as InetSocketAddress).address.hostAddress!!, + keyPair!!, + LittleEndianDataInputStream(socket.getInputStream()), + LittleEndianDataOutputStream(socket.getOutputStream()), onClose = { s -> - session?.removeSocketSession(s) + if (channelSocket != null) + session?.removeChannel(channelSocket!!) + }, + isHandshakeAllowed = { _, pk, pairingCode -> + Logger.v(TAG, "Check if handshake allowed from '${pk}'.") + + synchronized (_authorizedDevices) + { + if (_authorizedDevices.values.contains(pk)) + return@SyncSocketSession true + } + + Logger.v(TAG, "Check if handshake allowed with pairing code '${pairingCode}' with active pairing code '${_pairingCode}'."); + if (_pairingCode == null || pairingCode.isNullOrEmpty()) + return@SyncSocketSession false + + return@SyncSocketSession _pairingCode == pairingCode }, onHandshakeComplete = { s -> val remotePublicKey = s.remotePublicKey @@ -303,6 +767,8 @@ class StateSync { Logger.i(TAG, "Handshake complete with (LocalPublicKey = ${s.localPublicKey}, RemotePublicKey = ${s.remotePublicKey})") + channelSocket = ChannelSocket(s) + synchronized(_sessions) { session = _sessions[s.remotePublicKey] if (session == null) { @@ -310,126 +776,99 @@ class StateSync { _nameStorage.get(remotePublicKey) } - session = SyncSession(remotePublicKey, onAuthorized = { it, isNewlyAuthorized, isNewSession -> - if (!isNewSession) { - return@SyncSession - } + synchronized(_lastAddressStorage) { + _lastAddressStorage.setAndSave(remotePublicKey, s.remoteAddress) + } - it.remoteDeviceName?.let { remoteDeviceName -> - synchronized(_nameStorage) { - _nameStorage.setAndSave(remotePublicKey, remoteDeviceName) - } - } - - Logger.i(TAG, "${s.remotePublicKey} authorized (name: ${it.displayName})") - synchronized(_lastAddressStorage) { - _lastAddressStorage.setAndSave(remotePublicKey, s.remoteAddress) - } - - onAuthorized(it, s) - _authorizedDevices.addDistinct(remotePublicKey) - _authorizedDevices.save() - deviceUpdatedOrAdded.emit(it.remotePublicKey, session!!) - - checkForSync(it); - }, onUnauthorized = { - unauthorize(remotePublicKey) - - synchronized(_sessions) { - session?.close() - _sessions.remove(remotePublicKey) - } - }, onConnectedChanged = { it, connected -> - Logger.i(TAG, "${s.remotePublicKey} connected: " + connected) - deviceUpdatedOrAdded.emit(it.remotePublicKey, session!!) - }, onClose = { - Logger.i(TAG, "${s.remotePublicKey} closed") - - synchronized(_sessions) - { - _sessions.remove(it.remotePublicKey) - } - - deviceRemoved.emit(it.remotePublicKey) - - }, remoteDeviceName) + session = createNewSyncSession(remotePublicKey, remoteDeviceName, onAuthorized) _sessions[remotePublicKey] = session!! } - session!!.addSocketSession(s) + session!!.addChannel(channelSocket!!) } - if (isResponder) { - val isAuthorized = synchronized(_authorizedDevices) { - _authorizedDevices.values.contains(remotePublicKey) - } - - if (!isAuthorized) { - val scope = StateApp.instance.scopeOrNull - val activity = SyncShowPairingCodeActivity.activity - - if (scope != null && activity != null) { - scope.launch(Dispatchers.Main) { - UIDialogs.showConfirmationDialog(activity, "Allow connection from ${remotePublicKey}?", action = { - scope.launch(Dispatchers.IO) { - try { - session!!.authorize(s) - Logger.i(TAG, "Connection authorized for $remotePublicKey by confirmation") - } catch (e: Throwable) { - Logger.e(TAG, "Failed to send authorize", e) - } - } - }, cancelAction = { - scope.launch(Dispatchers.IO) { - try { - unauthorize(remotePublicKey) - } catch (e: Throwable) { - Logger.w(TAG, "Failed to send unauthorize", e) - } - - synchronized(_sessions) { - session?.close() - _sessions.remove(remotePublicKey) - } - } - }) - } - } else { - val publicKey = session!!.remotePublicKey - session!!.unauthorize(s) - session!!.close() - - synchronized(_sessions) { - _sessions.remove(publicKey) - } - - Logger.i(TAG, "Connection unauthorized for ${remotePublicKey} because not authorized and not on pairing activity to ask") - } - } else { - //Responder does not need to check because already approved - session!!.authorize(s) - Logger.i(TAG, "Connection authorized for ${remotePublicKey} because already authorized") - } - } else { - //Initiator does not need to check because the manual action of scanning the QR counts as approval - session!!.authorize(s) - Logger.i(TAG, "Connection authorized for ${remotePublicKey} because initiator") - } + handleAuthorization(channelSocket!!, isResponder) }, onData = { s, opcode, subOpcode, data -> - session?.handlePacket(s, opcode, subOpcode, data) - }) + session?.handlePacket(opcode, subOpcode, data) + } + ) + } + + private fun handleAuthorization(channel: IChannel, isResponder: Boolean) { + val syncSession = channel.syncSession!! + val remotePublicKey = channel.remotePublicKey!! + + if (isResponder) { + val isAuthorized = synchronized(_authorizedDevices) { + _authorizedDevices.values.contains(remotePublicKey) + } + + if (!isAuthorized) { + val scope = StateApp.instance.scopeOrNull + val activity = SyncShowPairingCodeActivity.activity + + if (scope != null && activity != null) { + scope.launch(Dispatchers.Main) { + UIDialogs.showConfirmationDialog(activity, "Allow connection from ${remotePublicKey}?", + action = { + scope.launch(Dispatchers.IO) { + try { + syncSession.authorize() + Logger.i(TAG, "Connection authorized for $remotePublicKey by confirmation") + } catch (e: Throwable) { + Logger.e(TAG, "Failed to send authorize", e) + } + } + }, + cancelAction = { + scope.launch(Dispatchers.IO) { + try { + unauthorize(remotePublicKey) + } catch (e: Throwable) { + Logger.w(TAG, "Failed to send unauthorize", e) + } + + syncSession.close() + synchronized(_sessions) { + _sessions.remove(remotePublicKey) + } + } + } + ) + } + } else { + val publicKey = syncSession.remotePublicKey + syncSession.unauthorize() + syncSession.close() + + synchronized(_sessions) { + _sessions.remove(publicKey) + } + + Logger.i(TAG, "Connection unauthorized for $remotePublicKey because not authorized and not on pairing activity to ask") + } + } else { + //Responder does not need to check because already approved + syncSession.authorize() + Logger.i(TAG, "Connection authorized for $remotePublicKey because already authorized") + } + } else { + //Initiator does not need to check because the manual action of scanning the QR counts as approval + syncSession.authorize() + Logger.i(TAG, "Connection authorized for $remotePublicKey because initiator") + } } inline fun broadcastJsonData(subOpcode: UByte, data: T) { - broadcast(SyncSocketSession.Opcode.DATA.value, subOpcode, Json.encodeToString(data)); + broadcast(Opcode.DATA.value, subOpcode, Json.encodeToString(data)); } fun broadcastData(subOpcode: UByte, data: String) { - broadcast(SyncSocketSession.Opcode.DATA.value, subOpcode, data.toByteArray(Charsets.UTF_8)); + broadcast(Opcode.DATA.value, subOpcode, ByteBuffer.wrap(data.toByteArray(Charsets.UTF_8))); } fun broadcast(opcode: UByte, subOpcode: UByte, data: String) { - broadcast(opcode, subOpcode, data.toByteArray(Charsets.UTF_8)); + broadcast(opcode, subOpcode, ByteBuffer.wrap(data.toByteArray(Charsets.UTF_8))); } - fun broadcast(opcode: UByte, subOpcode: UByte, data: ByteArray) { + fun broadcast(opcode: UByte, subOpcode: UByte, data: ByteBuffer) { for(session in getAuthorizedSessions()) { try { session.send(opcode, subOpcode, data); @@ -456,21 +895,46 @@ class StateSync { _serverSocket?.close() _serverSocket = null - //_thread?.join() + _thread?.interrupt() _thread = null + _connectThread?.interrupt() _connectThread = null + _threadRelay?.interrupt() + _threadRelay = null + + _relaySession?.stop() + _relaySession = null } - fun connect(deviceInfo: SyncDeviceInfo, onStatusUpdate: ((session: SyncSocketSession?, complete: Boolean, message: String) -> Unit)? = null): SyncSocketSession { - onStatusUpdate?.invoke(null, false, "Connecting...") - val socket = getConnectedSocket(deviceInfo.addresses.map { InetAddress.getByName(it) }, deviceInfo.port) ?: throw Exception("Failed to connect") - onStatusUpdate?.invoke(null, false, "Handshaking...") + fun connect(deviceInfo: SyncDeviceInfo, onStatusUpdate: ((complete: Boolean?, message: String) -> Unit)? = null) { + try { + connect(deviceInfo.addresses, deviceInfo.port, deviceInfo.publicKey, deviceInfo.pairingCode, onStatusUpdate) + } catch (e: Throwable) { + Logger.e(TAG, "Failed to connect directly", e) + val relaySession = _relaySession + if (relaySession != null) { + onStatusUpdate?.invoke(null, "Connecting via relay...") - val session = createSocketSession(socket, false) { _, ss -> - onStatusUpdate?.invoke(ss, true, "Handshake complete") + runBlocking { + relaySession.startRelayedChannel(deviceInfo.publicKey, deviceInfo.pairingCode) + onStatusUpdate?.invoke(true, "Connected") + } + } else { + throw Exception("Failed to connect.") + } + } + } + + fun connect(addresses: Array, port: Int, publicKey: String, pairingCode: String?, onStatusUpdate: ((complete: Boolean?, message: String) -> Unit)? = null): SyncSocketSession { + onStatusUpdate?.invoke(null, "Connecting directly...") + val socket = getConnectedSocket(addresses.map { InetAddress.getByName(it) }, port) ?: throw Exception("Failed to connect") + onStatusUpdate?.invoke(null, "Handshaking...") + + val session = createSocketSession(socket, false) { s -> + onStatusUpdate?.invoke(true, "Authorized") } - session.startAsInitiator(deviceInfo.publicKey) + session.startAsInitiator(publicKey, pairingCode) return session } @@ -526,6 +990,8 @@ class StateSync { val hash = "BLAKE2b" var protocolName = "Noise_${pattern}_${dh}_${cipher}_${hash}" val version = 1 + val RELAY_SERVER = "relay.grayjay.app" + val RELAY_PUBLIC_KEY = "xGbHRzDOvE6plRbQaFgSen82eijF+gxS0yeUaeEErkw=" private const val TAG = "StateSync" const val PORT = 12315 diff --git a/app/src/main/java/com/futo/platformplayer/sync/internal/Channel.kt b/app/src/main/java/com/futo/platformplayer/sync/internal/Channel.kt new file mode 100644 index 00000000..2d3f1580 --- /dev/null +++ b/app/src/main/java/com/futo/platformplayer/sync/internal/Channel.kt @@ -0,0 +1,332 @@ +package com.futo.platformplayer.sync.internal + +import com.futo.platformplayer.logging.Logger +import com.futo.platformplayer.noise.protocol.CipherStatePair +import com.futo.platformplayer.noise.protocol.DHState +import com.futo.platformplayer.noise.protocol.HandshakeState +import com.futo.platformplayer.states.StateSync +import java.nio.ByteBuffer +import java.nio.ByteOrder +import java.util.Base64 + +interface IChannel : AutoCloseable { + val remotePublicKey: String? + val remoteVersion: Int? + var authorizable: IAuthorizable? + var syncSession: SyncSession? + fun setDataHandler(onData: ((SyncSocketSession, IChannel, UByte, UByte, ByteBuffer) -> Unit)?) + fun send(opcode: UByte, subOpcode: UByte = 0u, data: ByteBuffer? = null) + fun setCloseHandler(onClose: ((IChannel) -> Unit)?) +} + +class ChannelSocket(private val session: SyncSocketSession) : IChannel { + override val remotePublicKey: String? get() = session.remotePublicKey + override val remoteVersion: Int? get() = session.remoteVersion + private var onData: ((SyncSocketSession, IChannel, UByte, UByte, ByteBuffer) -> Unit)? = null + private var onClose: ((IChannel) -> Unit)? = null + + override var authorizable: IAuthorizable? + get() = session.authorizable + set(value) { session.authorizable = value } + override var syncSession: SyncSession? = null + + override fun setDataHandler(onData: ((SyncSocketSession, IChannel, UByte, UByte, ByteBuffer) -> Unit)?) { + this.onData = onData + } + + override fun setCloseHandler(onClose: ((IChannel) -> Unit)?) { + this.onClose = onClose + } + + override fun close() { + session.stop() + onClose?.invoke(this) + } + + fun invokeDataHandler(opcode: UByte, subOpcode: UByte, data: ByteBuffer) { + onData?.invoke(session, this, opcode, subOpcode, data) + } + + override fun send(opcode: UByte, subOpcode: UByte, data: ByteBuffer?) { + if (data != null) { + session.send(opcode, subOpcode, data) + } else { + session.send(opcode, subOpcode) + } + } +} + +class ChannelRelayed( + private val session: SyncSocketSession, + private val localKeyPair: DHState, + private val publicKey: String, + private val initiator: Boolean +) : IChannel { + private val sendLock = Object() + private val decryptLock = Object() + private var handshakeState: HandshakeState? = if (initiator) { + HandshakeState(StateSync.protocolName, HandshakeState.INITIATOR).apply { + localKeyPair.copyFrom(this@ChannelRelayed.localKeyPair) + remotePublicKey.setPublicKey(Base64.getDecoder().decode(publicKey), 0) + } + } else { + HandshakeState(StateSync.protocolName, HandshakeState.RESPONDER).apply { + localKeyPair.copyFrom(this@ChannelRelayed.localKeyPair) + } + } + private var transport: CipherStatePair? = null + override var authorizable: IAuthorizable? = null + val isAuthorized: Boolean get() = authorizable?.isAuthorized ?: false + var connectionId: Long = 0L + override var remotePublicKey: String? = publicKey + private set + override var remoteVersion: Int? = null + private set + override var syncSession: SyncSession? = null + + private var onData: ((SyncSocketSession, IChannel, UByte, UByte, ByteBuffer) -> Unit)? = null + private var onClose: ((IChannel) -> Unit)? = null + private var disposed = false + + init { + handshakeState?.start() + } + + override fun setDataHandler(onData: ((SyncSocketSession, IChannel, UByte, UByte, ByteBuffer) -> Unit)?) { + this.onData = onData + } + + override fun setCloseHandler(onClose: ((IChannel) -> Unit)?) { + this.onClose = onClose + } + + override fun close() { + disposed = true + + if (connectionId != 0L) { + Thread { + try { + session.sendRelayError(connectionId, SyncErrorCode.ConnectionClosed) + } catch (e: Exception) { + Logger.e("ChannelRelayed", "Exception while sending relay error", e) + } + }.start() + } + + transport?.sender?.destroy() + transport?.receiver?.destroy() + transport = null + handshakeState?.destroy() + handshakeState = null + + onClose?.invoke(this) + } + + private fun throwIfDisposed() { + if (disposed) throw IllegalStateException("ChannelRelayed is disposed") + } + + fun invokeDataHandler(opcode: UByte, subOpcode: UByte, data: ByteBuffer) { + onData?.invoke(session, this, opcode, subOpcode, data) + } + + private fun completeHandshake(remoteVersion: Int, transport: CipherStatePair) { + throwIfDisposed() + + this.remoteVersion = remoteVersion + val remoteKeyBytes = ByteArray(handshakeState!!.remotePublicKey.publicKeyLength) + handshakeState!!.remotePublicKey.getPublicKey(remoteKeyBytes, 0) + this.remotePublicKey = Base64.getEncoder().encodeToString(remoteKeyBytes) + handshakeState?.destroy() + handshakeState = null + this.transport = transport + Logger.i("ChannelRelayed", "Completed handshake for connectionId $connectionId") + } + + private fun sendPacket(packet: ByteArray) { + throwIfDisposed() + + synchronized(sendLock) { + val encryptedPayload = ByteArray(packet.size + 16) + val encryptedLength = transport!!.sender.encryptWithAd(null, packet, 0, encryptedPayload, 0, packet.size) + + val relayedPacket = ByteArray(8 + encryptedLength) + ByteBuffer.wrap(relayedPacket).order(ByteOrder.LITTLE_ENDIAN).apply { + putLong(connectionId) + put(encryptedPayload, 0, encryptedLength) + } + + session.send(Opcode.RELAY.value, RelayOpcode.DATA.value, ByteBuffer.wrap(relayedPacket).order(ByteOrder.LITTLE_ENDIAN)) + } + } + + fun sendError(errorCode: SyncErrorCode) { + throwIfDisposed() + + synchronized(sendLock) { + val packet = ByteArray(4) + ByteBuffer.wrap(packet).order(ByteOrder.LITTLE_ENDIAN).putInt(errorCode.value) + + val encryptedPayload = ByteArray(4 + 16) + val encryptedLength = transport!!.sender.encryptWithAd(null, packet, 0, encryptedPayload, 0, packet.size) + + val relayedPacket = ByteArray(8 + encryptedLength) + ByteBuffer.wrap(relayedPacket).order(ByteOrder.LITTLE_ENDIAN).apply { + putLong(connectionId) + put(encryptedPayload, 0, encryptedLength) + } + + session.send(Opcode.RELAY.value, RelayOpcode.ERROR.value, ByteBuffer.wrap(relayedPacket)) + } + } + + override fun send(opcode: UByte, subOpcode: UByte, data: ByteBuffer?) { + throwIfDisposed() + + val actualCount = data?.remaining() ?: 0 + val ENCRYPTION_OVERHEAD = 16 + val CONNECTION_ID_SIZE = 8 + val HEADER_SIZE = 6 + val MAX_DATA_PER_PACKET = SyncSocketSession.MAXIMUM_PACKET_SIZE - HEADER_SIZE - CONNECTION_ID_SIZE - ENCRYPTION_OVERHEAD - 16 + + if (actualCount > MAX_DATA_PER_PACKET && data != null) { + val streamId = session.generateStreamId() + val totalSize = actualCount + var sendOffset = 0 + + while (sendOffset < totalSize) { + val bytesRemaining = totalSize - sendOffset + val bytesToSend = minOf(MAX_DATA_PER_PACKET - 8 - 2, bytesRemaining) + + val streamData: ByteArray + val streamOpcode: StreamOpcode + if (sendOffset == 0) { + streamOpcode = StreamOpcode.START + streamData = ByteArray(4 + 4 + 1 + 1 + bytesToSend) + ByteBuffer.wrap(streamData).order(ByteOrder.LITTLE_ENDIAN).apply { + putInt(streamId) + putInt(totalSize) + put(opcode.toByte()) + put(subOpcode.toByte()) + put(data.array(), data.position() + sendOffset, bytesToSend) + } + } else { + streamData = ByteArray(4 + 4 + bytesToSend) + ByteBuffer.wrap(streamData).order(ByteOrder.LITTLE_ENDIAN).apply { + putInt(streamId) + putInt(sendOffset) + put(data.array(), data.position() + sendOffset, bytesToSend) + } + streamOpcode = if (bytesToSend < bytesRemaining) StreamOpcode.DATA else StreamOpcode.END + } + + val fullPacket = ByteArray(HEADER_SIZE + streamData.size) + ByteBuffer.wrap(fullPacket).order(ByteOrder.LITTLE_ENDIAN).apply { + putInt(streamData.size + 2) + put(Opcode.STREAM.value.toByte()) + put(streamOpcode.value.toByte()) + put(streamData) + } + + sendPacket(fullPacket) + sendOffset += bytesToSend + } + } else { + val packet = ByteArray(HEADER_SIZE + actualCount) + ByteBuffer.wrap(packet).order(ByteOrder.LITTLE_ENDIAN).apply { + putInt(actualCount + 2) + put(opcode.toByte()) + put(subOpcode.toByte()) + if (actualCount > 0 && data != null) put(data.array(), data.position(), actualCount) + } + sendPacket(packet) + } + } + + fun sendRequestTransport(requestId: Int, publicKey: String, pairingCode: String? = null) { + throwIfDisposed() + + synchronized(sendLock) { + val channelMessage = ByteArray(1024) + val channelBytesWritten = handshakeState!!.writeMessage(channelMessage, 0, null, 0, 0) + + val publicKeyBytes = Base64.getDecoder().decode(publicKey) + if (publicKeyBytes.size != 32) throw IllegalArgumentException("Public key must be 32 bytes") + + val (pairingMessageLength, pairingMessage) = if (pairingCode != null) { + val pairingHandshake = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.INITIATOR).apply { + remotePublicKey.setPublicKey(publicKeyBytes, 0) + start() + } + val pairingCodeBytes = pairingCode.toByteArray(Charsets.UTF_8) + if (pairingCodeBytes.size > 32) throw IllegalArgumentException("Pairing code must not exceed 32 bytes") + val pairingMessageBuffer = ByteArray(1024) + val bytesWritten = pairingHandshake.writeMessage(pairingMessageBuffer, 0, pairingCodeBytes, 0, pairingCodeBytes.size) + bytesWritten to pairingMessageBuffer.copyOf(bytesWritten) + } else { + 0 to ByteArray(0) + } + + val packetSize = 4 + 32 + 4 + pairingMessageLength + 4 + channelBytesWritten + val packet = ByteArray(packetSize) + ByteBuffer.wrap(packet).order(ByteOrder.LITTLE_ENDIAN).apply { + putInt(requestId) + put(publicKeyBytes) + putInt(pairingMessageLength) + if (pairingMessageLength > 0) put(pairingMessage) + putInt(channelBytesWritten) + put(channelMessage, 0, channelBytesWritten) + } + + session.send(Opcode.REQUEST.value, RequestOpcode.TRANSPORT.value, ByteBuffer.wrap(packet)) + } + } + + fun sendResponseTransport(remoteVersion: Int, requestId: Int, handshakeMessage: ByteArray) { + throwIfDisposed() + + synchronized(sendLock) { + val message = ByteArray(1024) + val plaintext = ByteArray(1024) + handshakeState!!.readMessage(handshakeMessage, 0, handshakeMessage.size, plaintext, 0) + val bytesWritten = handshakeState!!.writeMessage(message, 0, null, 0, 0) + val transport = handshakeState!!.split() + + val responsePacket = ByteArray(20 + bytesWritten) + ByteBuffer.wrap(responsePacket).order(ByteOrder.LITTLE_ENDIAN).apply { + putInt(0) // Status code + putLong(connectionId) + putInt(requestId) + putInt(bytesWritten) + put(message, 0, bytesWritten) + } + + completeHandshake(remoteVersion, transport) + session.send(Opcode.RESPONSE.value, ResponseOpcode.TRANSPORT.value, ByteBuffer.wrap(responsePacket)) + } + } + + fun decrypt(encryptedPayload: ByteBuffer): ByteBuffer { + throwIfDisposed() + + synchronized(decryptLock) { + val encryptedBytes = ByteArray(encryptedPayload.remaining()).also { encryptedPayload.get(it) } + val decryptedPayload = ByteArray(encryptedBytes.size - 16) + val plen = transport!!.receiver.decryptWithAd(null, encryptedBytes, 0, decryptedPayload, 0, encryptedBytes.size) + if (plen != decryptedPayload.size) throw IllegalStateException("Expected decrypted payload length to be $plen") + return ByteBuffer.wrap(decryptedPayload).order(ByteOrder.LITTLE_ENDIAN) + } + } + + fun handleTransportRelayed(remoteVersion: Int, connectionId: Long, handshakeMessage: ByteArray) { + throwIfDisposed() + + synchronized(decryptLock) { + this.connectionId = connectionId + val plaintext = ByteArray(1024) + val plen = handshakeState!!.readMessage(handshakeMessage, 0, handshakeMessage.size, plaintext, 0) + val transport = handshakeState!!.split() + completeHandshake(remoteVersion, transport) + } + } +} \ No newline at end of file diff --git a/app/src/main/java/com/futo/platformplayer/sync/internal/Opcode.kt b/app/src/main/java/com/futo/platformplayer/sync/internal/Opcode.kt new file mode 100644 index 00000000..8a12b579 --- /dev/null +++ b/app/src/main/java/com/futo/platformplayer/sync/internal/Opcode.kt @@ -0,0 +1,60 @@ +package com.futo.platformplayer.sync.internal + +enum class Opcode(val value: UByte) { + PING(0u), + PONG(1u), + NOTIFY(2u), + STREAM(3u), + DATA(4u), + REQUEST(5u), + RESPONSE(6u), + RELAY(7u) +} + +enum class NotifyOpcode(val value: UByte) { + AUTHORIZED(0u), + UNAUTHORIZED(1u), + CONNECTION_INFO(2u) +} + +enum class StreamOpcode(val value: UByte) { + START(0u), + DATA(1u), + END(2u) +} + +enum class RequestOpcode(val value: UByte) { + CONNECTION_INFO(0u), + TRANSPORT(1u), + TRANSPORT_RELAYED(2u), + PUBLISH_RECORD(3u), + DELETE_RECORD(4u), + LIST_RECORD_KEYS(5u), + GET_RECORD(6u), + BULK_PUBLISH_RECORD(7u), + BULK_GET_RECORD(8u), + BULK_CONNECTION_INFO(9u), + BULK_DELETE_RECORD(10u) +} + +enum class ResponseOpcode(val value: UByte) { + CONNECTION_INFO(0u), + TRANSPORT(1u), + TRANSPORT_RELAYED(2u), //TODO: Server errors also included in this one, disentangle? + PUBLISH_RECORD(3u), + DELETE_RECORD(4u), + LIST_RECORD_KEYS(5u), + GET_RECORD(6u), + BULK_PUBLISH_RECORD(7u), + BULK_GET_RECORD(8u), + BULK_CONNECTION_INFO(9u), + BULK_DELETE_RECORD(10u) +} + +enum class RelayOpcode(val value: UByte) { + DATA(0u), + RELAYED_DATA(1u), + ERROR(2u), + RELAYED_ERROR(3u), + RELAY_ERROR(4u) +} \ No newline at end of file diff --git a/app/src/main/java/com/futo/platformplayer/sync/internal/SyncDeviceInfo.kt b/app/src/main/java/com/futo/platformplayer/sync/internal/SyncDeviceInfo.kt index 17a70860..a3bb6e00 100644 --- a/app/src/main/java/com/futo/platformplayer/sync/internal/SyncDeviceInfo.kt +++ b/app/src/main/java/com/futo/platformplayer/sync/internal/SyncDeviceInfo.kt @@ -5,10 +5,12 @@ class SyncDeviceInfo { var publicKey: String var addresses: Array var port: Int + var pairingCode: String? - constructor(publicKey: String, addresses: Array, port: Int) { + constructor(publicKey: String, addresses: Array, port: Int, pairingCode: String?) { this.publicKey = publicKey this.addresses = addresses this.port = port + this.pairingCode = pairingCode } } \ No newline at end of file diff --git a/app/src/main/java/com/futo/platformplayer/sync/internal/SyncErrorCode.kt b/app/src/main/java/com/futo/platformplayer/sync/internal/SyncErrorCode.kt new file mode 100644 index 00000000..0b4be0ce --- /dev/null +++ b/app/src/main/java/com/futo/platformplayer/sync/internal/SyncErrorCode.kt @@ -0,0 +1,6 @@ +package com.futo.platformplayer.sync.internal + +enum class SyncErrorCode(val value: Int) { + ConnectionClosed(1), + NotFound(2) +} \ No newline at end of file diff --git a/app/src/main/java/com/futo/platformplayer/sync/internal/SyncSession.kt b/app/src/main/java/com/futo/platformplayer/sync/internal/SyncSession.kt index e4273d63..76af1edb 100644 --- a/app/src/main/java/com/futo/platformplayer/sync/internal/SyncSession.kt +++ b/app/src/main/java/com/futo/platformplayer/sync/internal/SyncSession.kt @@ -1,37 +1,13 @@ package com.futo.platformplayer.sync.internal import com.futo.platformplayer.UIDialogs -import com.futo.platformplayer.activities.MainActivity -import com.futo.platformplayer.api.media.Serializer import com.futo.platformplayer.logging.Logger -import com.futo.platformplayer.models.HistoryVideo import com.futo.platformplayer.models.Subscription -import com.futo.platformplayer.smartMerge -import com.futo.platformplayer.states.StateApp -import com.futo.platformplayer.states.StateBackup -import com.futo.platformplayer.states.StateHistory -import com.futo.platformplayer.states.StatePlaylists -import com.futo.platformplayer.states.StateSubscriptionGroups import com.futo.platformplayer.states.StateSubscriptions -import com.futo.platformplayer.states.StateSync -import com.futo.platformplayer.sync.SyncSessionData -import com.futo.platformplayer.sync.internal.SyncSocketSession.Opcode -import com.futo.platformplayer.sync.models.SendToDevicePackage -import com.futo.platformplayer.sync.models.SyncPlaylistsPackage -import com.futo.platformplayer.sync.models.SyncSubscriptionGroupsPackage import com.futo.platformplayer.sync.models.SyncSubscriptionsPackage -import com.futo.platformplayer.sync.models.SyncWatchLaterPackage -import com.futo.platformplayer.toUtf8String -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.launch import kotlinx.serialization.encodeToString import kotlinx.serialization.json.Json -import java.io.ByteArrayInputStream import java.nio.ByteBuffer -import java.nio.ByteOrder -import java.time.Instant -import java.time.OffsetDateTime -import java.time.ZoneOffset import java.util.UUID interface IAuthorizable { @@ -39,13 +15,14 @@ interface IAuthorizable { } class SyncSession : IAuthorizable { - private val _socketSessions: MutableList = mutableListOf() + private val _channels: MutableList = mutableListOf() private var _authorized: Boolean = false private var _remoteAuthorized: Boolean = false private val _onAuthorized: (session: SyncSession, isNewlyAuthorized: Boolean, isNewSession: Boolean) -> Unit private val _onUnauthorized: (session: SyncSession) -> Unit private val _onClose: (session: SyncSession) -> Unit private val _onConnectedChanged: (session: SyncSession, connected: Boolean) -> Unit + private val _dataHandler: (session: SyncSession, opcode: UByte, subOpcode: UByte, data: ByteBuffer) -> Unit val remotePublicKey: String override val isAuthorized get() = _authorized && _remoteAuthorized private var _wasAuthorized = false @@ -56,140 +33,151 @@ class SyncSession : IAuthorizable { private set val displayName: String get() = remoteDeviceName ?: remotePublicKey - var connected: Boolean = false - private set(v) { - if (field != v) { - field = v - this._onConnectedChanged(this, v) + val linkType: LinkType get() + { + var hasProxied = false + var hasDirect = false + synchronized(_channels) + { + for (channel in _channels) + { + if (channel is ChannelRelayed) + hasProxied = true + if (channel is ChannelSocket) + hasDirect = true + if (hasProxied && hasDirect) + return LinkType.Local + } } + + if (hasProxied) + return LinkType.Proxied + if (hasDirect) + return LinkType.Local + return LinkType.None } - constructor(remotePublicKey: String, onAuthorized: (session: SyncSession, isNewlyAuthorized: Boolean, isNewSession: Boolean) -> Unit, onUnauthorized: (session: SyncSession) -> Unit, onConnectedChanged: (session: SyncSession, connected: Boolean) -> Unit, onClose: (session: SyncSession) -> Unit, remoteDeviceName: String?) { + var connected: Boolean = false + private set(v) { + if (field != v) { + field = v + this._onConnectedChanged(this, v) + } + } + + constructor( + remotePublicKey: String, + onAuthorized: (session: SyncSession, isNewlyAuthorized: Boolean, isNewSession: Boolean) -> Unit, + onUnauthorized: (session: SyncSession) -> Unit, + onConnectedChanged: (session: SyncSession, connected: Boolean) -> Unit, + onClose: (session: SyncSession) -> Unit, + dataHandler: (session: SyncSession, opcode: UByte, subOpcode: UByte, data: ByteBuffer) -> Unit, + remoteDeviceName: String? = null + ) { this.remotePublicKey = remotePublicKey + this.remoteDeviceName = remoteDeviceName _onAuthorized = onAuthorized _onUnauthorized = onUnauthorized _onConnectedChanged = onConnectedChanged _onClose = onClose + _dataHandler = dataHandler } - fun addSocketSession(socketSession: SyncSocketSession) { - if (socketSession.remotePublicKey != remotePublicKey) { - throw Exception("Public key of session must match public key of socket session") + fun addChannel(channel: IChannel) { + if (channel.remotePublicKey != remotePublicKey) { + throw Exception("Public key of session must match public key of channel") } - synchronized(_socketSessions) { - _socketSessions.add(socketSession) - connected = _socketSessions.isNotEmpty() + synchronized(_channels) { + _channels.add(channel) + connected = _channels.isNotEmpty() } - socketSession.authorizable = this + channel.authorizable = this + channel.syncSession = this } - fun authorize(socketSession: SyncSocketSession) { + fun authorize() { Logger.i(TAG, "Sent AUTHORIZED with session id $_id") - - if (socketSession.remoteVersion >= 3) { - val idStringBytes = _id.toString().toByteArray() - val nameBytes = "${android.os.Build.MANUFACTURER}-${android.os.Build.MODEL}".toByteArray() - val buffer = ByteArray(1 + idStringBytes.size + 1 + nameBytes.size) - socketSession.send(Opcode.NOTIFY_AUTHORIZED.value, 0u, ByteBuffer.wrap(buffer).order(ByteOrder.LITTLE_ENDIAN).apply { - put(idStringBytes.size.toByte()) - put(idStringBytes) - put(nameBytes.size.toByte()) - put(nameBytes) - }.apply { flip() }) - } else { - socketSession.send(Opcode.NOTIFY_AUTHORIZED.value, 0u, ByteBuffer.wrap(_id.toString().toByteArray())) - } + val idString = _id.toString() + val idBytes = idString.toByteArray(Charsets.UTF_8) + val name = "${android.os.Build.MANUFACTURER}-${android.os.Build.MODEL}" + val nameBytes = name.toByteArray(Charsets.UTF_8) + val buffer = ByteArray(1 + idBytes.size + 1 + nameBytes.size) + buffer[0] = idBytes.size.toByte() + System.arraycopy(idBytes, 0, buffer, 1, idBytes.size) + buffer[1 + idBytes.size] = nameBytes.size.toByte() + System.arraycopy(nameBytes, 0, buffer, 2 + idBytes.size, nameBytes.size) + send(Opcode.NOTIFY.value, NotifyOpcode.AUTHORIZED.value, ByteBuffer.wrap(buffer)) _authorized = true checkAuthorized() } - fun unauthorize(socketSession: SyncSocketSession? = null) { - if (socketSession != null) - socketSession.send(Opcode.NOTIFY_UNAUTHORIZED.value) - else { - val ss = synchronized(_socketSessions) { - _socketSessions.first() - } - - ss.send(Opcode.NOTIFY_UNAUTHORIZED.value) - } + fun unauthorize() { + send(Opcode.NOTIFY.value, NotifyOpcode.UNAUTHORIZED.value) } private fun checkAuthorized() { if (isAuthorized) { - val isNewlyAuthorized = !_wasAuthorized; - val isNewSession = _lastAuthorizedRemoteId != _remoteId; - Logger.i(TAG, "onAuthorized (isNewlyAuthorized = $isNewlyAuthorized, isNewSession = $isNewSession)"); - _onAuthorized.invoke(this, !_wasAuthorized, _lastAuthorizedRemoteId != _remoteId) + val isNewlyAuthorized = !_wasAuthorized + val isNewSession = _lastAuthorizedRemoteId != _remoteId + Logger.i(TAG, "onAuthorized (isNewlyAuthorized = $isNewlyAuthorized, isNewSession = $isNewSession)") + _onAuthorized(this, isNewlyAuthorized, isNewSession) _wasAuthorized = true _lastAuthorizedRemoteId = _remoteId } } - fun removeSocketSession(socketSession: SyncSocketSession) { - synchronized(_socketSessions) { - _socketSessions.remove(socketSession) - connected = _socketSessions.isNotEmpty() + fun removeChannel(channel: IChannel) { + synchronized(_channels) { + _channels.remove(channel) + connected = _channels.isNotEmpty() } } fun close() { - synchronized(_socketSessions) { - for (socketSession in _socketSessions) { - socketSession.stop() - } - - _socketSessions.clear() + synchronized(_channels) { + _channels.forEach { it.close() } + _channels.clear() } - - _onClose.invoke(this) + _onClose(this) } - fun handlePacket(socketSession: SyncSocketSession, opcode: UByte, subOpcode: UByte, data: ByteBuffer) { + fun handlePacket(opcode: UByte, subOpcode: UByte, data: ByteBuffer) { try { - Logger.i(TAG, "Handle packet (opcode: ${opcode}, subOpcode: ${subOpcode}, data.length: ${data.remaining()})") + Logger.i(TAG, "Handle packet (opcode: $opcode, subOpcode: $subOpcode, data.length: ${data.remaining()})") when (opcode) { - Opcode.NOTIFY_AUTHORIZED.value -> { - if (socketSession.remoteVersion >= 3) { + Opcode.NOTIFY.value -> when (subOpcode) { + NotifyOpcode.AUTHORIZED.value -> { val idByteCount = data.get().toInt() if (idByteCount > 64) throw Exception("Id should always be smaller than 64 bytes") - val idBytes = ByteArray(idByteCount) data.get(idBytes) val nameByteCount = data.get().toInt() if (nameByteCount > 64) throw Exception("Name should always be smaller than 64 bytes") - val nameBytes = ByteArray(nameByteCount) data.get(nameBytes) _remoteId = UUID.fromString(idBytes.toString(Charsets.UTF_8)) remoteDeviceName = nameBytes.toString(Charsets.UTF_8) - } else { - val str = data.toUtf8String() - _remoteId = if (data.remaining() >= 0) UUID.fromString(str) else UUID.fromString("00000000-0000-0000-0000-000000000000") - remoteDeviceName = null + _remoteAuthorized = true + Logger.i(TAG, "Received AUTHORIZED with session id $_remoteId (device name: '${remoteDeviceName ?: "not set"}')") + checkAuthorized() + return + } + NotifyOpcode.UNAUTHORIZED.value -> { + _remoteAuthorized = false + _remoteId = null + remoteDeviceName = null + _lastAuthorizedRemoteId = null + _onUnauthorized(this) + return } - - _remoteAuthorized = true - Logger.i(TAG, "Received AUTHORIZED with session id $_remoteId (device name: '${remoteDeviceName ?: "not set"}')") - checkAuthorized() - return } - Opcode.NOTIFY_UNAUTHORIZED.value -> { - _remoteId = null - remoteDeviceName = null - _lastAuthorizedRemoteId = null - _remoteAuthorized = false - _onUnauthorized(this) - return - } - //TODO: Handle any kind of packet (that is not necessarily authorized) } if (!isAuthorized) { @@ -197,282 +185,58 @@ class SyncSession : IAuthorizable { } if (opcode != Opcode.DATA.value) { - Logger.w(TAG, "Unknown opcode received: (opcode = ${opcode}, subOpcode = ${subOpcode})}") + Logger.w(TAG, "Unknown opcode received: (opcode = $opcode, subOpcode = $subOpcode)") return } - Logger.i(TAG, "Received (opcode = ${opcode}, subOpcode = ${subOpcode}) (${data.remaining()} bytes)") - //TODO: Abstract this out - when (subOpcode) { - GJSyncOpcodes.sendToDevices -> { - StateApp.instance.scopeOrNull?.launch(Dispatchers.Main) { - val context = StateApp.instance.contextOrNull; - if (context != null && context is MainActivity) { - val dataBody = ByteArray(data.remaining()); - val remainder = data.remaining(); - data.get(dataBody, 0, remainder); - val json = String(dataBody, Charsets.UTF_8); - val obj = Json.decodeFromString(json); - UIDialogs.appToast("Received url from device [${socketSession.remotePublicKey}]:\n{${obj.url}"); - context.handleUrl(obj.url, obj.position); - } - }; - } - - GJSyncOpcodes.syncStateExchange -> { - val dataBody = ByteArray(data.remaining()); - data.get(dataBody); - val json = String(dataBody, Charsets.UTF_8); - val syncSessionData = Serializer.json.decodeFromString(json); - - Logger.i(TAG, "Received SyncSessionData from " + remotePublicKey); - - - sendData(GJSyncOpcodes.syncSubscriptions, StateSubscriptions.instance.getSyncSubscriptionsPackageString()); - sendData(GJSyncOpcodes.syncSubscriptionGroups, StateSubscriptionGroups.instance.getSyncSubscriptionGroupsPackageString()); - sendData(GJSyncOpcodes.syncPlaylists, StatePlaylists.instance.getSyncPlaylistsPackageString()) - - sendData(GJSyncOpcodes.syncWatchLater, Json.encodeToString(StatePlaylists.instance.getWatchLaterSyncPacket(false))); - - val recentHistory = StateHistory.instance.getRecentHistory(syncSessionData.lastHistory); - if(recentHistory.size > 0) - sendJsonData(GJSyncOpcodes.syncHistory, recentHistory); - } - - GJSyncOpcodes.syncExport -> { - val dataBody = ByteArray(data.remaining()); - val bytesStr = ByteArrayInputStream(data.array(), data.position(), data.remaining()); - try { - val exportStruct = StateBackup.ExportStructure.fromZipBytes(bytesStr); - for (store in exportStruct.stores) { - if (store.key.equals("subscriptions", true)) { - val subStore = - StateSubscriptions.instance.getUnderlyingSubscriptionsStore(); - StateApp.instance.scopeOrNull?.launch(Dispatchers.IO) { - val pack = SyncSubscriptionsPackage( - store.value.map { - subStore.fromReconstruction(it, exportStruct.cache) - }, - StateSubscriptions.instance.getSubscriptionRemovals() - ); - handleSyncSubscriptionPackage(this@SyncSession, pack); - } - } - } - } finally { - bytesStr.close(); - } - } - - GJSyncOpcodes.syncSubscriptions -> { - val dataBody = ByteArray(data.remaining()); - data.get(dataBody); - val json = String(dataBody, Charsets.UTF_8); - val subPackage = Serializer.json.decodeFromString(json); - handleSyncSubscriptionPackage(this, subPackage); - - val newestSub = subPackage.subscriptions.maxOf { it.creationTime }; - - val sesData = StateSync.instance.getSyncSessionData(remotePublicKey); - if(newestSub > sesData.lastSubscription) { - sesData.lastSubscription = newestSub; - StateSync.instance.saveSyncSessionData(sesData); - } - } - - GJSyncOpcodes.syncSubscriptionGroups -> { - val dataBody = ByteArray(data.remaining()); - data.get(dataBody); - val json = String(dataBody, Charsets.UTF_8); - val pack = Serializer.json.decodeFromString(json); - - var lastSubgroupChange = OffsetDateTime.MIN; - for(group in pack.groups){ - if(group.lastChange > lastSubgroupChange) - lastSubgroupChange = group.lastChange; - - val existing = StateSubscriptionGroups.instance.getSubscriptionGroup(group.id); - - if(existing == null) - StateSubscriptionGroups.instance.updateSubscriptionGroup(group, false, true); - else if(existing.lastChange < group.lastChange) { - existing.name = group.name; - existing.urls = group.urls; - existing.image = group.image; - existing.priority = group.priority; - existing.lastChange = group.lastChange; - StateSubscriptionGroups.instance.updateSubscriptionGroup(existing, false, true); - } - } - for(removal in pack.groupRemovals) { - val creation = StateSubscriptionGroups.instance.getSubscriptionGroup(removal.key); - val removalTime = OffsetDateTime.ofInstant(Instant.ofEpochSecond(removal.value, 0), ZoneOffset.UTC); - if(creation != null && creation.creationTime < removalTime) - StateSubscriptionGroups.instance.deleteSubscriptionGroup(removal.key, false); - } - } - - GJSyncOpcodes.syncPlaylists -> { - val dataBody = ByteArray(data.remaining()); - data.get(dataBody); - val json = String(dataBody, Charsets.UTF_8); - val pack = Serializer.json.decodeFromString(json); - - for(playlist in pack.playlists) { - val existing = StatePlaylists.instance.getPlaylist(playlist.id); - - if(existing == null) - StatePlaylists.instance.createOrUpdatePlaylist(playlist, false); - else if(existing.dateUpdate.toLocalDateTime() < playlist.dateUpdate.toLocalDateTime()) { - existing.dateUpdate = playlist.dateUpdate; - existing.name = playlist.name; - existing.videos = playlist.videos; - existing.dateCreation = playlist.dateCreation; - existing.datePlayed = playlist.datePlayed; - StatePlaylists.instance.createOrUpdatePlaylist(existing, false); - } - } - for(removal in pack.playlistRemovals) { - val creation = StatePlaylists.instance.getPlaylist(removal.key); - val removalTime = OffsetDateTime.ofInstant(Instant.ofEpochSecond(removal.value, 0), ZoneOffset.UTC); - if(creation != null && creation.dateCreation < removalTime) - StatePlaylists.instance.removePlaylist(creation, false); - - } - } - - GJSyncOpcodes.syncWatchLater -> { - val dataBody = ByteArray(data.remaining()); - data.get(dataBody); - val json = String(dataBody, Charsets.UTF_8); - val pack = Serializer.json.decodeFromString(json); - - Logger.i(TAG, "SyncWatchLater received ${pack.videos.size} (${pack.videoAdds?.size}, ${pack.videoRemovals?.size})"); - - val allExisting = StatePlaylists.instance.getWatchLater(); - for(video in pack.videos) { - val existing = allExisting.firstOrNull { it.url == video.url }; - val time = if(pack.videoAdds != null && pack.videoAdds.containsKey(video.url)) OffsetDateTime.ofInstant(Instant.ofEpochSecond(pack.videoAdds[video.url] ?: 0), ZoneOffset.UTC) else OffsetDateTime.MIN; - - if(existing == null) { - StatePlaylists.instance.addToWatchLater(video, false); - if(time > OffsetDateTime.MIN) - StatePlaylists.instance.setWatchLaterAddTime(video.url, time); - } - } - for(removal in pack.videoRemovals) { - val watchLater = allExisting.firstOrNull { it.url == removal.key } ?: continue; - val creation = StatePlaylists.instance.getWatchLaterRemovalTime(watchLater.url) ?: OffsetDateTime.MIN; - val removalTime = OffsetDateTime.ofInstant(Instant.ofEpochSecond(removal.value), ZoneOffset.UTC); - if(creation < removalTime) - StatePlaylists.instance.removeFromWatchLater(watchLater, false, removalTime); - } - - val packReorderTime = OffsetDateTime.ofInstant(Instant.ofEpochSecond(pack.reorderTime), ZoneOffset.UTC); - val localReorderTime = StatePlaylists.instance.getWatchLaterLastReorderTime(); - if(localReorderTime < packReorderTime && pack.ordering != null) { - StatePlaylists.instance.updateWatchLaterOrdering(smartMerge(pack.ordering!!, StatePlaylists.instance.getWatchLaterOrdering()), true); - } - } - - GJSyncOpcodes.syncHistory -> { - val dataBody = ByteArray(data.remaining()); - data.get(dataBody); - val json = String(dataBody, Charsets.UTF_8); - val history = Serializer.json.decodeFromString>(json); - Logger.i(TAG, "SyncHistory received ${history.size} videos from ${remotePublicKey}"); - - var lastHistory = OffsetDateTime.MIN; - for(video in history){ - val hist = StateHistory.instance.getHistoryByVideo(video.video, true, video.date); - if(hist != null) - StateHistory.instance.updateHistoryPosition(video.video, hist, true, video.position, video.date) - if(lastHistory < video.date) - lastHistory = video.date; - } - - if(lastHistory != OffsetDateTime.MIN && history.size > 1) { - val sesData = StateSync.instance.getSyncSessionData(remotePublicKey); - if (lastHistory > sesData.lastHistory) { - sesData.lastHistory = lastHistory; - StateSync.instance.saveSyncSessionData(sesData); - } - } - } - } + Logger.i(TAG, "Received (opcode = $opcode, subOpcode = $subOpcode) (${data.remaining()} bytes)") + _dataHandler.invoke(this, opcode, subOpcode, data) + } catch (ex: Exception) { + Logger.w(TAG, "Failed to handle sync package $opcode: ${ex.message}", ex) } catch(ex: Exception) { Logger.w(TAG, "Failed to handle sync package ${opcode}: ${ex.message}", ex); } } - private fun handleSyncSubscriptionPackage(origin: SyncSession, pack: SyncSubscriptionsPackage) { - val added = mutableListOf() - for(sub in pack.subscriptions) { - if(!StateSubscriptions.instance.isSubscribed(sub.channel)) { - val removalTime = StateSubscriptions.instance.getSubscriptionRemovalTime(sub.channel.url); - if(sub.creationTime > removalTime) { - val newSub = StateSubscriptions.instance.addSubscription(sub.channel, sub.creationTime); - added.add(newSub); - } - } - } - if(added.size > 3) - UIDialogs.appToast("${added.size} Subscriptions from ${origin.remotePublicKey.substring(0, Math.min(8, origin.remotePublicKey.length))}"); - else if(added.size > 0) - UIDialogs.appToast("Subscriptions from ${origin.remotePublicKey.substring(0, Math.min(8, origin.remotePublicKey.length))}:\n" + - added.map { it.channel.name }.joinToString("\n")); - - - if(pack.subscriptions != null && pack.subscriptions.size > 0) { - for (subRemoved in pack.subscriptionRemovals) { - val removed = StateSubscriptions.instance.applySubscriptionRemovals(pack.subscriptionRemovals); - if(removed.size > 3) - UIDialogs.appToast("Removed ${removed.size} Subscriptions from ${origin.remotePublicKey.substring(0, Math.min(8, origin.remotePublicKey.length))}"); - else if(removed.size > 0) - UIDialogs.appToast("Subscriptions removed from ${origin.remotePublicKey.substring(0, Math.min(8, origin.remotePublicKey.length))}:\n" + - removed.map { it.channel.name }.joinToString("\n")); - - } - } - } - inline fun sendJsonData(subOpcode: UByte, data: T) { - send(Opcode.DATA.value, subOpcode, Json.encodeToString(data)); + send(Opcode.DATA.value, subOpcode, Json.encodeToString(data)) } - fun sendData(subOpcode: UByte, data: String) { - send(Opcode.DATA.value, subOpcode, data.toByteArray(Charsets.UTF_8)); - } - fun send(opcode: UByte, subOpcode: UByte, data: String) { - send(opcode, subOpcode, data.toByteArray(Charsets.UTF_8)); - } - fun send(opcode: UByte, subOpcode: UByte, data: ByteArray) { - val socketSessions = synchronized(_socketSessions) { - _socketSessions.toList() - } - if (socketSessions.isEmpty()) { - Logger.v(TAG, "Packet was not sent (opcode = ${opcode}, subOpcode = ${subOpcode}) due to no connected sockets") + fun sendData(subOpcode: UByte, data: String) { + send(Opcode.DATA.value, subOpcode, ByteBuffer.wrap(data.toByteArray(Charsets.UTF_8))) + } + + fun send(opcode: UByte, subOpcode: UByte, data: String) { + send(opcode, subOpcode, ByteBuffer.wrap(data.toByteArray(Charsets.UTF_8))) + } + + fun send(opcode: UByte, subOpcode: UByte, data: ByteBuffer? = null) { + //TODO: Prioritize local connections + val channels = synchronized(_channels) { _channels.toList() } + if (channels.isEmpty()) { + //TODO: Should this throw? + Logger.v(TAG, "Packet was not sent (opcode = $opcode, subOpcode = $subOpcode) due to no connected sockets") return } var sent = false - for (socketSession in socketSessions) { + for (channel in channels) { try { - socketSession.send(opcode, subOpcode, ByteBuffer.wrap(data)) + channel.send(opcode, subOpcode, data) sent = true break } catch (e: Throwable) { - Logger.w(TAG, "Packet failed to send (opcode = ${opcode}, subOpcode = ${subOpcode})", e) + Logger.w(TAG, "Packet failed to send (opcode = $opcode, subOpcode = $subOpcode)", e) } } if (!sent) { - throw Exception("Packet was not sent (opcode = ${opcode}, subOpcode = ${subOpcode}) due to send errors and no remaining candidates") + throw Exception("Packet was not sent (opcode = $opcode, subOpcode = $subOpcode) due to send errors and no remaining candidates") } } - private companion object { - const val TAG = "SyncSession" + companion object { + private const val TAG = "SyncSession" } } \ No newline at end of file diff --git a/app/src/main/java/com/futo/platformplayer/sync/internal/SyncSocketSession.kt b/app/src/main/java/com/futo/platformplayer/sync/internal/SyncSocketSession.kt index c997cec4..c8f4f683 100644 --- a/app/src/main/java/com/futo/platformplayer/sync/internal/SyncSocketSession.kt +++ b/app/src/main/java/com/futo/platformplayer/sync/internal/SyncSocketSession.kt @@ -1,5 +1,6 @@ package com.futo.platformplayer.sync.internal +import android.os.Build import com.futo.platformplayer.LittleEndianDataInputStream import com.futo.platformplayer.LittleEndianDataOutputStream import com.futo.platformplayer.ensureNotMainThread @@ -8,22 +9,19 @@ import com.futo.platformplayer.noise.protocol.CipherStatePair import com.futo.platformplayer.noise.protocol.DHState import com.futo.platformplayer.noise.protocol.HandshakeState import com.futo.platformplayer.states.StateSync +import kotlinx.coroutines.CompletableDeferred +import java.net.Inet4Address +import java.net.Inet6Address +import java.net.InetAddress +import java.net.NetworkInterface import java.nio.ByteBuffer import java.nio.ByteOrder -import java.util.UUID +import java.util.Base64 +import java.util.Locale +import java.util.concurrent.ConcurrentHashMap +import kotlin.math.min class SyncSocketSession { - enum class Opcode(val value: UByte) { - PING(0u), - PONG(1u), - NOTIFY_AUTHORIZED(2u), - NOTIFY_UNAUTHORIZED(3u), - STREAM_START(4u), - STREAM_DATA(5u), - STREAM_END(6u), - DATA(7u) - } - private val _inputStream: LittleEndianDataInputStream private val _outputStream: LittleEndianDataOutputStream private val _sendLockObject = Object() @@ -32,11 +30,15 @@ class SyncSocketSession { private val _sendBuffer = ByteArray(MAXIMUM_PACKET_SIZE) private val _sendBufferEncrypted = ByteArray(MAXIMUM_PACKET_SIZE_ENCRYPTED) private val _syncStreams = hashMapOf() - private val _streamIdGenerator = 0 + private var _streamIdGenerator = 0 private val _streamIdGeneratorLock = Object() - private val _onClose: (session: SyncSocketSession) -> Unit - private val _onHandshakeComplete: (session: SyncSocketSession) -> Unit - private var _thread: Thread? = null + private var _requestIdGenerator = 0 + private val _requestIdGeneratorLock = Object() + private val _onClose: ((session: SyncSocketSession) -> Unit)? + private val _onHandshakeComplete: ((session: SyncSocketSession) -> Unit)? + private val _onNewChannel: ((session: SyncSocketSession, channel: ChannelRelayed) -> Unit)? + private val _onChannelEstablished: ((session: SyncSocketSession, channel: ChannelRelayed, isResponder: Boolean) -> Unit)? + private val _isHandshakeAllowed: ((session: SyncSocketSession, remotePublicKey: String, pairingCode: String?) -> Boolean)? private var _cipherStatePair: CipherStatePair? = null private var _remotePublicKey: String? = null val remotePublicKey: String? get() = _remotePublicKey @@ -44,55 +46,90 @@ class SyncSocketSession { private val _localKeyPair: DHState private var _localPublicKey: String val localPublicKey: String get() = _localPublicKey - private val _onData: (session: SyncSocketSession, opcode: UByte, subOpcode: UByte, data: ByteBuffer) -> Unit + private val _onData: ((session: SyncSocketSession, opcode: UByte, subOpcode: UByte, data: ByteBuffer) -> Unit)? + val isAuthorized: Boolean + get() = authorizable?.isAuthorized ?: false var authorizable: IAuthorizable? = null var remoteVersion: Int = -1 private set val remoteAddress: String - constructor(remoteAddress: String, localKeyPair: DHState, inputStream: LittleEndianDataInputStream, outputStream: LittleEndianDataOutputStream, onClose: (session: SyncSocketSession) -> Unit, onHandshakeComplete: (session: SyncSocketSession) -> Unit, onData: (session: SyncSocketSession, opcode: UByte, subOpcode: UByte, data: ByteBuffer) -> Unit) { + private val _channels = ConcurrentHashMap() + private val _pendingChannels = ConcurrentHashMap>>() + private val _pendingConnectionInfoRequests = ConcurrentHashMap>() + private val _pendingPublishRequests = ConcurrentHashMap>() + private val _pendingDeleteRequests = ConcurrentHashMap>() + private val _pendingListKeysRequests = ConcurrentHashMap>>>() + private val _pendingGetRecordRequests = ConcurrentHashMap?>>() + private val _pendingBulkGetRecordRequests = ConcurrentHashMap>>>() + private val _pendingBulkConnectionInfoRequests = ConcurrentHashMap>>() + + data class ConnectionInfo( + val port: UShort, + val name: String, + val remoteIp: String, + val ipv4Addresses: List, + val ipv6Addresses: List, + val allowLocalDirect: Boolean, + val allowRemoteDirect: Boolean, + val allowRemoteHolePunched: Boolean, + val allowRemoteProxied: Boolean + ) + + constructor( + remoteAddress: String, + localKeyPair: DHState, + inputStream: LittleEndianDataInputStream, + outputStream: LittleEndianDataOutputStream, + onClose: ((session: SyncSocketSession) -> Unit)? = null, + onHandshakeComplete: ((session: SyncSocketSession) -> Unit)? = null, + onData: ((session: SyncSocketSession, opcode: UByte, subOpcode: UByte, data: ByteBuffer) -> Unit)? = null, + onNewChannel: ((session: SyncSocketSession, channel: ChannelRelayed) -> Unit)? = null, + onChannelEstablished: ((session: SyncSocketSession, channel: ChannelRelayed, isResponder: Boolean) -> Unit)? = null, + isHandshakeAllowed: ((session: SyncSocketSession, remotePublicKey: String, pairingCode: String?) -> Boolean)? = null + ) { _inputStream = inputStream _outputStream = outputStream _onClose = onClose _onHandshakeComplete = onHandshakeComplete _localKeyPair = localKeyPair _onData = onData + _onNewChannel = onNewChannel + _onChannelEstablished = onChannelEstablished + _isHandshakeAllowed = isHandshakeAllowed this.remoteAddress = remoteAddress val localPublicKey = ByteArray(localKeyPair.publicKeyLength) localKeyPair.getPublicKey(localPublicKey, 0) - _localPublicKey = java.util.Base64.getEncoder().encodeToString(localPublicKey) + _localPublicKey = Base64.getEncoder().encodeToString(localPublicKey) } - fun startAsInitiator(remotePublicKey: String) { + fun startAsInitiator(remotePublicKey: String, pairingCode: String? = null) { _started = true - _thread = Thread { - try { - handshakeAsInitiator(remotePublicKey) - _onHandshakeComplete.invoke(this) - receiveLoop() - } catch (e: Throwable) { - Logger.e(TAG, "Failed to run as initiator", e) - } finally { - stop() - } - }.apply { start() } + try { + handshakeAsInitiator(remotePublicKey, pairingCode) + _onHandshakeComplete?.invoke(this) + receiveLoop() + } catch (e: Throwable) { + Logger.e(TAG, "Failed to run as initiator", e) + } finally { + stop() + } } fun startAsResponder() { _started = true - _thread = Thread { - try { - handshakeAsResponder() - _onHandshakeComplete.invoke(this) + try { + if (handshakeAsResponder()) { + _onHandshakeComplete?.invoke(this) receiveLoop() - } catch(e: Throwable) { - Logger.e(TAG, "Failed to run as responder", e) - } finally { - stop() } - }.apply { start() } + } catch (e: Throwable) { + Logger.e(TAG, "Failed to run as responder", e) + } finally { + stop() + } } private fun receiveLoop() { @@ -116,7 +153,7 @@ class SyncSocketSession { val plen: Int = _cipherStatePair!!.receiver.decryptWithAd(null, _buffer, 0, _bufferDecrypted, 0, messageSize) //Logger.i(TAG, "Decrypted message (size = ${plen})") - handleData(_bufferDecrypted, plen) + handleData(_bufferDecrypted, plen, null) } catch (e: Throwable) { Logger.e(TAG, "Exception while receiving data", e) break @@ -126,46 +163,129 @@ class SyncSocketSession { fun stop() { _started = false - _onClose(this) + _pendingConnectionInfoRequests.forEach { it.value.cancel() } + _pendingConnectionInfoRequests.clear() + _pendingPublishRequests.forEach { it.value.cancel() } + _pendingPublishRequests.clear() + _pendingDeleteRequests.forEach { it.value.cancel() } + _pendingDeleteRequests.clear() + _pendingListKeysRequests.forEach { it.value.cancel() } + _pendingListKeysRequests.clear() + _pendingGetRecordRequests.forEach { it.value.cancel() } + _pendingGetRecordRequests.clear() + _pendingBulkGetRecordRequests.forEach { it.value.cancel() } + _pendingBulkGetRecordRequests.clear() + _pendingBulkConnectionInfoRequests.forEach { it.value.cancel() } + _pendingBulkConnectionInfoRequests.clear() + _pendingChannels.forEach { it.value.first.close(); it.value.second.cancel() } + _pendingChannels.clear() + synchronized(_syncStreams) { + _syncStreams.clear() + } + _channels.values.forEach { it.close() } + _channels.clear() + _onClose?.invoke(this) _inputStream.close() _outputStream.close() - _thread = null + _cipherStatePair?.sender?.destroy() + _cipherStatePair?.receiver?.destroy() Logger.i(TAG, "Session closed") } - private fun handshakeAsInitiator(remotePublicKey: String) { + private fun handshakeAsInitiator(remotePublicKey: String, pairingCode: String?) { performVersionCheck() val initiator = HandshakeState(StateSync.protocolName, HandshakeState.INITIATOR) initiator.localKeyPair.copyFrom(_localKeyPair) + initiator.remotePublicKey.setPublicKey(Base64.getDecoder().decode(remotePublicKey), 0) + initiator.start() - initiator.remotePublicKey.setPublicKey(java.util.Base64.getDecoder().decode(remotePublicKey), 0) - _cipherStatePair = handshake(initiator) - - _remotePublicKey = initiator.remotePublicKey.let { - val pkey = ByteArray(it.publicKeyLength) - it.getPublicKey(pkey, 0) - return@let java.util.Base64.getEncoder().encodeToString(pkey) + val pairingMessage: ByteArray + val pairingMessageLength: Int + if (pairingCode != null) { + val pairingHandshake = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.INITIATOR) + pairingHandshake.remotePublicKey.setPublicKey(Base64.getDecoder().decode(remotePublicKey), 0) + pairingHandshake.start() + val pairingCodeBytes = pairingCode.toByteArray(Charsets.UTF_8) + val pairingBuffer = ByteArray(512) + pairingMessageLength = pairingHandshake.writeMessage(pairingBuffer, 0, pairingCodeBytes, 0, pairingCodeBytes.size) + pairingMessage = pairingBuffer.copyOf(pairingMessageLength) + } else { + pairingMessage = ByteArray(0) + pairingMessageLength = 0 } + + val mainBuffer = ByteArray(512) + val mainLength = initiator.writeMessage(mainBuffer, 0, null, 0, 0) + + val messageData = ByteBuffer.allocate(4 + pairingMessageLength + mainLength).order(ByteOrder.LITTLE_ENDIAN) + messageData.putInt(pairingMessageLength) + if (pairingMessageLength > 0) messageData.put(pairingMessage) + messageData.put(mainBuffer, 0, mainLength) + val messageDataArray = messageData.array() + _outputStream.writeInt(messageDataArray.size) + _outputStream.write(messageDataArray) + + val responseSize = _inputStream.readInt() + val responseMessage = ByteArray(responseSize) + _inputStream.readFully(responseMessage) + val plaintext = ByteArray(512) // Buffer for any payload (none expected here) + initiator.readMessage(responseMessage, 0, responseSize, plaintext, 0) + + _cipherStatePair = initiator.split() + val remoteKeyBytes = ByteArray(initiator.remotePublicKey.publicKeyLength) + initiator.remotePublicKey.getPublicKey(remoteKeyBytes, 0) + _remotePublicKey = Base64.getEncoder().encodeToString(remoteKeyBytes) } - private fun handshakeAsResponder() { + private fun handshakeAsResponder(): Boolean { performVersionCheck() val responder = HandshakeState(StateSync.protocolName, HandshakeState.RESPONDER) responder.localKeyPair.copyFrom(_localKeyPair) - _cipherStatePair = handshake(responder) + responder.start() - _remotePublicKey = responder.remotePublicKey.let { - val pkey = ByteArray(it.publicKeyLength) - it.getPublicKey(pkey, 0) - return@let java.util.Base64.getEncoder().encodeToString(pkey) + val messageSize = _inputStream.readInt() + val message = ByteArray(messageSize) + _inputStream.readFully(message) + val messageBuffer = ByteBuffer.wrap(message).order(ByteOrder.LITTLE_ENDIAN) + + val pairingMessageLength = messageBuffer.int + val pairingMessage = if (pairingMessageLength > 0) ByteArray(pairingMessageLength).also { messageBuffer.get(it) } else byteArrayOf() + val mainLength = messageSize - 4 - pairingMessageLength + val mainMessage = ByteArray(mainLength).also { messageBuffer.get(it) } + + var pairingCode: String? = null + if (pairingMessageLength > 0) { + val pairingHandshake = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.RESPONDER) + pairingHandshake.localKeyPair.copyFrom(_localKeyPair) + pairingHandshake.start() + val pairingPlaintext = ByteArray(512) + val plaintextLength = pairingHandshake.readMessage(pairingMessage, 0, pairingMessageLength, pairingPlaintext, 0) + pairingCode = String(pairingPlaintext, 0, plaintextLength, Charsets.UTF_8) + } + + val plaintext = ByteArray(512) + responder.readMessage(mainMessage, 0, mainLength, plaintext, 0) + + val responseBuffer = ByteArray(512) + val responseLength = responder.writeMessage(responseBuffer, 0, null, 0, 0) + _outputStream.writeInt(responseLength) + _outputStream.write(responseBuffer, 0, responseLength) + + _cipherStatePair = responder.split() + val remoteKeyBytes = ByteArray(responder.remotePublicKey.publicKeyLength) + responder.remotePublicKey.getPublicKey(remoteKeyBytes, 0) + _remotePublicKey = Base64.getEncoder().encodeToString(remoteKeyBytes) + + return (_remotePublicKey != _localPublicKey && (_isHandshakeAllowed?.invoke(this, _remotePublicKey!!, pairingCode) ?: true)).also { + if (!it) stop() } } private fun performVersionCheck() { - val CURRENT_VERSION = 3 - val MINIMUM_VERSION = 2 + val CURRENT_VERSION = 4 + val MINIMUM_VERSION = 4 _outputStream.writeInt(CURRENT_VERSION) remoteVersion = _inputStream.readInt() Logger.i(TAG, "performVersionCheck (version = $remoteVersion)") @@ -173,44 +293,8 @@ class SyncSocketSession { throw Exception("Invalid version") } - private fun handshake(handshakeState: HandshakeState): CipherStatePair { - handshakeState.start() - - val message = ByteArray(8192) - val plaintext = ByteArray(8192) - - while (_started) { - when (handshakeState.action) { - HandshakeState.READ_MESSAGE -> { - val messageSize = _inputStream.readInt() - Logger.i(TAG, "Handshake read message (size = ${messageSize})") - - var bytesRead = 0 - while (bytesRead < messageSize) { - val read = _inputStream.read(message, bytesRead, messageSize - bytesRead) - if (read == -1) - throw Exception("Stream closed") - bytesRead += read - } - - handshakeState.readMessage(message, 0, messageSize, plaintext, 0) - } - HandshakeState.WRITE_MESSAGE -> { - val messageSize = handshakeState.writeMessage(message, 0, null, 0, 0) - Logger.i(TAG, "Handshake wrote message (size = ${messageSize})") - _outputStream.writeInt(messageSize) - _outputStream.write(message, 0, messageSize) - } - HandshakeState.SPLIT -> { - //Logger.i(TAG, "Handshake split") - return handshakeState.split() - } - else -> throw Exception("Unexpected state (handshakeState.action = ${handshakeState.action})") - } - } - - throw Exception("Handshake finished without completing") - } + fun generateStreamId(): Int = synchronized(_streamIdGeneratorLock) { _streamIdGenerator++ } + private fun generateRequestId(): Int = synchronized(_requestIdGeneratorLock) { _requestIdGenerator++ } fun send(opcode: UByte, subOpcode: UByte, data: ByteBuffer) { ensureNotMainThread() @@ -219,37 +303,35 @@ class SyncSocketSession { val segmentSize = MAXIMUM_PACKET_SIZE - HEADER_SIZE val segmentData = ByteArray(segmentSize) var sendOffset = 0 - val id = synchronized(_streamIdGeneratorLock) { - _streamIdGenerator + 1 - } + val id = generateStreamId() while (sendOffset < data.remaining()) { val bytesRemaining = data.remaining() - sendOffset var bytesToSend: Int var segmentPacketSize: Int - val segmentOpcode: UByte + val streamOp: StreamOpcode if (sendOffset == 0) { - segmentOpcode = Opcode.STREAM_START.value + streamOp = StreamOpcode.START bytesToSend = segmentSize - 4 - 4 - 1 - 1 segmentPacketSize = bytesToSend + 4 + 4 + 1 + 1 } else { bytesToSend = minOf(segmentSize - 4 - 4, bytesRemaining) - segmentOpcode = if (bytesToSend >= bytesRemaining) Opcode.STREAM_END.value else Opcode.STREAM_DATA.value + streamOp = if (bytesToSend >= bytesRemaining) StreamOpcode.END else StreamOpcode.DATA segmentPacketSize = bytesToSend + 4 + 4 } ByteBuffer.wrap(segmentData).order(ByteOrder.LITTLE_ENDIAN).apply { putInt(id) - putInt(if (segmentOpcode == Opcode.STREAM_START.value) data.remaining() else sendOffset) - if (segmentOpcode == Opcode.STREAM_START.value) { + putInt(if (streamOp == StreamOpcode.START) data.remaining() else sendOffset) + if (streamOp == StreamOpcode.START) { put(opcode.toByte()) put(subOpcode.toByte()) } put(data.array(), data.position() + sendOffset, bytesToSend) } - send(segmentOpcode, 0u, ByteBuffer.wrap(segmentData, 0, segmentPacketSize)) + send(Opcode.STREAM.value, streamOp.value, ByteBuffer.wrap(segmentData, 0, segmentPacketSize)) sendOffset += bytesToSend } } else { @@ -270,6 +352,7 @@ class SyncSocketSession { } } + @OptIn(ExperimentalUnsignedTypes::class) fun send(opcode: UByte, subOpcode: UByte = 0u) { ensureNotMainThread() @@ -288,108 +371,806 @@ class SyncSocketSession { } } - private fun handleData(data: ByteArray, length: Int) { + private fun handleData(data: ByteArray, length: Int, sourceChannel: ChannelRelayed?) { + return handleData(ByteBuffer.wrap(data, 0, length).order(ByteOrder.LITTLE_ENDIAN), sourceChannel) + } + + private fun handleData(data: ByteBuffer, sourceChannel: ChannelRelayed?) { + val length = data.remaining() if (length < HEADER_SIZE) throw Exception("Packet must be at least 6 bytes (header size)") - val size = ByteBuffer.wrap(data, 0, 4).order(ByteOrder.LITTLE_ENDIAN).int + val size = data.int if (size != length - 4) throw Exception("Incomplete packet received") - val opcode = data.asUByteArray()[4] - val subOpcode = data.asUByteArray()[5] - val packetData = ByteBuffer.wrap(data, HEADER_SIZE, size - 2) - handlePacket(opcode, subOpcode, packetData.order(ByteOrder.LITTLE_ENDIAN)) + val opcode = data.get().toUByte() + val subOpcode = data.get().toUByte() + handlePacket(opcode, subOpcode, data, sourceChannel) } - private fun handlePacket(opcode: UByte, subOpcode: UByte, data: ByteBuffer) { + private fun handleRequest(subOpcode: UByte, data: ByteBuffer, sourceChannel: ChannelRelayed?) { + when (subOpcode) { + RequestOpcode.TRANSPORT_RELAYED.value -> { + Logger.i(TAG, "Received request for a relayed transport") + if (data.remaining() < 52) { + Logger.e(TAG, "HandleRequestTransport: Packet too short") + return + } + val remoteVersion = data.int + val connectionId = data.long + val requestId = data.int + val publicKeyBytes = ByteArray(32).also { data.get(it) } + val pairingMessageLength = data.int + if (pairingMessageLength > 128) throw IllegalArgumentException("Pairing message length ($pairingMessageLength) exceeds maximum (128)") + val pairingMessage = if (pairingMessageLength > 0) ByteArray(pairingMessageLength).also { data.get(it) } else ByteArray(0) + val channelMessageLength = data.int + if (data.remaining() != channelMessageLength) { + Logger.e(TAG, "Invalid packet size. Expected ${52 + pairingMessageLength + 4 + channelMessageLength}, got ${data.capacity()}") + return + } + val channelHandshakeMessage = ByteArray(channelMessageLength).also { data.get(it) } + val publicKey = Base64.getEncoder().encodeToString(publicKeyBytes) + val pairingCode = if (pairingMessageLength > 0) { + val pairingProtocol = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.RESPONDER).apply { + localKeyPair.copyFrom(_localKeyPair) + start() + } + val plaintext = ByteArray(1024) + val length = pairingProtocol.readMessage(pairingMessage, 0, pairingMessageLength, plaintext, 0) + String(plaintext, 0, length, Charsets.UTF_8) + } else null + val isAllowed = publicKey != _localPublicKey && (_isHandshakeAllowed?.invoke(this, publicKey, pairingCode) ?: true) + if (!isAllowed) { + val rp = ByteBuffer.allocate(16).order(ByteOrder.LITTLE_ENDIAN) + rp.putInt(2) // Status code for not allowed + rp.putLong(connectionId) + rp.putInt(requestId) + rp.rewind() + send(Opcode.RESPONSE.value, ResponseOpcode.TRANSPORT.value, rp) + return + } + val channel = ChannelRelayed(this, _localKeyPair, publicKey, false) + channel.connectionId = connectionId + _onNewChannel?.invoke(this, channel) + _channels[connectionId] = channel + channel.sendResponseTransport(remoteVersion, requestId, channelHandshakeMessage) + _onChannelEstablished?.invoke(this, channel, true) + } + else -> Logger.w(TAG, "Unhandled request opcode: $subOpcode") + } + } + + private fun handleResponse(subOpcode: UByte, data: ByteBuffer, sourceChannel: ChannelRelayed?) { + if (data.remaining() < 8) { + Logger.e(TAG, "Response packet too short") + return + } + val requestId = data.int + val statusCode = data.int + when (subOpcode) { + ResponseOpcode.CONNECTION_INFO.value -> { + _pendingConnectionInfoRequests.remove(requestId)?.let { tcs -> + if (statusCode == 0) { + try { + val connectionInfo = parseConnectionInfo(data) + tcs.complete(connectionInfo) + } catch (e: Exception) { + tcs.completeExceptionally(e) + } + } else { + tcs.complete(null) + } + } ?: Logger.e(TAG, "No pending request for requestId $requestId") + } + ResponseOpcode.TRANSPORT_RELAYED.value -> { + if (statusCode == 0) { + if (data.remaining() < 16) { + Logger.e(TAG, "RESPONSE_TRANSPORT packet too short") + return + } + val remoteVersion = data.int + val connectionId = data.long + val messageLength = data.int + if (data.remaining() != messageLength) { + Logger.e(TAG, "Invalid RESPONSE_TRANSPORT packet size. Expected ${16 + messageLength}, got ${data.remaining() + 16}") + return + } + val handshakeMessage = ByteArray(messageLength).also { data.get(it) } + _pendingChannels.remove(requestId)?.let { (channel, tcs) -> + channel.handleTransportRelayed(remoteVersion, connectionId, handshakeMessage) + _channels[connectionId] = channel + tcs.complete(channel) + _onChannelEstablished?.invoke(this, channel, false) + } ?: Logger.e(TAG, "No pending channel for requestId $requestId") + } else { + _pendingChannels.remove(requestId)?.let { (channel, tcs) -> + channel.close() + tcs.completeExceptionally(Exception("Relayed transport request $requestId failed with code $statusCode")) + } + } + } + ResponseOpcode.PUBLISH_RECORD.value, ResponseOpcode.BULK_PUBLISH_RECORD.value -> { + _pendingPublishRequests.remove(requestId)?.complete(statusCode == 0) + ?: Logger.e(TAG, "No pending publish request for requestId $requestId") + } + ResponseOpcode.DELETE_RECORD.value, ResponseOpcode.BULK_DELETE_RECORD.value -> { + _pendingDeleteRequests.remove(requestId)?.complete(statusCode == 0) + ?: Logger.e(TAG, "No pending delete request for requestId $requestId") + } + ResponseOpcode.LIST_RECORD_KEYS.value -> { + _pendingListKeysRequests.remove(requestId)?.let { tcs -> + if (statusCode == 0) { + try { + val keyCount = data.int + val keys = mutableListOf>() + repeat(keyCount) { + val keyLength = data.get().toInt() + val key = ByteArray(keyLength).also { data.get(it) }.toString(Charsets.UTF_8) + val timestamp = data.long + keys.add(key to timestamp) + } + tcs.complete(keys) + } catch (e: Exception) { + tcs.completeExceptionally(e) + } + } else { + tcs.completeExceptionally(Exception("Error listing keys: status code $statusCode")) + } + } ?: Logger.e(TAG, "No pending list keys request for requestId $requestId") + } + ResponseOpcode.GET_RECORD.value -> { + _pendingGetRecordRequests.remove(requestId)?.let { tcs -> + if (statusCode == 0) { + try { + val blobLength = data.int + val encryptedBlob = ByteArray(blobLength).also { data.get(it) } + val timestamp = data.long + val protocol = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.RESPONDER).apply { + localKeyPair.copyFrom(_localKeyPair) + start() + } + val handshakeMessage = encryptedBlob.copyOf(48) + val plaintext = ByteArray(0) + protocol.readMessage(handshakeMessage, 0, 48, plaintext, 0) + val transportPair = protocol.split() + var blobOffset = 48 + val chunks = mutableListOf() + while (blobOffset + 4 <= encryptedBlob.size) { + val chunkLength = ByteBuffer.wrap(encryptedBlob, blobOffset, 4).order(ByteOrder.LITTLE_ENDIAN).int + blobOffset += 4 + val encryptedChunk = encryptedBlob.copyOfRange(blobOffset, blobOffset + chunkLength) + val decryptedChunk = ByteArray(chunkLength - 16) + transportPair.receiver.decryptWithAd(null, encryptedChunk, 0, decryptedChunk, 0, encryptedChunk.size) + chunks.add(decryptedChunk) + blobOffset += chunkLength + } + val dataResult = chunks.reduce { acc, bytes -> acc + bytes } + tcs.complete(dataResult to timestamp) + } catch (e: Exception) { + tcs.completeExceptionally(e) + } + } else if (statusCode == 2) { + tcs.complete(null) + } else { + tcs.completeExceptionally(Exception("Error getting record: statusCode $statusCode")) + } + } + } + ResponseOpcode.BULK_GET_RECORD.value -> { + _pendingBulkGetRecordRequests.remove(requestId)?.let { tcs -> + if (statusCode == 0) { + try { + val recordCount = data.get().toInt() + val records = mutableMapOf>() + repeat(recordCount) { + val publisherBytes = ByteArray(32).also { data.get(it) } + val publisher = Base64.getEncoder().encodeToString(publisherBytes) + val blobLength = data.int + val encryptedBlob = ByteArray(blobLength).also { data.get(it) } + val timestamp = data.long + val protocol = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.RESPONDER).apply { + localKeyPair.copyFrom(_localKeyPair) + start() + } + val handshakeMessage = encryptedBlob.copyOf(48) + val plaintext = ByteArray(0) + protocol.readMessage(handshakeMessage, 0, 48, plaintext, 0) + val transportPair = protocol.split() + var blobOffset = 48 + val chunks = mutableListOf() + while (blobOffset + 4 <= encryptedBlob.size) { + val chunkLength = ByteBuffer.wrap(encryptedBlob, blobOffset, 4).order(ByteOrder.LITTLE_ENDIAN).int + blobOffset += 4 + val encryptedChunk = encryptedBlob.copyOfRange(blobOffset, blobOffset + chunkLength) + val decryptedChunk = ByteArray(chunkLength - 16) + transportPair.receiver.decryptWithAd(null, encryptedChunk, 0, decryptedChunk, 0, encryptedChunk.size) + chunks.add(decryptedChunk) + blobOffset += chunkLength + } + val dataResult = chunks.reduce { acc, bytes -> acc + bytes } + records[publisher] = dataResult to timestamp + } + tcs.complete(records) + } catch (e: Exception) { + tcs.completeExceptionally(e) + } + } else { + tcs.completeExceptionally(Exception("Error getting bulk records: statusCode $statusCode")) + } + } + } + ResponseOpcode.BULK_CONNECTION_INFO.value -> { + _pendingBulkConnectionInfoRequests.remove(requestId)?.let { tcs -> + try { + val numResponses = data.get().toInt() + val result = mutableMapOf() + repeat(numResponses) { + val publicKey = Base64.getEncoder().encodeToString(ByteArray(32).also { data.get(it) }) + val status = data.get().toInt() + if (status == 0) { + val infoSize = data.int + val infoData = ByteArray(infoSize).also { data.get(it) } + result[publicKey] = parseConnectionInfo(ByteBuffer.wrap(infoData).order(ByteOrder.LITTLE_ENDIAN)) + } + } + tcs.complete(result) + } catch (e: Exception) { + tcs.completeExceptionally(e) + } + } ?: Logger.e(TAG, "No pending bulk request for requestId $requestId") + } + } + } + + private fun parseConnectionInfo(data: ByteBuffer): ConnectionInfo { + val ipSize = data.get().toInt() + val remoteIpBytes = ByteArray(ipSize).also { data.get(it) } + val remoteIp = remoteIpBytes.joinToString(".") { it.toUByte().toString() } + val handshakeMessage = ByteArray(48).also { data.get(it) } + val ciphertext = ByteArray(data.remaining()).also { data.get(it) } + val protocol = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.RESPONDER).apply { + localKeyPair.copyFrom(_localKeyPair) + start() + } + val plaintext = ByteArray(0) + protocol.readMessage(handshakeMessage, 0, 48, plaintext, 0) + val transportPair = protocol.split() + val decryptedData = ByteArray(ciphertext.size - 16) + transportPair.receiver.decryptWithAd(null, ciphertext, 0, decryptedData, 0, ciphertext.size) + val info = ByteBuffer.wrap(decryptedData).order(ByteOrder.LITTLE_ENDIAN) + val port = info.short.toUShort() + val nameLength = info.get().toInt() + val name = ByteArray(nameLength).also { info.get(it) }.toString(Charsets.UTF_8) + val ipv4Count = info.get().toInt() + val ipv4Addresses = List(ipv4Count) { ByteArray(4).also { info.get(it) }.joinToString(".") { it.toUByte().toString() } } + val ipv6Count = info.get().toInt() + val ipv6Addresses = List(ipv6Count) { ByteArray(16).also { info.get(it) }.joinToString(":") { it.toUByte().toString(16).padStart(2, '0') } } + val allowLocalDirect = info.get() != 0.toByte() + val allowRemoteDirect = info.get() != 0.toByte() + val allowRemoteHolePunched = info.get() != 0.toByte() + val allowRemoteProxied = info.get() != 0.toByte() + return ConnectionInfo(port, name, remoteIp, ipv4Addresses, ipv6Addresses, allowLocalDirect, allowRemoteDirect, allowRemoteHolePunched, allowRemoteProxied) + } + + private fun handleNotify(subOpcode: UByte, data: ByteBuffer, sourceChannel: ChannelRelayed?) { + when (subOpcode) { + NotifyOpcode.AUTHORIZED.value, NotifyOpcode.UNAUTHORIZED.value -> { + if (sourceChannel != null) + sourceChannel.invokeDataHandler(Opcode.NOTIFY.value, subOpcode, data) + else + _onData?.invoke(this, Opcode.NOTIFY.value, subOpcode, data) + } + NotifyOpcode.CONNECTION_INFO.value -> { /* Handle connection info if needed */ } + } + } + + fun sendRelayError(connectionId: Long, errorCode: SyncErrorCode) { + val packet = ByteBuffer.allocate(12).order(ByteOrder.LITTLE_ENDIAN) + packet.putLong(connectionId) + packet.putInt(errorCode.value) + packet.rewind() + send(Opcode.RELAY.value, RelayOpcode.RELAY_ERROR.value, packet) + } + + private fun handleRelay(subOpcode: UByte, data: ByteBuffer, sourceChannel: ChannelRelayed?) { + when (subOpcode) { + RelayOpcode.RELAYED_DATA.value -> { + if (data.remaining() < 8) { + Logger.e(TAG, "RELAYED_DATA packet too short") + return + } + val connectionId = data.long + val channel = _channels[connectionId] ?: run { + Logger.e(TAG, "No channel found for connectionId $connectionId") + return + } + val decryptedPayload = channel.decrypt(data) + try { + handleData(decryptedPayload, channel) + } catch (e: Exception) { + Logger.e(TAG, "Exception while handling relayed data", e) + channel.sendError(SyncErrorCode.ConnectionClosed) + channel.close() + _channels.remove(connectionId) + } + } + RelayOpcode.RELAYED_ERROR.value -> { + if (data.remaining() < 8) { + Logger.e(TAG, "RELAYED_ERROR packet too short") + return + } + val connectionId = data.long + val channel = _channels[connectionId] ?: run { + Logger.e(TAG, "No channel found for connectionId $connectionId") + sendRelayError(connectionId, SyncErrorCode.NotFound) + return + } + val decryptedPayload = channel.decrypt(data) + val errorCode = SyncErrorCode.entries.find { it.value == decryptedPayload.int } ?: SyncErrorCode.ConnectionClosed + Logger.e(TAG, "Received relayed error (errorCode = $errorCode) on connectionId $connectionId, closing") + channel.close() + _channels.remove(connectionId) + } + RelayOpcode.RELAY_ERROR.value -> { + if (data.remaining() < 12) { + Logger.e(TAG, "RELAY_ERROR packet too short") + return + } + val connectionId = data.long + val errorCode = SyncErrorCode.entries.find { it.value == data.int } ?: SyncErrorCode.ConnectionClosed + val channel = _channels[connectionId] ?: run { + Logger.e(TAG, "Received error code $errorCode for non-existent channel with connectionId $connectionId") + return + } + Logger.i(TAG, "Received relay error (errorCode = $errorCode) on connectionId $connectionId, closing") + channel.close() + _channels.remove(connectionId) + _pendingChannels.entries.find { it.value.first == channel }?.let { + _pendingChannels.remove(it.key)?.second?.cancel() + } + } + } + } + + private fun handlePacket(opcode: UByte, subOpcode: UByte, data: ByteBuffer, sourceChannel: ChannelRelayed?) { Logger.i(TAG, "Handle packet (opcode = ${opcode}, subOpcode = ${subOpcode})") when (opcode) { Opcode.PING.value -> { - send(Opcode.PONG.value) + if (sourceChannel != null) + sourceChannel.send(Opcode.PONG.value) + else + send(Opcode.PONG.value) //Logger.i(TAG, "Received ping, sent pong") return } Opcode.PONG.value -> { - //Logger.i(TAG, "Received pong") + Logger.v(TAG, "Received pong") return } - Opcode.NOTIFY_AUTHORIZED.value, - Opcode.NOTIFY_UNAUTHORIZED.value -> { - _onData.invoke(this, opcode, subOpcode, data) + Opcode.REQUEST.value -> { + handleRequest(subOpcode, data, sourceChannel) return } - } - - if (authorizable?.isAuthorized != true) { - return - } - - when (opcode) { - Opcode.STREAM_START.value -> { - val id = data.int - val expectedSize = data.int - val op = data.get().toUByte() - val subOp = data.get().toUByte() - - val syncStream = SyncStream(expectedSize, op, subOp) - if (data.remaining() > 0) { - syncStream.add(data.array(), data.position(), data.remaining()) - } - - synchronized(_syncStreams) { - _syncStreams[id] = syncStream - } + Opcode.RESPONSE.value -> { + handleResponse(subOpcode, data, sourceChannel) + return } - Opcode.STREAM_DATA.value -> { - val id = data.int - val expectedOffset = data.int - - val syncStream = synchronized(_syncStreams) { - _syncStreams[id] ?: throw Exception("Received data for sync stream that does not exist") - } - - if (expectedOffset != syncStream.bytesReceived) { - throw Exception("Expected offset does not match the amount of received bytes") - } - - if (data.remaining() > 0) { - syncStream.add(data.array(), data.position(), data.remaining()) - } + Opcode.NOTIFY.value -> { + handleNotify(subOpcode, data, sourceChannel) + return } - Opcode.STREAM_END.value -> { - val id = data.int - val expectedOffset = data.int - - val syncStream = synchronized(_syncStreams) { - _syncStreams.remove(id) ?: throw Exception("Received data for sync stream that does not exist") - } - - if (expectedOffset != syncStream.bytesReceived) { - throw Exception("Expected offset does not match the amount of received bytes") - } - - if (data.remaining() > 0) { - syncStream.add(data.array(), data.position(), data.remaining()) - } - - if (!syncStream.isComplete) { - throw Exception("After sync stream end, the stream must be complete") - } - - handlePacket(syncStream.opcode, syncStream.subOpcode, syncStream.getBytes().let { ByteBuffer.wrap(it).order(ByteOrder.LITTLE_ENDIAN) }) + Opcode.RELAY.value -> { + handleRelay(subOpcode, data, sourceChannel) + return } - Opcode.DATA.value -> { - _onData.invoke(this, opcode, subOpcode, data) - } - else -> { - Logger.w(TAG, "Unknown opcode received (opcode = ${opcode}, subOpcode = ${subOpcode})") + else -> if (isAuthorized) when (opcode) { + Opcode.STREAM.value -> when (subOpcode) + { + StreamOpcode.START.value -> { + val id = data.int + val expectedSize = data.int + val op = data.get().toUByte() + val subOp = data.get().toUByte() + + val syncStream = SyncStream(expectedSize, op, subOp) + if (data.remaining() > 0) { + syncStream.add(data.array(), data.position(), data.remaining()) + } + + synchronized(_syncStreams) { + _syncStreams[id] = syncStream + } + } + StreamOpcode.DATA.value -> { + val id = data.int + val expectedOffset = data.int + + val syncStream = synchronized(_syncStreams) { + _syncStreams[id] ?: throw Exception("Received data for sync stream that does not exist") + } + + if (expectedOffset != syncStream.bytesReceived) { + throw Exception("Expected offset does not match the amount of received bytes") + } + + if (data.remaining() > 0) { + syncStream.add(data.array(), data.position(), data.remaining()) + } + } + StreamOpcode.END.value -> { + val id = data.int + val expectedOffset = data.int + + val syncStream = synchronized(_syncStreams) { + _syncStreams.remove(id) ?: throw Exception("Received data for sync stream that does not exist") + } + + if (expectedOffset != syncStream.bytesReceived) { + throw Exception("Expected offset does not match the amount of received bytes") + } + + if (data.remaining() > 0) { + syncStream.add(data.array(), data.position(), data.remaining()) + } + + if (!syncStream.isComplete) { + throw Exception("After sync stream end, the stream must be complete") + } + + handlePacket(syncStream.opcode, syncStream.subOpcode, syncStream.getBytes().let { ByteBuffer.wrap(it).order(ByteOrder.LITTLE_ENDIAN) }, sourceChannel) + } + } + Opcode.DATA.value -> { + if (sourceChannel != null) + sourceChannel.invokeDataHandler(opcode, subOpcode, data) + else + _onData?.invoke(this, opcode, subOpcode, data) + } + else -> { + Logger.w(TAG, "Unknown opcode received (opcode = ${opcode}, subOpcode = ${subOpcode})") + } } } } + suspend fun requestConnectionInfo(publicKey: String): ConnectionInfo? { + val requestId = generateRequestId() + val deferred = CompletableDeferred() + _pendingConnectionInfoRequests[requestId] = deferred + try { + val publicKeyBytes = Base64.getDecoder().decode(publicKey) + if (publicKeyBytes.size != 32) throw IllegalArgumentException("Public key must be 32 bytes") + val packet = ByteBuffer.allocate(4 + 32).order(ByteOrder.LITTLE_ENDIAN) + packet.putInt(requestId) + packet.put(publicKeyBytes) + packet.rewind() + send(Opcode.REQUEST.value, RequestOpcode.CONNECTION_INFO.value, packet) + } catch (e: Exception) { + _pendingConnectionInfoRequests.remove(requestId)?.completeExceptionally(e) + throw e + } + return deferred.await() + } + + suspend fun requestBulkConnectionInfo(publicKeys: Array): Map { + val requestId = generateRequestId() + val deferred = CompletableDeferred>() + _pendingBulkConnectionInfoRequests[requestId] = deferred + try { + val packet = ByteBuffer.allocate(4 + 1 + publicKeys.size * 32).order(ByteOrder.LITTLE_ENDIAN) + packet.putInt(requestId) + packet.put(publicKeys.size.toByte()) + for (pk in publicKeys) { + val pkBytes = Base64.getDecoder().decode(pk) + if (pkBytes.size != 32) throw IllegalArgumentException("Invalid public key length for $pk") + packet.put(pkBytes) + } + packet.rewind() + send(Opcode.REQUEST.value, RequestOpcode.BULK_CONNECTION_INFO.value, packet) + } catch (e: Exception) { + _pendingBulkConnectionInfoRequests.remove(requestId)?.completeExceptionally(e) + throw e + } + return deferred.await() + } + + suspend fun startRelayedChannel(publicKey: String, pairingCode: String? = null): ChannelRelayed? { + val requestId = generateRequestId() + val deferred = CompletableDeferred() + val channel = ChannelRelayed(this, _localKeyPair, publicKey, true) + _onNewChannel?.invoke(this, channel) + _pendingChannels[requestId] = channel to deferred + try { + channel.sendRequestTransport(requestId, publicKey, pairingCode) + } catch (e: Exception) { + _pendingChannels.remove(requestId)?.let { it.first.close(); it.second.completeExceptionally(e) } + throw e + } + return deferred.await() + } + + private fun getDeviceName(): String { + val manufacturer = Build.MANUFACTURER.replaceFirstChar { if (it.isLowerCase()) it.titlecase( + Locale.getDefault()) else it.toString() } + val model = Build.MODEL + + return if (model.startsWith(manufacturer, ignoreCase = true)) { + model.replaceFirstChar { if (it.isLowerCase()) it.titlecase(Locale.getDefault()) else it.toString() } + } else { + "$manufacturer $model".replaceFirstChar { if (it.isLowerCase()) it.titlecase(Locale.getDefault()) else it.toString() } + } + } + + private fun getLimitedUtf8Bytes(str: String, maxByteLength: Int): ByteArray { + val bytes = str.toByteArray(Charsets.UTF_8) + if (bytes.size <= maxByteLength) return bytes + + var truncateAt = maxByteLength + while (truncateAt > 0 && (bytes[truncateAt].toInt() and 0xC0) == 0x80) { + truncateAt-- + } + return bytes.copyOf(truncateAt) + } + + fun publishConnectionInformation( + authorizedKeys: Array, + port: Int, + allowLocalDirect: Boolean, + allowRemoteDirect: Boolean, + allowRemoteHolePunched: Boolean, + allowRemoteProxied: Boolean + ) { + if (authorizedKeys.size > 255) throw IllegalArgumentException("Number of authorized keys exceeds 255") + + val ipv4Addresses = mutableListOf() + val ipv6Addresses = mutableListOf() + for (nic in NetworkInterface.getNetworkInterfaces()) { + if (nic.isUp) { + for (addr in nic.inetAddresses) { + if (!addr.isLoopbackAddress) { + when (addr) { + is Inet4Address -> ipv4Addresses.add(addr.hostAddress) + is Inet6Address -> ipv6Addresses.add(addr.hostAddress) + } + } + } + } + } + + val deviceName = getDeviceName() + val nameBytes = getLimitedUtf8Bytes(deviceName, 255) + + val blobSize = 2 + 1 + nameBytes.size + 1 + ipv4Addresses.size * 4 + 1 + ipv6Addresses.size * 16 + 1 + 1 + 1 + 1 + val data = ByteBuffer.allocate(blobSize).order(ByteOrder.LITTLE_ENDIAN) + data.putShort(port.toShort()) + data.put(nameBytes.size.toByte()) + data.put(nameBytes) + data.put(ipv4Addresses.size.toByte()) + for (addr in ipv4Addresses) { + val addrBytes = InetAddress.getByName(addr).address + data.put(addrBytes) + } + data.put(ipv6Addresses.size.toByte()) + for (addr in ipv6Addresses) { + val addrBytes = InetAddress.getByName(addr).address + data.put(addrBytes) + } + data.put(if (allowLocalDirect) 1 else 0) + data.put(if (allowRemoteDirect) 1 else 0) + data.put(if (allowRemoteHolePunched) 1 else 0) + data.put(if (allowRemoteProxied) 1 else 0) + + val handshakeSize = 48 // Noise handshake size for N pattern + + data.rewind() + val ciphertextSize = data.remaining() + 16 // Encrypted data size + val totalSize = 1 + authorizedKeys.size * (32 + handshakeSize + 4 + ciphertextSize) + val publishBytes = ByteBuffer.allocate(totalSize).order(ByteOrder.LITTLE_ENDIAN) + publishBytes.put(authorizedKeys.size.toByte()) + + for (key in authorizedKeys) { + val publicKeyBytes = Base64.getDecoder().decode(key) + if (publicKeyBytes.size != 32) throw IllegalArgumentException("Public key must be 32 bytes") + + val protocol = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.INITIATOR) + protocol.remotePublicKey.setPublicKey(publicKeyBytes, 0) + protocol.start() + + val handshakeMessage = ByteArray(handshakeSize) + val handshakeBytesWritten = protocol.writeMessage(handshakeMessage, 0, null, 0, 0) + if (handshakeBytesWritten != handshakeSize) throw IllegalStateException("Handshake message size mismatch") + + val transportPair = protocol.split() + + publishBytes.put(publicKeyBytes) + publishBytes.put(handshakeMessage) + + val ciphertext = ByteArray(ciphertextSize) + val ciphertextBytesWritten = transportPair.sender.encryptWithAd(null, data.array(), data.position(), ciphertext, 0, data.remaining()) + if (ciphertextBytesWritten != ciphertextSize) throw IllegalStateException("Ciphertext size mismatch") + + publishBytes.putInt(ciphertextBytesWritten) + publishBytes.put(ciphertext, 0, ciphertextBytesWritten) + } + + publishBytes.rewind() + send(Opcode.NOTIFY.value, NotifyOpcode.CONNECTION_INFO.value, publishBytes) + } + + suspend fun publishRecords(consumerPublicKeys: List, key: String, data: ByteArray): Boolean { + val keyBytes = key.toByteArray(Charsets.UTF_8) + if (key.isEmpty() || keyBytes.size > 32) throw IllegalArgumentException("Key must be 1-32 bytes") + if (consumerPublicKeys.isEmpty()) throw IllegalArgumentException("At least one consumer required") + val requestId = generateRequestId() + val deferred = CompletableDeferred() + _pendingPublishRequests[requestId] = deferred + try { + val MAX_PLAINTEXT_SIZE = 65535 + val HANDSHAKE_SIZE = 48 + val LENGTH_SIZE = 4 + val TAG_SIZE = 16 + val chunkCount = (data.size + MAX_PLAINTEXT_SIZE - 1) / MAX_PLAINTEXT_SIZE + + var blobSize = HANDSHAKE_SIZE + var dataOffset = 0 + for (i in 0 until chunkCount) { + val chunkSize = minOf(MAX_PLAINTEXT_SIZE, data.size - dataOffset) + blobSize += LENGTH_SIZE + (chunkSize + TAG_SIZE) + dataOffset += chunkSize + } + + val totalPacketSize = 4 + 1 + keyBytes.size + 1 + consumerPublicKeys.size * (32 + 4 + blobSize) + val packet = ByteBuffer.allocate(totalPacketSize).order(ByteOrder.LITTLE_ENDIAN) + packet.putInt(requestId) + packet.put(keyBytes.size.toByte()) + packet.put(keyBytes) + packet.put(consumerPublicKeys.size.toByte()) + + for (consumer in consumerPublicKeys) { + val consumerBytes = Base64.getDecoder().decode(consumer) + if (consumerBytes.size != 32) throw IllegalArgumentException("Consumer public key must be 32 bytes") + packet.put(consumerBytes) + val protocol = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.INITIATOR).apply { + remotePublicKey.setPublicKey(consumerBytes, 0) + start() + } + val handshakeMessage = ByteArray(HANDSHAKE_SIZE) + protocol.writeMessage(handshakeMessage, 0, null, 0, 0) + val transportPair = protocol.split() + packet.putInt(blobSize) + packet.put(handshakeMessage) + + dataOffset = 0 + for (i in 0 until chunkCount) { + val chunkSize = minOf(MAX_PLAINTEXT_SIZE, data.size - dataOffset) + val plaintext = data.copyOfRange(dataOffset, dataOffset + chunkSize) + val ciphertext = ByteArray(chunkSize + TAG_SIZE) + val written = transportPair.sender.encryptWithAd(null, plaintext, 0, ciphertext, 0, plaintext.size) + packet.putInt(written) + packet.put(ciphertext, 0, written) + dataOffset += chunkSize + } + } + packet.rewind() + send(Opcode.REQUEST.value, RequestOpcode.BULK_PUBLISH_RECORD.value, packet) + } catch (e: Exception) { + _pendingPublishRequests.remove(requestId)?.completeExceptionally(e) + throw e + } + return deferred.await() + } + + suspend fun getRecord(publisherPublicKey: String, key: String): Pair? { + if (key.isEmpty() || key.length > 32) throw IllegalArgumentException("Key must be 1-32 bytes") + val requestId = generateRequestId() + val deferred = CompletableDeferred?>() + _pendingGetRecordRequests[requestId] = deferred + try { + val publisherBytes = Base64.getDecoder().decode(publisherPublicKey) + if (publisherBytes.size != 32) throw IllegalArgumentException("Publisher public key must be 32 bytes") + val keyBytes = key.toByteArray(Charsets.UTF_8) + val packet = ByteBuffer.allocate(4 + 32 + 1 + keyBytes.size).order(ByteOrder.LITTLE_ENDIAN) + packet.putInt(requestId) + packet.put(publisherBytes) + packet.put(keyBytes.size.toByte()) + packet.put(keyBytes) + packet.rewind() + send(Opcode.REQUEST.value, RequestOpcode.GET_RECORD.value, packet) + } catch (e: Exception) { + _pendingGetRecordRequests.remove(requestId)?.completeExceptionally(e) + throw e + } + return deferred.await() + } + + suspend fun getRecords(publisherPublicKeys: List, key: String): Map> { + if (key.isEmpty() || key.length > 32) throw IllegalArgumentException("Key must be 1-32 bytes") + if (publisherPublicKeys.isEmpty()) return emptyMap() + val requestId = generateRequestId() + val deferred = CompletableDeferred>>() + _pendingBulkGetRecordRequests[requestId] = deferred + try { + val keyBytes = key.toByteArray(Charsets.UTF_8) + val packet = ByteBuffer.allocate(4 + 1 + keyBytes.size + 1 + publisherPublicKeys.size * 32).order(ByteOrder.LITTLE_ENDIAN) + packet.putInt(requestId) + packet.put(keyBytes.size.toByte()) + packet.put(keyBytes) + packet.put(publisherPublicKeys.size.toByte()) + for (publisher in publisherPublicKeys) { + val bytes = Base64.getDecoder().decode(publisher) + if (bytes.size != 32) throw IllegalArgumentException("Publisher public key must be 32 bytes") + packet.put(bytes) + } + packet.rewind() + send(Opcode.REQUEST.value, RequestOpcode.BULK_GET_RECORD.value, packet) + } catch (e: Exception) { + _pendingBulkGetRecordRequests.remove(requestId)?.completeExceptionally(e) + throw e + } + return deferred.await() + } + + suspend fun deleteRecords(publisherPublicKey: String, consumerPublicKey: String, keys: List): Boolean { + if (keys.any { it.toByteArray(Charsets.UTF_8).size > 32 }) throw IllegalArgumentException("Keys must be at most 32 bytes") + val requestId = generateRequestId() + val deferred = CompletableDeferred() + _pendingDeleteRequests[requestId] = deferred + try { + val publisherBytes = Base64.getDecoder().decode(publisherPublicKey) + if (publisherBytes.size != 32) throw IllegalArgumentException("Publisher public key must be 32 bytes") + val consumerBytes = Base64.getDecoder().decode(consumerPublicKey) + if (consumerBytes.size != 32) throw IllegalArgumentException("Consumer public key must be 32 bytes") + val packetSize = 4 + 32 + 32 + 1 + keys.sumOf { 1 + it.toByteArray(Charsets.UTF_8).size } + val packet = ByteBuffer.allocate(packetSize).order(ByteOrder.LITTLE_ENDIAN) + packet.putInt(requestId) + packet.put(publisherBytes) + packet.put(consumerBytes) + packet.put(keys.size.toByte()) + for (key in keys) { + val keyBytes = key.toByteArray(Charsets.UTF_8) + packet.put(keyBytes.size.toByte()) + packet.put(keyBytes) + } + packet.rewind() + send(Opcode.REQUEST.value, RequestOpcode.BULK_DELETE_RECORD.value, packet) + } catch (e: Exception) { + _pendingDeleteRequests.remove(requestId)?.completeExceptionally(e) + throw e + } + return deferred.await() + } + + suspend fun listRecordKeys(publisherPublicKey: String, consumerPublicKey: String): List> { + val requestId = generateRequestId() + val deferred = CompletableDeferred>>() + _pendingListKeysRequests[requestId] = deferred + try { + val publisherBytes = Base64.getDecoder().decode(publisherPublicKey) + if (publisherBytes.size != 32) throw IllegalArgumentException("Publisher public key must be 32 bytes") + val consumerBytes = Base64.getDecoder().decode(consumerPublicKey) + if (consumerBytes.size != 32) throw IllegalArgumentException("Consumer public key must be 32 bytes") + val packet = ByteBuffer.allocate(4 + 32 + 32).order(ByteOrder.LITTLE_ENDIAN) + packet.putInt(requestId) + packet.put(publisherBytes) + packet.put(consumerBytes) + packet.rewind() + send(Opcode.REQUEST.value, RequestOpcode.LIST_RECORD_KEYS.value, packet) + } catch (e: Exception) { + _pendingListKeysRequests.remove(requestId)?.completeExceptionally(e) + throw e + } + return deferred.await() + } + companion object { + val dh = "25519" + val pattern = "N" + val cipher = "ChaChaPoly" + val hash = "BLAKE2b" + var nProtocolName = "Noise_${pattern}_${dh}_${cipher}_${hash}" + private const val TAG = "SyncSocketSession" const val MAXIMUM_PACKET_SIZE = 65535 - 16 const val MAXIMUM_PACKET_SIZE_ENCRYPTED = MAXIMUM_PACKET_SIZE + 16