diff --git a/core/src/main/kotlin/me/rhunk/snapenhance/util/download/DownloadServer.kt b/core/src/main/kotlin/me/rhunk/snapenhance/util/download/DownloadServer.kt index 26c30e73..509d0074 100644 --- a/core/src/main/kotlin/me/rhunk/snapenhance/util/download/DownloadServer.kt +++ b/core/src/main/kotlin/me/rhunk/snapenhance/util/download/DownloadServer.kt @@ -2,6 +2,8 @@ package me.rhunk.snapenhance.util.download import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job +import kotlinx.coroutines.delay import kotlinx.coroutines.launch import me.rhunk.snapenhance.Logger import java.io.BufferedReader @@ -10,16 +12,20 @@ import java.io.InputStreamReader import java.io.PrintWriter import java.net.ServerSocket import java.net.Socket -import java.net.SocketTimeoutException +import java.net.SocketException import java.util.Locale import java.util.StringTokenizer import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.ThreadLocalRandom +import kotlin.random.Random class DownloadServer( private val timeout: Int = 10000 ) { - private val port = ThreadLocalRandom.current().nextInt(10000, 65535) + private val port = Random.nextInt(10000, 65535) + + private val coroutineScope = CoroutineScope(Dispatchers.IO) + private var timeoutJob: Job? = null + private var socketJob: Job? = null private val cachedData = ConcurrentHashMap>() private var serverSocket: ServerSocket? = null @@ -30,27 +36,36 @@ class DownloadServer( return } - CoroutineScope(Dispatchers.IO).launch { + coroutineScope.launch(Dispatchers.IO) { Logger.debug("starting download server on port $port") serverSocket = ServerSocket(port) - serverSocket!!.soTimeout = timeout callback(this@DownloadServer) while (!serverSocket!!.isClosed) { try { val socket = serverSocket!!.accept() - launch(Dispatchers.IO) { + timeoutJob?.cancel() + launch { handleRequest(socket) + timeoutJob = launch { + delay(timeout.toLong()) + Logger.debug("download server closed due to timeout") + runCatching { + socketJob?.cancel() + socket.close() + serverSocket?.close() + }.onFailure { + Logger.error(it) + } + } } - } catch (e: SocketTimeoutException) { - serverSocket?.close() - serverSocket = null - Logger.debug("download server closed") + } catch (e: SocketException) { + Logger.debug("download server timed out") break; - } catch (e: Exception) { + } catch (e: Throwable) { Logger.error("failed to handle request", e) } } - } + }.also { socketJob = it } } fun close() {