diff --git a/app/src/main/java/io/legado/app/help/ConcurrentRateLimiter.kt b/app/src/main/java/io/legado/app/help/ConcurrentRateLimiter.kt new file mode 100644 index 000000000..d36144e02 --- /dev/null +++ b/app/src/main/java/io/legado/app/help/ConcurrentRateLimiter.kt @@ -0,0 +1,135 @@ +package io.legado.app.help + +import io.legado.app.data.entities.BaseSource +import io.legado.app.exception.ConcurrentException +import io.legado.app.model.analyzeRule.AnalyzeUrl.ConcurrentRecord +import kotlinx.coroutines.delay + +class ConcurrentRateLimiter(val source: BaseSource?) { + + companion object { + private val concurrentRecordMap = hashMapOf() + } + + /** + * 开始访问,并发判断 + */ + @Throws(ConcurrentException::class) + private fun fetchStart(): ConcurrentRecord? { + source ?: return null + val concurrentRate = source.concurrentRate + if (concurrentRate.isNullOrEmpty() || concurrentRate == "0") { + return null + } + val rateIndex = concurrentRate.indexOf("/") + var fetchRecord = concurrentRecordMap[source.getKey()] + if (fetchRecord == null) { + synchronized(concurrentRecordMap) { + fetchRecord = concurrentRecordMap[source.getKey()] + if (fetchRecord == null) { + fetchRecord = ConcurrentRecord(rateIndex > 0, System.currentTimeMillis(), 1) + concurrentRecordMap[source.getKey()] = fetchRecord + return fetchRecord + } + } + } + val waitTime: Int = synchronized(fetchRecord!!) { + try { + if (!fetchRecord.isConcurrent) { + //并发控制非 次数/毫秒 + if (fetchRecord.frequency > 0) { + //已经有访问线程,直接等待 + return@synchronized concurrentRate.toInt() + } + //没有线程访问,判断还剩多少时间可以访问 + val nextTime = fetchRecord.time + concurrentRate.toInt() + if (System.currentTimeMillis() >= nextTime) { + fetchRecord.time = System.currentTimeMillis() + fetchRecord.frequency = 1 + return@synchronized 0 + } + return@synchronized (nextTime - System.currentTimeMillis()).toInt() + } else { + //并发控制为 次数/毫秒 + val sj = concurrentRate.substring(rateIndex + 1) + val nextTime = fetchRecord.time + sj.toInt() + if (System.currentTimeMillis() >= nextTime) { + //已经过了限制时间,重置开始时间 + fetchRecord.time = System.currentTimeMillis() + fetchRecord.frequency = 1 + return@synchronized 0 + } + val cs = concurrentRate.substring(0, rateIndex) + if (fetchRecord.frequency > cs.toInt()) { + return@synchronized (nextTime - System.currentTimeMillis()).toInt() + } else { + fetchRecord.frequency += 1 + return@synchronized 0 + } + } + } catch (_: Exception) { + return@synchronized 0 + } + } + if (waitTime > 0) { + throw ConcurrentException( + "根据并发率还需等待${waitTime}毫秒才可以访问", + waitTime = waitTime + ) + } + return fetchRecord + } + + /** + * 访问结束 + */ + fun fetchEnd(concurrentRecord: ConcurrentRecord?) { + if (concurrentRecord != null && !concurrentRecord.isConcurrent) { + synchronized(concurrentRecord) { + concurrentRecord.frequency -= 1 + } + } + } + + /** + * 获取并发记录,若处于并发限制状态下则会等待 + */ + suspend fun getConcurrentRecord(): ConcurrentRecord? { + while (true) { + try { + return fetchStart() + } catch (e: ConcurrentException) { + delay(e.waitTime.toLong()) + } + } + } + + fun getConcurrentRecordBlocking(): ConcurrentRecord? { + while (true) { + try { + return fetchStart() + } catch (e: ConcurrentException) { + Thread.sleep(e.waitTime.toLong()) + } + } + } + + suspend inline fun withLimit(block: () -> T): T { + val concurrentRecord = getConcurrentRecord() + try { + return block() + } finally { + fetchEnd(concurrentRecord) + } + } + + inline fun withLimitBlocking(block: () -> T): T { + val concurrentRecord = getConcurrentRecordBlocking() + try { + return block() + } finally { + fetchEnd(concurrentRecord) + } + } + +} diff --git a/app/src/main/java/io/legado/app/help/JsExtensions.kt b/app/src/main/java/io/legado/app/help/JsExtensions.kt index 20c40b5ca..50be03c59 100644 --- a/app/src/main/java/io/legado/app/help/JsExtensions.kt +++ b/app/src/main/java/io/legado/app/help/JsExtensions.kt @@ -47,6 +47,7 @@ import io.legado.app.utils.toStringArray import io.legado.app.utils.toastOnUi import kotlinx.coroutines.Dispatchers.IO import kotlinx.coroutines.async +import kotlinx.coroutines.ensureActive import kotlinx.coroutines.runBlocking import okio.use import org.jsoup.Connection @@ -358,13 +359,17 @@ interface JsExtensions : JsEncodeUtils { val requestHeaders = if (getSource()?.enabledCookieJar == true) { headers.toMutableMap().apply { put(cookieJarHeader, "1") } } else headers - val response = Jsoup.connect(urlStr) - .sslSocketFactory(SSLHelper.unsafeSSLSocketFactory) - .ignoreContentType(true) - .followRedirects(false) - .headers(requestHeaders) - .method(Connection.Method.GET) - .execute() + val rateLimiter = ConcurrentRateLimiter(getSource()) + val response = rateLimiter.withLimitBlocking { + context.ensureActive() + Jsoup.connect(urlStr) + .sslSocketFactory(SSLHelper.unsafeSSLSocketFactory) + .ignoreContentType(true) + .followRedirects(false) + .headers(requestHeaders) + .method(Connection.Method.GET) + .execute() + } return response } @@ -375,13 +380,17 @@ interface JsExtensions : JsEncodeUtils { val requestHeaders = if (getSource()?.enabledCookieJar == true) { headers.toMutableMap().apply { put(cookieJarHeader, "1") } } else headers - val response = Jsoup.connect(urlStr) - .sslSocketFactory(SSLHelper.unsafeSSLSocketFactory) - .ignoreContentType(true) - .followRedirects(false) - .headers(requestHeaders) - .method(Connection.Method.HEAD) - .execute() + val rateLimiter = ConcurrentRateLimiter(getSource()) + val response = rateLimiter.withLimitBlocking { + context.ensureActive() + Jsoup.connect(urlStr) + .sslSocketFactory(SSLHelper.unsafeSSLSocketFactory) + .ignoreContentType(true) + .followRedirects(false) + .headers(requestHeaders) + .method(Connection.Method.HEAD) + .execute() + } return response } @@ -392,14 +401,18 @@ interface JsExtensions : JsEncodeUtils { val requestHeaders = if (getSource()?.enabledCookieJar == true) { headers.toMutableMap().apply { put(cookieJarHeader, "1") } } else headers - val response = Jsoup.connect(urlStr) - .sslSocketFactory(SSLHelper.unsafeSSLSocketFactory) - .ignoreContentType(true) - .followRedirects(false) - .requestBody(body) - .headers(requestHeaders) - .method(Connection.Method.POST) - .execute() + val rateLimiter = ConcurrentRateLimiter(getSource()) + val response = rateLimiter.withLimitBlocking { + context.ensureActive() + Jsoup.connect(urlStr) + .sslSocketFactory(SSLHelper.unsafeSSLSocketFactory) + .ignoreContentType(true) + .followRedirects(false) + .requestBody(body) + .headers(requestHeaders) + .method(Connection.Method.POST) + .execute() + } return response } diff --git a/app/src/main/java/io/legado/app/model/analyzeRule/AnalyzeUrl.kt b/app/src/main/java/io/legado/app/model/analyzeRule/AnalyzeUrl.kt index 1aaebb21b..5dde288b3 100644 --- a/app/src/main/java/io/legado/app/model/analyzeRule/AnalyzeUrl.kt +++ b/app/src/main/java/io/legado/app/model/analyzeRule/AnalyzeUrl.kt @@ -15,8 +15,8 @@ import io.legado.app.constant.AppPattern.dataUriRegex import io.legado.app.data.entities.BaseSource import io.legado.app.data.entities.Book import io.legado.app.data.entities.BookChapter -import io.legado.app.exception.ConcurrentException import io.legado.app.help.CacheManager +import io.legado.app.help.ConcurrentRateLimiter import io.legado.app.help.JsExtensions import io.legado.app.help.config.AppConfig import io.legado.app.help.exoplayer.ExoPlayerHelper @@ -46,7 +46,6 @@ import io.legado.app.utils.isJsonArray import io.legado.app.utils.isJsonObject import io.legado.app.utils.isXml import io.legado.app.utils.splitNotBlank -import kotlinx.coroutines.delay import kotlinx.coroutines.runBlocking import okhttp3.MediaType.Companion.toMediaType import okhttp3.OkHttpClient @@ -86,7 +85,6 @@ class AnalyzeUrl( companion object { val paramPattern: Pattern = Pattern.compile("\\s*,\\s*(?=\\{)") private val pagePattern = Pattern.compile("<(.*?)>") - private val concurrentRecordMap = hashMapOf() } var ruleUrl = "" @@ -110,6 +108,7 @@ class AnalyzeUrl( private val enabledCookieJar = source?.enabledCookieJar ?: false private val domain: String private var webViewDelayTime: Long = 0 + private val concurrentRateLimiter = ConcurrentRateLimiter(source) // 服务器ID var serverID: Long? = null @@ -331,99 +330,6 @@ class AnalyzeUrl( ?: "" } - /** - * 开始访问,并发判断 - */ - @Throws(ConcurrentException::class) - private fun fetchStart(): ConcurrentRecord? { - source ?: return null - val concurrentRate = source.concurrentRate - if (concurrentRate.isNullOrEmpty() || concurrentRate == "0") { - return null - } - val rateIndex = concurrentRate.indexOf("/") - var fetchRecord = concurrentRecordMap[source.getKey()] - if (fetchRecord == null) { - synchronized(concurrentRecordMap) { - fetchRecord = concurrentRecordMap[source.getKey()] - if (fetchRecord == null) { - fetchRecord = ConcurrentRecord(rateIndex > 0, System.currentTimeMillis(), 1) - concurrentRecordMap[source.getKey()] = fetchRecord - return fetchRecord - } - } - } - val waitTime: Int = synchronized(fetchRecord!!) { - try { - if (!fetchRecord.isConcurrent) { - //并发控制非 次数/毫秒 - if (fetchRecord.frequency > 0) { - //已经有访问线程,直接等待 - return@synchronized concurrentRate.toInt() - } - //没有线程访问,判断还剩多少时间可以访问 - val nextTime = fetchRecord.time + concurrentRate.toInt() - if (System.currentTimeMillis() >= nextTime) { - fetchRecord.time = System.currentTimeMillis() - fetchRecord.frequency = 1 - return@synchronized 0 - } - return@synchronized (nextTime - System.currentTimeMillis()).toInt() - } else { - //并发控制为 次数/毫秒 - val sj = concurrentRate.substring(rateIndex + 1) - val nextTime = fetchRecord.time + sj.toInt() - if (System.currentTimeMillis() >= nextTime) { - //已经过了限制时间,重置开始时间 - fetchRecord.time = System.currentTimeMillis() - fetchRecord.frequency = 1 - return@synchronized 0 - } - val cs = concurrentRate.substring(0, rateIndex) - if (fetchRecord.frequency > cs.toInt()) { - return@synchronized (nextTime - System.currentTimeMillis()).toInt() - } else { - fetchRecord.frequency += 1 - return@synchronized 0 - } - } - } catch (e: Exception) { - return@synchronized 0 - } - } - if (waitTime > 0) { - throw ConcurrentException( - "根据并发率还需等待${waitTime}毫秒才可以访问", - waitTime = waitTime - ) - } - return fetchRecord - } - - /** - * 访问结束 - */ - private fun fetchEnd(concurrentRecord: ConcurrentRecord?) { - if (concurrentRecord != null && !concurrentRecord.isConcurrent) { - synchronized(concurrentRecord) { - concurrentRecord.frequency -= 1 - } - } - } - - /** - * 获取并发记录,若处于并发限制状态下则会等待 - */ - private suspend fun getConcurrentRecord(): ConcurrentRecord? { - while (true) { - try { - return fetchStart() - } catch (e: ConcurrentException) { - delay(e.waitTime.toLong()) - } - } - } - /** * 访问网站,返回StrResponse */ @@ -435,8 +341,7 @@ class AnalyzeUrl( if (type != null) { return StrResponse(url, HexUtil.encodeHexStr(getByteArrayAwait())) } - val concurrentRecord = getConcurrentRecord() - try { + concurrentRateLimiter.withLimit { setCookie() val strResponse: StrResponse if (this.useWebView && useWebView) { @@ -500,9 +405,6 @@ class AnalyzeUrl( } } return strResponse - } finally { - //saveCookie() - fetchEnd(concurrentRecord) } } @@ -521,8 +423,7 @@ class AnalyzeUrl( * 访问网站,返回Response */ suspend fun getResponseAwait(): Response { - val concurrentRecord = getConcurrentRecord() - try { + concurrentRateLimiter.withLimit { setCookie() val response = getClient().newCallResponse(retry) { addHeaders(headerMap) @@ -545,9 +446,6 @@ class AnalyzeUrl( } } return response - } finally { - //saveCookie() - fetchEnd(concurrentRecord) } }