mirror of
https://gitlab.futo.org/videostreaming/grayjay.git
synced 2025-05-02 15:44:26 +02:00
Fixed various implementation bugs.
This commit is contained in:
parent
1ae9f0ea26
commit
955ba23b0d
@ -93,13 +93,17 @@ class SyncServerTests {
|
|||||||
|
|
||||||
val tcsDataB = CompletableDeferred<ByteArray>()
|
val tcsDataB = CompletableDeferred<ByteArray>()
|
||||||
channelB.setDataHandler { _, _, o, so, d ->
|
channelB.setDataHandler { _, _, o, so, d ->
|
||||||
if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataB.complete(d.array())
|
val b = ByteArray(d.remaining())
|
||||||
|
d.get(b)
|
||||||
|
if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataB.complete(b)
|
||||||
}
|
}
|
||||||
channelA.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(byteArrayOf(1, 2, 3)))
|
channelA.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(byteArrayOf(1, 2, 3)))
|
||||||
|
|
||||||
val tcsDataA = CompletableDeferred<ByteArray>()
|
val tcsDataA = CompletableDeferred<ByteArray>()
|
||||||
channelA.setDataHandler { _, _, o, so, d ->
|
channelA.setDataHandler { _, _, o, so, d ->
|
||||||
if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataA.complete(d.array())
|
val b = ByteArray(d.remaining())
|
||||||
|
d.get(b)
|
||||||
|
if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataA.complete(b)
|
||||||
}
|
}
|
||||||
channelB.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(byteArrayOf(4, 5, 6)))
|
channelB.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(byteArrayOf(4, 5, 6)))
|
||||||
|
|
||||||
@ -231,7 +235,9 @@ class SyncServerTests {
|
|||||||
|
|
||||||
val tcsDataB = CompletableDeferred<ByteArray>()
|
val tcsDataB = CompletableDeferred<ByteArray>()
|
||||||
channelB.setDataHandler { _, _, o, so, d ->
|
channelB.setDataHandler { _, _, o, so, d ->
|
||||||
if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataB.complete(d.array())
|
val b = ByteArray(d.remaining())
|
||||||
|
d.get(b)
|
||||||
|
if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataB.complete(b)
|
||||||
}
|
}
|
||||||
channelA.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(largeData))
|
channelA.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(largeData))
|
||||||
val receivedData = withTimeout(10000.milliseconds) { tcsDataB.await() }
|
val receivedData = withTimeout(10000.milliseconds) { tcsDataB.await() }
|
||||||
|
@ -14,7 +14,7 @@ interface IChannel : AutoCloseable {
|
|||||||
val remoteVersion: Int?
|
val remoteVersion: Int?
|
||||||
var authorizable: IAuthorizable?
|
var authorizable: IAuthorizable?
|
||||||
fun setDataHandler(onData: ((SyncSocketSession, IChannel, UByte, UByte, ByteBuffer) -> Unit)?)
|
fun setDataHandler(onData: ((SyncSocketSession, IChannel, UByte, UByte, ByteBuffer) -> Unit)?)
|
||||||
fun send(opcode: UByte, subOpcode: UByte, data: ByteBuffer? = null)
|
fun send(opcode: UByte, subOpcode: UByte = 0u, data: ByteBuffer? = null)
|
||||||
fun setCloseHandler(onClose: ((IChannel) -> Unit)?)
|
fun setCloseHandler(onClose: ((IChannel) -> Unit)?)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -326,12 +326,4 @@ class ChannelRelayed(
|
|||||||
completeHandshake(remoteVersion, transport)
|
completeHandshake(remoteVersion, transport)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fun handleData(data: ByteBuffer) {
|
|
||||||
val size = data.int
|
|
||||||
if (size != data.remaining()) throw IllegalStateException("Incomplete packet received")
|
|
||||||
val opcode = data.get().toUByte()
|
|
||||||
val subOpcode = data.get().toUByte()
|
|
||||||
invokeDataHandler(opcode, subOpcode, data)
|
|
||||||
}
|
|
||||||
}
|
}
|
@ -155,7 +155,7 @@ class SyncSocketSession {
|
|||||||
val plen: Int = _cipherStatePair!!.receiver.decryptWithAd(null, _buffer, 0, _bufferDecrypted, 0, messageSize)
|
val plen: Int = _cipherStatePair!!.receiver.decryptWithAd(null, _buffer, 0, _bufferDecrypted, 0, messageSize)
|
||||||
//Logger.i(TAG, "Decrypted message (size = ${plen})")
|
//Logger.i(TAG, "Decrypted message (size = ${plen})")
|
||||||
|
|
||||||
handleData(_bufferDecrypted, plen)
|
handleData(_bufferDecrypted, plen, null)
|
||||||
} catch (e: Throwable) {
|
} catch (e: Throwable) {
|
||||||
Logger.e(TAG, "Exception while receiving data", e)
|
Logger.e(TAG, "Exception while receiving data", e)
|
||||||
break
|
break
|
||||||
@ -374,22 +374,25 @@ class SyncSocketSession {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@OptIn(ExperimentalUnsignedTypes::class)
|
private fun handleData(data: ByteArray, length: Int, sourceChannel: ChannelRelayed?) {
|
||||||
private fun handleData(data: ByteArray, length: Int) {
|
return handleData(ByteBuffer.wrap(data, 0, length).order(ByteOrder.LITTLE_ENDIAN), sourceChannel)
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun handleData(data: ByteBuffer, sourceChannel: ChannelRelayed?) {
|
||||||
|
val length = data.remaining()
|
||||||
if (length < HEADER_SIZE)
|
if (length < HEADER_SIZE)
|
||||||
throw Exception("Packet must be at least 6 bytes (header size)")
|
throw Exception("Packet must be at least 6 bytes (header size)")
|
||||||
|
|
||||||
val size = ByteBuffer.wrap(data, 0, 4).order(ByteOrder.LITTLE_ENDIAN).int
|
val size = data.int
|
||||||
if (size != length - 4)
|
if (size != length - 4)
|
||||||
throw Exception("Incomplete packet received")
|
throw Exception("Incomplete packet received")
|
||||||
|
|
||||||
val opcode = data.asUByteArray()[4]
|
val opcode = data.get().toUByte()
|
||||||
val subOpcode = data.asUByteArray()[5]
|
val subOpcode = data.get().toUByte()
|
||||||
val packetData = ByteBuffer.wrap(data, HEADER_SIZE, size - 2)
|
handlePacket(opcode, subOpcode, data, sourceChannel)
|
||||||
handlePacket(opcode, subOpcode, packetData.order(ByteOrder.LITTLE_ENDIAN))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun handleRequest(subOpcode: UByte, data: ByteBuffer) {
|
private fun handleRequest(subOpcode: UByte, data: ByteBuffer, sourceChannel: ChannelRelayed?) {
|
||||||
when (subOpcode) {
|
when (subOpcode) {
|
||||||
RequestOpcode.TRANSPORT_RELAYED.value -> {
|
RequestOpcode.TRANSPORT_RELAYED.value -> {
|
||||||
Logger.i(TAG, "Received request for a relayed transport")
|
Logger.i(TAG, "Received request for a relayed transport")
|
||||||
@ -440,7 +443,7 @@ class SyncSocketSession {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun handleResponse(subOpcode: UByte, data: ByteBuffer) {
|
private fun handleResponse(subOpcode: UByte, data: ByteBuffer, sourceChannel: ChannelRelayed?) {
|
||||||
if (data.remaining() < 8) {
|
if (data.remaining() < 8) {
|
||||||
Logger.e(TAG, "Response packet too short")
|
Logger.e(TAG, "Response packet too short")
|
||||||
return
|
return
|
||||||
@ -651,7 +654,7 @@ class SyncSocketSession {
|
|||||||
return ConnectionInfo(port, name, remoteIp, ipv4Addresses, ipv6Addresses, allowLocalDirect, allowRemoteDirect, allowRemoteHolePunched, allowRemoteProxied)
|
return ConnectionInfo(port, name, remoteIp, ipv4Addresses, ipv6Addresses, allowLocalDirect, allowRemoteDirect, allowRemoteHolePunched, allowRemoteProxied)
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun handleNotify(subOpcode: UByte, data: ByteBuffer) {
|
private fun handleNotify(subOpcode: UByte, data: ByteBuffer, sourceChannel: ChannelRelayed?) {
|
||||||
when (subOpcode) {
|
when (subOpcode) {
|
||||||
NotifyOpcode.AUTHORIZED.value, NotifyOpcode.UNAUTHORIZED.value -> _onData?.invoke(this, Opcode.NOTIFY.value, subOpcode, data)
|
NotifyOpcode.AUTHORIZED.value, NotifyOpcode.UNAUTHORIZED.value -> _onData?.invoke(this, Opcode.NOTIFY.value, subOpcode, data)
|
||||||
NotifyOpcode.CONNECTION_INFO.value -> { /* Handle connection info if needed */ }
|
NotifyOpcode.CONNECTION_INFO.value -> { /* Handle connection info if needed */ }
|
||||||
@ -666,7 +669,7 @@ class SyncSocketSession {
|
|||||||
send(Opcode.RELAY.value, RelayOpcode.RELAY_ERROR.value, packet)
|
send(Opcode.RELAY.value, RelayOpcode.RELAY_ERROR.value, packet)
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun handleRelay(subOpcode: UByte, data: ByteBuffer) {
|
private fun handleRelay(subOpcode: UByte, data: ByteBuffer, sourceChannel: ChannelRelayed?) {
|
||||||
when (subOpcode) {
|
when (subOpcode) {
|
||||||
RelayOpcode.RELAYED_DATA.value -> {
|
RelayOpcode.RELAYED_DATA.value -> {
|
||||||
if (data.remaining() < 8) {
|
if (data.remaining() < 8) {
|
||||||
@ -680,7 +683,7 @@ class SyncSocketSession {
|
|||||||
}
|
}
|
||||||
val decryptedPayload = channel.decrypt(data)
|
val decryptedPayload = channel.decrypt(data)
|
||||||
try {
|
try {
|
||||||
channel.handleData(decryptedPayload)
|
handleData(decryptedPayload, channel)
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
Logger.e(TAG, "Exception while handling relayed data", e)
|
Logger.e(TAG, "Exception while handling relayed data", e)
|
||||||
channel.sendError(SyncErrorCode.ConnectionClosed)
|
channel.sendError(SyncErrorCode.ConnectionClosed)
|
||||||
@ -726,33 +729,36 @@ class SyncSocketSession {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun handlePacket(opcode: UByte, subOpcode: UByte, data: ByteBuffer) {
|
private fun handlePacket(opcode: UByte, subOpcode: UByte, data: ByteBuffer, sourceChannel: ChannelRelayed?) {
|
||||||
Logger.i(TAG, "Handle packet (opcode = ${opcode}, subOpcode = ${subOpcode})")
|
Logger.i(TAG, "Handle packet (opcode = ${opcode}, subOpcode = ${subOpcode})")
|
||||||
|
|
||||||
when (opcode) {
|
when (opcode) {
|
||||||
Opcode.PING.value -> {
|
Opcode.PING.value -> {
|
||||||
send(Opcode.PONG.value)
|
if (sourceChannel != null)
|
||||||
|
sourceChannel.send(Opcode.PONG.value)
|
||||||
|
else
|
||||||
|
send(Opcode.PONG.value)
|
||||||
//Logger.i(TAG, "Received ping, sent pong")
|
//Logger.i(TAG, "Received ping, sent pong")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
Opcode.PONG.value -> {
|
Opcode.PONG.value -> {
|
||||||
//Logger.i(TAG, "Received pong")
|
Logger.v(TAG, "Received pong")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
Opcode.REQUEST.value -> {
|
Opcode.REQUEST.value -> {
|
||||||
handleRequest(subOpcode, data)
|
handleRequest(subOpcode, data, sourceChannel)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
Opcode.RESPONSE.value -> {
|
Opcode.RESPONSE.value -> {
|
||||||
handleResponse(subOpcode, data)
|
handleResponse(subOpcode, data, sourceChannel)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
Opcode.NOTIFY.value -> {
|
Opcode.NOTIFY.value -> {
|
||||||
handleNotify(subOpcode, data)
|
handleNotify(subOpcode, data, sourceChannel)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
Opcode.RELAY.value -> {
|
Opcode.RELAY.value -> {
|
||||||
handleRelay(subOpcode, data)
|
handleRelay(subOpcode, data, sourceChannel)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
else -> if (isAuthorized) when (opcode) {
|
else -> if (isAuthorized) when (opcode) {
|
||||||
@ -809,12 +815,18 @@ class SyncSocketSession {
|
|||||||
throw Exception("After sync stream end, the stream must be complete")
|
throw Exception("After sync stream end, the stream must be complete")
|
||||||
}
|
}
|
||||||
|
|
||||||
handlePacket(syncStream.opcode, syncStream.subOpcode, syncStream.getBytes().let { ByteBuffer.wrap(it).order(ByteOrder.LITTLE_ENDIAN) })
|
handlePacket(syncStream.opcode, syncStream.subOpcode, syncStream.getBytes().let { ByteBuffer.wrap(it).order(ByteOrder.LITTLE_ENDIAN) }, sourceChannel)
|
||||||
}
|
|
||||||
else -> {
|
|
||||||
Logger.w(TAG, "Unknown opcode received (opcode = ${opcode}, subOpcode = ${subOpcode})")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Opcode.DATA.value -> {
|
||||||
|
if (sourceChannel != null)
|
||||||
|
sourceChannel.invokeDataHandler(opcode, subOpcode, data)
|
||||||
|
else
|
||||||
|
_onData?.invoke(this, opcode, subOpcode, data)
|
||||||
|
}
|
||||||
|
else -> {
|
||||||
|
Logger.w(TAG, "Unknown opcode received (opcode = ${opcode}, subOpcode = ${subOpcode})")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -995,18 +1007,27 @@ class SyncSocketSession {
|
|||||||
val deferred = CompletableDeferred<Boolean>()
|
val deferred = CompletableDeferred<Boolean>()
|
||||||
_pendingPublishRequests[requestId] = deferred
|
_pendingPublishRequests[requestId] = deferred
|
||||||
try {
|
try {
|
||||||
val MAX_PLAINTEXT_SIZE = 65535 - 16 // Adjust for tag size
|
val MAX_PLAINTEXT_SIZE = 65535
|
||||||
val HANDSHAKE_SIZE = 48
|
val HANDSHAKE_SIZE = 48
|
||||||
val LENGTH_SIZE = 4
|
val LENGTH_SIZE = 4
|
||||||
val TAG_SIZE = 16
|
val TAG_SIZE = 16
|
||||||
val chunkCount = (data.size + MAX_PLAINTEXT_SIZE - 1) / MAX_PLAINTEXT_SIZE
|
val chunkCount = (data.size + MAX_PLAINTEXT_SIZE - 1) / MAX_PLAINTEXT_SIZE
|
||||||
val blobSize = HANDSHAKE_SIZE + chunkCount * (LENGTH_SIZE + MAX_PLAINTEXT_SIZE + TAG_SIZE)
|
|
||||||
|
var blobSize = HANDSHAKE_SIZE
|
||||||
|
var dataOffset = 0
|
||||||
|
for (i in 0 until chunkCount) {
|
||||||
|
val chunkSize = minOf(MAX_PLAINTEXT_SIZE, data.size - dataOffset)
|
||||||
|
blobSize += LENGTH_SIZE + (chunkSize + TAG_SIZE)
|
||||||
|
dataOffset += chunkSize
|
||||||
|
}
|
||||||
|
|
||||||
val totalPacketSize = 4 + 1 + keyBytes.size + 1 + consumerPublicKeys.size * (32 + 4 + blobSize)
|
val totalPacketSize = 4 + 1 + keyBytes.size + 1 + consumerPublicKeys.size * (32 + 4 + blobSize)
|
||||||
val packet = ByteBuffer.allocate(totalPacketSize).order(ByteOrder.LITTLE_ENDIAN)
|
val packet = ByteBuffer.allocate(totalPacketSize).order(ByteOrder.LITTLE_ENDIAN)
|
||||||
packet.putInt(requestId)
|
packet.putInt(requestId)
|
||||||
packet.put(keyBytes.size.toByte())
|
packet.put(keyBytes.size.toByte())
|
||||||
packet.put(keyBytes)
|
packet.put(keyBytes)
|
||||||
packet.put(consumerPublicKeys.size.toByte())
|
packet.put(consumerPublicKeys.size.toByte())
|
||||||
|
|
||||||
for (consumer in consumerPublicKeys) {
|
for (consumer in consumerPublicKeys) {
|
||||||
val consumerBytes = Base64.getDecoder().decode(consumer)
|
val consumerBytes = Base64.getDecoder().decode(consumer)
|
||||||
if (consumerBytes.size != 32) throw IllegalArgumentException("Consumer public key must be 32 bytes")
|
if (consumerBytes.size != 32) throw IllegalArgumentException("Consumer public key must be 32 bytes")
|
||||||
@ -1020,9 +1041,10 @@ class SyncSocketSession {
|
|||||||
val transportPair = protocol.split()
|
val transportPair = protocol.split()
|
||||||
packet.putInt(blobSize)
|
packet.putInt(blobSize)
|
||||||
packet.put(handshakeMessage)
|
packet.put(handshakeMessage)
|
||||||
var dataOffset = 0
|
|
||||||
|
dataOffset = 0
|
||||||
for (i in 0 until chunkCount) {
|
for (i in 0 until chunkCount) {
|
||||||
val chunkSize = min(MAX_PLAINTEXT_SIZE, data.size - dataOffset)
|
val chunkSize = minOf(MAX_PLAINTEXT_SIZE, data.size - dataOffset)
|
||||||
val plaintext = data.copyOfRange(dataOffset, dataOffset + chunkSize)
|
val plaintext = data.copyOfRange(dataOffset, dataOffset + chunkSize)
|
||||||
val ciphertext = ByteArray(chunkSize + TAG_SIZE)
|
val ciphertext = ByteArray(chunkSize + TAG_SIZE)
|
||||||
val written = transportPair.sender.encryptWithAd(null, plaintext, 0, ciphertext, 0, plaintext.size)
|
val written = transportPair.sender.encryptWithAd(null, plaintext, 0, ciphertext, 0, plaintext.size)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user