diff --git a/app/src/main/java/com/futo/platformplayer/casting/StateCasting.kt b/app/src/main/java/com/futo/platformplayer/casting/StateCasting.kt index 25d2bcd1..865af56f 100644 --- a/app/src/main/java/com/futo/platformplayer/casting/StateCasting.kt +++ b/app/src/main/java/com/futo/platformplayer/casting/StateCasting.kt @@ -4,10 +4,12 @@ import android.app.AlertDialog import android.content.ContentResolver import android.content.Context import android.net.Uri +import android.net.nsd.NsdManager +import android.net.nsd.NsdServiceInfo +import android.os.Build import android.os.Looper import android.util.Base64 import android.util.Log -import android.util.Xml import androidx.annotation.OptIn import androidx.media3.common.util.UnstableApi import com.futo.platformplayer.R @@ -40,8 +42,6 @@ import com.futo.platformplayer.constructs.Event1 import com.futo.platformplayer.constructs.Event2 import com.futo.platformplayer.exceptions.UnsupportedCastException import com.futo.platformplayer.logging.Logger -import com.futo.platformplayer.mdns.DnsService -import com.futo.platformplayer.mdns.ServiceDiscoverer import com.futo.platformplayer.models.CastingDeviceInfo import com.futo.platformplayer.parsers.HLS import com.futo.platformplayer.states.StateApp @@ -55,7 +55,6 @@ import kotlinx.coroutines.launch import kotlinx.coroutines.withContext import kotlinx.serialization.Serializable import kotlinx.serialization.json.Json -import java.io.ByteArrayInputStream import java.net.InetAddress import java.net.URLDecoder import java.net.URLEncoder @@ -84,48 +83,15 @@ class StateCasting { private var _audioExecutor: JSRequestExecutor? = null private val _client = ManagedHttpClient(); var _resumeCastingDevice: CastingDeviceInfo? = null; - val _serviceDiscoverer = ServiceDiscoverer(arrayOf( - "_googlecast._tcp.local", - "_airplay._tcp.local", - "_fastcast._tcp.local", - "_fcast._tcp.local" - )) { handleServiceUpdated(it) } - + private var _nsdManager: NsdManager? = null val isCasting: Boolean get() = activeDevice != null; - private fun handleServiceUpdated(services: List) { - for (s in services) { - //TODO: Addresses IPv4 only? - val addresses = s.addresses.toTypedArray() - val port = s.port.toInt() - var name = s.texts.firstOrNull { it.startsWith("md=") }?.substring("md=".length) - if (s.name.endsWith("._googlecast._tcp.local")) { - if (name == null) { - name = s.name.substring(0, s.name.length - "._googlecast._tcp.local".length) - } - - addOrUpdateChromeCastDevice(name, addresses, port) - } else if (s.name.endsWith("._airplay._tcp.local")) { - if (name == null) { - name = s.name.substring(0, s.name.length - "._airplay._tcp.local".length) - } - - addOrUpdateAirPlayDevice(name, addresses, port) - } else if (s.name.endsWith("._fastcast._tcp.local")) { - if (name == null) { - name = s.name.substring(0, s.name.length - "._fastcast._tcp.local".length) - } - - addOrUpdateFastCastDevice(name, addresses, port) - } else if (s.name.endsWith("._fcast._tcp.local")) { - if (name == null) { - name = s.name.substring(0, s.name.length - "._fcast._tcp.local".length) - } - - addOrUpdateFastCastDevice(name, addresses, port) - } - } - } + private val _discoveryListeners = mapOf( + "_googlecast._tcp" to createDiscoveryListener(::addOrUpdateChromeCastDevice), + "_airplay._tcp" to createDiscoveryListener(::addOrUpdateAirPlayDevice), + "_fastcast._tcp" to createDiscoveryListener(::addOrUpdateFastCastDevice), + "_fcast._tcp" to createDiscoveryListener(::addOrUpdateFastCastDevice) + ) fun handleUrl(context: Context, url: String) { val uri = Uri.parse(url) @@ -197,23 +163,25 @@ class StateCasting { enableDeveloper(true); Logger.i(TAG, "CastingService started."); + + _nsdManager = context.getSystemService(Context.NSD_SERVICE) as NsdManager } @Synchronized fun startDiscovering() { - try { - _serviceDiscoverer.start() - } catch (e: Throwable) { - Logger.i(TAG, "Failed to start ServiceDiscoverer", e) + _nsdManager?.apply { + _discoveryListeners.forEach { + discoverServices(it.key, NsdManager.PROTOCOL_DNS_SD, it.value) + } } } @Synchronized fun stopDiscovering() { - try { - _serviceDiscoverer.stop() - } catch (e: Throwable) { - Logger.i(TAG, "Failed to stop ServiceDiscoverer", e) + _nsdManager?.apply { + _discoveryListeners.forEach { + stopServiceDiscovery(it.value) + } } } @@ -239,6 +207,77 @@ class StateCasting { _castServer.removeAllHandlers(); Logger.i(TAG, "CastingService stopped.") + + _nsdManager = null + } + + private fun createDiscoveryListener(addOrUpdate: (String, Array, Int) -> Unit): NsdManager.DiscoveryListener { + return object : NsdManager.DiscoveryListener { + override fun onDiscoveryStarted(regType: String) { + Log.d(TAG, "Service discovery started for $regType") + } + + override fun onDiscoveryStopped(serviceType: String) { + Log.i(TAG, "Discovery stopped: $serviceType") + } + + override fun onServiceLost(service: NsdServiceInfo) { + Log.e(TAG, "service lost: $service") + // TODO: Handle service lost, e.g., remove device + } + + override fun onStartDiscoveryFailed(serviceType: String, errorCode: Int) { + Log.e(TAG, "Discovery failed for $serviceType: Error code:$errorCode") + _nsdManager?.stopServiceDiscovery(this) + } + + override fun onStopDiscoveryFailed(serviceType: String, errorCode: Int) { + Log.e(TAG, "Stop discovery failed for $serviceType: Error code:$errorCode") + _nsdManager?.stopServiceDiscovery(this) + } + + override fun onServiceFound(service: NsdServiceInfo) { + Log.v(TAG, "Service discovery success for ${service.serviceType}: $service") + val addresses = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.UPSIDE_DOWN_CAKE) { + service.hostAddresses.toTypedArray() + } else { + arrayOf(service.host) + } + addOrUpdate(service.serviceName, addresses, service.port) + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.UPSIDE_DOWN_CAKE) { + _nsdManager?.registerServiceInfoCallback(service, { it.run() }, object : NsdManager.ServiceInfoCallback { + override fun onServiceUpdated(serviceInfo: NsdServiceInfo) { + Log.v(TAG, "onServiceUpdated: $serviceInfo") + addOrUpdate(serviceInfo.serviceName, serviceInfo.hostAddresses.toTypedArray(), serviceInfo.port) + } + + override fun onServiceLost() { + Log.v(TAG, "onServiceLost: $service") + // TODO: Handle service lost + } + + override fun onServiceInfoCallbackRegistrationFailed(errorCode: Int) { + Log.v(TAG, "onServiceInfoCallbackRegistrationFailed: $errorCode") + } + + override fun onServiceInfoCallbackUnregistered() { + Log.v(TAG, "onServiceInfoCallbackUnregistered") + } + }) + } else { + _nsdManager?.resolveService(service, object : NsdManager.ResolveListener { + override fun onResolveFailed(serviceInfo: NsdServiceInfo, errorCode: Int) { + Log.v(TAG, "Resolve failed: $errorCode") + } + + override fun onServiceResolved(serviceInfo: NsdServiceInfo) { + Log.v(TAG, "Resolve Succeeded: $serviceInfo") + addOrUpdate(serviceInfo.serviceName, arrayOf(serviceInfo.host), serviceInfo.port) + } + }) + } + } + } } private val _castingDialogLock = Any(); diff --git a/app/src/main/java/com/futo/platformplayer/mdns/BroadcastService.kt b/app/src/main/java/com/futo/platformplayer/mdns/BroadcastService.kt deleted file mode 100644 index ac3c61e0..00000000 --- a/app/src/main/java/com/futo/platformplayer/mdns/BroadcastService.kt +++ /dev/null @@ -1,11 +0,0 @@ -package com.futo.platformplayer.mdns - -data class BroadcastService( - val deviceName: String, - val serviceName: String, - val port: UShort, - val ttl: UInt, - val weight: UShort, - val priority: UShort, - val texts: List? = null -) \ No newline at end of file diff --git a/app/src/main/java/com/futo/platformplayer/mdns/DnsPacket.kt b/app/src/main/java/com/futo/platformplayer/mdns/DnsPacket.kt deleted file mode 100644 index 2c27edf8..00000000 --- a/app/src/main/java/com/futo/platformplayer/mdns/DnsPacket.kt +++ /dev/null @@ -1,93 +0,0 @@ -package com.futo.platformplayer.mdns - -import java.nio.ByteBuffer -import java.nio.ByteOrder - -enum class QueryResponse(val value: Byte) { - Query(0), - Response(1) -} - -enum class DnsOpcode(val value: Byte) { - StandardQuery(0), - InverseQuery(1), - ServerStatusRequest(2) -} - -enum class DnsResponseCode(val value: Byte) { - NoError(0), - FormatError(1), - ServerFailure(2), - NameError(3), - NotImplemented(4), - Refused(5) -} - -data class DnsPacketHeader( - val identifier: UShort, - val queryResponse: Int, - val opcode: Int, - val authoritativeAnswer: Boolean, - val truncated: Boolean, - val recursionDesired: Boolean, - val recursionAvailable: Boolean, - val answerAuthenticated: Boolean, - val nonAuthenticatedData: Boolean, - val responseCode: DnsResponseCode -) - -data class DnsPacket( - val header: DnsPacketHeader, - val questions: List, - val answers: List, - val authorities: List, - val additionals: List -) { - companion object { - fun parse(data: ByteArray): DnsPacket { - val span = data.asUByteArray() - val flags = (span[2].toInt() shl 8 or span[3].toInt()).toUShort() - val questionCount = (span[4].toInt() shl 8 or span[5].toInt()).toUShort() - val answerCount = (span[6].toInt() shl 8 or span[7].toInt()).toUShort() - val authorityCount = (span[8].toInt() shl 8 or span[9].toInt()).toUShort() - val additionalCount = (span[10].toInt() shl 8 or span[11].toInt()).toUShort() - - var position = 12 - - val questions = List(questionCount.toInt()) { - DnsQuestion.parse(data, position).also { position = it.second } - }.map { it.first } - - val answers = List(answerCount.toInt()) { - DnsResourceRecord.parse(data, position).also { position = it.second } - }.map { it.first } - - val authorities = List(authorityCount.toInt()) { - DnsResourceRecord.parse(data, position).also { position = it.second } - }.map { it.first } - - val additionals = List(additionalCount.toInt()) { - DnsResourceRecord.parse(data, position).also { position = it.second } - }.map { it.first } - - return DnsPacket( - header = DnsPacketHeader( - identifier = (span[0].toInt() shl 8 or span[1].toInt()).toUShort(), - queryResponse = ((flags.toUInt() shr 15) and 0b1u).toInt(), - opcode = ((flags.toUInt() shr 11) and 0b1111u).toInt(), - authoritativeAnswer = (flags.toInt() shr 10) and 0b1 != 0, - truncated = (flags.toInt() shr 9) and 0b1 != 0, - recursionDesired = (flags.toInt() shr 8) and 0b1 != 0, - recursionAvailable = (flags.toInt() shr 7) and 0b1 != 0, - answerAuthenticated = (flags.toInt() shr 5) and 0b1 != 0, - nonAuthenticatedData = (flags.toInt() shr 4) and 0b1 != 0, - responseCode = DnsResponseCode.entries[flags.toInt() and 0b1111] - ), - questions = questions, - answers = answers, - authorities = authorities, - additionals = additionals - ) - } - } -} \ No newline at end of file diff --git a/app/src/main/java/com/futo/platformplayer/mdns/DnsQuestion.kt b/app/src/main/java/com/futo/platformplayer/mdns/DnsQuestion.kt deleted file mode 100644 index 01a7bd77..00000000 --- a/app/src/main/java/com/futo/platformplayer/mdns/DnsQuestion.kt +++ /dev/null @@ -1,110 +0,0 @@ -package com.futo.platformplayer.mdns - -import com.futo.platformplayer.mdns.Extensions.readDomainName -import java.nio.ByteBuffer -import java.nio.ByteOrder - - -enum class QuestionType(val value: UShort) { - A(1u), - NS(2u), - MD(3u), - MF(4u), - CNAME(5u), - SOA(6u), - MB(7u), - MG(8u), - MR(9u), - NULL(10u), - WKS(11u), - PTR(12u), - HINFO(13u), - MINFO(14u), - MX(15u), - TXT(16u), - RP(17u), - AFSDB(18u), - SIG(24u), - KEY(25u), - AAAA(28u), - LOC(29u), - SRV(33u), - NAPTR(35u), - KX(36u), - CERT(37u), - DNAME(39u), - APL(42u), - DS(43u), - SSHFP(44u), - IPSECKEY(45u), - RRSIG(46u), - NSEC(47u), - DNSKEY(48u), - DHCID(49u), - NSEC3(50u), - NSEC3PARAM(51u), - TSLA(52u), - SMIMEA(53u), - HIP(55u), - CDS(59u), - CDNSKEY(60u), - OPENPGPKEY(61u), - CSYNC(62u), - ZONEMD(63u), - SVCB(64u), - HTTPS(65u), - EUI48(108u), - EUI64(109u), - TKEY(249u), - TSIG(250u), - URI(256u), - CAA(257u), - TA(32768u), - DLV(32769u), - AXFR(252u), - IXFR(251u), - OPT(41u), - MAILB(253u), - MALA(254u), - All(252u) -} - -enum class QuestionClass(val value: UShort) { - IN(1u), - CS(2u), - CH(3u), - HS(4u), - All(255u) -} - -data class DnsQuestion( - override val name: String, - override val type: Int, - override val clazz: Int, - val queryUnicast: Boolean -) : DnsResourceRecordBase(name, type, clazz) { - companion object { - fun parse(data: ByteArray, startPosition: Int): Pair { - val span = data.asUByteArray() - var position = startPosition - val qname = span.readDomainName(position).also { position = it.second } - val qtype = (span[position].toInt() shl 8 or span[position + 1].toInt()).toUShort() - position += 2 - val qclass = (span[position].toInt() shl 8 or span[position + 1].toInt()).toUShort() - position += 2 - - return DnsQuestion( - name = qname.first, - type = qtype.toInt(), - queryUnicast = ((qclass.toInt() shr 15) and 0b1) != 0, - clazz = qclass.toInt() and 0b111111111111111 - ) to position - } - } -} - -open class DnsResourceRecordBase( - open val name: String, - open val type: Int, - open val clazz: Int -) diff --git a/app/src/main/java/com/futo/platformplayer/mdns/DnsReader.kt b/app/src/main/java/com/futo/platformplayer/mdns/DnsReader.kt deleted file mode 100644 index 83c329ff..00000000 --- a/app/src/main/java/com/futo/platformplayer/mdns/DnsReader.kt +++ /dev/null @@ -1,514 +0,0 @@ -package com.futo.platformplayer.mdns - -import com.futo.platformplayer.mdns.Extensions.readDomainName -import java.nio.ByteBuffer -import java.nio.ByteOrder -import java.nio.charset.StandardCharsets -import kotlin.math.pow -import java.net.InetAddress - -data class PTRRecord(val domainName: String) - -data class ARecord(val address: InetAddress) - -data class AAAARecord(val address: InetAddress) - -data class MXRecord(val preference: UShort, val exchange: String) - -data class CNAMERecord(val cname: String) - -data class TXTRecord(val texts: List) - -data class SOARecord( - val primaryNameServer: String, - val responsibleAuthorityMailbox: String, - val serialNumber: Int, - val refreshInterval: Int, - val retryInterval: Int, - val expiryLimit: Int, - val minimumTTL: Int -) - -data class SRVRecord(val priority: UShort, val weight: UShort, val port: UShort, val target: String) - -data class NSRecord(val nameServer: String) - -data class CAARecord(val flags: Byte, val tag: String, val value: String) - -data class HINFORecord(val cpu: String, val os: String) - -data class RPRecord(val mailbox: String, val txtDomainName: String) - - -data class AFSDBRecord(val subtype: UShort, val hostname: String) -data class LOCRecord( - val version: Byte, - val size: Double, - val horizontalPrecision: Double, - val verticalPrecision: Double, - val latitude: Double, - val longitude: Double, - val altitude: Double -) { - companion object { - fun decodeSizeOrPrecision(coded: Byte): Double { - val baseValue = (coded.toInt() shr 4) and 0x0F - val exponent = coded.toInt() and 0x0F - return baseValue * 10.0.pow(exponent.toDouble()) - } - - fun decodeLatitudeOrLongitude(coded: Int): Double { - val arcSeconds = coded / 1E3 - return arcSeconds / 3600.0 - } - - fun decodeAltitude(coded: Int): Double { - return (coded / 100.0) - 100000.0 - } - } -} - -data class NAPTRRecord( - val order: UShort, - val preference: UShort, - val flags: String, - val services: String, - val regexp: String, - val replacement: String -) - -data class RRSIGRecord( - val typeCovered: UShort, - val algorithm: Byte, - val labels: Byte, - val originalTTL: UInt, - val signatureExpiration: UInt, - val signatureInception: UInt, - val keyTag: UShort, - val signersName: String, - val signature: ByteArray -) - -data class KXRecord(val preference: UShort, val exchanger: String) - -data class CERTRecord(val type: UShort, val keyTag: UShort, val algorithm: Byte, val certificate: ByteArray) - - - -data class DNAMERecord(val target: String) - -data class DSRecord(val keyTag: UShort, val algorithm: Byte, val digestType: Byte, val digest: ByteArray) - -data class SSHFPRecord(val algorithm: Byte, val fingerprintType: Byte, val fingerprint: ByteArray) - -data class TLSARecord(val usage: Byte, val selector: Byte, val matchingType: Byte, val certificateAssociationData: ByteArray) - -data class SMIMEARecord(val usage: Byte, val selector: Byte, val matchingType: Byte, val certificateAssociationData: ByteArray) - -data class URIRecord(val priority: UShort, val weight: UShort, val target: String) - -data class NSECRecord(val ownerName: String, val typeBitMaps: List>) -data class NSEC3Record( - val hashAlgorithm: Byte, - val flags: Byte, - val iterations: UShort, - val salt: ByteArray, - val nextHashedOwnerName: ByteArray, - val typeBitMaps: List -) - -data class NSEC3PARAMRecord(val hashAlgorithm: Byte, val flags: Byte, val iterations: UShort, val salt: ByteArray) -data class SPFRecord(val texts: List) -data class TKEYRecord( - val algorithm: String, - val inception: UInt, - val expiration: UInt, - val mode: UShort, - val error: UShort, - val keyData: ByteArray, - val otherData: ByteArray -) - -data class TSIGRecord( - val algorithmName: String, - val timeSigned: UInt, - val fudge: UShort, - val mac: ByteArray, - val originalID: UShort, - val error: UShort, - val otherData: ByteArray -) - -data class OPTRecordOption(val code: UShort, val data: ByteArray) -data class OPTRecord(val options: List) - -class DnsReader(private val data: ByteArray, private var position: Int = 0, private val length: Int = data.size) { - - private val endPosition: Int = position + length - - fun readDomainName(): String { - return data.asUByteArray().readDomainName(position).also { position = it.second }.first - } - - fun readDouble(): Double { - checkRemainingBytes(Double.SIZE_BYTES) - val result = ByteBuffer.wrap(data, position, Double.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).double - position += Double.SIZE_BYTES - return result - } - - fun readInt16(): Short { - checkRemainingBytes(Short.SIZE_BYTES) - val result = ByteBuffer.wrap(data, position, Short.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).short - position += Short.SIZE_BYTES - return result - } - - fun readInt32(): Int { - checkRemainingBytes(Int.SIZE_BYTES) - val result = ByteBuffer.wrap(data, position, Int.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).int - position += Int.SIZE_BYTES - return result - } - - fun readInt64(): Long { - checkRemainingBytes(Long.SIZE_BYTES) - val result = ByteBuffer.wrap(data, position, Long.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).long - position += Long.SIZE_BYTES - return result - } - - fun readSingle(): Float { - checkRemainingBytes(Float.SIZE_BYTES) - val result = ByteBuffer.wrap(data, position, Float.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).float - position += Float.SIZE_BYTES - return result - } - - fun readByte(): Byte { - checkRemainingBytes(Byte.SIZE_BYTES) - return data[position++] - } - - fun readBytes(length: Int): ByteArray { - checkRemainingBytes(length) - return ByteArray(length).also { data.copyInto(it, startIndex = position, endIndex = position + length) } - .also { position += length } - } - - fun readUInt16(): UShort { - checkRemainingBytes(Short.SIZE_BYTES) - val result = ByteBuffer.wrap(data, position, Short.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).short.toUShort() - position += Short.SIZE_BYTES - return result - } - - fun readUInt32(): UInt { - checkRemainingBytes(Int.SIZE_BYTES) - val result = ByteBuffer.wrap(data, position, Int.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).int.toUInt() - position += Int.SIZE_BYTES - return result - } - - fun readUInt64(): ULong { - checkRemainingBytes(Long.SIZE_BYTES) - val result = ByteBuffer.wrap(data, position, Long.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).long.toULong() - position += Long.SIZE_BYTES - return result - } - - fun readString(): String { - val length = data[position++].toInt() - checkRemainingBytes(length) - return String(data, position, length, StandardCharsets.UTF_8).also { position += length } - } - - private fun checkRemainingBytes(requiredBytes: Int) { - if (position + requiredBytes > endPosition) throw IndexOutOfBoundsException() - } - - fun readRPRecord(): RPRecord { - return RPRecord(readDomainName(), readDomainName()) - } - - fun readKXRecord(): KXRecord { - val preference = readUInt16() - val exchanger = readDomainName() - return KXRecord(preference, exchanger) - } - - fun readCERTRecord(): CERTRecord { - val type = readUInt16() - val keyTag = readUInt16() - val algorithm = readByte() - val certificateLength = readUInt16().toInt() - 5 - val certificate = readBytes(certificateLength) - return CERTRecord(type, keyTag, algorithm, certificate) - } - - fun readPTRRecord(): PTRRecord { - return PTRRecord(readDomainName()) - } - - fun readARecord(): ARecord { - val address = readBytes(4) - return ARecord(InetAddress.getByAddress(address)) - } - - fun readAAAARecord(): AAAARecord { - val address = readBytes(16) - return AAAARecord(InetAddress.getByAddress(address)) - } - - fun readMXRecord(): MXRecord { - val preference = readUInt16() - val exchange = readDomainName() - return MXRecord(preference, exchange) - } - - fun readCNAMERecord(): CNAMERecord { - return CNAMERecord(readDomainName()) - } - - fun readTXTRecord(): TXTRecord { - val texts = mutableListOf() - while (position < endPosition) { - val textLength = data[position++].toInt() - checkRemainingBytes(textLength) - val text = String(data, position, textLength, StandardCharsets.UTF_8) - texts.add(text) - position += textLength - } - return TXTRecord(texts) - } - - fun readSOARecord(): SOARecord { - val primaryNameServer = readDomainName() - val responsibleAuthorityMailbox = readDomainName() - val serialNumber = readInt32() - val refreshInterval = readInt32() - val retryInterval = readInt32() - val expiryLimit = readInt32() - val minimumTTL = readInt32() - return SOARecord(primaryNameServer, responsibleAuthorityMailbox, serialNumber, refreshInterval, retryInterval, expiryLimit, minimumTTL) - } - - fun readSRVRecord(): SRVRecord { - val priority = readUInt16() - val weight = readUInt16() - val port = readUInt16() - val target = readDomainName() - return SRVRecord(priority, weight, port, target) - } - - fun readNSRecord(): NSRecord { - return NSRecord(readDomainName()) - } - - fun readCAARecord(): CAARecord { - val length = readUInt16().toInt() - val flags = readByte() - val tagLength = readByte().toInt() - val tag = String(data, position, tagLength, StandardCharsets.US_ASCII).also { position += tagLength } - val valueLength = length - 1 - 1 - tagLength - val value = String(data, position, valueLength, StandardCharsets.US_ASCII).also { position += valueLength } - return CAARecord(flags, tag, value) - } - - fun readHINFORecord(): HINFORecord { - val cpuLength = readByte().toInt() - val cpu = String(data, position, cpuLength, StandardCharsets.US_ASCII).also { position += cpuLength } - val osLength = readByte().toInt() - val os = String(data, position, osLength, StandardCharsets.US_ASCII).also { position += osLength } - return HINFORecord(cpu, os) - } - - fun readAFSDBRecord(): AFSDBRecord { - return AFSDBRecord(readUInt16(), readDomainName()) - } - - fun readLOCRecord(): LOCRecord { - val version = readByte() - val size = LOCRecord.decodeSizeOrPrecision(readByte()) - val horizontalPrecision = LOCRecord.decodeSizeOrPrecision(readByte()) - val verticalPrecision = LOCRecord.decodeSizeOrPrecision(readByte()) - val latitudeCoded = readInt32() - val longitudeCoded = readInt32() - val altitudeCoded = readInt32() - val latitude = LOCRecord.decodeLatitudeOrLongitude(latitudeCoded) - val longitude = LOCRecord.decodeLatitudeOrLongitude(longitudeCoded) - val altitude = LOCRecord.decodeAltitude(altitudeCoded) - return LOCRecord(version, size, horizontalPrecision, verticalPrecision, latitude, longitude, altitude) - } - - fun readNAPTRRecord(): NAPTRRecord { - val order = readUInt16() - val preference = readUInt16() - val flags = readString() - val services = readString() - val regexp = readString() - val replacement = readDomainName() - return NAPTRRecord(order, preference, flags, services, regexp, replacement) - } - - fun readDNAMERecord(): DNAMERecord { - return DNAMERecord(readDomainName()) - } - - fun readDSRecord(): DSRecord { - val keyTag = readUInt16() - val algorithm = readByte() - val digestType = readByte() - val digestLength = readUInt16().toInt() - 4 - val digest = readBytes(digestLength) - return DSRecord(keyTag, algorithm, digestType, digest) - } - - fun readSSHFPRecord(): SSHFPRecord { - val algorithm = readByte() - val fingerprintType = readByte() - val fingerprintLength = readUInt16().toInt() - 2 - val fingerprint = readBytes(fingerprintLength) - return SSHFPRecord(algorithm, fingerprintType, fingerprint) - } - - fun readTLSARecord(): TLSARecord { - val usage = readByte() - val selector = readByte() - val matchingType = readByte() - val dataLength = readUInt16().toInt() - 3 - val certificateAssociationData = readBytes(dataLength) - return TLSARecord(usage, selector, matchingType, certificateAssociationData) - } - - fun readSMIMEARecord(): SMIMEARecord { - val usage = readByte() - val selector = readByte() - val matchingType = readByte() - val dataLength = readUInt16().toInt() - 3 - val certificateAssociationData = readBytes(dataLength) - return SMIMEARecord(usage, selector, matchingType, certificateAssociationData) - } - - fun readURIRecord(): URIRecord { - val priority = readUInt16() - val weight = readUInt16() - val length = readUInt16().toInt() - val target = String(data, position, length, StandardCharsets.US_ASCII).also { position += length } - return URIRecord(priority, weight, target) - } - - fun readRRSIGRecord(): RRSIGRecord { - val typeCovered = readUInt16() - val algorithm = readByte() - val labels = readByte() - val originalTTL = readUInt32() - val signatureExpiration = readUInt32() - val signatureInception = readUInt32() - val keyTag = readUInt16() - val signersName = readDomainName() - val signatureLength = readUInt16().toInt() - val signature = readBytes(signatureLength) - return RRSIGRecord( - typeCovered, - algorithm, - labels, - originalTTL, - signatureExpiration, - signatureInception, - keyTag, - signersName, - signature - ) - } - - fun readNSECRecord(): NSECRecord { - val ownerName = readDomainName() - val typeBitMaps = mutableListOf>() - while (position < endPosition) { - val windowBlock = readByte() - val bitmapLength = readByte().toInt() - val bitmap = readBytes(bitmapLength) - typeBitMaps.add(windowBlock to bitmap) - } - return NSECRecord(ownerName, typeBitMaps) - } - - fun readNSEC3Record(): NSEC3Record { - val hashAlgorithm = readByte() - val flags = readByte() - val iterations = readUInt16() - val saltLength = readByte().toInt() - val salt = readBytes(saltLength) - val hashLength = readByte().toInt() - val nextHashedOwnerName = readBytes(hashLength) - val bitMapLength = readUInt16().toInt() - val typeBitMaps = mutableListOf() - val endPos = position + bitMapLength - while (position < endPos) { - typeBitMaps.add(readUInt16()) - } - return NSEC3Record(hashAlgorithm, flags, iterations, salt, nextHashedOwnerName, typeBitMaps) - } - - fun readNSEC3PARAMRecord(): NSEC3PARAMRecord { - val hashAlgorithm = readByte() - val flags = readByte() - val iterations = readUInt16() - val saltLength = readByte().toInt() - val salt = readBytes(saltLength) - return NSEC3PARAMRecord(hashAlgorithm, flags, iterations, salt) - } - - - fun readSPFRecord(): SPFRecord { - val length = readUInt16().toInt() - val texts = mutableListOf() - val endPos = position + length - while (position < endPos) { - val textLength = readByte().toInt() - val text = String(data, position, textLength, StandardCharsets.US_ASCII).also { position += textLength } - texts.add(text) - } - return SPFRecord(texts) - } - - fun readTKEYRecord(): TKEYRecord { - val algorithm = readDomainName() - val inception = readUInt32() - val expiration = readUInt32() - val mode = readUInt16() - val error = readUInt16() - val keySize = readUInt16().toInt() - val keyData = readBytes(keySize) - val otherSize = readUInt16().toInt() - val otherData = readBytes(otherSize) - return TKEYRecord(algorithm, inception, expiration, mode, error, keyData, otherData) - } - - fun readTSIGRecord(): TSIGRecord { - val algorithmName = readDomainName() - val timeSigned = readUInt32() - val fudge = readUInt16() - val macSize = readUInt16().toInt() - val mac = readBytes(macSize) - val originalID = readUInt16() - val error = readUInt16() - val otherSize = readUInt16().toInt() - val otherData = readBytes(otherSize) - return TSIGRecord(algorithmName, timeSigned, fudge, mac, originalID, error, otherData) - } - - - - fun readOPTRecord(): OPTRecord { - val options = mutableListOf() - while (position < endPosition) { - val optionCode = readUInt16() - val optionLength = readUInt16().toInt() - val optionData = readBytes(optionLength) - options.add(OPTRecordOption(optionCode, optionData)) - } - return OPTRecord(options) - } -} diff --git a/app/src/main/java/com/futo/platformplayer/mdns/DnsResourceRecord.kt b/app/src/main/java/com/futo/platformplayer/mdns/DnsResourceRecord.kt deleted file mode 100644 index 87ec0e5f..00000000 --- a/app/src/main/java/com/futo/platformplayer/mdns/DnsResourceRecord.kt +++ /dev/null @@ -1,117 +0,0 @@ -package com.futo.platformplayer.mdns - -import com.futo.platformplayer.mdns.Extensions.readDomainName - -enum class ResourceRecordType(val value: UShort) { - None(0u), - A(1u), - NS(2u), - MD(3u), - MF(4u), - CNAME(5u), - SOA(6u), - MB(7u), - MG(8u), - MR(9u), - NULL(10u), - WKS(11u), - PTR(12u), - HINFO(13u), - MINFO(14u), - MX(15u), - TXT(16u), - RP(17u), - AFSDB(18u), - SIG(24u), - KEY(25u), - AAAA(28u), - LOC(29u), - SRV(33u), - NAPTR(35u), - KX(36u), - CERT(37u), - DNAME(39u), - APL(42u), - DS(43u), - SSHFP(44u), - IPSECKEY(45u), - RRSIG(46u), - NSEC(47u), - DNSKEY(48u), - DHCID(49u), - NSEC3(50u), - NSEC3PARAM(51u), - TSLA(52u), - SMIMEA(53u), - HIP(55u), - CDS(59u), - CDNSKEY(60u), - OPENPGPKEY(61u), - CSYNC(62u), - ZONEMD(63u), - SVCB(64u), - HTTPS(65u), - EUI48(108u), - EUI64(109u), - TKEY(249u), - TSIG(250u), - URI(256u), - CAA(257u), - TA(32768u), - DLV(32769u), - AXFR(252u), - IXFR(251u), - OPT(41u) -} - -enum class ResourceRecordClass(val value: UShort) { - IN(1u), - CS(2u), - CH(3u), - HS(4u) -} - -data class DnsResourceRecord( - override val name: String, - override val type: Int, - override val clazz: Int, - val timeToLive: UInt, - val cacheFlush: Boolean, - val dataPosition: Int = -1, - val dataLength: Int = -1, - private val data: ByteArray? = null -) : DnsResourceRecordBase(name, type, clazz) { - - companion object { - fun parse(data: ByteArray, startPosition: Int): Pair { - val span = data.asUByteArray() - var position = startPosition - val name = span.readDomainName(position).also { position = it.second } - val type = (span[position].toInt() shl 8 or span[position + 1].toInt()).toUShort() - position += 2 - val clazz = (span[position].toInt() shl 8 or span[position + 1].toInt()).toUShort() - position += 2 - val ttl = (span[position].toInt() shl 24 or (span[position + 1].toInt() shl 16) or - (span[position + 2].toInt() shl 8) or span[position + 3].toInt()).toUInt() - position += 4 - val rdlength = (span[position].toInt() shl 8 or span[position + 1].toInt()).toUShort() - val rdposition = position + 2 - position += 2 + rdlength.toInt() - - return DnsResourceRecord( - name = name.first, - type = type.toInt(), - clazz = clazz.toInt() and 0b1111111_11111111, - timeToLive = ttl, - cacheFlush = ((clazz.toInt() shr 15) and 0b1) != 0, - dataPosition = rdposition, - dataLength = rdlength.toInt(), - data = data - ) to position - } - } - - fun getDataReader(): DnsReader { - return DnsReader(data!!, dataPosition, dataLength) - } -} diff --git a/app/src/main/java/com/futo/platformplayer/mdns/DnsWriter.kt b/app/src/main/java/com/futo/platformplayer/mdns/DnsWriter.kt deleted file mode 100644 index 48a04580..00000000 --- a/app/src/main/java/com/futo/platformplayer/mdns/DnsWriter.kt +++ /dev/null @@ -1,208 +0,0 @@ -package com.futo.platformplayer.mdns - -import java.nio.ByteBuffer -import java.nio.ByteOrder -import java.nio.charset.StandardCharsets - -class DnsWriter { - private val data = mutableListOf() - private val namePositions = mutableMapOf() - - fun toByteArray(): ByteArray = data.toByteArray() - - fun writePacket( - header: DnsPacketHeader, - questionCount: Int? = null, questionWriter: ((DnsWriter, Int) -> Unit)? = null, - answerCount: Int? = null, answerWriter: ((DnsWriter, Int) -> Unit)? = null, - authorityCount: Int? = null, authorityWriter: ((DnsWriter, Int) -> Unit)? = null, - additionalsCount: Int? = null, additionalWriter: ((DnsWriter, Int) -> Unit)? = null - ) { - if (questionCount != null && questionWriter == null || questionCount == null && questionWriter != null) - throw Exception("When question count is given, question writer should also be given.") - if (answerCount != null && answerWriter == null || answerCount == null && answerWriter != null) - throw Exception("When answer count is given, answer writer should also be given.") - if (authorityCount != null && authorityWriter == null || authorityCount == null && authorityWriter != null) - throw Exception("When authority count is given, authority writer should also be given.") - if (additionalsCount != null && additionalWriter == null || additionalsCount == null && additionalWriter != null) - throw Exception("When additionals count is given, additional writer should also be given.") - - writeHeader(header, questionCount ?: 0, answerCount ?: 0, authorityCount ?: 0, additionalsCount ?: 0) - - repeat(questionCount ?: 0) { questionWriter?.invoke(this, it) } - repeat(answerCount ?: 0) { answerWriter?.invoke(this, it) } - repeat(authorityCount ?: 0) { authorityWriter?.invoke(this, it) } - repeat(additionalsCount ?: 0) { additionalWriter?.invoke(this, it) } - } - - fun writeHeader(header: DnsPacketHeader, questionCount: Int, answerCount: Int, authorityCount: Int, additionalsCount: Int) { - write(header.identifier) - - var flags: UShort = 0u - flags = flags or ((header.queryResponse.toUInt() and 0xFFFFu) shl 15).toUShort() - flags = flags or ((header.opcode.toUInt() and 0xFFFFu) shl 11).toUShort() - flags = flags or ((if (header.authoritativeAnswer) 1u else 0u) shl 10).toUShort() - flags = flags or ((if (header.truncated) 1u else 0u) shl 9).toUShort() - flags = flags or ((if (header.recursionDesired) 1u else 0u) shl 8).toUShort() - flags = flags or ((if (header.recursionAvailable) 1u else 0u) shl 7).toUShort() - flags = flags or ((if (header.answerAuthenticated) 1u else 0u) shl 5).toUShort() - flags = flags or ((if (header.nonAuthenticatedData) 1u else 0u) shl 4).toUShort() - flags = flags or header.responseCode.value.toUShort() - write(flags) - - write(questionCount.toUShort()) - write(answerCount.toUShort()) - write(authorityCount.toUShort()) - write(additionalsCount.toUShort()) - } - - fun writeDomainName(name: String) { - synchronized(namePositions) { - val labels = name.split('.') - for (label in labels) { - val nameAtOffset = name.substring(name.indexOf(label)) - if (namePositions.containsKey(nameAtOffset)) { - val position = namePositions[nameAtOffset]!! - val pointer = (0b11000000_00000000 or position).toUShort() - write(pointer) - return - } - if (label.isNotEmpty()) { - val labelBytes = label.toByteArray(StandardCharsets.UTF_8) - val nameStartPos = data.size - write(labelBytes.size.toByte()) - write(labelBytes) - namePositions[nameAtOffset] = nameStartPos - } - } - write(0.toByte()) - } - } - - fun write(value: DnsResourceRecord, dataWriter: (DnsWriter) -> Unit) { - writeDomainName(value.name) - write(value.type.toUShort()) - val cls = ((if (value.cacheFlush) 1u else 0u) shl 15).toUShort() or value.clazz.toUShort() - write(cls) - write(value.timeToLive) - - val lengthOffset = data.size - write(0.toUShort()) - dataWriter(this) - val rdLength = data.size - lengthOffset - 2 - val rdLengthBytes = ByteBuffer.allocate(2).order(ByteOrder.BIG_ENDIAN).putShort(rdLength.toShort()).array() - data[lengthOffset] = rdLengthBytes[0] - data[lengthOffset + 1] = rdLengthBytes[1] - } - - fun write(value: DnsQuestion) { - writeDomainName(value.name) - write(value.type.toUShort()) - write(((if (value.queryUnicast) 1u else 0u shl 15).toUShort() or value.clazz.toUShort())) - } - - fun write(value: Double) { - val bytes = ByteBuffer.allocate(Double.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).putDouble(value).array() - write(bytes) - } - - fun write(value: Short) { - val bytes = ByteBuffer.allocate(Short.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).putShort(value).array() - write(bytes) - } - - fun write(value: Int) { - val bytes = ByteBuffer.allocate(Int.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).putInt(value).array() - write(bytes) - } - - fun write(value: Long) { - val bytes = ByteBuffer.allocate(Long.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).putLong(value).array() - write(bytes) - } - - fun write(value: Float) { - val bytes = ByteBuffer.allocate(Float.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).putFloat(value).array() - write(bytes) - } - - fun write(value: Byte) { - data.add(value) - } - - fun write(value: ByteArray) { - data.addAll(value.asIterable()) - } - - fun write(value: ByteArray, offset: Int, length: Int) { - data.addAll(value.slice(offset until offset + length)) - } - - fun write(value: UShort) { - val bytes = ByteBuffer.allocate(Short.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).putShort(value.toShort()).array() - write(bytes) - } - - fun write(value: UInt) { - val bytes = ByteBuffer.allocate(Int.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).putInt(value.toInt()).array() - write(bytes) - } - - fun write(value: ULong) { - val bytes = ByteBuffer.allocate(Long.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).putLong(value.toLong()).array() - write(bytes) - } - - fun write(value: String) { - val bytes = value.toByteArray(StandardCharsets.UTF_8) - write(bytes.size.toByte()) - write(bytes) - } - - fun write(value: PTRRecord) { - writeDomainName(value.domainName) - } - - fun write(value: ARecord) { - val bytes = value.address.address - if (bytes.size != 4) throw Exception("Unexpected amount of address bytes.") - write(bytes) - } - - fun write(value: AAAARecord) { - val bytes = value.address.address - if (bytes.size != 16) throw Exception("Unexpected amount of address bytes.") - write(bytes) - } - - fun write(value: TXTRecord) { - value.texts.forEach { - val bytes = it.toByteArray(StandardCharsets.UTF_8) - write(bytes.size.toByte()) - write(bytes) - } - } - - fun write(value: SRVRecord) { - write(value.priority) - write(value.weight) - write(value.port) - writeDomainName(value.target) - } - - fun write(value: NSECRecord) { - writeDomainName(value.ownerName) - value.typeBitMaps.forEach { (windowBlock, bitmap) -> - write(windowBlock) - write(bitmap.size.toByte()) - write(bitmap) - } - } - - fun write(value: OPTRecord) { - value.options.forEach { option -> - write(option.code) - write(option.data.size.toUShort()) - write(option.data) - } - } -} \ No newline at end of file diff --git a/app/src/main/java/com/futo/platformplayer/mdns/Extensions.kt b/app/src/main/java/com/futo/platformplayer/mdns/Extensions.kt deleted file mode 100644 index 48bb4c6a..00000000 --- a/app/src/main/java/com/futo/platformplayer/mdns/Extensions.kt +++ /dev/null @@ -1,63 +0,0 @@ -package com.futo.platformplayer.mdns - -import android.util.Log - -object Extensions { - fun ByteArray.toByteDump(): String { - val result = StringBuilder() - for (i in indices) { - result.append(String.format("%02X ", this[i])) - - if ((i + 1) % 16 == 0 || i == size - 1) { - val padding = 3 * (16 - (i % 16 + 1)) - if (i == size - 1 && (i + 1) % 16 != 0) result.append(" ".repeat(padding)) - - result.append("; ") - val start = i - (i % 16) - val end = minOf(i, size - 1) - for (j in start..end) { - val ch = if (this[j] in 32..127) this[j].toChar() else '.' - result.append(ch) - } - if (i != size - 1) result.appendLine() - } - } - return result.toString() - } - - fun UByteArray.readDomainName(startPosition: Int): Pair { - var position = startPosition - return readDomainName(position, 0) - } - - private fun UByteArray.readDomainName(position: Int, depth: Int = 0): Pair { - if (depth > 16) throw Exception("Exceeded maximum recursion depth in DNS packet. Possible circular reference.") - - val domainParts = mutableListOf() - var newPosition = position - - while (true) { - if (newPosition < 0) - println() - - val length = this[newPosition].toUByte() - if ((length and 0b11000000u).toUInt() == 0b11000000u) { - val offset = (((length and 0b00111111u).toUInt()) shl 8) or this[newPosition + 1].toUInt() - val (part, _) = this.readDomainName(offset.toInt(), depth + 1) - domainParts.add(part) - newPosition += 2 - break - } else if (length.toUInt() == 0u) { - newPosition++ - break - } else { - newPosition++ - val part = String(this.asByteArray(), newPosition, length.toInt(), Charsets.UTF_8) - domainParts.add(part) - newPosition += length.toInt() - } - } - - return domainParts.joinToString(".") to newPosition - } -} \ No newline at end of file diff --git a/app/src/main/java/com/futo/platformplayer/mdns/MDNSListener.kt b/app/src/main/java/com/futo/platformplayer/mdns/MDNSListener.kt deleted file mode 100644 index b8ef3eea..00000000 --- a/app/src/main/java/com/futo/platformplayer/mdns/MDNSListener.kt +++ /dev/null @@ -1,501 +0,0 @@ -package com.futo.platformplayer.mdns - -import com.futo.platformplayer.logging.Logger -import kotlinx.coroutines.* -import java.net.* -import java.util.* -import java.util.concurrent.locks.ReentrantLock -import kotlin.concurrent.withLock - - -class MDNSListener { - companion object { - private val TAG = "MDNSListener" - const val MulticastPort = 5353 - val MulticastAddressIPv4: InetAddress = InetAddress.getByName("224.0.0.251") - val MulticastAddressIPv6: InetAddress = InetAddress.getByName("FF02::FB") - val MdnsEndpointIPv6: InetSocketAddress = InetSocketAddress(MulticastAddressIPv6, MulticastPort) - val MdnsEndpointIPv4: InetSocketAddress = InetSocketAddress(MulticastAddressIPv4, MulticastPort) - } - - private val _lockObject = ReentrantLock() - private var _receiver4: MulticastSocket? = null - private var _receiver6: MulticastSocket? = null - private val _senders = mutableListOf() - private val _nicMonitor = NICMonitor() - private val _serviceRecordAggregator = ServiceRecordAggregator() - private var _started = false - private var _threadReceiver4: Thread? = null - private var _threadReceiver6: Thread? = null - private var _scope: CoroutineScope? = null - - var onPacket: ((DnsPacket) -> Unit)? = null - var onServicesUpdated: ((List) -> Unit)? = null - - private val _recordLockObject = ReentrantLock() - private val _recordsA = mutableListOf>() - private val _recordsAAAA = mutableListOf>() - private val _recordsPTR = mutableListOf>() - private val _recordsTXT = mutableListOf>() - private val _recordsSRV = mutableListOf>() - private val _services = mutableListOf() - - init { - _nicMonitor.added = { onNicsAdded(it) } - _nicMonitor.removed = { onNicsRemoved(it) } - _serviceRecordAggregator.onServicesUpdated = { onServicesUpdated?.invoke(it) } - } - - fun start() { - if (_started) { - Logger.i(TAG, "Already started.") - return - } - _started = true - - _scope = CoroutineScope(Dispatchers.IO); - - Logger.i(TAG, "Starting") - _lockObject.withLock { - val receiver4 = MulticastSocket(null).apply { - reuseAddress = true - bind(InetSocketAddress(InetAddress.getByName("0.0.0.0"), MulticastPort)) - } - - _receiver4 = receiver4 - - val receiver6 = MulticastSocket(null).apply { - reuseAddress = true - bind(InetSocketAddress(InetAddress.getByName("::"), MulticastPort)) - } - _receiver6 = receiver6 - - _nicMonitor.start() - _serviceRecordAggregator.start() - onNicsAdded(_nicMonitor.current) - - _threadReceiver4 = Thread { - receiveLoop(receiver4) - }.apply { start() } - - _threadReceiver6 = Thread { - receiveLoop(receiver6) - }.apply { start() } - } - } - - fun queryServices(names: Array) { - if (names.isEmpty()) throw IllegalArgumentException("At least one name must be specified.") - - val writer = DnsWriter() - writer.writePacket( - DnsPacketHeader( - identifier = 0u, - queryResponse = QueryResponse.Query.value.toInt(), - opcode = DnsOpcode.StandardQuery.value.toInt(), - truncated = false, - nonAuthenticatedData = false, - recursionDesired = false, - answerAuthenticated = false, - authoritativeAnswer = false, - recursionAvailable = false, - responseCode = DnsResponseCode.NoError - ), - questionCount = names.size, - questionWriter = { w, i -> - w.write( - DnsQuestion( - name = names[i], - type = QuestionType.PTR.value.toInt(), - clazz = QuestionClass.IN.value.toInt(), - queryUnicast = false - ) - ) - } - ) - - send(writer.toByteArray()) - } - - private fun send(data: ByteArray) { - _lockObject.withLock { - for (sender in _senders) { - try { - val endPoint = if (sender.localAddress is Inet4Address) MdnsEndpointIPv4 else MdnsEndpointIPv6 - sender.send(DatagramPacket(data, data.size, endPoint)) - } catch (e: Exception) { - Logger.i(TAG, "Failed to send on ${sender.localSocketAddress}: ${e.message}.") - } - } - } - } - - fun queryAllQuestions(names: Array) { - if (names.isEmpty()) throw IllegalArgumentException("At least one name must be specified.") - - val questions = names.flatMap { _serviceRecordAggregator.getAllQuestions(it) } - questions.groupBy { it.name }.forEach { (_, questionsForHost) -> - val writer = DnsWriter() - writer.writePacket( - DnsPacketHeader( - identifier = 0u, - queryResponse = QueryResponse.Query.value.toInt(), - opcode = DnsOpcode.StandardQuery.value.toInt(), - truncated = false, - nonAuthenticatedData = false, - recursionDesired = false, - answerAuthenticated = false, - authoritativeAnswer = false, - recursionAvailable = false, - responseCode = DnsResponseCode.NoError - ), - questionCount = questionsForHost.size, - questionWriter = { w, i -> w.write(questionsForHost[i]) } - ) - send(writer.toByteArray()) - } - } - - private fun onNicsAdded(nics: List) { - _lockObject.withLock { - if (!_started) return - - val addresses = nics.flatMap { nic -> - nic.interfaceAddresses.map { it.address } - .filter { it is Inet4Address || (it is Inet6Address && it.isLinkLocalAddress) } - } - - addresses.forEach { address -> - Logger.i(TAG, "New address discovered $address") - - try { - when (address) { - is Inet4Address -> { - _receiver4?.let { receiver4 -> - //receiver4.setOption(StandardSocketOptions.IP_MULTICAST_IF, NetworkInterface.getByInetAddress(address)) - val ni = NetworkInterface.getByInetAddress(address) - receiver4.networkInterface = ni - receiver4.joinGroup(InetSocketAddress(MulticastAddressIPv4, MulticastPort), ni) - } - - val sender = MulticastSocket(null).apply { - reuseAddress = true - bind(InetSocketAddress(address, MulticastPort)) - joinGroup(InetSocketAddress(MulticastAddressIPv4, MulticastPort), NetworkInterface.getByInetAddress(address)) - } - _senders.add(sender) - } - - is Inet6Address -> { - _receiver6?.let { receiver6 -> - //receiver6.setOption(StandardSocketOptions.IP_MULTICAST_IF, NetworkInterface.getByInetAddress(address)) - val ni = NetworkInterface.getByInetAddress(address) - receiver6.networkInterface = ni - receiver6.joinGroup(InetSocketAddress(MulticastAddressIPv6, MulticastPort), ni) - } - - val sender = MulticastSocket(null).apply { - reuseAddress = true - bind(InetSocketAddress(address, MulticastPort)) - joinGroup(InetSocketAddress(MulticastAddressIPv6, MulticastPort), NetworkInterface.getByInetAddress(address)) - } - _senders.add(sender) - } - - else -> throw UnsupportedOperationException("Address type ${address.javaClass.name} is not supported.") - } - } catch (e: Exception) { - Logger.i(TAG, "Exception occurred when processing added address $address: ${e.message}.") - // Close the socket if there was an error - (_senders.lastOrNull() as? MulticastSocket)?.close() - } - } - } - - if (nics.isNotEmpty()) { - try { - updateBroadcastRecords() - broadcastRecords() - } catch (e: Exception) { - Logger.i(TAG, "Exception occurred when broadcasting records: ${e.message}.") - } - } - } - - private fun onNicsRemoved(nics: List) { - _lockObject.withLock { - if (!_started) return - //TODO: Cleanup? - } - - if (nics.isNotEmpty()) { - try { - updateBroadcastRecords() - broadcastRecords() - } catch (e: Exception) { - Logger.e(TAG, "Exception occurred when broadcasting records", e) - } - } - } - - private fun receiveLoop(client: DatagramSocket) { - Logger.i(TAG, "Started receive loop") - - val buffer = ByteArray(8972) - val packet = DatagramPacket(buffer, buffer.size) - while (_started) { - try { - client.receive(packet) - handleResult(packet) - } catch (e: Exception) { - Logger.e(TAG, "An exception occurred while handling UDP result:", e) - } - } - - Logger.i(TAG, "Stopped receive loop") - } - - fun broadcastService( - deviceName: String, - serviceName: String, - port: UShort, - ttl: UInt = 120u, - weight: UShort = 0u, - priority: UShort = 0u, - texts: List? = null - ) { - _recordLockObject.withLock { - _services.add( - BroadcastService( - deviceName = deviceName, - port = port, - priority = priority, - serviceName = serviceName, - texts = texts, - ttl = ttl, - weight = weight - ) - ) - } - - updateBroadcastRecords() - broadcastRecords() - } - - private fun updateBroadcastRecords() { - _recordLockObject.withLock { - _recordsSRV.clear() - _recordsPTR.clear() - _recordsA.clear() - _recordsAAAA.clear() - _recordsTXT.clear() - - _services.forEach { service -> - val id = UUID.randomUUID().toString() - val deviceDomainName = "${service.deviceName}.${service.serviceName}" - val addressName = "$id.local" - - _recordsSRV.add( - DnsResourceRecord( - clazz = ResourceRecordClass.IN.value.toInt(), - type = ResourceRecordType.SRV.value.toInt(), - timeToLive = service.ttl, - name = deviceDomainName, - cacheFlush = false - ) to SRVRecord( - target = addressName, - port = service.port, - priority = service.priority, - weight = service.weight - ) - ) - - _recordsPTR.add( - DnsResourceRecord( - clazz = ResourceRecordClass.IN.value.toInt(), - type = ResourceRecordType.PTR.value.toInt(), - timeToLive = service.ttl, - name = service.serviceName, - cacheFlush = false - ) to PTRRecord( - domainName = deviceDomainName - ) - ) - - val addresses = _nicMonitor.current.flatMap { nic -> - nic.interfaceAddresses.map { it.address } - } - - addresses.forEach { address -> - when (address) { - is Inet4Address -> _recordsA.add( - DnsResourceRecord( - clazz = ResourceRecordClass.IN.value.toInt(), - type = ResourceRecordType.A.value.toInt(), - timeToLive = service.ttl, - name = addressName, - cacheFlush = false - ) to ARecord( - address = address - ) - ) - - is Inet6Address -> _recordsAAAA.add( - DnsResourceRecord( - clazz = ResourceRecordClass.IN.value.toInt(), - type = ResourceRecordType.AAAA.value.toInt(), - timeToLive = service.ttl, - name = addressName, - cacheFlush = false - ) to AAAARecord( - address = address - ) - ) - - else -> Logger.i(TAG, "Invalid address type: $address.") - } - } - - if (service.texts != null) { - _recordsTXT.add( - DnsResourceRecord( - clazz = ResourceRecordClass.IN.value.toInt(), - type = ResourceRecordType.TXT.value.toInt(), - timeToLive = service.ttl, - name = deviceDomainName, - cacheFlush = false - ) to TXTRecord( - texts = service.texts - ) - ) - } - } - } - } - - private fun broadcastRecords(questions: List? = null) { - val writer = DnsWriter() - _recordLockObject.withLock { - val recordsA: List> - val recordsAAAA: List> - val recordsPTR: List> - val recordsTXT: List> - val recordsSRV: List> - - if (questions != null) { - recordsA = _recordsA.filter { r -> questions.any { q -> q.name == r.first.name && q.clazz == r.first.clazz && q.type == r.first.type } } - recordsAAAA = _recordsAAAA.filter { r -> questions.any { q -> q.name == r.first.name && q.clazz == r.first.clazz && q.type == r.first.type } } - recordsPTR = _recordsPTR.filter { r -> questions.any { q -> q.name == r.first.name && q.clazz == r.first.clazz && q.type == r.first.type } } - recordsSRV = _recordsSRV.filter { r -> questions.any { q -> q.name == r.first.name && q.clazz == r.first.clazz && q.type == r.first.type } } - recordsTXT = _recordsTXT.filter { r -> questions.any { q -> q.name == r.first.name && q.clazz == r.first.clazz && q.type == r.first.type } } - } else { - recordsA = _recordsA - recordsAAAA = _recordsAAAA - recordsPTR = _recordsPTR - recordsSRV = _recordsSRV - recordsTXT = _recordsTXT - } - - val answerCount = recordsA.size + recordsAAAA.size + recordsPTR.size + recordsSRV.size + recordsTXT.size - if (answerCount < 1) return - - val txtOffset = recordsA.size + recordsAAAA.size + recordsPTR.size + recordsSRV.size - val srvOffset = recordsA.size + recordsAAAA.size + recordsPTR.size - val ptrOffset = recordsA.size + recordsAAAA.size - val aaaaOffset = recordsA.size - - writer.writePacket( - DnsPacketHeader( - identifier = 0u, - queryResponse = QueryResponse.Response.value.toInt(), - opcode = DnsOpcode.StandardQuery.value.toInt(), - truncated = false, - nonAuthenticatedData = false, - recursionDesired = false, - answerAuthenticated = false, - authoritativeAnswer = true, - recursionAvailable = false, - responseCode = DnsResponseCode.NoError - ), - answerCount = answerCount, - answerWriter = { w, i -> - when { - i >= txtOffset -> { - val record = recordsTXT[i - txtOffset] - w.write(record.first) { it.write(record.second) } - } - - i >= srvOffset -> { - val record = recordsSRV[i - srvOffset] - w.write(record.first) { it.write(record.second) } - } - - i >= ptrOffset -> { - val record = recordsPTR[i - ptrOffset] - w.write(record.first) { it.write(record.second) } - } - - i >= aaaaOffset -> { - val record = recordsAAAA[i - aaaaOffset] - w.write(record.first) { it.write(record.second) } - } - - else -> { - val record = recordsA[i] - w.write(record.first) { it.write(record.second) } - } - } - } - ) - } - - send(writer.toByteArray()) - } - - private fun handleResult(result: DatagramPacket) { - try { - val packet = DnsPacket.parse(result.data) - if (packet.questions.isNotEmpty()) { - _scope?.launch(Dispatchers.IO) { - try { - broadcastRecords(packet.questions) - } catch (e: Throwable) { - Logger.i(TAG, "Broadcasting records failed", e) - } - } - - } - _serviceRecordAggregator.add(packet) - onPacket?.invoke(packet) - } catch (e: Exception) { - Logger.v(TAG, "Failed to handle packet: ${Base64.getEncoder().encodeToString(result.data.slice(IntRange(0, result.length - 1)).toByteArray())}", e) - } - } - - fun stop() { - _lockObject.withLock { - _started = false - - _scope?.cancel() - _scope = null - - _nicMonitor.stop() - _serviceRecordAggregator.stop() - - _receiver4?.close() - _receiver4 = null - - _receiver6?.close() - _receiver6 = null - - _senders.forEach { it.close() } - _senders.clear() - } - - _threadReceiver4?.join() - _threadReceiver4 = null - - _threadReceiver6?.join() - _threadReceiver6 = null - } -} diff --git a/app/src/main/java/com/futo/platformplayer/mdns/NICMonitor.kt b/app/src/main/java/com/futo/platformplayer/mdns/NICMonitor.kt deleted file mode 100644 index 884e1514..00000000 --- a/app/src/main/java/com/futo/platformplayer/mdns/NICMonitor.kt +++ /dev/null @@ -1,66 +0,0 @@ -package com.futo.platformplayer.mdns - -import kotlinx.coroutines.* -import java.net.NetworkInterface - -class NICMonitor { - private val lockObject = Any() - private val nics = mutableListOf() - private var cts: Job? = null - - val current: List - get() = synchronized(nics) { nics.toList() } - - var added: ((List) -> Unit)? = null - var removed: ((List) -> Unit)? = null - - fun start() { - synchronized(lockObject) { - if (cts != null) throw Exception("Already started.") - - cts = CoroutineScope(Dispatchers.Default).launch { - loopAsync() - } - } - - nics.clear() - nics.addAll(getCurrentInterfaces().toList()) - } - - fun stop() { - synchronized(lockObject) { - cts?.cancel() - cts = null - } - - synchronized(nics) { - nics.clear() - } - } - - private suspend fun loopAsync() { - while (cts?.isActive == true) { - try { - val currentNics = getCurrentInterfaces().toList() - removed?.invoke(nics.filter { k -> currentNics.none { n -> k.name == n.name } }) - added?.invoke(currentNics.filter { nic -> nics.none { k -> k.name == nic.name } }) - - synchronized(nics) { - nics.clear() - nics.addAll(currentNics) - } - } catch (ex: Exception) { - // Ignored - } - delay(5000) - } - } - - private fun getCurrentInterfaces(): List { - val nics = NetworkInterface.getNetworkInterfaces().toList() - .filter { it.isUp && !it.isLoopback } - - return if (nics.isNotEmpty()) nics else NetworkInterface.getNetworkInterfaces().toList() - .filter { it.isUp } - } -} diff --git a/app/src/main/java/com/futo/platformplayer/mdns/ServiceDiscoverer.kt b/app/src/main/java/com/futo/platformplayer/mdns/ServiceDiscoverer.kt deleted file mode 100644 index f4a3e5e9..00000000 --- a/app/src/main/java/com/futo/platformplayer/mdns/ServiceDiscoverer.kt +++ /dev/null @@ -1,71 +0,0 @@ -package com.futo.platformplayer.mdns - -import com.futo.platformplayer.logging.Logger -import java.lang.Thread.sleep - -class ServiceDiscoverer(names: Array, private val _onServicesUpdated: (List) -> Unit) { - private val _names: Array - private var _listener: MDNSListener? = null - private var _started = false - private var _thread: Thread? = null - - init { - if (names.isEmpty()) throw IllegalArgumentException("At least one name must be specified.") - _names = names - } - - fun broadcastService( - deviceName: String, - serviceName: String, - port: UShort, - ttl: UInt = 120u, - weight: UShort = 0u, - priority: UShort = 0u, - texts: List? = null - ) { - _listener?.let { - it.broadcastService(deviceName, serviceName, port, ttl, weight, priority, texts) - } - } - - fun stop() { - _started = false - _listener?.stop() - _listener = null - _thread?.join() - _thread = null - } - - fun start() { - if (_started) { - Logger.i(TAG, "Already started.") - return - } - _started = true - - val listener = MDNSListener() - _listener = listener - listener.onServicesUpdated = { _onServicesUpdated?.invoke(it) } - listener.start() - - _thread = Thread { - try { - sleep(2000) - - while (_started) { - listener.queryServices(_names) - sleep(2000) - listener.queryAllQuestions(_names) - sleep(2000) - } - } catch (e: Throwable) { - Logger.i(TAG, "Exception in loop thread", e) - stop() - } - }.apply { start() } - } - - companion object { - private val TAG = "ServiceDiscoverer" - } -} diff --git a/app/src/main/java/com/futo/platformplayer/mdns/ServiceRecordAggregator.kt b/app/src/main/java/com/futo/platformplayer/mdns/ServiceRecordAggregator.kt deleted file mode 100644 index 5292d375..00000000 --- a/app/src/main/java/com/futo/platformplayer/mdns/ServiceRecordAggregator.kt +++ /dev/null @@ -1,226 +0,0 @@ -package com.futo.platformplayer.mdns - -import com.futo.platformplayer.logging.Logger -import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.Job -import kotlinx.coroutines.delay -import kotlinx.coroutines.isActive -import kotlinx.coroutines.launch -import java.net.InetAddress -import java.util.Date - -data class DnsService( - var name: String, - var target: String, - var port: UShort, - val addresses: MutableList = mutableListOf(), - val pointers: MutableList = mutableListOf(), - val texts: MutableList = mutableListOf() -) - -data class CachedDnsAddressRecord( - val expirationTime: Date, - val address: InetAddress -) - -data class CachedDnsTxtRecord( - val expirationTime: Date, - val texts: List -) - -data class CachedDnsPtrRecord( - val expirationTime: Date, - val target: String -) - -data class CachedDnsSrvRecord( - val expirationTime: Date, - val service: SRVRecord -) - -class ServiceRecordAggregator { - private val _lockObject = Any() - private val _cachedAddressRecords = mutableMapOf>() - private val _cachedTxtRecords = mutableMapOf() - private val _cachedPtrRecords = mutableMapOf>() - private val _cachedSrvRecords = mutableMapOf() - private val _currentServices = mutableListOf() - private var _cts: Job? = null - - var onServicesUpdated: ((List) -> Unit)? = null - - fun start() { - synchronized(_lockObject) { - if (_cts != null) throw Exception("Already started.") - - _cts = CoroutineScope(Dispatchers.Default).launch { - try { - while (isActive) { - val now = Date() - synchronized(_currentServices) { - _cachedAddressRecords.forEach { it.value.removeAll { record -> now.after(record.expirationTime) } } - _cachedTxtRecords.entries.removeIf { now.after(it.value.expirationTime) } - _cachedSrvRecords.entries.removeIf { now.after(it.value.expirationTime) } - _cachedPtrRecords.forEach { it.value.removeAll { record -> now.after(record.expirationTime) } } - - val newServices = getCurrentServices() - _currentServices.clear() - _currentServices.addAll(newServices) - } - - onServicesUpdated?.invoke(_currentServices.toList()) - delay(5000) - } - } catch (e: Throwable) { - Logger.e(TAG, "Unexpected failure in MDNS loop", e) - } - } - } - } - - fun stop() { - synchronized(_lockObject) { - _cts?.cancel() - _cts = null - } - } - - fun add(packet: DnsPacket) { - val currentServices: List - val dnsResourceRecords = packet.answers + packet.additionals + packet.authorities - val txtRecords = dnsResourceRecords.filter { it.type == ResourceRecordType.TXT.value.toInt() }.map { it to it.getDataReader().readTXTRecord() } - val aRecords = dnsResourceRecords.filter { it.type == ResourceRecordType.A.value.toInt() }.map { it to it.getDataReader().readARecord() } - val aaaaRecords = dnsResourceRecords.filter { it.type == ResourceRecordType.AAAA.value.toInt() }.map { it to it.getDataReader().readAAAARecord() } - val srvRecords = dnsResourceRecords.filter { it.type == ResourceRecordType.SRV.value.toInt() }.map { it to it.getDataReader().readSRVRecord() } - val ptrRecords = dnsResourceRecords.filter { it.type == ResourceRecordType.PTR.value.toInt() }.map { it to it.getDataReader().readPTRRecord() } - - /*val builder = StringBuilder() - builder.appendLine("Received records:") - srvRecords.forEach { builder.appendLine("SRV ${it.first.name} ${it.first.type} ${it.first.clazz} TTL ${it.first.timeToLive}: (Port: ${it.second.port}, Target: ${it.second.target}, Priority: ${it.second.priority}, Weight: ${it.second.weight})") } - ptrRecords.forEach { builder.appendLine("PTR ${it.first.name} ${it.first.type} ${it.first.clazz} TTL ${it.first.timeToLive}: ${it.second.domainName}") } - txtRecords.forEach { builder.appendLine("TXT ${it.first.name} ${it.first.type} ${it.first.clazz} TTL ${it.first.timeToLive}: ${it.second.texts.joinToString(", ")}") } - aRecords.forEach { builder.appendLine("A ${it.first.name} ${it.first.type} ${it.first.clazz} TTL ${it.first.timeToLive}: ${it.second.address}") } - aaaaRecords.forEach { builder.appendLine("AAAA ${it.first.name} ${it.first.type} ${it.first.clazz} TTL ${it.first.timeToLive}: ${it.second.address}") } - Logger.i(TAG, "$builder")*/ - - synchronized(this._currentServices) { - ptrRecords.forEach { record -> - val cachedPtrRecord = _cachedPtrRecords.getOrPut(record.first.name) { mutableListOf() } - val newPtrRecord = CachedDnsPtrRecord(Date(System.currentTimeMillis() + record.first.timeToLive.toLong() * 1000L), record.second.domainName) - cachedPtrRecord.replaceOrAdd(newPtrRecord) { it.target == record.second.domainName } - } - - aRecords.forEach { aRecord -> - val cachedARecord = _cachedAddressRecords.getOrPut(aRecord.first.name) { mutableListOf() } - val newARecord = CachedDnsAddressRecord(Date(System.currentTimeMillis() + aRecord.first.timeToLive.toLong() * 1000L), aRecord.second.address) - cachedARecord.replaceOrAdd(newARecord) { it.address == newARecord.address } - } - - aaaaRecords.forEach { aaaaRecord -> - val cachedAaaaRecord = _cachedAddressRecords.getOrPut(aaaaRecord.first.name) { mutableListOf() } - val newAaaaRecord = CachedDnsAddressRecord(Date(System.currentTimeMillis() + aaaaRecord.first.timeToLive.toLong() * 1000L), aaaaRecord.second.address) - cachedAaaaRecord.replaceOrAdd(newAaaaRecord) { it.address == newAaaaRecord.address } - } - - txtRecords.forEach { txtRecord -> - _cachedTxtRecords[txtRecord.first.name] = CachedDnsTxtRecord(Date(System.currentTimeMillis() + txtRecord.first.timeToLive.toLong() * 1000L), txtRecord.second.texts) - } - - srvRecords.forEach { srvRecord -> - _cachedSrvRecords[srvRecord.first.name] = CachedDnsSrvRecord(Date(System.currentTimeMillis() + srvRecord.first.timeToLive.toLong() * 1000L), srvRecord.second) - } - - currentServices = getCurrentServices() - this._currentServices.clear() - this._currentServices.addAll(currentServices) - } - - onServicesUpdated?.invoke(currentServices) - } - - fun getAllQuestions(serviceName: String): List { - val questions = mutableListOf() - synchronized(_currentServices) { - val servicePtrRecords = _cachedPtrRecords[serviceName] ?: return emptyList() - - val ptrWithoutSrvRecord = servicePtrRecords.filterNot { _cachedSrvRecords.containsKey(it.target) }.map { it.target } - questions.addAll(ptrWithoutSrvRecord.flatMap { s -> - listOf( - DnsQuestion( - name = s, - type = QuestionType.SRV.value.toInt(), - clazz = QuestionClass.IN.value.toInt(), - queryUnicast = false - ) - ) - }) - - val incompleteCurrentServices = _currentServices.filter { it.addresses.isEmpty() && it.name.endsWith(serviceName) } - questions.addAll(incompleteCurrentServices.flatMap { s -> - listOf( - DnsQuestion( - name = s.name, - type = QuestionType.TXT.value.toInt(), - clazz = QuestionClass.IN.value.toInt(), - queryUnicast = false - ), - DnsQuestion( - name = s.target, - type = QuestionType.A.value.toInt(), - clazz = QuestionClass.IN.value.toInt(), - queryUnicast = false - ), - DnsQuestion( - name = s.target, - type = QuestionType.AAAA.value.toInt(), - clazz = QuestionClass.IN.value.toInt(), - queryUnicast = false - ) - ) - }) - } - return questions - } - - private fun getCurrentServices(): MutableList { - val currentServices = _cachedSrvRecords.map { (key, value) -> - DnsService( - name = key, - target = value.service.target, - port = value.service.port - ) - }.toMutableList() - - currentServices.forEach { service -> - _cachedAddressRecords[service.target]?.let { - service.addresses.addAll(it.map { record -> record.address }) - } - } - - currentServices.forEach { service -> - service.pointers.addAll(_cachedPtrRecords.filter { it.value.any { ptr -> ptr.target == service.name } }.map { it.key }) - } - - currentServices.forEach { service -> - _cachedTxtRecords[service.name]?.let { - service.texts.addAll(it.texts) - } - } - - return currentServices - } - - private inline fun MutableList.replaceOrAdd(newElement: T, predicate: (T) -> Boolean) { - val index = indexOfFirst(predicate) - if (index >= 0) { - this[index] = newElement - } else { - add(newElement) - } - } - - private companion object { - private const val TAG = "ServiceRecordAggregator" - } -} diff --git a/app/src/main/java/com/futo/platformplayer/states/StateApp.kt b/app/src/main/java/com/futo/platformplayer/states/StateApp.kt index bdad1395..d15a604a 100644 --- a/app/src/main/java/com/futo/platformplayer/states/StateApp.kt +++ b/app/src/main/java/com/futo/platformplayer/states/StateApp.kt @@ -411,7 +411,7 @@ class StateApp { } if (Settings.instance.synchronization.enabled) { - StateSync.instance.start() + StateSync.instance.start(context) } Logger.onLogSubmitted.subscribe { diff --git a/app/src/main/java/com/futo/platformplayer/states/StateSync.kt b/app/src/main/java/com/futo/platformplayer/states/StateSync.kt index 7e5d02e9..6bbd0c7d 100644 --- a/app/src/main/java/com/futo/platformplayer/states/StateSync.kt +++ b/app/src/main/java/com/futo/platformplayer/states/StateSync.kt @@ -1,5 +1,8 @@ package com.futo.platformplayer.states +import android.content.Context +import android.net.nsd.NsdManager +import android.net.nsd.NsdServiceInfo import android.os.Build import android.util.Log import com.futo.platformplayer.LittleEndianDataInputStream @@ -9,14 +12,14 @@ import com.futo.platformplayer.UIDialogs import com.futo.platformplayer.activities.MainActivity import com.futo.platformplayer.activities.SyncShowPairingCodeActivity import com.futo.platformplayer.api.media.Serializer +import com.futo.platformplayer.casting.StateCasting +import com.futo.platformplayer.casting.StateCasting.Companion import com.futo.platformplayer.constructs.Event1 import com.futo.platformplayer.constructs.Event2 import com.futo.platformplayer.encryption.GEncryptionProvider import com.futo.platformplayer.generateReadablePassword import com.futo.platformplayer.getConnectedSocket import com.futo.platformplayer.logging.Logger -import com.futo.platformplayer.mdns.DnsService -import com.futo.platformplayer.mdns.ServiceDiscoverer import com.futo.platformplayer.models.HistoryVideo import com.futo.platformplayer.models.Subscription import com.futo.platformplayer.noise.protocol.DHState @@ -81,12 +84,30 @@ class StateSync { //TODO: Should sync mdns and casting mdns be merged? //TODO: Decrease interval that devices are updated //TODO: Send less data - private val _serviceDiscoverer = ServiceDiscoverer(arrayOf("_gsync._tcp.local")) { handleServiceUpdated(it) } + private val _pairingCode: String? = generateReadablePassword(8) val pairingCode: String? get() = _pairingCode private var _relaySession: SyncSocketSession? = null private var _threadRelay: Thread? = null private val _remotePendingStatusUpdate = mutableMapOf Unit>() + private var _nsdManager: NsdManager? = null + private val _registrationListener = object : NsdManager.RegistrationListener { + override fun onServiceRegistered(serviceInfo: NsdServiceInfo) { + Log.v(TAG, "onServiceRegistered: ${serviceInfo.serviceName}") + } + + override fun onRegistrationFailed(serviceInfo: NsdServiceInfo, errorCode: Int) { + Log.v(TAG, "onRegistrationFailed: ${serviceInfo.serviceName} (error code: $errorCode)") + } + + override fun onServiceUnregistered(serviceInfo: NsdServiceInfo) { + Log.v(TAG, "onServiceUnregistered: ${serviceInfo.serviceName}") + } + + override fun onUnregistrationFailed(serviceInfo: NsdServiceInfo, errorCode: Int) { + Log.v(TAG, "onUnregistrationFailed: ${serviceInfo.serviceName} (error code: $errorCode)") + } + } var keyPair: DHState? = null var publicKey: String? = null @@ -101,15 +122,116 @@ class StateSync { } } - fun start() { + fun start(context: Context) { if (_started) { Logger.i(TAG, "Already started.") return } _started = true + _nsdManager = context.getSystemService(Context.NSD_SERVICE) as NsdManager - if (Settings.instance.synchronization.broadcast || Settings.instance.synchronization.connectDiscovered) { - _serviceDiscoverer.start() + if (Settings.instance.synchronization.connectDiscovered) { + _nsdManager?.apply { + discoverServices("_gsync._tcp", NsdManager.PROTOCOL_DNS_SD, object : NsdManager.DiscoveryListener { + override fun onDiscoveryStarted(regType: String) { + Log.d(TAG, "Service discovery started for $regType") + } + + override fun onDiscoveryStopped(serviceType: String) { + Log.i(TAG, "Discovery stopped: $serviceType") + } + + override fun onServiceLost(service: NsdServiceInfo) { + Log.e(TAG, "service lost: $service") + // TODO: Handle service lost, e.g., remove device + } + + override fun onStartDiscoveryFailed(serviceType: String, errorCode: Int) { + Log.e(TAG, "Discovery failed for $serviceType: Error code:$errorCode") + _nsdManager?.stopServiceDiscovery(this) + } + + override fun onStopDiscoveryFailed(serviceType: String, errorCode: Int) { + Log.e(TAG, "Stop discovery failed for $serviceType: Error code:$errorCode") + _nsdManager?.stopServiceDiscovery(this) + } + + fun addOrUpdate(name: String, adrs: Array, port: Int, attributes: Map) { + if (!Settings.instance.synchronization.connectDiscovered) { + return + } + + val urlSafePkey = attributes.get("pk")?.decodeToString() ?: return + val pkey = Base64.getEncoder().encodeToString(Base64.getDecoder().decode(urlSafePkey.replace('-', '+').replace('_', '/'))) + val syncDeviceInfo = SyncDeviceInfo(pkey, adrs.map { it.hostAddress }.toTypedArray(), port, null) + val authorized = isAuthorized(pkey) + + if (authorized && !isConnected(pkey)) { + val now = System.currentTimeMillis() + val lastConnectTime = synchronized(_lastConnectTimesMdns) { + _lastConnectTimesMdns[pkey] ?: 0 + } + + //Connect once every 30 seconds, max + if (now - lastConnectTime > 30000) { + synchronized(_lastConnectTimesMdns) { + _lastConnectTimesMdns[pkey] = now + } + + Logger.i(TAG, "Found device authorized device '${name}' with pkey=$pkey, attempting to connect") + + try { + connect(syncDeviceInfo) + } catch (e: Throwable) { + Logger.i(TAG, "Failed to connect to $pkey", e) + } + } + } + } + + override fun onServiceFound(service: NsdServiceInfo) { + Log.v(TAG, "Service discovery success for ${service.serviceType}: $service") + addOrUpdate(service.serviceName, if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.UPSIDE_DOWN_CAKE) { + service.hostAddresses.toTypedArray() + } else { + arrayOf(service.host) + }, service.port, service.attributes) + + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.UPSIDE_DOWN_CAKE) { + _nsdManager?.registerServiceInfoCallback(service, { it.run() }, object : NsdManager.ServiceInfoCallback { + override fun onServiceUpdated(serviceInfo: NsdServiceInfo) { + Log.v(TAG, "onServiceUpdated: $serviceInfo") + addOrUpdate(serviceInfo.serviceName, serviceInfo.hostAddresses.toTypedArray(), serviceInfo.port, serviceInfo.attributes) + } + + override fun onServiceLost() { + Log.v(TAG, "onServiceLost: $service") + // TODO: Handle service lost + } + + override fun onServiceInfoCallbackRegistrationFailed(errorCode: Int) { + Log.v(TAG, "onServiceInfoCallbackRegistrationFailed: $errorCode") + } + + override fun onServiceInfoCallbackUnregistered() { + Log.v(TAG, "onServiceInfoCallbackUnregistered") + } + }) + } else { + _nsdManager?.resolveService(service, object : NsdManager.ResolveListener { + override fun onResolveFailed(serviceInfo: NsdServiceInfo, errorCode: Int) { + Log.v(TAG, "Resolve failed: $errorCode") + } + + override fun onServiceResolved(serviceInfo: NsdServiceInfo) { + Log.v(TAG, "Resolve Succeeded: $serviceInfo") + addOrUpdate(serviceInfo.serviceName, arrayOf(serviceInfo.host), serviceInfo.port, serviceInfo.attributes) + } + }) + } + } + }) + } } try { @@ -142,7 +264,19 @@ class StateSync { } if (Settings.instance.synchronization.broadcast) { - publicKey?.let { _serviceDiscoverer.broadcastService(getDeviceName(), "_gsync._tcp.local", PORT.toUShort(), texts = arrayListOf("pk=${it.replace('+', '-').replace('/', '_').replace("=", "")}")) } + val pk = publicKey + val nsdManager = _nsdManager + + if (pk != null && nsdManager != null) { + val serviceInfo = NsdServiceInfo().apply { + serviceName = getDeviceName() + serviceType = "_gsync._tcp" + port = PORT + setAttribute("pk", pk.replace('+', '-').replace('/', '_').replace("=", "")) + } + + nsdManager.registerService(serviceInfo, NsdManager.PROTOCOL_DNS_SD, _registrationListener) + } } Logger.i(TAG, "Sync key pair initialized (public key = ${publicKey})") @@ -318,7 +452,7 @@ class StateSync { override val isAuthorized: Boolean get() = true } - _relaySession!!.startAsInitiator(RELAY_PUBLIC_KEY, APP_ID, null) + _relaySession!!.runAsInitiator(RELAY_PUBLIC_KEY, APP_ID, null) Log.i(TAG, "Started relay session.") } catch (e: Throwable) { @@ -331,6 +465,8 @@ class StateSync { } }.apply { start() } } + + } private fun getDeviceName(): String { @@ -382,48 +518,6 @@ class StateSync { _syncSessionData.setAndSave(data.publicKey, data); } - private fun handleServiceUpdated(services: List) { - if (!Settings.instance.synchronization.connectDiscovered) { - return - } - - for (s in services) { - //TODO: Addresses IPv4 only? - val addresses = s.addresses.mapNotNull { it.hostAddress }.toTypedArray() - val port = s.port.toInt() - if (s.name.endsWith("._gsync._tcp.local")) { - val name = s.name.substring(0, s.name.length - "._gsync._tcp.local".length) - val urlSafePkey = s.texts.firstOrNull { it.startsWith("pk=") }?.substring("pk=".length) ?: continue - val pkey = Base64.getEncoder().encodeToString(Base64.getDecoder().decode(urlSafePkey.replace('-', '+').replace('_', '/'))) - - val syncDeviceInfo = SyncDeviceInfo(pkey, addresses, port, null) - val authorized = isAuthorized(pkey) - - if (authorized && !isConnected(pkey)) { - val now = System.currentTimeMillis() - val lastConnectTime = synchronized(_lastConnectTimesMdns) { - _lastConnectTimesMdns[pkey] ?: 0 - } - - //Connect once every 30 seconds, max - if (now - lastConnectTime > 30000) { - synchronized(_lastConnectTimesMdns) { - _lastConnectTimesMdns[pkey] = now - } - - Logger.i(TAG, "Found device authorized device '${name}' with pkey=$pkey, attempting to connect") - - try { - connect(syncDeviceInfo) - } catch (e: Throwable) { - Logger.i(TAG, "Failed to connect to $pkey", e) - } - } - } - } - } - } - private fun unauthorize(remotePublicKey: String) { Logger.i(TAG, "${remotePublicKey} unauthorized received") _authorizedDevices.remove(remotePublicKey) @@ -899,7 +993,7 @@ class StateSync { fun stop() { _started = false - _serviceDiscoverer.stop() + _nsdManager?.unregisterService(_registrationListener) _serverSocket?.close() _serverSocket = null diff --git a/app/src/main/java/com/futo/platformplayer/sync/internal/SyncSocketSession.kt b/app/src/main/java/com/futo/platformplayer/sync/internal/SyncSocketSession.kt index ad928698..e66da337 100644 --- a/app/src/main/java/com/futo/platformplayer/sync/internal/SyncSocketSession.kt +++ b/app/src/main/java/com/futo/platformplayer/sync/internal/SyncSocketSession.kt @@ -121,6 +121,19 @@ class SyncSocketSession { }.apply { start() } } + fun runAsInitiator(remotePublicKey: String, appId: UInt = 0u, pairingCode: String? = null) { + _started = true + try { + handshakeAsInitiator(remotePublicKey, appId, pairingCode) + _onHandshakeComplete?.invoke(this) + receiveLoop() + } catch (e: Throwable) { + Logger.e(TAG, "Failed to run as initiator", e) + } finally { + stop() + } + } + fun startAsResponder() { _started = true _thread = Thread {