mirror of
https://gitlab.futo.org/videostreaming/grayjay.git
synced 2025-04-29 22:24:29 +02:00
Switch to NsdManager.
This commit is contained in:
parent
d9d00e452e
commit
5b143bdc76
@ -4,10 +4,12 @@ import android.app.AlertDialog
|
||||
import android.content.ContentResolver
|
||||
import android.content.Context
|
||||
import android.net.Uri
|
||||
import android.net.nsd.NsdManager
|
||||
import android.net.nsd.NsdServiceInfo
|
||||
import android.os.Build
|
||||
import android.os.Looper
|
||||
import android.util.Base64
|
||||
import android.util.Log
|
||||
import android.util.Xml
|
||||
import androidx.annotation.OptIn
|
||||
import androidx.media3.common.util.UnstableApi
|
||||
import com.futo.platformplayer.R
|
||||
@ -40,8 +42,6 @@ import com.futo.platformplayer.constructs.Event1
|
||||
import com.futo.platformplayer.constructs.Event2
|
||||
import com.futo.platformplayer.exceptions.UnsupportedCastException
|
||||
import com.futo.platformplayer.logging.Logger
|
||||
import com.futo.platformplayer.mdns.DnsService
|
||||
import com.futo.platformplayer.mdns.ServiceDiscoverer
|
||||
import com.futo.platformplayer.models.CastingDeviceInfo
|
||||
import com.futo.platformplayer.parsers.HLS
|
||||
import com.futo.platformplayer.states.StateApp
|
||||
@ -55,7 +55,6 @@ import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.withContext
|
||||
import kotlinx.serialization.Serializable
|
||||
import kotlinx.serialization.json.Json
|
||||
import java.io.ByteArrayInputStream
|
||||
import java.net.InetAddress
|
||||
import java.net.URLDecoder
|
||||
import java.net.URLEncoder
|
||||
@ -84,48 +83,15 @@ class StateCasting {
|
||||
private var _audioExecutor: JSRequestExecutor? = null
|
||||
private val _client = ManagedHttpClient();
|
||||
var _resumeCastingDevice: CastingDeviceInfo? = null;
|
||||
val _serviceDiscoverer = ServiceDiscoverer(arrayOf(
|
||||
"_googlecast._tcp.local",
|
||||
"_airplay._tcp.local",
|
||||
"_fastcast._tcp.local",
|
||||
"_fcast._tcp.local"
|
||||
)) { handleServiceUpdated(it) }
|
||||
|
||||
private var _nsdManager: NsdManager? = null
|
||||
val isCasting: Boolean get() = activeDevice != null;
|
||||
|
||||
private fun handleServiceUpdated(services: List<DnsService>) {
|
||||
for (s in services) {
|
||||
//TODO: Addresses IPv4 only?
|
||||
val addresses = s.addresses.toTypedArray()
|
||||
val port = s.port.toInt()
|
||||
var name = s.texts.firstOrNull { it.startsWith("md=") }?.substring("md=".length)
|
||||
if (s.name.endsWith("._googlecast._tcp.local")) {
|
||||
if (name == null) {
|
||||
name = s.name.substring(0, s.name.length - "._googlecast._tcp.local".length)
|
||||
}
|
||||
|
||||
addOrUpdateChromeCastDevice(name, addresses, port)
|
||||
} else if (s.name.endsWith("._airplay._tcp.local")) {
|
||||
if (name == null) {
|
||||
name = s.name.substring(0, s.name.length - "._airplay._tcp.local".length)
|
||||
}
|
||||
|
||||
addOrUpdateAirPlayDevice(name, addresses, port)
|
||||
} else if (s.name.endsWith("._fastcast._tcp.local")) {
|
||||
if (name == null) {
|
||||
name = s.name.substring(0, s.name.length - "._fastcast._tcp.local".length)
|
||||
}
|
||||
|
||||
addOrUpdateFastCastDevice(name, addresses, port)
|
||||
} else if (s.name.endsWith("._fcast._tcp.local")) {
|
||||
if (name == null) {
|
||||
name = s.name.substring(0, s.name.length - "._fcast._tcp.local".length)
|
||||
}
|
||||
|
||||
addOrUpdateFastCastDevice(name, addresses, port)
|
||||
}
|
||||
}
|
||||
}
|
||||
private val _discoveryListeners = mapOf(
|
||||
"_googlecast._tcp" to createDiscoveryListener(::addOrUpdateChromeCastDevice),
|
||||
"_airplay._tcp" to createDiscoveryListener(::addOrUpdateAirPlayDevice),
|
||||
"_fastcast._tcp" to createDiscoveryListener(::addOrUpdateFastCastDevice),
|
||||
"_fcast._tcp" to createDiscoveryListener(::addOrUpdateFastCastDevice)
|
||||
)
|
||||
|
||||
fun handleUrl(context: Context, url: String) {
|
||||
val uri = Uri.parse(url)
|
||||
@ -197,23 +163,25 @@ class StateCasting {
|
||||
enableDeveloper(true);
|
||||
|
||||
Logger.i(TAG, "CastingService started.");
|
||||
|
||||
_nsdManager = context.getSystemService(Context.NSD_SERVICE) as NsdManager
|
||||
}
|
||||
|
||||
@Synchronized
|
||||
fun startDiscovering() {
|
||||
try {
|
||||
_serviceDiscoverer.start()
|
||||
} catch (e: Throwable) {
|
||||
Logger.i(TAG, "Failed to start ServiceDiscoverer", e)
|
||||
_nsdManager?.apply {
|
||||
_discoveryListeners.forEach {
|
||||
discoverServices(it.key, NsdManager.PROTOCOL_DNS_SD, it.value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Synchronized
|
||||
fun stopDiscovering() {
|
||||
try {
|
||||
_serviceDiscoverer.stop()
|
||||
} catch (e: Throwable) {
|
||||
Logger.i(TAG, "Failed to stop ServiceDiscoverer", e)
|
||||
_nsdManager?.apply {
|
||||
_discoveryListeners.forEach {
|
||||
stopServiceDiscovery(it.value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -239,6 +207,77 @@ class StateCasting {
|
||||
_castServer.removeAllHandlers();
|
||||
|
||||
Logger.i(TAG, "CastingService stopped.")
|
||||
|
||||
_nsdManager = null
|
||||
}
|
||||
|
||||
private fun createDiscoveryListener(addOrUpdate: (String, Array<InetAddress>, Int) -> Unit): NsdManager.DiscoveryListener {
|
||||
return object : NsdManager.DiscoveryListener {
|
||||
override fun onDiscoveryStarted(regType: String) {
|
||||
Log.d(TAG, "Service discovery started for $regType")
|
||||
}
|
||||
|
||||
override fun onDiscoveryStopped(serviceType: String) {
|
||||
Log.i(TAG, "Discovery stopped: $serviceType")
|
||||
}
|
||||
|
||||
override fun onServiceLost(service: NsdServiceInfo) {
|
||||
Log.e(TAG, "service lost: $service")
|
||||
// TODO: Handle service lost, e.g., remove device
|
||||
}
|
||||
|
||||
override fun onStartDiscoveryFailed(serviceType: String, errorCode: Int) {
|
||||
Log.e(TAG, "Discovery failed for $serviceType: Error code:$errorCode")
|
||||
_nsdManager?.stopServiceDiscovery(this)
|
||||
}
|
||||
|
||||
override fun onStopDiscoveryFailed(serviceType: String, errorCode: Int) {
|
||||
Log.e(TAG, "Stop discovery failed for $serviceType: Error code:$errorCode")
|
||||
_nsdManager?.stopServiceDiscovery(this)
|
||||
}
|
||||
|
||||
override fun onServiceFound(service: NsdServiceInfo) {
|
||||
Log.v(TAG, "Service discovery success for ${service.serviceType}: $service")
|
||||
val addresses = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.UPSIDE_DOWN_CAKE) {
|
||||
service.hostAddresses.toTypedArray()
|
||||
} else {
|
||||
arrayOf(service.host)
|
||||
}
|
||||
addOrUpdate(service.serviceName, addresses, service.port)
|
||||
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.UPSIDE_DOWN_CAKE) {
|
||||
_nsdManager?.registerServiceInfoCallback(service, { it.run() }, object : NsdManager.ServiceInfoCallback {
|
||||
override fun onServiceUpdated(serviceInfo: NsdServiceInfo) {
|
||||
Log.v(TAG, "onServiceUpdated: $serviceInfo")
|
||||
addOrUpdate(serviceInfo.serviceName, serviceInfo.hostAddresses.toTypedArray(), serviceInfo.port)
|
||||
}
|
||||
|
||||
override fun onServiceLost() {
|
||||
Log.v(TAG, "onServiceLost: $service")
|
||||
// TODO: Handle service lost
|
||||
}
|
||||
|
||||
override fun onServiceInfoCallbackRegistrationFailed(errorCode: Int) {
|
||||
Log.v(TAG, "onServiceInfoCallbackRegistrationFailed: $errorCode")
|
||||
}
|
||||
|
||||
override fun onServiceInfoCallbackUnregistered() {
|
||||
Log.v(TAG, "onServiceInfoCallbackUnregistered")
|
||||
}
|
||||
})
|
||||
} else {
|
||||
_nsdManager?.resolveService(service, object : NsdManager.ResolveListener {
|
||||
override fun onResolveFailed(serviceInfo: NsdServiceInfo, errorCode: Int) {
|
||||
Log.v(TAG, "Resolve failed: $errorCode")
|
||||
}
|
||||
|
||||
override fun onServiceResolved(serviceInfo: NsdServiceInfo) {
|
||||
Log.v(TAG, "Resolve Succeeded: $serviceInfo")
|
||||
addOrUpdate(serviceInfo.serviceName, arrayOf(serviceInfo.host), serviceInfo.port)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private val _castingDialogLock = Any();
|
||||
|
@ -1,11 +0,0 @@
|
||||
package com.futo.platformplayer.mdns
|
||||
|
||||
data class BroadcastService(
|
||||
val deviceName: String,
|
||||
val serviceName: String,
|
||||
val port: UShort,
|
||||
val ttl: UInt,
|
||||
val weight: UShort,
|
||||
val priority: UShort,
|
||||
val texts: List<String>? = null
|
||||
)
|
@ -1,93 +0,0 @@
|
||||
package com.futo.platformplayer.mdns
|
||||
|
||||
import java.nio.ByteBuffer
|
||||
import java.nio.ByteOrder
|
||||
|
||||
enum class QueryResponse(val value: Byte) {
|
||||
Query(0),
|
||||
Response(1)
|
||||
}
|
||||
|
||||
enum class DnsOpcode(val value: Byte) {
|
||||
StandardQuery(0),
|
||||
InverseQuery(1),
|
||||
ServerStatusRequest(2)
|
||||
}
|
||||
|
||||
enum class DnsResponseCode(val value: Byte) {
|
||||
NoError(0),
|
||||
FormatError(1),
|
||||
ServerFailure(2),
|
||||
NameError(3),
|
||||
NotImplemented(4),
|
||||
Refused(5)
|
||||
}
|
||||
|
||||
data class DnsPacketHeader(
|
||||
val identifier: UShort,
|
||||
val queryResponse: Int,
|
||||
val opcode: Int,
|
||||
val authoritativeAnswer: Boolean,
|
||||
val truncated: Boolean,
|
||||
val recursionDesired: Boolean,
|
||||
val recursionAvailable: Boolean,
|
||||
val answerAuthenticated: Boolean,
|
||||
val nonAuthenticatedData: Boolean,
|
||||
val responseCode: DnsResponseCode
|
||||
)
|
||||
|
||||
data class DnsPacket(
|
||||
val header: DnsPacketHeader,
|
||||
val questions: List<DnsQuestion>,
|
||||
val answers: List<DnsResourceRecord>,
|
||||
val authorities: List<DnsResourceRecord>,
|
||||
val additionals: List<DnsResourceRecord>
|
||||
) {
|
||||
companion object {
|
||||
fun parse(data: ByteArray): DnsPacket {
|
||||
val span = data.asUByteArray()
|
||||
val flags = (span[2].toInt() shl 8 or span[3].toInt()).toUShort()
|
||||
val questionCount = (span[4].toInt() shl 8 or span[5].toInt()).toUShort()
|
||||
val answerCount = (span[6].toInt() shl 8 or span[7].toInt()).toUShort()
|
||||
val authorityCount = (span[8].toInt() shl 8 or span[9].toInt()).toUShort()
|
||||
val additionalCount = (span[10].toInt() shl 8 or span[11].toInt()).toUShort()
|
||||
|
||||
var position = 12
|
||||
|
||||
val questions = List(questionCount.toInt()) {
|
||||
DnsQuestion.parse(data, position).also { position = it.second }
|
||||
}.map { it.first }
|
||||
|
||||
val answers = List(answerCount.toInt()) {
|
||||
DnsResourceRecord.parse(data, position).also { position = it.second }
|
||||
}.map { it.first }
|
||||
|
||||
val authorities = List(authorityCount.toInt()) {
|
||||
DnsResourceRecord.parse(data, position).also { position = it.second }
|
||||
}.map { it.first }
|
||||
|
||||
val additionals = List(additionalCount.toInt()) {
|
||||
DnsResourceRecord.parse(data, position).also { position = it.second }
|
||||
}.map { it.first }
|
||||
|
||||
return DnsPacket(
|
||||
header = DnsPacketHeader(
|
||||
identifier = (span[0].toInt() shl 8 or span[1].toInt()).toUShort(),
|
||||
queryResponse = ((flags.toUInt() shr 15) and 0b1u).toInt(),
|
||||
opcode = ((flags.toUInt() shr 11) and 0b1111u).toInt(),
|
||||
authoritativeAnswer = (flags.toInt() shr 10) and 0b1 != 0,
|
||||
truncated = (flags.toInt() shr 9) and 0b1 != 0,
|
||||
recursionDesired = (flags.toInt() shr 8) and 0b1 != 0,
|
||||
recursionAvailable = (flags.toInt() shr 7) and 0b1 != 0,
|
||||
answerAuthenticated = (flags.toInt() shr 5) and 0b1 != 0,
|
||||
nonAuthenticatedData = (flags.toInt() shr 4) and 0b1 != 0,
|
||||
responseCode = DnsResponseCode.entries[flags.toInt() and 0b1111]
|
||||
),
|
||||
questions = questions,
|
||||
answers = answers,
|
||||
authorities = authorities,
|
||||
additionals = additionals
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
@ -1,110 +0,0 @@
|
||||
package com.futo.platformplayer.mdns
|
||||
|
||||
import com.futo.platformplayer.mdns.Extensions.readDomainName
|
||||
import java.nio.ByteBuffer
|
||||
import java.nio.ByteOrder
|
||||
|
||||
|
||||
enum class QuestionType(val value: UShort) {
|
||||
A(1u),
|
||||
NS(2u),
|
||||
MD(3u),
|
||||
MF(4u),
|
||||
CNAME(5u),
|
||||
SOA(6u),
|
||||
MB(7u),
|
||||
MG(8u),
|
||||
MR(9u),
|
||||
NULL(10u),
|
||||
WKS(11u),
|
||||
PTR(12u),
|
||||
HINFO(13u),
|
||||
MINFO(14u),
|
||||
MX(15u),
|
||||
TXT(16u),
|
||||
RP(17u),
|
||||
AFSDB(18u),
|
||||
SIG(24u),
|
||||
KEY(25u),
|
||||
AAAA(28u),
|
||||
LOC(29u),
|
||||
SRV(33u),
|
||||
NAPTR(35u),
|
||||
KX(36u),
|
||||
CERT(37u),
|
||||
DNAME(39u),
|
||||
APL(42u),
|
||||
DS(43u),
|
||||
SSHFP(44u),
|
||||
IPSECKEY(45u),
|
||||
RRSIG(46u),
|
||||
NSEC(47u),
|
||||
DNSKEY(48u),
|
||||
DHCID(49u),
|
||||
NSEC3(50u),
|
||||
NSEC3PARAM(51u),
|
||||
TSLA(52u),
|
||||
SMIMEA(53u),
|
||||
HIP(55u),
|
||||
CDS(59u),
|
||||
CDNSKEY(60u),
|
||||
OPENPGPKEY(61u),
|
||||
CSYNC(62u),
|
||||
ZONEMD(63u),
|
||||
SVCB(64u),
|
||||
HTTPS(65u),
|
||||
EUI48(108u),
|
||||
EUI64(109u),
|
||||
TKEY(249u),
|
||||
TSIG(250u),
|
||||
URI(256u),
|
||||
CAA(257u),
|
||||
TA(32768u),
|
||||
DLV(32769u),
|
||||
AXFR(252u),
|
||||
IXFR(251u),
|
||||
OPT(41u),
|
||||
MAILB(253u),
|
||||
MALA(254u),
|
||||
All(252u)
|
||||
}
|
||||
|
||||
enum class QuestionClass(val value: UShort) {
|
||||
IN(1u),
|
||||
CS(2u),
|
||||
CH(3u),
|
||||
HS(4u),
|
||||
All(255u)
|
||||
}
|
||||
|
||||
data class DnsQuestion(
|
||||
override val name: String,
|
||||
override val type: Int,
|
||||
override val clazz: Int,
|
||||
val queryUnicast: Boolean
|
||||
) : DnsResourceRecordBase(name, type, clazz) {
|
||||
companion object {
|
||||
fun parse(data: ByteArray, startPosition: Int): Pair<DnsQuestion, Int> {
|
||||
val span = data.asUByteArray()
|
||||
var position = startPosition
|
||||
val qname = span.readDomainName(position).also { position = it.second }
|
||||
val qtype = (span[position].toInt() shl 8 or span[position + 1].toInt()).toUShort()
|
||||
position += 2
|
||||
val qclass = (span[position].toInt() shl 8 or span[position + 1].toInt()).toUShort()
|
||||
position += 2
|
||||
|
||||
return DnsQuestion(
|
||||
name = qname.first,
|
||||
type = qtype.toInt(),
|
||||
queryUnicast = ((qclass.toInt() shr 15) and 0b1) != 0,
|
||||
clazz = qclass.toInt() and 0b111111111111111
|
||||
) to position
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
open class DnsResourceRecordBase(
|
||||
open val name: String,
|
||||
open val type: Int,
|
||||
open val clazz: Int
|
||||
)
|
@ -1,514 +0,0 @@
|
||||
package com.futo.platformplayer.mdns
|
||||
|
||||
import com.futo.platformplayer.mdns.Extensions.readDomainName
|
||||
import java.nio.ByteBuffer
|
||||
import java.nio.ByteOrder
|
||||
import java.nio.charset.StandardCharsets
|
||||
import kotlin.math.pow
|
||||
import java.net.InetAddress
|
||||
|
||||
data class PTRRecord(val domainName: String)
|
||||
|
||||
data class ARecord(val address: InetAddress)
|
||||
|
||||
data class AAAARecord(val address: InetAddress)
|
||||
|
||||
data class MXRecord(val preference: UShort, val exchange: String)
|
||||
|
||||
data class CNAMERecord(val cname: String)
|
||||
|
||||
data class TXTRecord(val texts: List<String>)
|
||||
|
||||
data class SOARecord(
|
||||
val primaryNameServer: String,
|
||||
val responsibleAuthorityMailbox: String,
|
||||
val serialNumber: Int,
|
||||
val refreshInterval: Int,
|
||||
val retryInterval: Int,
|
||||
val expiryLimit: Int,
|
||||
val minimumTTL: Int
|
||||
)
|
||||
|
||||
data class SRVRecord(val priority: UShort, val weight: UShort, val port: UShort, val target: String)
|
||||
|
||||
data class NSRecord(val nameServer: String)
|
||||
|
||||
data class CAARecord(val flags: Byte, val tag: String, val value: String)
|
||||
|
||||
data class HINFORecord(val cpu: String, val os: String)
|
||||
|
||||
data class RPRecord(val mailbox: String, val txtDomainName: String)
|
||||
|
||||
|
||||
data class AFSDBRecord(val subtype: UShort, val hostname: String)
|
||||
data class LOCRecord(
|
||||
val version: Byte,
|
||||
val size: Double,
|
||||
val horizontalPrecision: Double,
|
||||
val verticalPrecision: Double,
|
||||
val latitude: Double,
|
||||
val longitude: Double,
|
||||
val altitude: Double
|
||||
) {
|
||||
companion object {
|
||||
fun decodeSizeOrPrecision(coded: Byte): Double {
|
||||
val baseValue = (coded.toInt() shr 4) and 0x0F
|
||||
val exponent = coded.toInt() and 0x0F
|
||||
return baseValue * 10.0.pow(exponent.toDouble())
|
||||
}
|
||||
|
||||
fun decodeLatitudeOrLongitude(coded: Int): Double {
|
||||
val arcSeconds = coded / 1E3
|
||||
return arcSeconds / 3600.0
|
||||
}
|
||||
|
||||
fun decodeAltitude(coded: Int): Double {
|
||||
return (coded / 100.0) - 100000.0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
data class NAPTRRecord(
|
||||
val order: UShort,
|
||||
val preference: UShort,
|
||||
val flags: String,
|
||||
val services: String,
|
||||
val regexp: String,
|
||||
val replacement: String
|
||||
)
|
||||
|
||||
data class RRSIGRecord(
|
||||
val typeCovered: UShort,
|
||||
val algorithm: Byte,
|
||||
val labels: Byte,
|
||||
val originalTTL: UInt,
|
||||
val signatureExpiration: UInt,
|
||||
val signatureInception: UInt,
|
||||
val keyTag: UShort,
|
||||
val signersName: String,
|
||||
val signature: ByteArray
|
||||
)
|
||||
|
||||
data class KXRecord(val preference: UShort, val exchanger: String)
|
||||
|
||||
data class CERTRecord(val type: UShort, val keyTag: UShort, val algorithm: Byte, val certificate: ByteArray)
|
||||
|
||||
|
||||
|
||||
data class DNAMERecord(val target: String)
|
||||
|
||||
data class DSRecord(val keyTag: UShort, val algorithm: Byte, val digestType: Byte, val digest: ByteArray)
|
||||
|
||||
data class SSHFPRecord(val algorithm: Byte, val fingerprintType: Byte, val fingerprint: ByteArray)
|
||||
|
||||
data class TLSARecord(val usage: Byte, val selector: Byte, val matchingType: Byte, val certificateAssociationData: ByteArray)
|
||||
|
||||
data class SMIMEARecord(val usage: Byte, val selector: Byte, val matchingType: Byte, val certificateAssociationData: ByteArray)
|
||||
|
||||
data class URIRecord(val priority: UShort, val weight: UShort, val target: String)
|
||||
|
||||
data class NSECRecord(val ownerName: String, val typeBitMaps: List<Pair<Byte, ByteArray>>)
|
||||
data class NSEC3Record(
|
||||
val hashAlgorithm: Byte,
|
||||
val flags: Byte,
|
||||
val iterations: UShort,
|
||||
val salt: ByteArray,
|
||||
val nextHashedOwnerName: ByteArray,
|
||||
val typeBitMaps: List<UShort>
|
||||
)
|
||||
|
||||
data class NSEC3PARAMRecord(val hashAlgorithm: Byte, val flags: Byte, val iterations: UShort, val salt: ByteArray)
|
||||
data class SPFRecord(val texts: List<String>)
|
||||
data class TKEYRecord(
|
||||
val algorithm: String,
|
||||
val inception: UInt,
|
||||
val expiration: UInt,
|
||||
val mode: UShort,
|
||||
val error: UShort,
|
||||
val keyData: ByteArray,
|
||||
val otherData: ByteArray
|
||||
)
|
||||
|
||||
data class TSIGRecord(
|
||||
val algorithmName: String,
|
||||
val timeSigned: UInt,
|
||||
val fudge: UShort,
|
||||
val mac: ByteArray,
|
||||
val originalID: UShort,
|
||||
val error: UShort,
|
||||
val otherData: ByteArray
|
||||
)
|
||||
|
||||
data class OPTRecordOption(val code: UShort, val data: ByteArray)
|
||||
data class OPTRecord(val options: List<OPTRecordOption>)
|
||||
|
||||
class DnsReader(private val data: ByteArray, private var position: Int = 0, private val length: Int = data.size) {
|
||||
|
||||
private val endPosition: Int = position + length
|
||||
|
||||
fun readDomainName(): String {
|
||||
return data.asUByteArray().readDomainName(position).also { position = it.second }.first
|
||||
}
|
||||
|
||||
fun readDouble(): Double {
|
||||
checkRemainingBytes(Double.SIZE_BYTES)
|
||||
val result = ByteBuffer.wrap(data, position, Double.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).double
|
||||
position += Double.SIZE_BYTES
|
||||
return result
|
||||
}
|
||||
|
||||
fun readInt16(): Short {
|
||||
checkRemainingBytes(Short.SIZE_BYTES)
|
||||
val result = ByteBuffer.wrap(data, position, Short.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).short
|
||||
position += Short.SIZE_BYTES
|
||||
return result
|
||||
}
|
||||
|
||||
fun readInt32(): Int {
|
||||
checkRemainingBytes(Int.SIZE_BYTES)
|
||||
val result = ByteBuffer.wrap(data, position, Int.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).int
|
||||
position += Int.SIZE_BYTES
|
||||
return result
|
||||
}
|
||||
|
||||
fun readInt64(): Long {
|
||||
checkRemainingBytes(Long.SIZE_BYTES)
|
||||
val result = ByteBuffer.wrap(data, position, Long.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).long
|
||||
position += Long.SIZE_BYTES
|
||||
return result
|
||||
}
|
||||
|
||||
fun readSingle(): Float {
|
||||
checkRemainingBytes(Float.SIZE_BYTES)
|
||||
val result = ByteBuffer.wrap(data, position, Float.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).float
|
||||
position += Float.SIZE_BYTES
|
||||
return result
|
||||
}
|
||||
|
||||
fun readByte(): Byte {
|
||||
checkRemainingBytes(Byte.SIZE_BYTES)
|
||||
return data[position++]
|
||||
}
|
||||
|
||||
fun readBytes(length: Int): ByteArray {
|
||||
checkRemainingBytes(length)
|
||||
return ByteArray(length).also { data.copyInto(it, startIndex = position, endIndex = position + length) }
|
||||
.also { position += length }
|
||||
}
|
||||
|
||||
fun readUInt16(): UShort {
|
||||
checkRemainingBytes(Short.SIZE_BYTES)
|
||||
val result = ByteBuffer.wrap(data, position, Short.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).short.toUShort()
|
||||
position += Short.SIZE_BYTES
|
||||
return result
|
||||
}
|
||||
|
||||
fun readUInt32(): UInt {
|
||||
checkRemainingBytes(Int.SIZE_BYTES)
|
||||
val result = ByteBuffer.wrap(data, position, Int.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).int.toUInt()
|
||||
position += Int.SIZE_BYTES
|
||||
return result
|
||||
}
|
||||
|
||||
fun readUInt64(): ULong {
|
||||
checkRemainingBytes(Long.SIZE_BYTES)
|
||||
val result = ByteBuffer.wrap(data, position, Long.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).long.toULong()
|
||||
position += Long.SIZE_BYTES
|
||||
return result
|
||||
}
|
||||
|
||||
fun readString(): String {
|
||||
val length = data[position++].toInt()
|
||||
checkRemainingBytes(length)
|
||||
return String(data, position, length, StandardCharsets.UTF_8).also { position += length }
|
||||
}
|
||||
|
||||
private fun checkRemainingBytes(requiredBytes: Int) {
|
||||
if (position + requiredBytes > endPosition) throw IndexOutOfBoundsException()
|
||||
}
|
||||
|
||||
fun readRPRecord(): RPRecord {
|
||||
return RPRecord(readDomainName(), readDomainName())
|
||||
}
|
||||
|
||||
fun readKXRecord(): KXRecord {
|
||||
val preference = readUInt16()
|
||||
val exchanger = readDomainName()
|
||||
return KXRecord(preference, exchanger)
|
||||
}
|
||||
|
||||
fun readCERTRecord(): CERTRecord {
|
||||
val type = readUInt16()
|
||||
val keyTag = readUInt16()
|
||||
val algorithm = readByte()
|
||||
val certificateLength = readUInt16().toInt() - 5
|
||||
val certificate = readBytes(certificateLength)
|
||||
return CERTRecord(type, keyTag, algorithm, certificate)
|
||||
}
|
||||
|
||||
fun readPTRRecord(): PTRRecord {
|
||||
return PTRRecord(readDomainName())
|
||||
}
|
||||
|
||||
fun readARecord(): ARecord {
|
||||
val address = readBytes(4)
|
||||
return ARecord(InetAddress.getByAddress(address))
|
||||
}
|
||||
|
||||
fun readAAAARecord(): AAAARecord {
|
||||
val address = readBytes(16)
|
||||
return AAAARecord(InetAddress.getByAddress(address))
|
||||
}
|
||||
|
||||
fun readMXRecord(): MXRecord {
|
||||
val preference = readUInt16()
|
||||
val exchange = readDomainName()
|
||||
return MXRecord(preference, exchange)
|
||||
}
|
||||
|
||||
fun readCNAMERecord(): CNAMERecord {
|
||||
return CNAMERecord(readDomainName())
|
||||
}
|
||||
|
||||
fun readTXTRecord(): TXTRecord {
|
||||
val texts = mutableListOf<String>()
|
||||
while (position < endPosition) {
|
||||
val textLength = data[position++].toInt()
|
||||
checkRemainingBytes(textLength)
|
||||
val text = String(data, position, textLength, StandardCharsets.UTF_8)
|
||||
texts.add(text)
|
||||
position += textLength
|
||||
}
|
||||
return TXTRecord(texts)
|
||||
}
|
||||
|
||||
fun readSOARecord(): SOARecord {
|
||||
val primaryNameServer = readDomainName()
|
||||
val responsibleAuthorityMailbox = readDomainName()
|
||||
val serialNumber = readInt32()
|
||||
val refreshInterval = readInt32()
|
||||
val retryInterval = readInt32()
|
||||
val expiryLimit = readInt32()
|
||||
val minimumTTL = readInt32()
|
||||
return SOARecord(primaryNameServer, responsibleAuthorityMailbox, serialNumber, refreshInterval, retryInterval, expiryLimit, minimumTTL)
|
||||
}
|
||||
|
||||
fun readSRVRecord(): SRVRecord {
|
||||
val priority = readUInt16()
|
||||
val weight = readUInt16()
|
||||
val port = readUInt16()
|
||||
val target = readDomainName()
|
||||
return SRVRecord(priority, weight, port, target)
|
||||
}
|
||||
|
||||
fun readNSRecord(): NSRecord {
|
||||
return NSRecord(readDomainName())
|
||||
}
|
||||
|
||||
fun readCAARecord(): CAARecord {
|
||||
val length = readUInt16().toInt()
|
||||
val flags = readByte()
|
||||
val tagLength = readByte().toInt()
|
||||
val tag = String(data, position, tagLength, StandardCharsets.US_ASCII).also { position += tagLength }
|
||||
val valueLength = length - 1 - 1 - tagLength
|
||||
val value = String(data, position, valueLength, StandardCharsets.US_ASCII).also { position += valueLength }
|
||||
return CAARecord(flags, tag, value)
|
||||
}
|
||||
|
||||
fun readHINFORecord(): HINFORecord {
|
||||
val cpuLength = readByte().toInt()
|
||||
val cpu = String(data, position, cpuLength, StandardCharsets.US_ASCII).also { position += cpuLength }
|
||||
val osLength = readByte().toInt()
|
||||
val os = String(data, position, osLength, StandardCharsets.US_ASCII).also { position += osLength }
|
||||
return HINFORecord(cpu, os)
|
||||
}
|
||||
|
||||
fun readAFSDBRecord(): AFSDBRecord {
|
||||
return AFSDBRecord(readUInt16(), readDomainName())
|
||||
}
|
||||
|
||||
fun readLOCRecord(): LOCRecord {
|
||||
val version = readByte()
|
||||
val size = LOCRecord.decodeSizeOrPrecision(readByte())
|
||||
val horizontalPrecision = LOCRecord.decodeSizeOrPrecision(readByte())
|
||||
val verticalPrecision = LOCRecord.decodeSizeOrPrecision(readByte())
|
||||
val latitudeCoded = readInt32()
|
||||
val longitudeCoded = readInt32()
|
||||
val altitudeCoded = readInt32()
|
||||
val latitude = LOCRecord.decodeLatitudeOrLongitude(latitudeCoded)
|
||||
val longitude = LOCRecord.decodeLatitudeOrLongitude(longitudeCoded)
|
||||
val altitude = LOCRecord.decodeAltitude(altitudeCoded)
|
||||
return LOCRecord(version, size, horizontalPrecision, verticalPrecision, latitude, longitude, altitude)
|
||||
}
|
||||
|
||||
fun readNAPTRRecord(): NAPTRRecord {
|
||||
val order = readUInt16()
|
||||
val preference = readUInt16()
|
||||
val flags = readString()
|
||||
val services = readString()
|
||||
val regexp = readString()
|
||||
val replacement = readDomainName()
|
||||
return NAPTRRecord(order, preference, flags, services, regexp, replacement)
|
||||
}
|
||||
|
||||
fun readDNAMERecord(): DNAMERecord {
|
||||
return DNAMERecord(readDomainName())
|
||||
}
|
||||
|
||||
fun readDSRecord(): DSRecord {
|
||||
val keyTag = readUInt16()
|
||||
val algorithm = readByte()
|
||||
val digestType = readByte()
|
||||
val digestLength = readUInt16().toInt() - 4
|
||||
val digest = readBytes(digestLength)
|
||||
return DSRecord(keyTag, algorithm, digestType, digest)
|
||||
}
|
||||
|
||||
fun readSSHFPRecord(): SSHFPRecord {
|
||||
val algorithm = readByte()
|
||||
val fingerprintType = readByte()
|
||||
val fingerprintLength = readUInt16().toInt() - 2
|
||||
val fingerprint = readBytes(fingerprintLength)
|
||||
return SSHFPRecord(algorithm, fingerprintType, fingerprint)
|
||||
}
|
||||
|
||||
fun readTLSARecord(): TLSARecord {
|
||||
val usage = readByte()
|
||||
val selector = readByte()
|
||||
val matchingType = readByte()
|
||||
val dataLength = readUInt16().toInt() - 3
|
||||
val certificateAssociationData = readBytes(dataLength)
|
||||
return TLSARecord(usage, selector, matchingType, certificateAssociationData)
|
||||
}
|
||||
|
||||
fun readSMIMEARecord(): SMIMEARecord {
|
||||
val usage = readByte()
|
||||
val selector = readByte()
|
||||
val matchingType = readByte()
|
||||
val dataLength = readUInt16().toInt() - 3
|
||||
val certificateAssociationData = readBytes(dataLength)
|
||||
return SMIMEARecord(usage, selector, matchingType, certificateAssociationData)
|
||||
}
|
||||
|
||||
fun readURIRecord(): URIRecord {
|
||||
val priority = readUInt16()
|
||||
val weight = readUInt16()
|
||||
val length = readUInt16().toInt()
|
||||
val target = String(data, position, length, StandardCharsets.US_ASCII).also { position += length }
|
||||
return URIRecord(priority, weight, target)
|
||||
}
|
||||
|
||||
fun readRRSIGRecord(): RRSIGRecord {
|
||||
val typeCovered = readUInt16()
|
||||
val algorithm = readByte()
|
||||
val labels = readByte()
|
||||
val originalTTL = readUInt32()
|
||||
val signatureExpiration = readUInt32()
|
||||
val signatureInception = readUInt32()
|
||||
val keyTag = readUInt16()
|
||||
val signersName = readDomainName()
|
||||
val signatureLength = readUInt16().toInt()
|
||||
val signature = readBytes(signatureLength)
|
||||
return RRSIGRecord(
|
||||
typeCovered,
|
||||
algorithm,
|
||||
labels,
|
||||
originalTTL,
|
||||
signatureExpiration,
|
||||
signatureInception,
|
||||
keyTag,
|
||||
signersName,
|
||||
signature
|
||||
)
|
||||
}
|
||||
|
||||
fun readNSECRecord(): NSECRecord {
|
||||
val ownerName = readDomainName()
|
||||
val typeBitMaps = mutableListOf<Pair<Byte, ByteArray>>()
|
||||
while (position < endPosition) {
|
||||
val windowBlock = readByte()
|
||||
val bitmapLength = readByte().toInt()
|
||||
val bitmap = readBytes(bitmapLength)
|
||||
typeBitMaps.add(windowBlock to bitmap)
|
||||
}
|
||||
return NSECRecord(ownerName, typeBitMaps)
|
||||
}
|
||||
|
||||
fun readNSEC3Record(): NSEC3Record {
|
||||
val hashAlgorithm = readByte()
|
||||
val flags = readByte()
|
||||
val iterations = readUInt16()
|
||||
val saltLength = readByte().toInt()
|
||||
val salt = readBytes(saltLength)
|
||||
val hashLength = readByte().toInt()
|
||||
val nextHashedOwnerName = readBytes(hashLength)
|
||||
val bitMapLength = readUInt16().toInt()
|
||||
val typeBitMaps = mutableListOf<UShort>()
|
||||
val endPos = position + bitMapLength
|
||||
while (position < endPos) {
|
||||
typeBitMaps.add(readUInt16())
|
||||
}
|
||||
return NSEC3Record(hashAlgorithm, flags, iterations, salt, nextHashedOwnerName, typeBitMaps)
|
||||
}
|
||||
|
||||
fun readNSEC3PARAMRecord(): NSEC3PARAMRecord {
|
||||
val hashAlgorithm = readByte()
|
||||
val flags = readByte()
|
||||
val iterations = readUInt16()
|
||||
val saltLength = readByte().toInt()
|
||||
val salt = readBytes(saltLength)
|
||||
return NSEC3PARAMRecord(hashAlgorithm, flags, iterations, salt)
|
||||
}
|
||||
|
||||
|
||||
fun readSPFRecord(): SPFRecord {
|
||||
val length = readUInt16().toInt()
|
||||
val texts = mutableListOf<String>()
|
||||
val endPos = position + length
|
||||
while (position < endPos) {
|
||||
val textLength = readByte().toInt()
|
||||
val text = String(data, position, textLength, StandardCharsets.US_ASCII).also { position += textLength }
|
||||
texts.add(text)
|
||||
}
|
||||
return SPFRecord(texts)
|
||||
}
|
||||
|
||||
fun readTKEYRecord(): TKEYRecord {
|
||||
val algorithm = readDomainName()
|
||||
val inception = readUInt32()
|
||||
val expiration = readUInt32()
|
||||
val mode = readUInt16()
|
||||
val error = readUInt16()
|
||||
val keySize = readUInt16().toInt()
|
||||
val keyData = readBytes(keySize)
|
||||
val otherSize = readUInt16().toInt()
|
||||
val otherData = readBytes(otherSize)
|
||||
return TKEYRecord(algorithm, inception, expiration, mode, error, keyData, otherData)
|
||||
}
|
||||
|
||||
fun readTSIGRecord(): TSIGRecord {
|
||||
val algorithmName = readDomainName()
|
||||
val timeSigned = readUInt32()
|
||||
val fudge = readUInt16()
|
||||
val macSize = readUInt16().toInt()
|
||||
val mac = readBytes(macSize)
|
||||
val originalID = readUInt16()
|
||||
val error = readUInt16()
|
||||
val otherSize = readUInt16().toInt()
|
||||
val otherData = readBytes(otherSize)
|
||||
return TSIGRecord(algorithmName, timeSigned, fudge, mac, originalID, error, otherData)
|
||||
}
|
||||
|
||||
|
||||
|
||||
fun readOPTRecord(): OPTRecord {
|
||||
val options = mutableListOf<OPTRecordOption>()
|
||||
while (position < endPosition) {
|
||||
val optionCode = readUInt16()
|
||||
val optionLength = readUInt16().toInt()
|
||||
val optionData = readBytes(optionLength)
|
||||
options.add(OPTRecordOption(optionCode, optionData))
|
||||
}
|
||||
return OPTRecord(options)
|
||||
}
|
||||
}
|
@ -1,117 +0,0 @@
|
||||
package com.futo.platformplayer.mdns
|
||||
|
||||
import com.futo.platformplayer.mdns.Extensions.readDomainName
|
||||
|
||||
enum class ResourceRecordType(val value: UShort) {
|
||||
None(0u),
|
||||
A(1u),
|
||||
NS(2u),
|
||||
MD(3u),
|
||||
MF(4u),
|
||||
CNAME(5u),
|
||||
SOA(6u),
|
||||
MB(7u),
|
||||
MG(8u),
|
||||
MR(9u),
|
||||
NULL(10u),
|
||||
WKS(11u),
|
||||
PTR(12u),
|
||||
HINFO(13u),
|
||||
MINFO(14u),
|
||||
MX(15u),
|
||||
TXT(16u),
|
||||
RP(17u),
|
||||
AFSDB(18u),
|
||||
SIG(24u),
|
||||
KEY(25u),
|
||||
AAAA(28u),
|
||||
LOC(29u),
|
||||
SRV(33u),
|
||||
NAPTR(35u),
|
||||
KX(36u),
|
||||
CERT(37u),
|
||||
DNAME(39u),
|
||||
APL(42u),
|
||||
DS(43u),
|
||||
SSHFP(44u),
|
||||
IPSECKEY(45u),
|
||||
RRSIG(46u),
|
||||
NSEC(47u),
|
||||
DNSKEY(48u),
|
||||
DHCID(49u),
|
||||
NSEC3(50u),
|
||||
NSEC3PARAM(51u),
|
||||
TSLA(52u),
|
||||
SMIMEA(53u),
|
||||
HIP(55u),
|
||||
CDS(59u),
|
||||
CDNSKEY(60u),
|
||||
OPENPGPKEY(61u),
|
||||
CSYNC(62u),
|
||||
ZONEMD(63u),
|
||||
SVCB(64u),
|
||||
HTTPS(65u),
|
||||
EUI48(108u),
|
||||
EUI64(109u),
|
||||
TKEY(249u),
|
||||
TSIG(250u),
|
||||
URI(256u),
|
||||
CAA(257u),
|
||||
TA(32768u),
|
||||
DLV(32769u),
|
||||
AXFR(252u),
|
||||
IXFR(251u),
|
||||
OPT(41u)
|
||||
}
|
||||
|
||||
enum class ResourceRecordClass(val value: UShort) {
|
||||
IN(1u),
|
||||
CS(2u),
|
||||
CH(3u),
|
||||
HS(4u)
|
||||
}
|
||||
|
||||
data class DnsResourceRecord(
|
||||
override val name: String,
|
||||
override val type: Int,
|
||||
override val clazz: Int,
|
||||
val timeToLive: UInt,
|
||||
val cacheFlush: Boolean,
|
||||
val dataPosition: Int = -1,
|
||||
val dataLength: Int = -1,
|
||||
private val data: ByteArray? = null
|
||||
) : DnsResourceRecordBase(name, type, clazz) {
|
||||
|
||||
companion object {
|
||||
fun parse(data: ByteArray, startPosition: Int): Pair<DnsResourceRecord, Int> {
|
||||
val span = data.asUByteArray()
|
||||
var position = startPosition
|
||||
val name = span.readDomainName(position).also { position = it.second }
|
||||
val type = (span[position].toInt() shl 8 or span[position + 1].toInt()).toUShort()
|
||||
position += 2
|
||||
val clazz = (span[position].toInt() shl 8 or span[position + 1].toInt()).toUShort()
|
||||
position += 2
|
||||
val ttl = (span[position].toInt() shl 24 or (span[position + 1].toInt() shl 16) or
|
||||
(span[position + 2].toInt() shl 8) or span[position + 3].toInt()).toUInt()
|
||||
position += 4
|
||||
val rdlength = (span[position].toInt() shl 8 or span[position + 1].toInt()).toUShort()
|
||||
val rdposition = position + 2
|
||||
position += 2 + rdlength.toInt()
|
||||
|
||||
return DnsResourceRecord(
|
||||
name = name.first,
|
||||
type = type.toInt(),
|
||||
clazz = clazz.toInt() and 0b1111111_11111111,
|
||||
timeToLive = ttl,
|
||||
cacheFlush = ((clazz.toInt() shr 15) and 0b1) != 0,
|
||||
dataPosition = rdposition,
|
||||
dataLength = rdlength.toInt(),
|
||||
data = data
|
||||
) to position
|
||||
}
|
||||
}
|
||||
|
||||
fun getDataReader(): DnsReader {
|
||||
return DnsReader(data!!, dataPosition, dataLength)
|
||||
}
|
||||
}
|
@ -1,208 +0,0 @@
|
||||
package com.futo.platformplayer.mdns
|
||||
|
||||
import java.nio.ByteBuffer
|
||||
import java.nio.ByteOrder
|
||||
import java.nio.charset.StandardCharsets
|
||||
|
||||
class DnsWriter {
|
||||
private val data = mutableListOf<Byte>()
|
||||
private val namePositions = mutableMapOf<String, Int>()
|
||||
|
||||
fun toByteArray(): ByteArray = data.toByteArray()
|
||||
|
||||
fun writePacket(
|
||||
header: DnsPacketHeader,
|
||||
questionCount: Int? = null, questionWriter: ((DnsWriter, Int) -> Unit)? = null,
|
||||
answerCount: Int? = null, answerWriter: ((DnsWriter, Int) -> Unit)? = null,
|
||||
authorityCount: Int? = null, authorityWriter: ((DnsWriter, Int) -> Unit)? = null,
|
||||
additionalsCount: Int? = null, additionalWriter: ((DnsWriter, Int) -> Unit)? = null
|
||||
) {
|
||||
if (questionCount != null && questionWriter == null || questionCount == null && questionWriter != null)
|
||||
throw Exception("When question count is given, question writer should also be given.")
|
||||
if (answerCount != null && answerWriter == null || answerCount == null && answerWriter != null)
|
||||
throw Exception("When answer count is given, answer writer should also be given.")
|
||||
if (authorityCount != null && authorityWriter == null || authorityCount == null && authorityWriter != null)
|
||||
throw Exception("When authority count is given, authority writer should also be given.")
|
||||
if (additionalsCount != null && additionalWriter == null || additionalsCount == null && additionalWriter != null)
|
||||
throw Exception("When additionals count is given, additional writer should also be given.")
|
||||
|
||||
writeHeader(header, questionCount ?: 0, answerCount ?: 0, authorityCount ?: 0, additionalsCount ?: 0)
|
||||
|
||||
repeat(questionCount ?: 0) { questionWriter?.invoke(this, it) }
|
||||
repeat(answerCount ?: 0) { answerWriter?.invoke(this, it) }
|
||||
repeat(authorityCount ?: 0) { authorityWriter?.invoke(this, it) }
|
||||
repeat(additionalsCount ?: 0) { additionalWriter?.invoke(this, it) }
|
||||
}
|
||||
|
||||
fun writeHeader(header: DnsPacketHeader, questionCount: Int, answerCount: Int, authorityCount: Int, additionalsCount: Int) {
|
||||
write(header.identifier)
|
||||
|
||||
var flags: UShort = 0u
|
||||
flags = flags or ((header.queryResponse.toUInt() and 0xFFFFu) shl 15).toUShort()
|
||||
flags = flags or ((header.opcode.toUInt() and 0xFFFFu) shl 11).toUShort()
|
||||
flags = flags or ((if (header.authoritativeAnswer) 1u else 0u) shl 10).toUShort()
|
||||
flags = flags or ((if (header.truncated) 1u else 0u) shl 9).toUShort()
|
||||
flags = flags or ((if (header.recursionDesired) 1u else 0u) shl 8).toUShort()
|
||||
flags = flags or ((if (header.recursionAvailable) 1u else 0u) shl 7).toUShort()
|
||||
flags = flags or ((if (header.answerAuthenticated) 1u else 0u) shl 5).toUShort()
|
||||
flags = flags or ((if (header.nonAuthenticatedData) 1u else 0u) shl 4).toUShort()
|
||||
flags = flags or header.responseCode.value.toUShort()
|
||||
write(flags)
|
||||
|
||||
write(questionCount.toUShort())
|
||||
write(answerCount.toUShort())
|
||||
write(authorityCount.toUShort())
|
||||
write(additionalsCount.toUShort())
|
||||
}
|
||||
|
||||
fun writeDomainName(name: String) {
|
||||
synchronized(namePositions) {
|
||||
val labels = name.split('.')
|
||||
for (label in labels) {
|
||||
val nameAtOffset = name.substring(name.indexOf(label))
|
||||
if (namePositions.containsKey(nameAtOffset)) {
|
||||
val position = namePositions[nameAtOffset]!!
|
||||
val pointer = (0b11000000_00000000 or position).toUShort()
|
||||
write(pointer)
|
||||
return
|
||||
}
|
||||
if (label.isNotEmpty()) {
|
||||
val labelBytes = label.toByteArray(StandardCharsets.UTF_8)
|
||||
val nameStartPos = data.size
|
||||
write(labelBytes.size.toByte())
|
||||
write(labelBytes)
|
||||
namePositions[nameAtOffset] = nameStartPos
|
||||
}
|
||||
}
|
||||
write(0.toByte())
|
||||
}
|
||||
}
|
||||
|
||||
fun write(value: DnsResourceRecord, dataWriter: (DnsWriter) -> Unit) {
|
||||
writeDomainName(value.name)
|
||||
write(value.type.toUShort())
|
||||
val cls = ((if (value.cacheFlush) 1u else 0u) shl 15).toUShort() or value.clazz.toUShort()
|
||||
write(cls)
|
||||
write(value.timeToLive)
|
||||
|
||||
val lengthOffset = data.size
|
||||
write(0.toUShort())
|
||||
dataWriter(this)
|
||||
val rdLength = data.size - lengthOffset - 2
|
||||
val rdLengthBytes = ByteBuffer.allocate(2).order(ByteOrder.BIG_ENDIAN).putShort(rdLength.toShort()).array()
|
||||
data[lengthOffset] = rdLengthBytes[0]
|
||||
data[lengthOffset + 1] = rdLengthBytes[1]
|
||||
}
|
||||
|
||||
fun write(value: DnsQuestion) {
|
||||
writeDomainName(value.name)
|
||||
write(value.type.toUShort())
|
||||
write(((if (value.queryUnicast) 1u else 0u shl 15).toUShort() or value.clazz.toUShort()))
|
||||
}
|
||||
|
||||
fun write(value: Double) {
|
||||
val bytes = ByteBuffer.allocate(Double.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).putDouble(value).array()
|
||||
write(bytes)
|
||||
}
|
||||
|
||||
fun write(value: Short) {
|
||||
val bytes = ByteBuffer.allocate(Short.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).putShort(value).array()
|
||||
write(bytes)
|
||||
}
|
||||
|
||||
fun write(value: Int) {
|
||||
val bytes = ByteBuffer.allocate(Int.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).putInt(value).array()
|
||||
write(bytes)
|
||||
}
|
||||
|
||||
fun write(value: Long) {
|
||||
val bytes = ByteBuffer.allocate(Long.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).putLong(value).array()
|
||||
write(bytes)
|
||||
}
|
||||
|
||||
fun write(value: Float) {
|
||||
val bytes = ByteBuffer.allocate(Float.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).putFloat(value).array()
|
||||
write(bytes)
|
||||
}
|
||||
|
||||
fun write(value: Byte) {
|
||||
data.add(value)
|
||||
}
|
||||
|
||||
fun write(value: ByteArray) {
|
||||
data.addAll(value.asIterable())
|
||||
}
|
||||
|
||||
fun write(value: ByteArray, offset: Int, length: Int) {
|
||||
data.addAll(value.slice(offset until offset + length))
|
||||
}
|
||||
|
||||
fun write(value: UShort) {
|
||||
val bytes = ByteBuffer.allocate(Short.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).putShort(value.toShort()).array()
|
||||
write(bytes)
|
||||
}
|
||||
|
||||
fun write(value: UInt) {
|
||||
val bytes = ByteBuffer.allocate(Int.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).putInt(value.toInt()).array()
|
||||
write(bytes)
|
||||
}
|
||||
|
||||
fun write(value: ULong) {
|
||||
val bytes = ByteBuffer.allocate(Long.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).putLong(value.toLong()).array()
|
||||
write(bytes)
|
||||
}
|
||||
|
||||
fun write(value: String) {
|
||||
val bytes = value.toByteArray(StandardCharsets.UTF_8)
|
||||
write(bytes.size.toByte())
|
||||
write(bytes)
|
||||
}
|
||||
|
||||
fun write(value: PTRRecord) {
|
||||
writeDomainName(value.domainName)
|
||||
}
|
||||
|
||||
fun write(value: ARecord) {
|
||||
val bytes = value.address.address
|
||||
if (bytes.size != 4) throw Exception("Unexpected amount of address bytes.")
|
||||
write(bytes)
|
||||
}
|
||||
|
||||
fun write(value: AAAARecord) {
|
||||
val bytes = value.address.address
|
||||
if (bytes.size != 16) throw Exception("Unexpected amount of address bytes.")
|
||||
write(bytes)
|
||||
}
|
||||
|
||||
fun write(value: TXTRecord) {
|
||||
value.texts.forEach {
|
||||
val bytes = it.toByteArray(StandardCharsets.UTF_8)
|
||||
write(bytes.size.toByte())
|
||||
write(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
fun write(value: SRVRecord) {
|
||||
write(value.priority)
|
||||
write(value.weight)
|
||||
write(value.port)
|
||||
writeDomainName(value.target)
|
||||
}
|
||||
|
||||
fun write(value: NSECRecord) {
|
||||
writeDomainName(value.ownerName)
|
||||
value.typeBitMaps.forEach { (windowBlock, bitmap) ->
|
||||
write(windowBlock)
|
||||
write(bitmap.size.toByte())
|
||||
write(bitmap)
|
||||
}
|
||||
}
|
||||
|
||||
fun write(value: OPTRecord) {
|
||||
value.options.forEach { option ->
|
||||
write(option.code)
|
||||
write(option.data.size.toUShort())
|
||||
write(option.data)
|
||||
}
|
||||
}
|
||||
}
|
@ -1,63 +0,0 @@
|
||||
package com.futo.platformplayer.mdns
|
||||
|
||||
import android.util.Log
|
||||
|
||||
object Extensions {
|
||||
fun ByteArray.toByteDump(): String {
|
||||
val result = StringBuilder()
|
||||
for (i in indices) {
|
||||
result.append(String.format("%02X ", this[i]))
|
||||
|
||||
if ((i + 1) % 16 == 0 || i == size - 1) {
|
||||
val padding = 3 * (16 - (i % 16 + 1))
|
||||
if (i == size - 1 && (i + 1) % 16 != 0) result.append(" ".repeat(padding))
|
||||
|
||||
result.append("; ")
|
||||
val start = i - (i % 16)
|
||||
val end = minOf(i, size - 1)
|
||||
for (j in start..end) {
|
||||
val ch = if (this[j] in 32..127) this[j].toChar() else '.'
|
||||
result.append(ch)
|
||||
}
|
||||
if (i != size - 1) result.appendLine()
|
||||
}
|
||||
}
|
||||
return result.toString()
|
||||
}
|
||||
|
||||
fun UByteArray.readDomainName(startPosition: Int): Pair<String, Int> {
|
||||
var position = startPosition
|
||||
return readDomainName(position, 0)
|
||||
}
|
||||
|
||||
private fun UByteArray.readDomainName(position: Int, depth: Int = 0): Pair<String, Int> {
|
||||
if (depth > 16) throw Exception("Exceeded maximum recursion depth in DNS packet. Possible circular reference.")
|
||||
|
||||
val domainParts = mutableListOf<String>()
|
||||
var newPosition = position
|
||||
|
||||
while (true) {
|
||||
if (newPosition < 0)
|
||||
println()
|
||||
|
||||
val length = this[newPosition].toUByte()
|
||||
if ((length and 0b11000000u).toUInt() == 0b11000000u) {
|
||||
val offset = (((length and 0b00111111u).toUInt()) shl 8) or this[newPosition + 1].toUInt()
|
||||
val (part, _) = this.readDomainName(offset.toInt(), depth + 1)
|
||||
domainParts.add(part)
|
||||
newPosition += 2
|
||||
break
|
||||
} else if (length.toUInt() == 0u) {
|
||||
newPosition++
|
||||
break
|
||||
} else {
|
||||
newPosition++
|
||||
val part = String(this.asByteArray(), newPosition, length.toInt(), Charsets.UTF_8)
|
||||
domainParts.add(part)
|
||||
newPosition += length.toInt()
|
||||
}
|
||||
}
|
||||
|
||||
return domainParts.joinToString(".") to newPosition
|
||||
}
|
||||
}
|
@ -1,501 +0,0 @@
|
||||
package com.futo.platformplayer.mdns
|
||||
|
||||
import com.futo.platformplayer.logging.Logger
|
||||
import kotlinx.coroutines.*
|
||||
import java.net.*
|
||||
import java.util.*
|
||||
import java.util.concurrent.locks.ReentrantLock
|
||||
import kotlin.concurrent.withLock
|
||||
|
||||
|
||||
class MDNSListener {
|
||||
companion object {
|
||||
private val TAG = "MDNSListener"
|
||||
const val MulticastPort = 5353
|
||||
val MulticastAddressIPv4: InetAddress = InetAddress.getByName("224.0.0.251")
|
||||
val MulticastAddressIPv6: InetAddress = InetAddress.getByName("FF02::FB")
|
||||
val MdnsEndpointIPv6: InetSocketAddress = InetSocketAddress(MulticastAddressIPv6, MulticastPort)
|
||||
val MdnsEndpointIPv4: InetSocketAddress = InetSocketAddress(MulticastAddressIPv4, MulticastPort)
|
||||
}
|
||||
|
||||
private val _lockObject = ReentrantLock()
|
||||
private var _receiver4: MulticastSocket? = null
|
||||
private var _receiver6: MulticastSocket? = null
|
||||
private val _senders = mutableListOf<MulticastSocket>()
|
||||
private val _nicMonitor = NICMonitor()
|
||||
private val _serviceRecordAggregator = ServiceRecordAggregator()
|
||||
private var _started = false
|
||||
private var _threadReceiver4: Thread? = null
|
||||
private var _threadReceiver6: Thread? = null
|
||||
private var _scope: CoroutineScope? = null
|
||||
|
||||
var onPacket: ((DnsPacket) -> Unit)? = null
|
||||
var onServicesUpdated: ((List<DnsService>) -> Unit)? = null
|
||||
|
||||
private val _recordLockObject = ReentrantLock()
|
||||
private val _recordsA = mutableListOf<Pair<DnsResourceRecord, ARecord>>()
|
||||
private val _recordsAAAA = mutableListOf<Pair<DnsResourceRecord, AAAARecord>>()
|
||||
private val _recordsPTR = mutableListOf<Pair<DnsResourceRecord, PTRRecord>>()
|
||||
private val _recordsTXT = mutableListOf<Pair<DnsResourceRecord, TXTRecord>>()
|
||||
private val _recordsSRV = mutableListOf<Pair<DnsResourceRecord, SRVRecord>>()
|
||||
private val _services = mutableListOf<BroadcastService>()
|
||||
|
||||
init {
|
||||
_nicMonitor.added = { onNicsAdded(it) }
|
||||
_nicMonitor.removed = { onNicsRemoved(it) }
|
||||
_serviceRecordAggregator.onServicesUpdated = { onServicesUpdated?.invoke(it) }
|
||||
}
|
||||
|
||||
fun start() {
|
||||
if (_started) {
|
||||
Logger.i(TAG, "Already started.")
|
||||
return
|
||||
}
|
||||
_started = true
|
||||
|
||||
_scope = CoroutineScope(Dispatchers.IO);
|
||||
|
||||
Logger.i(TAG, "Starting")
|
||||
_lockObject.withLock {
|
||||
val receiver4 = MulticastSocket(null).apply {
|
||||
reuseAddress = true
|
||||
bind(InetSocketAddress(InetAddress.getByName("0.0.0.0"), MulticastPort))
|
||||
}
|
||||
|
||||
_receiver4 = receiver4
|
||||
|
||||
val receiver6 = MulticastSocket(null).apply {
|
||||
reuseAddress = true
|
||||
bind(InetSocketAddress(InetAddress.getByName("::"), MulticastPort))
|
||||
}
|
||||
_receiver6 = receiver6
|
||||
|
||||
_nicMonitor.start()
|
||||
_serviceRecordAggregator.start()
|
||||
onNicsAdded(_nicMonitor.current)
|
||||
|
||||
_threadReceiver4 = Thread {
|
||||
receiveLoop(receiver4)
|
||||
}.apply { start() }
|
||||
|
||||
_threadReceiver6 = Thread {
|
||||
receiveLoop(receiver6)
|
||||
}.apply { start() }
|
||||
}
|
||||
}
|
||||
|
||||
fun queryServices(names: Array<String>) {
|
||||
if (names.isEmpty()) throw IllegalArgumentException("At least one name must be specified.")
|
||||
|
||||
val writer = DnsWriter()
|
||||
writer.writePacket(
|
||||
DnsPacketHeader(
|
||||
identifier = 0u,
|
||||
queryResponse = QueryResponse.Query.value.toInt(),
|
||||
opcode = DnsOpcode.StandardQuery.value.toInt(),
|
||||
truncated = false,
|
||||
nonAuthenticatedData = false,
|
||||
recursionDesired = false,
|
||||
answerAuthenticated = false,
|
||||
authoritativeAnswer = false,
|
||||
recursionAvailable = false,
|
||||
responseCode = DnsResponseCode.NoError
|
||||
),
|
||||
questionCount = names.size,
|
||||
questionWriter = { w, i ->
|
||||
w.write(
|
||||
DnsQuestion(
|
||||
name = names[i],
|
||||
type = QuestionType.PTR.value.toInt(),
|
||||
clazz = QuestionClass.IN.value.toInt(),
|
||||
queryUnicast = false
|
||||
)
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
send(writer.toByteArray())
|
||||
}
|
||||
|
||||
private fun send(data: ByteArray) {
|
||||
_lockObject.withLock {
|
||||
for (sender in _senders) {
|
||||
try {
|
||||
val endPoint = if (sender.localAddress is Inet4Address) MdnsEndpointIPv4 else MdnsEndpointIPv6
|
||||
sender.send(DatagramPacket(data, data.size, endPoint))
|
||||
} catch (e: Exception) {
|
||||
Logger.i(TAG, "Failed to send on ${sender.localSocketAddress}: ${e.message}.")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun queryAllQuestions(names: Array<String>) {
|
||||
if (names.isEmpty()) throw IllegalArgumentException("At least one name must be specified.")
|
||||
|
||||
val questions = names.flatMap { _serviceRecordAggregator.getAllQuestions(it) }
|
||||
questions.groupBy { it.name }.forEach { (_, questionsForHost) ->
|
||||
val writer = DnsWriter()
|
||||
writer.writePacket(
|
||||
DnsPacketHeader(
|
||||
identifier = 0u,
|
||||
queryResponse = QueryResponse.Query.value.toInt(),
|
||||
opcode = DnsOpcode.StandardQuery.value.toInt(),
|
||||
truncated = false,
|
||||
nonAuthenticatedData = false,
|
||||
recursionDesired = false,
|
||||
answerAuthenticated = false,
|
||||
authoritativeAnswer = false,
|
||||
recursionAvailable = false,
|
||||
responseCode = DnsResponseCode.NoError
|
||||
),
|
||||
questionCount = questionsForHost.size,
|
||||
questionWriter = { w, i -> w.write(questionsForHost[i]) }
|
||||
)
|
||||
send(writer.toByteArray())
|
||||
}
|
||||
}
|
||||
|
||||
private fun onNicsAdded(nics: List<NetworkInterface>) {
|
||||
_lockObject.withLock {
|
||||
if (!_started) return
|
||||
|
||||
val addresses = nics.flatMap { nic ->
|
||||
nic.interfaceAddresses.map { it.address }
|
||||
.filter { it is Inet4Address || (it is Inet6Address && it.isLinkLocalAddress) }
|
||||
}
|
||||
|
||||
addresses.forEach { address ->
|
||||
Logger.i(TAG, "New address discovered $address")
|
||||
|
||||
try {
|
||||
when (address) {
|
||||
is Inet4Address -> {
|
||||
_receiver4?.let { receiver4 ->
|
||||
//receiver4.setOption(StandardSocketOptions.IP_MULTICAST_IF, NetworkInterface.getByInetAddress(address))
|
||||
val ni = NetworkInterface.getByInetAddress(address)
|
||||
receiver4.networkInterface = ni
|
||||
receiver4.joinGroup(InetSocketAddress(MulticastAddressIPv4, MulticastPort), ni)
|
||||
}
|
||||
|
||||
val sender = MulticastSocket(null).apply {
|
||||
reuseAddress = true
|
||||
bind(InetSocketAddress(address, MulticastPort))
|
||||
joinGroup(InetSocketAddress(MulticastAddressIPv4, MulticastPort), NetworkInterface.getByInetAddress(address))
|
||||
}
|
||||
_senders.add(sender)
|
||||
}
|
||||
|
||||
is Inet6Address -> {
|
||||
_receiver6?.let { receiver6 ->
|
||||
//receiver6.setOption(StandardSocketOptions.IP_MULTICAST_IF, NetworkInterface.getByInetAddress(address))
|
||||
val ni = NetworkInterface.getByInetAddress(address)
|
||||
receiver6.networkInterface = ni
|
||||
receiver6.joinGroup(InetSocketAddress(MulticastAddressIPv6, MulticastPort), ni)
|
||||
}
|
||||
|
||||
val sender = MulticastSocket(null).apply {
|
||||
reuseAddress = true
|
||||
bind(InetSocketAddress(address, MulticastPort))
|
||||
joinGroup(InetSocketAddress(MulticastAddressIPv6, MulticastPort), NetworkInterface.getByInetAddress(address))
|
||||
}
|
||||
_senders.add(sender)
|
||||
}
|
||||
|
||||
else -> throw UnsupportedOperationException("Address type ${address.javaClass.name} is not supported.")
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
Logger.i(TAG, "Exception occurred when processing added address $address: ${e.message}.")
|
||||
// Close the socket if there was an error
|
||||
(_senders.lastOrNull() as? MulticastSocket)?.close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (nics.isNotEmpty()) {
|
||||
try {
|
||||
updateBroadcastRecords()
|
||||
broadcastRecords()
|
||||
} catch (e: Exception) {
|
||||
Logger.i(TAG, "Exception occurred when broadcasting records: ${e.message}.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun onNicsRemoved(nics: List<NetworkInterface>) {
|
||||
_lockObject.withLock {
|
||||
if (!_started) return
|
||||
//TODO: Cleanup?
|
||||
}
|
||||
|
||||
if (nics.isNotEmpty()) {
|
||||
try {
|
||||
updateBroadcastRecords()
|
||||
broadcastRecords()
|
||||
} catch (e: Exception) {
|
||||
Logger.e(TAG, "Exception occurred when broadcasting records", e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun receiveLoop(client: DatagramSocket) {
|
||||
Logger.i(TAG, "Started receive loop")
|
||||
|
||||
val buffer = ByteArray(8972)
|
||||
val packet = DatagramPacket(buffer, buffer.size)
|
||||
while (_started) {
|
||||
try {
|
||||
client.receive(packet)
|
||||
handleResult(packet)
|
||||
} catch (e: Exception) {
|
||||
Logger.e(TAG, "An exception occurred while handling UDP result:", e)
|
||||
}
|
||||
}
|
||||
|
||||
Logger.i(TAG, "Stopped receive loop")
|
||||
}
|
||||
|
||||
fun broadcastService(
|
||||
deviceName: String,
|
||||
serviceName: String,
|
||||
port: UShort,
|
||||
ttl: UInt = 120u,
|
||||
weight: UShort = 0u,
|
||||
priority: UShort = 0u,
|
||||
texts: List<String>? = null
|
||||
) {
|
||||
_recordLockObject.withLock {
|
||||
_services.add(
|
||||
BroadcastService(
|
||||
deviceName = deviceName,
|
||||
port = port,
|
||||
priority = priority,
|
||||
serviceName = serviceName,
|
||||
texts = texts,
|
||||
ttl = ttl,
|
||||
weight = weight
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
updateBroadcastRecords()
|
||||
broadcastRecords()
|
||||
}
|
||||
|
||||
private fun updateBroadcastRecords() {
|
||||
_recordLockObject.withLock {
|
||||
_recordsSRV.clear()
|
||||
_recordsPTR.clear()
|
||||
_recordsA.clear()
|
||||
_recordsAAAA.clear()
|
||||
_recordsTXT.clear()
|
||||
|
||||
_services.forEach { service ->
|
||||
val id = UUID.randomUUID().toString()
|
||||
val deviceDomainName = "${service.deviceName}.${service.serviceName}"
|
||||
val addressName = "$id.local"
|
||||
|
||||
_recordsSRV.add(
|
||||
DnsResourceRecord(
|
||||
clazz = ResourceRecordClass.IN.value.toInt(),
|
||||
type = ResourceRecordType.SRV.value.toInt(),
|
||||
timeToLive = service.ttl,
|
||||
name = deviceDomainName,
|
||||
cacheFlush = false
|
||||
) to SRVRecord(
|
||||
target = addressName,
|
||||
port = service.port,
|
||||
priority = service.priority,
|
||||
weight = service.weight
|
||||
)
|
||||
)
|
||||
|
||||
_recordsPTR.add(
|
||||
DnsResourceRecord(
|
||||
clazz = ResourceRecordClass.IN.value.toInt(),
|
||||
type = ResourceRecordType.PTR.value.toInt(),
|
||||
timeToLive = service.ttl,
|
||||
name = service.serviceName,
|
||||
cacheFlush = false
|
||||
) to PTRRecord(
|
||||
domainName = deviceDomainName
|
||||
)
|
||||
)
|
||||
|
||||
val addresses = _nicMonitor.current.flatMap { nic ->
|
||||
nic.interfaceAddresses.map { it.address }
|
||||
}
|
||||
|
||||
addresses.forEach { address ->
|
||||
when (address) {
|
||||
is Inet4Address -> _recordsA.add(
|
||||
DnsResourceRecord(
|
||||
clazz = ResourceRecordClass.IN.value.toInt(),
|
||||
type = ResourceRecordType.A.value.toInt(),
|
||||
timeToLive = service.ttl,
|
||||
name = addressName,
|
||||
cacheFlush = false
|
||||
) to ARecord(
|
||||
address = address
|
||||
)
|
||||
)
|
||||
|
||||
is Inet6Address -> _recordsAAAA.add(
|
||||
DnsResourceRecord(
|
||||
clazz = ResourceRecordClass.IN.value.toInt(),
|
||||
type = ResourceRecordType.AAAA.value.toInt(),
|
||||
timeToLive = service.ttl,
|
||||
name = addressName,
|
||||
cacheFlush = false
|
||||
) to AAAARecord(
|
||||
address = address
|
||||
)
|
||||
)
|
||||
|
||||
else -> Logger.i(TAG, "Invalid address type: $address.")
|
||||
}
|
||||
}
|
||||
|
||||
if (service.texts != null) {
|
||||
_recordsTXT.add(
|
||||
DnsResourceRecord(
|
||||
clazz = ResourceRecordClass.IN.value.toInt(),
|
||||
type = ResourceRecordType.TXT.value.toInt(),
|
||||
timeToLive = service.ttl,
|
||||
name = deviceDomainName,
|
||||
cacheFlush = false
|
||||
) to TXTRecord(
|
||||
texts = service.texts
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun broadcastRecords(questions: List<DnsQuestion>? = null) {
|
||||
val writer = DnsWriter()
|
||||
_recordLockObject.withLock {
|
||||
val recordsA: List<Pair<DnsResourceRecord, ARecord>>
|
||||
val recordsAAAA: List<Pair<DnsResourceRecord, AAAARecord>>
|
||||
val recordsPTR: List<Pair<DnsResourceRecord, PTRRecord>>
|
||||
val recordsTXT: List<Pair<DnsResourceRecord, TXTRecord>>
|
||||
val recordsSRV: List<Pair<DnsResourceRecord, SRVRecord>>
|
||||
|
||||
if (questions != null) {
|
||||
recordsA = _recordsA.filter { r -> questions.any { q -> q.name == r.first.name && q.clazz == r.first.clazz && q.type == r.first.type } }
|
||||
recordsAAAA = _recordsAAAA.filter { r -> questions.any { q -> q.name == r.first.name && q.clazz == r.first.clazz && q.type == r.first.type } }
|
||||
recordsPTR = _recordsPTR.filter { r -> questions.any { q -> q.name == r.first.name && q.clazz == r.first.clazz && q.type == r.first.type } }
|
||||
recordsSRV = _recordsSRV.filter { r -> questions.any { q -> q.name == r.first.name && q.clazz == r.first.clazz && q.type == r.first.type } }
|
||||
recordsTXT = _recordsTXT.filter { r -> questions.any { q -> q.name == r.first.name && q.clazz == r.first.clazz && q.type == r.first.type } }
|
||||
} else {
|
||||
recordsA = _recordsA
|
||||
recordsAAAA = _recordsAAAA
|
||||
recordsPTR = _recordsPTR
|
||||
recordsSRV = _recordsSRV
|
||||
recordsTXT = _recordsTXT
|
||||
}
|
||||
|
||||
val answerCount = recordsA.size + recordsAAAA.size + recordsPTR.size + recordsSRV.size + recordsTXT.size
|
||||
if (answerCount < 1) return
|
||||
|
||||
val txtOffset = recordsA.size + recordsAAAA.size + recordsPTR.size + recordsSRV.size
|
||||
val srvOffset = recordsA.size + recordsAAAA.size + recordsPTR.size
|
||||
val ptrOffset = recordsA.size + recordsAAAA.size
|
||||
val aaaaOffset = recordsA.size
|
||||
|
||||
writer.writePacket(
|
||||
DnsPacketHeader(
|
||||
identifier = 0u,
|
||||
queryResponse = QueryResponse.Response.value.toInt(),
|
||||
opcode = DnsOpcode.StandardQuery.value.toInt(),
|
||||
truncated = false,
|
||||
nonAuthenticatedData = false,
|
||||
recursionDesired = false,
|
||||
answerAuthenticated = false,
|
||||
authoritativeAnswer = true,
|
||||
recursionAvailable = false,
|
||||
responseCode = DnsResponseCode.NoError
|
||||
),
|
||||
answerCount = answerCount,
|
||||
answerWriter = { w, i ->
|
||||
when {
|
||||
i >= txtOffset -> {
|
||||
val record = recordsTXT[i - txtOffset]
|
||||
w.write(record.first) { it.write(record.second) }
|
||||
}
|
||||
|
||||
i >= srvOffset -> {
|
||||
val record = recordsSRV[i - srvOffset]
|
||||
w.write(record.first) { it.write(record.second) }
|
||||
}
|
||||
|
||||
i >= ptrOffset -> {
|
||||
val record = recordsPTR[i - ptrOffset]
|
||||
w.write(record.first) { it.write(record.second) }
|
||||
}
|
||||
|
||||
i >= aaaaOffset -> {
|
||||
val record = recordsAAAA[i - aaaaOffset]
|
||||
w.write(record.first) { it.write(record.second) }
|
||||
}
|
||||
|
||||
else -> {
|
||||
val record = recordsA[i]
|
||||
w.write(record.first) { it.write(record.second) }
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
send(writer.toByteArray())
|
||||
}
|
||||
|
||||
private fun handleResult(result: DatagramPacket) {
|
||||
try {
|
||||
val packet = DnsPacket.parse(result.data)
|
||||
if (packet.questions.isNotEmpty()) {
|
||||
_scope?.launch(Dispatchers.IO) {
|
||||
try {
|
||||
broadcastRecords(packet.questions)
|
||||
} catch (e: Throwable) {
|
||||
Logger.i(TAG, "Broadcasting records failed", e)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
_serviceRecordAggregator.add(packet)
|
||||
onPacket?.invoke(packet)
|
||||
} catch (e: Exception) {
|
||||
Logger.v(TAG, "Failed to handle packet: ${Base64.getEncoder().encodeToString(result.data.slice(IntRange(0, result.length - 1)).toByteArray())}", e)
|
||||
}
|
||||
}
|
||||
|
||||
fun stop() {
|
||||
_lockObject.withLock {
|
||||
_started = false
|
||||
|
||||
_scope?.cancel()
|
||||
_scope = null
|
||||
|
||||
_nicMonitor.stop()
|
||||
_serviceRecordAggregator.stop()
|
||||
|
||||
_receiver4?.close()
|
||||
_receiver4 = null
|
||||
|
||||
_receiver6?.close()
|
||||
_receiver6 = null
|
||||
|
||||
_senders.forEach { it.close() }
|
||||
_senders.clear()
|
||||
}
|
||||
|
||||
_threadReceiver4?.join()
|
||||
_threadReceiver4 = null
|
||||
|
||||
_threadReceiver6?.join()
|
||||
_threadReceiver6 = null
|
||||
}
|
||||
}
|
@ -1,66 +0,0 @@
|
||||
package com.futo.platformplayer.mdns
|
||||
|
||||
import kotlinx.coroutines.*
|
||||
import java.net.NetworkInterface
|
||||
|
||||
class NICMonitor {
|
||||
private val lockObject = Any()
|
||||
private val nics = mutableListOf<NetworkInterface>()
|
||||
private var cts: Job? = null
|
||||
|
||||
val current: List<NetworkInterface>
|
||||
get() = synchronized(nics) { nics.toList() }
|
||||
|
||||
var added: ((List<NetworkInterface>) -> Unit)? = null
|
||||
var removed: ((List<NetworkInterface>) -> Unit)? = null
|
||||
|
||||
fun start() {
|
||||
synchronized(lockObject) {
|
||||
if (cts != null) throw Exception("Already started.")
|
||||
|
||||
cts = CoroutineScope(Dispatchers.Default).launch {
|
||||
loopAsync()
|
||||
}
|
||||
}
|
||||
|
||||
nics.clear()
|
||||
nics.addAll(getCurrentInterfaces().toList())
|
||||
}
|
||||
|
||||
fun stop() {
|
||||
synchronized(lockObject) {
|
||||
cts?.cancel()
|
||||
cts = null
|
||||
}
|
||||
|
||||
synchronized(nics) {
|
||||
nics.clear()
|
||||
}
|
||||
}
|
||||
|
||||
private suspend fun loopAsync() {
|
||||
while (cts?.isActive == true) {
|
||||
try {
|
||||
val currentNics = getCurrentInterfaces().toList()
|
||||
removed?.invoke(nics.filter { k -> currentNics.none { n -> k.name == n.name } })
|
||||
added?.invoke(currentNics.filter { nic -> nics.none { k -> k.name == nic.name } })
|
||||
|
||||
synchronized(nics) {
|
||||
nics.clear()
|
||||
nics.addAll(currentNics)
|
||||
}
|
||||
} catch (ex: Exception) {
|
||||
// Ignored
|
||||
}
|
||||
delay(5000)
|
||||
}
|
||||
}
|
||||
|
||||
private fun getCurrentInterfaces(): List<NetworkInterface> {
|
||||
val nics = NetworkInterface.getNetworkInterfaces().toList()
|
||||
.filter { it.isUp && !it.isLoopback }
|
||||
|
||||
return if (nics.isNotEmpty()) nics else NetworkInterface.getNetworkInterfaces().toList()
|
||||
.filter { it.isUp }
|
||||
}
|
||||
}
|
@ -1,71 +0,0 @@
|
||||
package com.futo.platformplayer.mdns
|
||||
|
||||
import com.futo.platformplayer.logging.Logger
|
||||
import java.lang.Thread.sleep
|
||||
|
||||
class ServiceDiscoverer(names: Array<String>, private val _onServicesUpdated: (List<DnsService>) -> Unit) {
|
||||
private val _names: Array<String>
|
||||
private var _listener: MDNSListener? = null
|
||||
private var _started = false
|
||||
private var _thread: Thread? = null
|
||||
|
||||
init {
|
||||
if (names.isEmpty()) throw IllegalArgumentException("At least one name must be specified.")
|
||||
_names = names
|
||||
}
|
||||
|
||||
fun broadcastService(
|
||||
deviceName: String,
|
||||
serviceName: String,
|
||||
port: UShort,
|
||||
ttl: UInt = 120u,
|
||||
weight: UShort = 0u,
|
||||
priority: UShort = 0u,
|
||||
texts: List<String>? = null
|
||||
) {
|
||||
_listener?.let {
|
||||
it.broadcastService(deviceName, serviceName, port, ttl, weight, priority, texts)
|
||||
}
|
||||
}
|
||||
|
||||
fun stop() {
|
||||
_started = false
|
||||
_listener?.stop()
|
||||
_listener = null
|
||||
_thread?.join()
|
||||
_thread = null
|
||||
}
|
||||
|
||||
fun start() {
|
||||
if (_started) {
|
||||
Logger.i(TAG, "Already started.")
|
||||
return
|
||||
}
|
||||
_started = true
|
||||
|
||||
val listener = MDNSListener()
|
||||
_listener = listener
|
||||
listener.onServicesUpdated = { _onServicesUpdated?.invoke(it) }
|
||||
listener.start()
|
||||
|
||||
_thread = Thread {
|
||||
try {
|
||||
sleep(2000)
|
||||
|
||||
while (_started) {
|
||||
listener.queryServices(_names)
|
||||
sleep(2000)
|
||||
listener.queryAllQuestions(_names)
|
||||
sleep(2000)
|
||||
}
|
||||
} catch (e: Throwable) {
|
||||
Logger.i(TAG, "Exception in loop thread", e)
|
||||
stop()
|
||||
}
|
||||
}.apply { start() }
|
||||
}
|
||||
|
||||
companion object {
|
||||
private val TAG = "ServiceDiscoverer"
|
||||
}
|
||||
}
|
@ -1,226 +0,0 @@
|
||||
package com.futo.platformplayer.mdns
|
||||
|
||||
import com.futo.platformplayer.logging.Logger
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.Job
|
||||
import kotlinx.coroutines.delay
|
||||
import kotlinx.coroutines.isActive
|
||||
import kotlinx.coroutines.launch
|
||||
import java.net.InetAddress
|
||||
import java.util.Date
|
||||
|
||||
data class DnsService(
|
||||
var name: String,
|
||||
var target: String,
|
||||
var port: UShort,
|
||||
val addresses: MutableList<InetAddress> = mutableListOf(),
|
||||
val pointers: MutableList<String> = mutableListOf(),
|
||||
val texts: MutableList<String> = mutableListOf()
|
||||
)
|
||||
|
||||
data class CachedDnsAddressRecord(
|
||||
val expirationTime: Date,
|
||||
val address: InetAddress
|
||||
)
|
||||
|
||||
data class CachedDnsTxtRecord(
|
||||
val expirationTime: Date,
|
||||
val texts: List<String>
|
||||
)
|
||||
|
||||
data class CachedDnsPtrRecord(
|
||||
val expirationTime: Date,
|
||||
val target: String
|
||||
)
|
||||
|
||||
data class CachedDnsSrvRecord(
|
||||
val expirationTime: Date,
|
||||
val service: SRVRecord
|
||||
)
|
||||
|
||||
class ServiceRecordAggregator {
|
||||
private val _lockObject = Any()
|
||||
private val _cachedAddressRecords = mutableMapOf<String, MutableList<CachedDnsAddressRecord>>()
|
||||
private val _cachedTxtRecords = mutableMapOf<String, CachedDnsTxtRecord>()
|
||||
private val _cachedPtrRecords = mutableMapOf<String, MutableList<CachedDnsPtrRecord>>()
|
||||
private val _cachedSrvRecords = mutableMapOf<String, CachedDnsSrvRecord>()
|
||||
private val _currentServices = mutableListOf<DnsService>()
|
||||
private var _cts: Job? = null
|
||||
|
||||
var onServicesUpdated: ((List<DnsService>) -> Unit)? = null
|
||||
|
||||
fun start() {
|
||||
synchronized(_lockObject) {
|
||||
if (_cts != null) throw Exception("Already started.")
|
||||
|
||||
_cts = CoroutineScope(Dispatchers.Default).launch {
|
||||
try {
|
||||
while (isActive) {
|
||||
val now = Date()
|
||||
synchronized(_currentServices) {
|
||||
_cachedAddressRecords.forEach { it.value.removeAll { record -> now.after(record.expirationTime) } }
|
||||
_cachedTxtRecords.entries.removeIf { now.after(it.value.expirationTime) }
|
||||
_cachedSrvRecords.entries.removeIf { now.after(it.value.expirationTime) }
|
||||
_cachedPtrRecords.forEach { it.value.removeAll { record -> now.after(record.expirationTime) } }
|
||||
|
||||
val newServices = getCurrentServices()
|
||||
_currentServices.clear()
|
||||
_currentServices.addAll(newServices)
|
||||
}
|
||||
|
||||
onServicesUpdated?.invoke(_currentServices.toList())
|
||||
delay(5000)
|
||||
}
|
||||
} catch (e: Throwable) {
|
||||
Logger.e(TAG, "Unexpected failure in MDNS loop", e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun stop() {
|
||||
synchronized(_lockObject) {
|
||||
_cts?.cancel()
|
||||
_cts = null
|
||||
}
|
||||
}
|
||||
|
||||
fun add(packet: DnsPacket) {
|
||||
val currentServices: List<DnsService>
|
||||
val dnsResourceRecords = packet.answers + packet.additionals + packet.authorities
|
||||
val txtRecords = dnsResourceRecords.filter { it.type == ResourceRecordType.TXT.value.toInt() }.map { it to it.getDataReader().readTXTRecord() }
|
||||
val aRecords = dnsResourceRecords.filter { it.type == ResourceRecordType.A.value.toInt() }.map { it to it.getDataReader().readARecord() }
|
||||
val aaaaRecords = dnsResourceRecords.filter { it.type == ResourceRecordType.AAAA.value.toInt() }.map { it to it.getDataReader().readAAAARecord() }
|
||||
val srvRecords = dnsResourceRecords.filter { it.type == ResourceRecordType.SRV.value.toInt() }.map { it to it.getDataReader().readSRVRecord() }
|
||||
val ptrRecords = dnsResourceRecords.filter { it.type == ResourceRecordType.PTR.value.toInt() }.map { it to it.getDataReader().readPTRRecord() }
|
||||
|
||||
/*val builder = StringBuilder()
|
||||
builder.appendLine("Received records:")
|
||||
srvRecords.forEach { builder.appendLine("SRV ${it.first.name} ${it.first.type} ${it.first.clazz} TTL ${it.first.timeToLive}: (Port: ${it.second.port}, Target: ${it.second.target}, Priority: ${it.second.priority}, Weight: ${it.second.weight})") }
|
||||
ptrRecords.forEach { builder.appendLine("PTR ${it.first.name} ${it.first.type} ${it.first.clazz} TTL ${it.first.timeToLive}: ${it.second.domainName}") }
|
||||
txtRecords.forEach { builder.appendLine("TXT ${it.first.name} ${it.first.type} ${it.first.clazz} TTL ${it.first.timeToLive}: ${it.second.texts.joinToString(", ")}") }
|
||||
aRecords.forEach { builder.appendLine("A ${it.first.name} ${it.first.type} ${it.first.clazz} TTL ${it.first.timeToLive}: ${it.second.address}") }
|
||||
aaaaRecords.forEach { builder.appendLine("AAAA ${it.first.name} ${it.first.type} ${it.first.clazz} TTL ${it.first.timeToLive}: ${it.second.address}") }
|
||||
Logger.i(TAG, "$builder")*/
|
||||
|
||||
synchronized(this._currentServices) {
|
||||
ptrRecords.forEach { record ->
|
||||
val cachedPtrRecord = _cachedPtrRecords.getOrPut(record.first.name) { mutableListOf() }
|
||||
val newPtrRecord = CachedDnsPtrRecord(Date(System.currentTimeMillis() + record.first.timeToLive.toLong() * 1000L), record.second.domainName)
|
||||
cachedPtrRecord.replaceOrAdd(newPtrRecord) { it.target == record.second.domainName }
|
||||
}
|
||||
|
||||
aRecords.forEach { aRecord ->
|
||||
val cachedARecord = _cachedAddressRecords.getOrPut(aRecord.first.name) { mutableListOf() }
|
||||
val newARecord = CachedDnsAddressRecord(Date(System.currentTimeMillis() + aRecord.first.timeToLive.toLong() * 1000L), aRecord.second.address)
|
||||
cachedARecord.replaceOrAdd(newARecord) { it.address == newARecord.address }
|
||||
}
|
||||
|
||||
aaaaRecords.forEach { aaaaRecord ->
|
||||
val cachedAaaaRecord = _cachedAddressRecords.getOrPut(aaaaRecord.first.name) { mutableListOf() }
|
||||
val newAaaaRecord = CachedDnsAddressRecord(Date(System.currentTimeMillis() + aaaaRecord.first.timeToLive.toLong() * 1000L), aaaaRecord.second.address)
|
||||
cachedAaaaRecord.replaceOrAdd(newAaaaRecord) { it.address == newAaaaRecord.address }
|
||||
}
|
||||
|
||||
txtRecords.forEach { txtRecord ->
|
||||
_cachedTxtRecords[txtRecord.first.name] = CachedDnsTxtRecord(Date(System.currentTimeMillis() + txtRecord.first.timeToLive.toLong() * 1000L), txtRecord.second.texts)
|
||||
}
|
||||
|
||||
srvRecords.forEach { srvRecord ->
|
||||
_cachedSrvRecords[srvRecord.first.name] = CachedDnsSrvRecord(Date(System.currentTimeMillis() + srvRecord.first.timeToLive.toLong() * 1000L), srvRecord.second)
|
||||
}
|
||||
|
||||
currentServices = getCurrentServices()
|
||||
this._currentServices.clear()
|
||||
this._currentServices.addAll(currentServices)
|
||||
}
|
||||
|
||||
onServicesUpdated?.invoke(currentServices)
|
||||
}
|
||||
|
||||
fun getAllQuestions(serviceName: String): List<DnsQuestion> {
|
||||
val questions = mutableListOf<DnsQuestion>()
|
||||
synchronized(_currentServices) {
|
||||
val servicePtrRecords = _cachedPtrRecords[serviceName] ?: return emptyList()
|
||||
|
||||
val ptrWithoutSrvRecord = servicePtrRecords.filterNot { _cachedSrvRecords.containsKey(it.target) }.map { it.target }
|
||||
questions.addAll(ptrWithoutSrvRecord.flatMap { s ->
|
||||
listOf(
|
||||
DnsQuestion(
|
||||
name = s,
|
||||
type = QuestionType.SRV.value.toInt(),
|
||||
clazz = QuestionClass.IN.value.toInt(),
|
||||
queryUnicast = false
|
||||
)
|
||||
)
|
||||
})
|
||||
|
||||
val incompleteCurrentServices = _currentServices.filter { it.addresses.isEmpty() && it.name.endsWith(serviceName) }
|
||||
questions.addAll(incompleteCurrentServices.flatMap { s ->
|
||||
listOf(
|
||||
DnsQuestion(
|
||||
name = s.name,
|
||||
type = QuestionType.TXT.value.toInt(),
|
||||
clazz = QuestionClass.IN.value.toInt(),
|
||||
queryUnicast = false
|
||||
),
|
||||
DnsQuestion(
|
||||
name = s.target,
|
||||
type = QuestionType.A.value.toInt(),
|
||||
clazz = QuestionClass.IN.value.toInt(),
|
||||
queryUnicast = false
|
||||
),
|
||||
DnsQuestion(
|
||||
name = s.target,
|
||||
type = QuestionType.AAAA.value.toInt(),
|
||||
clazz = QuestionClass.IN.value.toInt(),
|
||||
queryUnicast = false
|
||||
)
|
||||
)
|
||||
})
|
||||
}
|
||||
return questions
|
||||
}
|
||||
|
||||
private fun getCurrentServices(): MutableList<DnsService> {
|
||||
val currentServices = _cachedSrvRecords.map { (key, value) ->
|
||||
DnsService(
|
||||
name = key,
|
||||
target = value.service.target,
|
||||
port = value.service.port
|
||||
)
|
||||
}.toMutableList()
|
||||
|
||||
currentServices.forEach { service ->
|
||||
_cachedAddressRecords[service.target]?.let {
|
||||
service.addresses.addAll(it.map { record -> record.address })
|
||||
}
|
||||
}
|
||||
|
||||
currentServices.forEach { service ->
|
||||
service.pointers.addAll(_cachedPtrRecords.filter { it.value.any { ptr -> ptr.target == service.name } }.map { it.key })
|
||||
}
|
||||
|
||||
currentServices.forEach { service ->
|
||||
_cachedTxtRecords[service.name]?.let {
|
||||
service.texts.addAll(it.texts)
|
||||
}
|
||||
}
|
||||
|
||||
return currentServices
|
||||
}
|
||||
|
||||
private inline fun <T> MutableList<T>.replaceOrAdd(newElement: T, predicate: (T) -> Boolean) {
|
||||
val index = indexOfFirst(predicate)
|
||||
if (index >= 0) {
|
||||
this[index] = newElement
|
||||
} else {
|
||||
add(newElement)
|
||||
}
|
||||
}
|
||||
|
||||
private companion object {
|
||||
private const val TAG = "ServiceRecordAggregator"
|
||||
}
|
||||
}
|
@ -411,7 +411,7 @@ class StateApp {
|
||||
}
|
||||
|
||||
if (Settings.instance.synchronization.enabled) {
|
||||
StateSync.instance.start()
|
||||
StateSync.instance.start(context)
|
||||
}
|
||||
|
||||
Logger.onLogSubmitted.subscribe {
|
||||
|
@ -1,5 +1,8 @@
|
||||
package com.futo.platformplayer.states
|
||||
|
||||
import android.content.Context
|
||||
import android.net.nsd.NsdManager
|
||||
import android.net.nsd.NsdServiceInfo
|
||||
import android.os.Build
|
||||
import android.util.Log
|
||||
import com.futo.platformplayer.LittleEndianDataInputStream
|
||||
@ -9,14 +12,14 @@ import com.futo.platformplayer.UIDialogs
|
||||
import com.futo.platformplayer.activities.MainActivity
|
||||
import com.futo.platformplayer.activities.SyncShowPairingCodeActivity
|
||||
import com.futo.platformplayer.api.media.Serializer
|
||||
import com.futo.platformplayer.casting.StateCasting
|
||||
import com.futo.platformplayer.casting.StateCasting.Companion
|
||||
import com.futo.platformplayer.constructs.Event1
|
||||
import com.futo.platformplayer.constructs.Event2
|
||||
import com.futo.platformplayer.encryption.GEncryptionProvider
|
||||
import com.futo.platformplayer.generateReadablePassword
|
||||
import com.futo.platformplayer.getConnectedSocket
|
||||
import com.futo.platformplayer.logging.Logger
|
||||
import com.futo.platformplayer.mdns.DnsService
|
||||
import com.futo.platformplayer.mdns.ServiceDiscoverer
|
||||
import com.futo.platformplayer.models.HistoryVideo
|
||||
import com.futo.platformplayer.models.Subscription
|
||||
import com.futo.platformplayer.noise.protocol.DHState
|
||||
@ -81,12 +84,30 @@ class StateSync {
|
||||
//TODO: Should sync mdns and casting mdns be merged?
|
||||
//TODO: Decrease interval that devices are updated
|
||||
//TODO: Send less data
|
||||
private val _serviceDiscoverer = ServiceDiscoverer(arrayOf("_gsync._tcp.local")) { handleServiceUpdated(it) }
|
||||
|
||||
private val _pairingCode: String? = generateReadablePassword(8)
|
||||
val pairingCode: String? get() = _pairingCode
|
||||
private var _relaySession: SyncSocketSession? = null
|
||||
private var _threadRelay: Thread? = null
|
||||
private val _remotePendingStatusUpdate = mutableMapOf<String, (complete: Boolean?, message: String) -> Unit>()
|
||||
private var _nsdManager: NsdManager? = null
|
||||
private val _registrationListener = object : NsdManager.RegistrationListener {
|
||||
override fun onServiceRegistered(serviceInfo: NsdServiceInfo) {
|
||||
Log.v(TAG, "onServiceRegistered: ${serviceInfo.serviceName}")
|
||||
}
|
||||
|
||||
override fun onRegistrationFailed(serviceInfo: NsdServiceInfo, errorCode: Int) {
|
||||
Log.v(TAG, "onRegistrationFailed: ${serviceInfo.serviceName} (error code: $errorCode)")
|
||||
}
|
||||
|
||||
override fun onServiceUnregistered(serviceInfo: NsdServiceInfo) {
|
||||
Log.v(TAG, "onServiceUnregistered: ${serviceInfo.serviceName}")
|
||||
}
|
||||
|
||||
override fun onUnregistrationFailed(serviceInfo: NsdServiceInfo, errorCode: Int) {
|
||||
Log.v(TAG, "onUnregistrationFailed: ${serviceInfo.serviceName} (error code: $errorCode)")
|
||||
}
|
||||
}
|
||||
|
||||
var keyPair: DHState? = null
|
||||
var publicKey: String? = null
|
||||
@ -101,15 +122,116 @@ class StateSync {
|
||||
}
|
||||
}
|
||||
|
||||
fun start() {
|
||||
fun start(context: Context) {
|
||||
if (_started) {
|
||||
Logger.i(TAG, "Already started.")
|
||||
return
|
||||
}
|
||||
_started = true
|
||||
_nsdManager = context.getSystemService(Context.NSD_SERVICE) as NsdManager
|
||||
|
||||
if (Settings.instance.synchronization.broadcast || Settings.instance.synchronization.connectDiscovered) {
|
||||
_serviceDiscoverer.start()
|
||||
if (Settings.instance.synchronization.connectDiscovered) {
|
||||
_nsdManager?.apply {
|
||||
discoverServices("_gsync._tcp", NsdManager.PROTOCOL_DNS_SD, object : NsdManager.DiscoveryListener {
|
||||
override fun onDiscoveryStarted(regType: String) {
|
||||
Log.d(TAG, "Service discovery started for $regType")
|
||||
}
|
||||
|
||||
override fun onDiscoveryStopped(serviceType: String) {
|
||||
Log.i(TAG, "Discovery stopped: $serviceType")
|
||||
}
|
||||
|
||||
override fun onServiceLost(service: NsdServiceInfo) {
|
||||
Log.e(TAG, "service lost: $service")
|
||||
// TODO: Handle service lost, e.g., remove device
|
||||
}
|
||||
|
||||
override fun onStartDiscoveryFailed(serviceType: String, errorCode: Int) {
|
||||
Log.e(TAG, "Discovery failed for $serviceType: Error code:$errorCode")
|
||||
_nsdManager?.stopServiceDiscovery(this)
|
||||
}
|
||||
|
||||
override fun onStopDiscoveryFailed(serviceType: String, errorCode: Int) {
|
||||
Log.e(TAG, "Stop discovery failed for $serviceType: Error code:$errorCode")
|
||||
_nsdManager?.stopServiceDiscovery(this)
|
||||
}
|
||||
|
||||
fun addOrUpdate(name: String, adrs: Array<InetAddress>, port: Int, attributes: Map<String, ByteArray>) {
|
||||
if (!Settings.instance.synchronization.connectDiscovered) {
|
||||
return
|
||||
}
|
||||
|
||||
val urlSafePkey = attributes.get("pk")?.decodeToString() ?: return
|
||||
val pkey = Base64.getEncoder().encodeToString(Base64.getDecoder().decode(urlSafePkey.replace('-', '+').replace('_', '/')))
|
||||
val syncDeviceInfo = SyncDeviceInfo(pkey, adrs.map { it.hostAddress }.toTypedArray(), port, null)
|
||||
val authorized = isAuthorized(pkey)
|
||||
|
||||
if (authorized && !isConnected(pkey)) {
|
||||
val now = System.currentTimeMillis()
|
||||
val lastConnectTime = synchronized(_lastConnectTimesMdns) {
|
||||
_lastConnectTimesMdns[pkey] ?: 0
|
||||
}
|
||||
|
||||
//Connect once every 30 seconds, max
|
||||
if (now - lastConnectTime > 30000) {
|
||||
synchronized(_lastConnectTimesMdns) {
|
||||
_lastConnectTimesMdns[pkey] = now
|
||||
}
|
||||
|
||||
Logger.i(TAG, "Found device authorized device '${name}' with pkey=$pkey, attempting to connect")
|
||||
|
||||
try {
|
||||
connect(syncDeviceInfo)
|
||||
} catch (e: Throwable) {
|
||||
Logger.i(TAG, "Failed to connect to $pkey", e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun onServiceFound(service: NsdServiceInfo) {
|
||||
Log.v(TAG, "Service discovery success for ${service.serviceType}: $service")
|
||||
addOrUpdate(service.serviceName, if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.UPSIDE_DOWN_CAKE) {
|
||||
service.hostAddresses.toTypedArray()
|
||||
} else {
|
||||
arrayOf(service.host)
|
||||
}, service.port, service.attributes)
|
||||
|
||||
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.UPSIDE_DOWN_CAKE) {
|
||||
_nsdManager?.registerServiceInfoCallback(service, { it.run() }, object : NsdManager.ServiceInfoCallback {
|
||||
override fun onServiceUpdated(serviceInfo: NsdServiceInfo) {
|
||||
Log.v(TAG, "onServiceUpdated: $serviceInfo")
|
||||
addOrUpdate(serviceInfo.serviceName, serviceInfo.hostAddresses.toTypedArray(), serviceInfo.port, serviceInfo.attributes)
|
||||
}
|
||||
|
||||
override fun onServiceLost() {
|
||||
Log.v(TAG, "onServiceLost: $service")
|
||||
// TODO: Handle service lost
|
||||
}
|
||||
|
||||
override fun onServiceInfoCallbackRegistrationFailed(errorCode: Int) {
|
||||
Log.v(TAG, "onServiceInfoCallbackRegistrationFailed: $errorCode")
|
||||
}
|
||||
|
||||
override fun onServiceInfoCallbackUnregistered() {
|
||||
Log.v(TAG, "onServiceInfoCallbackUnregistered")
|
||||
}
|
||||
})
|
||||
} else {
|
||||
_nsdManager?.resolveService(service, object : NsdManager.ResolveListener {
|
||||
override fun onResolveFailed(serviceInfo: NsdServiceInfo, errorCode: Int) {
|
||||
Log.v(TAG, "Resolve failed: $errorCode")
|
||||
}
|
||||
|
||||
override fun onServiceResolved(serviceInfo: NsdServiceInfo) {
|
||||
Log.v(TAG, "Resolve Succeeded: $serviceInfo")
|
||||
addOrUpdate(serviceInfo.serviceName, arrayOf(serviceInfo.host), serviceInfo.port, serviceInfo.attributes)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
@ -142,7 +264,19 @@ class StateSync {
|
||||
}
|
||||
|
||||
if (Settings.instance.synchronization.broadcast) {
|
||||
publicKey?.let { _serviceDiscoverer.broadcastService(getDeviceName(), "_gsync._tcp.local", PORT.toUShort(), texts = arrayListOf("pk=${it.replace('+', '-').replace('/', '_').replace("=", "")}")) }
|
||||
val pk = publicKey
|
||||
val nsdManager = _nsdManager
|
||||
|
||||
if (pk != null && nsdManager != null) {
|
||||
val serviceInfo = NsdServiceInfo().apply {
|
||||
serviceName = getDeviceName()
|
||||
serviceType = "_gsync._tcp"
|
||||
port = PORT
|
||||
setAttribute("pk", pk.replace('+', '-').replace('/', '_').replace("=", ""))
|
||||
}
|
||||
|
||||
nsdManager.registerService(serviceInfo, NsdManager.PROTOCOL_DNS_SD, _registrationListener)
|
||||
}
|
||||
}
|
||||
|
||||
Logger.i(TAG, "Sync key pair initialized (public key = ${publicKey})")
|
||||
@ -318,7 +452,7 @@ class StateSync {
|
||||
override val isAuthorized: Boolean get() = true
|
||||
}
|
||||
|
||||
_relaySession!!.startAsInitiator(RELAY_PUBLIC_KEY, APP_ID, null)
|
||||
_relaySession!!.runAsInitiator(RELAY_PUBLIC_KEY, APP_ID, null)
|
||||
|
||||
Log.i(TAG, "Started relay session.")
|
||||
} catch (e: Throwable) {
|
||||
@ -331,6 +465,8 @@ class StateSync {
|
||||
}
|
||||
}.apply { start() }
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
private fun getDeviceName(): String {
|
||||
@ -382,48 +518,6 @@ class StateSync {
|
||||
_syncSessionData.setAndSave(data.publicKey, data);
|
||||
}
|
||||
|
||||
private fun handleServiceUpdated(services: List<DnsService>) {
|
||||
if (!Settings.instance.synchronization.connectDiscovered) {
|
||||
return
|
||||
}
|
||||
|
||||
for (s in services) {
|
||||
//TODO: Addresses IPv4 only?
|
||||
val addresses = s.addresses.mapNotNull { it.hostAddress }.toTypedArray()
|
||||
val port = s.port.toInt()
|
||||
if (s.name.endsWith("._gsync._tcp.local")) {
|
||||
val name = s.name.substring(0, s.name.length - "._gsync._tcp.local".length)
|
||||
val urlSafePkey = s.texts.firstOrNull { it.startsWith("pk=") }?.substring("pk=".length) ?: continue
|
||||
val pkey = Base64.getEncoder().encodeToString(Base64.getDecoder().decode(urlSafePkey.replace('-', '+').replace('_', '/')))
|
||||
|
||||
val syncDeviceInfo = SyncDeviceInfo(pkey, addresses, port, null)
|
||||
val authorized = isAuthorized(pkey)
|
||||
|
||||
if (authorized && !isConnected(pkey)) {
|
||||
val now = System.currentTimeMillis()
|
||||
val lastConnectTime = synchronized(_lastConnectTimesMdns) {
|
||||
_lastConnectTimesMdns[pkey] ?: 0
|
||||
}
|
||||
|
||||
//Connect once every 30 seconds, max
|
||||
if (now - lastConnectTime > 30000) {
|
||||
synchronized(_lastConnectTimesMdns) {
|
||||
_lastConnectTimesMdns[pkey] = now
|
||||
}
|
||||
|
||||
Logger.i(TAG, "Found device authorized device '${name}' with pkey=$pkey, attempting to connect")
|
||||
|
||||
try {
|
||||
connect(syncDeviceInfo)
|
||||
} catch (e: Throwable) {
|
||||
Logger.i(TAG, "Failed to connect to $pkey", e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun unauthorize(remotePublicKey: String) {
|
||||
Logger.i(TAG, "${remotePublicKey} unauthorized received")
|
||||
_authorizedDevices.remove(remotePublicKey)
|
||||
@ -899,7 +993,7 @@ class StateSync {
|
||||
|
||||
fun stop() {
|
||||
_started = false
|
||||
_serviceDiscoverer.stop()
|
||||
_nsdManager?.unregisterService(_registrationListener)
|
||||
|
||||
_serverSocket?.close()
|
||||
_serverSocket = null
|
||||
|
@ -121,6 +121,19 @@ class SyncSocketSession {
|
||||
}.apply { start() }
|
||||
}
|
||||
|
||||
fun runAsInitiator(remotePublicKey: String, appId: UInt = 0u, pairingCode: String? = null) {
|
||||
_started = true
|
||||
try {
|
||||
handshakeAsInitiator(remotePublicKey, appId, pairingCode)
|
||||
_onHandshakeComplete?.invoke(this)
|
||||
receiveLoop()
|
||||
} catch (e: Throwable) {
|
||||
Logger.e(TAG, "Failed to run as initiator", e)
|
||||
} finally {
|
||||
stop()
|
||||
}
|
||||
}
|
||||
|
||||
fun startAsResponder() {
|
||||
_started = true
|
||||
_thread = Thread {
|
||||
|
Loading…
x
Reference in New Issue
Block a user