diff --git a/packages/react-native/ReactAndroid/build.gradle.kts b/packages/react-native/ReactAndroid/build.gradle.kts index b0136db3c8f5..46d877b2790b 100644 --- a/packages/react-native/ReactAndroid/build.gradle.kts +++ b/packages/react-native/ReactAndroid/build.gradle.kts @@ -750,7 +750,26 @@ kotlin { explicitApi() } -tasks.withType { jvmArgs = listOf("-Xshare:off") } +tasks.withType { + jvmArgs = listOf("-Xshare:off") + + // Performance / memory tests are tagged with @Category(PerformanceTest::class) and excluded + // from the default test run because they take seconds and need extra heap. Opt in with + // `-PrunPerfTests=true`. + val runPerfTests = + (project.findProperty("runPerfTests") as? String)?.toBoolean() ?: false + useJUnit { + if (runPerfTests) { + includeCategories("com.facebook.react.devsupport.PerformanceTest") + } else { + excludeCategories("com.facebook.react.devsupport.PerformanceTest") + } + } + if (runPerfTests) { + maxHeapSize = "2g" + jvmArgs("-XX:+AlwaysPreTouch") + } +} /* Publishing Configuration */ apply(from = "./publish.gradle") diff --git a/packages/react-native/ReactAndroid/src/main/java/com/facebook/react/devsupport/BundleDownloader.kt b/packages/react-native/ReactAndroid/src/main/java/com/facebook/react/devsupport/BundleDownloader.kt index e315d5a2d44d..eba09b1e19bf 100644 --- a/packages/react-native/ReactAndroid/src/main/java/com/facebook/react/devsupport/BundleDownloader.kt +++ b/packages/react-native/ReactAndroid/src/main/java/com/facebook/react/devsupport/BundleDownloader.kt @@ -27,6 +27,7 @@ import okhttp3.OkHttpClient import okhttp3.Request import okhttp3.Response import okio.Buffer +import okio.BufferedSink import okio.BufferedSource import okio.Okio import org.json.JSONException @@ -183,82 +184,28 @@ public class BundleDownloader public constructor(private val client: OkHttpClien } val source = checkNotNull(response.body()?.source()) val bodyReader = MultipartStreamReader(source, boundary) - val completed = - bodyReader.readAllParts( - object : ChunkListener { - @Throws(IOException::class) - override fun onChunkComplete( - headers: Map, - body: Buffer, - isLastChunk: Boolean, - ) { - // This will get executed for every chunk of the multipart response. The last chunk - // (isLastChunk = true) will be the JS bundle, the other ones will be progress - // events - // encoded as JSON. - if (isLastChunk) { - // The http status code for each separate chunk is in the X-Http-Status header. - var status = response.code() - if (headers.containsKey("X-Http-Status")) { - status = headers.getOrDefault("X-Http-Status", "0").toInt() - } - processBundleResult( - url, - status, - Headers.of(headers), - body, - outputFile, - bundleInfo, - callback, - ) - } else { - if ( - !headers.containsKey("Content-Type") || - headers["Content-Type"] != "application/json" - ) { - return - } - - try { - val progress = JSONObject(body.readUtf8()) - val status = - if (progress.has("status")) progress.getString("status") else "Bundling" - var done: Int? = null - if (progress.has("done")) { - done = progress.getInt("done") - } - var total: Int? = null - if (progress.has("total")) { - total = progress.getInt("total") - } - var percent: Int? = null - if (progress.has("percent")) { - percent = progress.getInt("percent") - } - callback.onProgress(status, done, total, percent) - } catch (e: JSONException) { - FLog.e(ReactConstants.TAG, "Error parsing progress JSON. $e") - } - } - } - - override fun onChunkProgress( - headers: Map, - loaded: Long, - total: Long, - ) { - if ("application/javascript" == headers["Content-Type"]) { - callback.onProgress( - "Downloading", - (loaded / 1024).toInt(), - (total / 1024).toInt(), - null, - ) - } - } - } + val tmpFile = File(outputFile.path + ".tmp") + val streamingHandler = + StreamingBundleChunkListener( + url = url, + outerStatus = response.code(), + outputFile = outputFile, + tmpFile = tmpFile, + bundleInfo = bundleInfo, + callback = callback, ) + val completed: Boolean = + try { + bodyReader.readAllParts(streamingHandler) + } finally { + streamingHandler.closeOpenSinkQuietly() + } if (!completed) { + // If we partially wrote a tmp file before the upstream died, scrap it so we don't leave + // half-baked bundles on disk. + if (tmpFile.exists()) { + tmpFile.delete() + } callback.onFailure( DebugServerException( (""" @@ -276,6 +223,142 @@ public class BundleDownloader public constructor(private val client: OkHttpClien } } + /** + * Routes multipart chunks for a bundle download. The JS bundle chunk (Content-Type + * `application/javascript` with an effective HTTP status of 200) is streamed directly into + * a temporary file via a [BufferedSink], so no copy of the body is held in heap. Progress + * JSON chunks and error responses are buffered in memory because they're either tiny or + * bounded, and the listener needs to parse them in full. + */ + private inner class StreamingBundleChunkListener( + private val url: String, + private val outerStatus: Int, + private val outputFile: File, + private val tmpFile: File, + private val bundleInfo: BundleInfo?, + private val callback: DevBundleDownloadListener, + ) : ChunkListener { + + private var bundleSink: BufferedSink? = null + + @Throws(IOException::class) + override fun onChunkHeader(headers: Map): BufferedSink? { + if (!isJsBundleChunk(headers)) return null + val effectiveStatus = effectiveStatus(headers) + if (effectiveStatus != 200) return null + // Stream the JS bundle straight to disk — never materialize in heap. + val sink = Okio.buffer(Okio.sink(tmpFile)) + bundleSink = sink + return sink + } + + @Throws(IOException::class) + override fun onChunkComplete( + headers: Map, + body: Buffer?, + isLastChunk: Boolean, + ) { + val sink = bundleSink + if (sink != null) { + bundleSink = null + sink.close() + finalizeStreamedBundle(headers) + return + } + when { + isJsBundleChunk(headers) -> { + // Bundle returned with an error status — it was buffered so we can surface a useful + // diagnostic to the developer. + val buffered = body ?: Buffer() + processBundleResult( + url, + effectiveStatus(headers), + Headers.of(headers), + buffered, + outputFile, + bundleInfo, + callback, + ) + } + isProgressChunk(headers) -> dispatchProgressJson(body) + else -> { + // Unknown chunk type. Log so a future Metro change is visible in logcat instead of + // silently stranding the dev loading view at 99%. + FLog.w(TAG, "Ignoring multipart chunk with unrecognized Content-Type: ${headers["Content-Type"]}") + } + } + } + + override fun onChunkProgress( + headers: Map, + loaded: Long, + total: Long, + ) { + if ("application/javascript" == headers["Content-Type"]) { + callback.onProgress( + "Downloading", + (loaded / 1024).toInt(), + (total / 1024).toInt(), + null, + ) + } + } + + /** Make sure we never leak the tmp-file sink if [readAllParts] throws mid-stream. */ + fun closeOpenSinkQuietly() { + val sink = bundleSink ?: return + bundleSink = null + try { + sink.close() + } catch (e: IOException) { + FLog.w(TAG, "Failed to close partial bundle sink", e) + } + } + + @Throws(IOException::class) + private fun finalizeStreamedBundle(headers: Map) { + if (bundleInfo != null) { + populateBundleInfo(url, Headers.of(headers), bundleInfo) + } + if (!tmpFile.renameTo(outputFile)) { + throw IOException("Couldn't rename $tmpFile to $outputFile") + } + callback.onSuccess() + } + + private fun dispatchProgressJson(body: Buffer?) { + val payload = body ?: return + try { + val progress = JSONObject(payload.readUtf8()) + val status = + if (progress.has("status")) progress.getString("status") else "Bundling" + val done = if (progress.has("done")) progress.getInt("done") else null + val total = if (progress.has("total")) progress.getInt("total") else null + val percent = if (progress.has("percent")) progress.getInt("percent") else null + callback.onProgress(status, done, total, percent) + } catch (e: JSONException) { + FLog.e(ReactConstants.TAG, "Error parsing progress JSON. $e") + } + } + + private fun effectiveStatus(headers: Map): Int = + headers["X-Http-Status"]?.toIntOrNull() ?: outerStatus + + /** + * Extract the media type (the part before `;`) from a Content-Type header, lower-cased. + * Metro sends e.g. `application/javascript; charset=UTF-8`, so a bare-string equality + * check would miss the bundle chunk and leave the dev loading view stranded. + */ + private fun mediaType(headers: Map): String? = + headers["Content-Type"]?.substringBefore(';')?.trim()?.lowercase() + + private fun isJsBundleChunk(headers: Map): Boolean = + mediaType(headers) == "application/javascript" + + private fun isProgressChunk(headers: Map): Boolean = + mediaType(headers) == "application/json" + } + @Throws(IOException::class) private fun processBundleResult( url: String, diff --git a/packages/react-native/ReactAndroid/src/main/java/com/facebook/react/devsupport/MultipartStreamReader.kt b/packages/react-native/ReactAndroid/src/main/java/com/facebook/react/devsupport/MultipartStreamReader.kt index 5ff3dc94532b..647371ee0932 100644 --- a/packages/react-native/ReactAndroid/src/main/java/com/facebook/react/devsupport/MultipartStreamReader.kt +++ b/packages/react-native/ReactAndroid/src/main/java/com/facebook/react/devsupport/MultipartStreamReader.kt @@ -10,158 +10,242 @@ package com.facebook.react.devsupport import java.io.IOException -import kotlin.math.max import okio.Buffer +import okio.BufferedSink import okio.BufferedSource import okio.ByteString -/** Utility class to parse the body of a response of type multipart/mixed. */ +/** + * Streaming parser for `multipart/mixed` responses. + * + * Unlike a buffer-all-then-split parser, this implementation keeps a working buffer that is at + * most `READ_CHUNK_SIZE + maxDelimLen` bytes large. Body bytes for a chunk are either: + * + * * delivered to a [BufferedSink] returned by [ChunkListener.onChunkHeader] (preferred for + * large bodies — e.g. the JS bundle), in which case they never accumulate in the reader's + * heap; or + * * accumulated into a per-chunk [Buffer] and delivered via [ChunkListener.onChunkComplete] + * (preferred for small bodies like progress JSON, where the listener wants to parse them). + * + * The reader does not know whether a given chunk is the final one until it encounters the next + * delimiter. Listeners that need to route based on "is this the last chunk?" must instead use + * the chunk headers (e.g. Content-Type or X-Http-Status). + */ internal class MultipartStreamReader( private val source: BufferedSource, - private val boundary: String, + boundary: String, ) { + + private val regularDelim: ByteString = ByteString.encodeUtf8("$CRLF--$boundary$CRLF") + private val closeDelim: ByteString = ByteString.encodeUtf8("$CRLF--$boundary--$CRLF") + private val headerSep: ByteString = ByteString.encodeUtf8("$CRLF$CRLF") + private val maxDelimLen: Long = maxOf(regularDelim.size(), closeDelim.size()).toLong() + private var lastProgressEvent: Long = 0 interface ChunkListener { - /** Invoked when a chunk of a multipart response is fully downloaded. */ + /** + * Invoked when a new chunk's headers have been parsed but before its body is read. + * + * Return a [BufferedSink] to have the reader stream the chunk body directly into it. In + * that case, the body bytes are never buffered in the reader and the `body` argument to + * [onChunkComplete] will be `null`. + * + * Return `null` to have the reader accumulate the body in memory and pass it to + * [onChunkComplete] as a [Buffer]. This is appropriate for small chunks that the listener + * intends to parse in full (e.g. JSON progress events). + * + * The reader does not know whether this chunk is the last one until it encounters the + * next delimiter — routing decisions must rely on the supplied [headers]. + */ @Throws(IOException::class) - fun onChunkComplete(headers: Map, body: Buffer, isLastChunk: Boolean) + fun onChunkHeader(headers: Map): BufferedSink? - /** Invoked as bytes of the current chunk are read. */ + /** + * Invoked when the chunk body is fully consumed. + * + * @param body the accumulated body, non-null iff [onChunkHeader] returned `null`. + */ + @Throws(IOException::class) + fun onChunkComplete(headers: Map, body: Buffer?, isLastChunk: Boolean) + + /** Invoked at most once every ~16 ms while the current chunk's body is being read. */ @Throws(IOException::class) fun onChunkProgress(headers: Map, loaded: Long, total: Long) } /** - * Reads all parts of the multipart response and execute the listener for each chunk received. + * Read all parts of the multipart response and invoke the listener for each chunk received. * - * @param listener Listener invoked when chunks are received. - * @return If the read was successful + * @return `true` if a valid closing delimiter was reached; `false` if the upstream ended + * prematurely. */ @Throws(IOException::class) fun readAllParts(listener: ChunkListener): Boolean { - val delimiter: ByteString = ByteString.encodeUtf8("$CRLF--$boundary$CRLF") - val closeDelimiter: ByteString = ByteString.encodeUtf8("$CRLF--$boundary--$CRLF") - val headersDelimiter: ByteString = ByteString.encodeUtf8(CRLF + CRLF) + val buffer = Buffer() - val bufferLen = 4 * 1024 - var chunkStart: Long = 0 - var bytesSeen: Long = 0 - val content = Buffer() - var currentHeaders: Map? = null - var currentHeadersLength: Long = 0 + // Skip the preamble — discard bytes until the first regular delimiter appears. We never + // observe a close delimiter before the first regular one in a well-formed response. + if (!skipUntil(buffer, regularDelim)) return false + buffer.skip(regularDelim.size().toLong()) while (true) { - var isCloseDelimiter = false - - // Search only a subset of chunk that we haven't seen before + few bytes - // to allow for the edge case when the delimiter is cut by read call. - val searchStart = - max((bytesSeen - closeDelimiter.size()).toDouble(), chunkStart.toDouble()).toLong() - var indexOfDelimiter = content.indexOf(delimiter, searchStart) - if (indexOfDelimiter == -1L) { - isCloseDelimiter = true - indexOfDelimiter = content.indexOf(closeDelimiter, searchStart) - } - - if (indexOfDelimiter == -1L) { - bytesSeen = content.size() - - if (currentHeaders == null) { - val indexOfHeaders = content.indexOf(headersDelimiter, searchStart) - if (indexOfHeaders >= 0) { - source.read(content, indexOfHeaders) - val headers = Buffer() - content.copyTo(headers, searchStart, indexOfHeaders - searchStart) - currentHeadersLength = headers.size() + headersDelimiter.size() - currentHeaders = parseHeaders(headers) + val headers = readHeaders(buffer) ?: return false + + val sink = listener.onChunkHeader(headers) + val accumulator: Buffer? = if (sink == null) Buffer() else null + val contentLength = headers["Content-Length"]?.toLongOrNull() ?: 0L + + var bodyDelivered = 0L + var isLast = false + var done = false + + while (!done) { + val hit = findDelimiter(buffer) + if (hit != null) { + // Body ends at hit.index; transfer those bytes, then consume the delimiter. + if (hit.index > 0) { + transfer(buffer, hit.index, sink, accumulator) + bodyDelivered += hit.index } + buffer.skip(hit.delimSize) + isLast = hit.isClose + done = true } else { - emitProgress(currentHeaders, content.size() - currentHeadersLength, false, listener) + // No delimiter yet — drain bytes that cannot possibly start an upcoming match (keep + // the last `maxDelimLen - 1` bytes as lookahead) and read more from upstream. + val safeToDrain = buffer.size() - (maxDelimLen - 1) + if (safeToDrain > 0) { + transfer(buffer, safeToDrain, sink, accumulator) + bodyDelivered += safeToDrain + emitProgress(headers, bodyDelivered, contentLength, isFinal = false, listener) + } + val read = source.read(buffer, READ_CHUNK_SIZE) + if (read <= 0L) return false } + } - val bytesRead = source.read(content, bufferLen.toLong()) - if (bytesRead <= 0) { - return false - } - continue + emitProgress(headers, bodyDelivered, contentLength, isFinal = true, listener) + sink?.flush() + listener.onChunkComplete(headers, accumulator, isLast) + + if (isLast) return true + } + } + + /** Read from upstream until [delim] appears in [buffer]; do not consume the delimiter. */ + @Throws(IOException::class) + private fun skipUntil(buffer: Buffer, delim: ByteString): Boolean { + val keep = (delim.size() - 1).toLong() + while (true) { + val idx = buffer.indexOf(delim) + if (idx >= 0) { + buffer.skip(idx) + return true } + val drop = buffer.size() - keep + if (drop > 0) buffer.skip(drop) + val read = source.read(buffer, READ_CHUNK_SIZE) + if (read <= 0L) return false + } + } - val chunkEnd = indexOfDelimiter - val length = chunkEnd - chunkStart - - // Ignore preamble - if (chunkStart > 0) { - val chunk = Buffer() - content.skip(chunkStart) - content.read(chunk, length) - emitProgress(currentHeaders, chunk.size() - currentHeadersLength, true, listener) - emitChunk(chunk, isCloseDelimiter, listener) - currentHeaders = null - currentHeadersLength = 0 - } else { - content.skip(chunkEnd) + /** + * Read and parse the chunk header block, consuming the trailing CRLF CRLF separator. If a + * delimiter is encountered before the header separator, the chunk is treated as having no + * headers and an empty map is returned without consuming the delimiter. + */ + @Throws(IOException::class) + private fun readHeaders(buffer: Buffer): Map? { + while (true) { + val sepIdx = buffer.indexOf(headerSep) + val regIdx = buffer.indexOf(regularDelim) + val closeIdx = buffer.indexOf(closeDelim) + val nextDelim = minNonNegative(regIdx, closeIdx) + + if (sepIdx >= 0 && (nextDelim < 0 || sepIdx < nextDelim)) { + val headersBuf = Buffer() + buffer.read(headersBuf, sepIdx) + buffer.skip(headerSep.size().toLong()) + return parseHeaders(headersBuf) } - if (isCloseDelimiter) { - return true + if (nextDelim >= 0) { + // Chunk has no headers section; let the body loop consume the delimiter. + return emptyMap() } - chunkStart = delimiter.size().toLong() - bytesSeen = chunkStart + val read = source.read(buffer, READ_CHUNK_SIZE) + if (read <= 0L) return null + } + } + + /** Locate whichever of the regular or close delimiters appears first in [buffer]. */ + private fun findDelimiter(buffer: Buffer): DelimiterHit? { + val regIdx = buffer.indexOf(regularDelim) + val closeIdx = buffer.indexOf(closeDelim) + return when { + regIdx < 0 && closeIdx < 0 -> null + regIdx < 0 -> DelimiterHit(closeIdx, closeDelim.size().toLong(), isClose = true) + closeIdx < 0 -> DelimiterHit(regIdx, regularDelim.size().toLong(), isClose = false) + regIdx <= closeIdx -> DelimiterHit(regIdx, regularDelim.size().toLong(), isClose = false) + else -> DelimiterHit(closeIdx, closeDelim.size().toLong(), isClose = true) + } + } + + private data class DelimiterHit(val index: Long, val delimSize: Long, val isClose: Boolean) + + /** + * Transfer [byteCount] bytes from [src] to either [sink] (streaming case) or [accumulator] + * (buffered case). Both branches use okio segment-move semantics, so no per-byte copy + * happens for large transfers. + */ + @Throws(IOException::class) + private fun transfer(src: Buffer, byteCount: Long, sink: BufferedSink?, accumulator: Buffer?) { + if (byteCount <= 0L) return + when { + sink != null -> sink.write(src, byteCount) + accumulator != null -> accumulator.write(src, byteCount) + else -> src.skip(byteCount) } } private fun parseHeaders(data: Buffer): Map { val headers: MutableMap = mutableMapOf() val text = data.readUtf8() - val lines = text.split(CRLF.toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray() + val lines = text.split(CRLF).dropLastWhile { it.isEmpty() } for (line in lines) { - val indexOfSeparator = line.indexOf(":") - if (indexOfSeparator == -1) { - continue - } - val key = line.substring(0, indexOfSeparator).trim { it <= ' ' } - val value = line.substring(indexOfSeparator + 1).trim { it <= ' ' } + val sep = line.indexOf(':') + if (sep == -1) continue + val key = line.substring(0, sep).trim { it <= ' ' } + val value = line.substring(sep + 1).trim { it <= ' ' } headers[key] = value } return headers } - @Throws(IOException::class) - private fun emitChunk(chunk: Buffer, done: Boolean, listener: ChunkListener) { - val marker: ByteString = ByteString.encodeUtf8(CRLF + CRLF) - val indexOfMarker = chunk.indexOf(marker) - if (indexOfMarker == -1L) { - listener.onChunkComplete(emptyMap(), chunk, done) - } else { - val headers = Buffer() - val body = Buffer() - chunk.read(headers, indexOfMarker) - chunk.skip(marker.size().toLong()) - chunk.readAll(body) - listener.onChunkComplete(parseHeaders(headers), body, done) - } - } - @Throws(IOException::class) private fun emitProgress( - headers: Map?, + headers: Map, + loaded: Long, contentLength: Long, isFinal: Boolean, - listener: ChunkListener?, + listener: ChunkListener, ) { - if (listener == null || headers == null) { - return - } - val currentTime = System.currentTimeMillis() - if (currentTime - lastProgressEvent > 16 || isFinal) { - lastProgressEvent = currentTime - val headersContentLength = headers.getOrDefault("Content-Length", "0").toLong() - listener.onChunkProgress(headers, contentLength, headersContentLength) + val now = System.currentTimeMillis() + if (isFinal || now - lastProgressEvent > 16) { + lastProgressEvent = now + listener.onChunkProgress(headers, loaded, contentLength) } } + private fun minNonNegative(a: Long, b: Long): Long = + when { + a < 0 -> b + b < 0 -> a + else -> minOf(a, b) + } + companion object { - // Standard line separator for HTTP. private const val CRLF = "\r\n" + private const val READ_CHUNK_SIZE: Long = 16L * 1024L } } diff --git a/packages/react-native/ReactAndroid/src/test/java/com/facebook/react/devsupport/AllocationProbe.kt b/packages/react-native/ReactAndroid/src/test/java/com/facebook/react/devsupport/AllocationProbe.kt new file mode 100644 index 000000000000..c313e2758deb --- /dev/null +++ b/packages/react-native/ReactAndroid/src/test/java/com/facebook/react/devsupport/AllocationProbe.kt @@ -0,0 +1,168 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.facebook.react.devsupport + +import org.junit.Assume + +/** + * Thin wrapper over HotSpot's `com.sun.management.ThreadMXBean` and the JMX memory pool API to + * give performance tests a uniform way to measure: + * * **allocated bytes** per thread (cumulative, GC-independent) — the right metric for streaming + * code. + * * **peak heap usage** across all heap pools (coarse upper bound; affected by GC timing). + * + * Everything is accessed via reflection because Android's `android.jar` (used to compile + * library unit tests) strips the `java.lang.management` package. At runtime the host HotSpot + * JVM provides the full implementation. + */ +internal object AllocationProbe { + + // --- Thread allocation (com.sun.management.ThreadMXBean) --------------------------------- + private val threadMxBean: Any? + private val getThreadAllocatedBytesSingle: java.lang.reflect.Method? + private val getThreadAllocatedBytesBulk: java.lang.reflect.Method? + private val getAllThreadIds: java.lang.reflect.Method? + private val setThreadAllocatedMemoryEnabled: java.lang.reflect.Method? + private val isThreadAllocatedMemorySupported: java.lang.reflect.Method? + + // --- Heap pool peak usage (java.lang.management.MemoryPoolMXBean) ------------------------ + private val heapPools: List + private val getPeakUsage: java.lang.reflect.Method? + private val resetPeakUsage: java.lang.reflect.Method? + private val memoryUsageGetUsed: java.lang.reflect.Method? + + init { + val managementFactory = runCatching { Class.forName("java.lang.management.ManagementFactory") } + .getOrNull() + val sunThreadMxBeanClass = + runCatching { Class.forName("com.sun.management.ThreadMXBean") }.getOrNull() + val threadMxBeanClass = + runCatching { Class.forName("java.lang.management.ThreadMXBean") }.getOrNull() + + threadMxBean = + runCatching { managementFactory?.getMethod("getThreadMXBean")?.invoke(null) } + .getOrNull() + ?.takeIf { sunThreadMxBeanClass?.isInstance(it) == true } + + getThreadAllocatedBytesSingle = + sunThreadMxBeanClass?.declaredMethods?.firstOrNull { + it.name == "getThreadAllocatedBytes" && + it.parameterTypes.size == 1 && + it.parameterTypes[0] == Long::class.javaPrimitiveType + } + getThreadAllocatedBytesBulk = + sunThreadMxBeanClass?.declaredMethods?.firstOrNull { + it.name == "getThreadAllocatedBytes" && + it.parameterTypes.size == 1 && + it.parameterTypes[0] == LongArray::class.java + } + getAllThreadIds = + runCatching { threadMxBeanClass?.getMethod("getAllThreadIds") }.getOrNull() + setThreadAllocatedMemoryEnabled = + runCatching { + sunThreadMxBeanClass?.getMethod( + "setThreadAllocatedMemoryEnabled", + Boolean::class.javaPrimitiveType, + ) + } + .getOrNull() + isThreadAllocatedMemorySupported = + runCatching { + sunThreadMxBeanClass?.getMethod("isThreadAllocatedMemorySupported") + } + .getOrNull() + + // Heap pool plumbing. + val memoryTypeClass = runCatching { Class.forName("java.lang.management.MemoryType") }.getOrNull() + val heapEnum = + runCatching { memoryTypeClass?.getField("HEAP")?.get(null) }.getOrNull() + val memoryPoolMxBeanClass = + runCatching { Class.forName("java.lang.management.MemoryPoolMXBean") }.getOrNull() + val getType = + runCatching { memoryPoolMxBeanClass?.getMethod("getType") }.getOrNull() + val allPools: List = + runCatching { + @Suppress("UNCHECKED_CAST") + (managementFactory?.getMethod("getMemoryPoolMXBeans")?.invoke(null) as? List) + ?: emptyList() + } + .getOrDefault(emptyList()) + heapPools = + if (getType != null && heapEnum != null) { + allPools.filter { runCatching { getType.invoke(it) == heapEnum }.getOrDefault(false) } + } else emptyList() + + getPeakUsage = + runCatching { memoryPoolMxBeanClass?.getMethod("getPeakUsage") }.getOrNull() + resetPeakUsage = + runCatching { memoryPoolMxBeanClass?.getMethod("resetPeakUsage") }.getOrNull() + memoryUsageGetUsed = + runCatching { Class.forName("java.lang.management.MemoryUsage").getMethod("getUsed") } + .getOrNull() + } + + /** Skips the calling test if per-thread allocation tracking isn't available. */ + fun requireSupported() { + Assume.assumeTrue( + "com.sun.management.ThreadMXBean is unavailable (non-HotSpot JVM?)", + threadMxBean != null && getThreadAllocatedBytesSingle != null, + ) + val supported = + runCatching { isThreadAllocatedMemorySupported?.invoke(threadMxBean) as? Boolean } + .getOrNull() ?: false + Assume.assumeTrue("Per-thread allocated memory is not supported on this JVM", supported) + runCatching { setThreadAllocatedMemoryEnabled?.invoke(threadMxBean, true) } + } + + /** Cumulative bytes allocated on [threadId] since that thread started, or 0 if unsupported. */ + fun allocatedBytes(threadId: Long): Long = + runCatching { + getThreadAllocatedBytesSingle?.invoke(threadMxBean, threadId) as? Long ?: 0L + } + .getOrDefault(0L) + + /** Sum of cumulative allocations across every currently live thread. */ + fun totalAllocatedBytes(): Long { + val bean = threadMxBean ?: return 0L + val ids = runCatching { getAllThreadIds?.invoke(bean) as? LongArray }.getOrNull() ?: return 0L + val arr = + runCatching { getThreadAllocatedBytesBulk?.invoke(bean, ids) as? LongArray }.getOrNull() + ?: return 0L + var sum = 0L + for (v in arr) if (v > 0) sum += v + return sum + } + + /** Reset peak heap counters across all heap pools. Call before a measured section. */ + fun resetPeakHeap() { + val reset = resetPeakUsage ?: return + heapPools.forEach { runCatching { reset.invoke(it) } } + } + + /** Peak bytes used across all heap pools since the last [resetPeakHeap]. */ + fun peakHeapBytes(): Long { + val getUsage = getPeakUsage ?: return 0L + val getUsed = memoryUsageGetUsed ?: return 0L + var total = 0L + for (pool in heapPools) { + val usage = runCatching { getUsage.invoke(pool) }.getOrNull() ?: continue + total += runCatching { getUsed.invoke(usage) as? Long ?: 0L }.getOrDefault(0L) + } + return total + } + + /** Encourage the runtime to run a full GC before a measurement. */ + fun settle() { + System.gc() + Thread.sleep(50) + System.gc() + } + + /** Format a byte count as `12.34 MB`. */ + fun fmt(bytes: Long): String = String.format("%.2f MB", bytes / 1024.0 / 1024.0) +} diff --git a/packages/react-native/ReactAndroid/src/test/java/com/facebook/react/devsupport/BundleDownloaderPerfTest.kt b/packages/react-native/ReactAndroid/src/test/java/com/facebook/react/devsupport/BundleDownloaderPerfTest.kt new file mode 100644 index 000000000000..ceabacc2cdad --- /dev/null +++ b/packages/react-native/ReactAndroid/src/test/java/com/facebook/react/devsupport/BundleDownloaderPerfTest.kt @@ -0,0 +1,158 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +@file:Suppress("DEPRECATION_ERROR") // Conflicting okhttp/okio versions + +package com.facebook.react.devsupport + +import com.facebook.react.devsupport.interfaces.DevBundleDownloadListener +import java.io.File +import java.nio.file.Files +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.atomic.AtomicReference +import okhttp3.Interceptor +import okhttp3.MediaType +import okhttp3.OkHttpClient +import okhttp3.Protocol +import okhttp3.Response +import okhttp3.ResponseBody +import okio.Okio +import org.assertj.core.api.Assertions.assertThat +import org.junit.After +import org.junit.Before +import org.junit.Test +import org.junit.experimental.categories.Category + +/** + * End-to-end performance & memory regression test for [BundleDownloader] against a 100 MB JS + * bundle. The OkHttp call is short-circuited by an [Interceptor] that returns a synthetic + * `multipart/mixed` [ResponseBody] backed by [LargeMultipartSource], so no socket is involved + * and the server-side never holds the payload. + * + * Run with: + * ``` + * ./gradlew :packages:react-native:ReactAndroid:testDebugUnitTest \ + * -PrunPerfTests -Preact.internal.useHermesNightly=true \ + * --tests "*BundleDownloaderPerfTest" + * ``` + */ +@Category(PerformanceTest::class) +class BundleDownloaderPerfTest { + + private val boundary = "perf_boundary" + private val payloadBytes = 100L * 1024 * 1024 // 100 MB + private val bundleUrl = "http://localhost/perf.bundle" + + private lateinit var tmpDir: File + private lateinit var outputFile: File + + @Before + fun setUp() { + AllocationProbe.requireSupported() + tmpDir = Files.createTempDirectory("bundle-downloader-perf").toFile() + outputFile = File(tmpDir, "bundle.js") + } + + @After + fun tearDown() { + tmpDir.deleteRecursively() + } + + @Test + fun downloads100MBMultipartBundleWithBoundedAllocation() { + val workerThreadId = AtomicLong(-1L) + val workerAllocStart = AtomicLong(0L) + + val syntheticInterceptor = Interceptor { chain -> + // We're on the OkHttp dispatcher thread at this point — capture it so we can measure the + // allocations the read path actually attributes to it. + val tid = Thread.currentThread().id + workerThreadId.set(tid) + workerAllocStart.set(AllocationProbe.allocatedBytes(tid)) + + val mediaType = MediaType.parse("multipart/mixed; boundary=\"$boundary\"") + val source = Okio.buffer(LargeMultipartSource(boundary, payloadBytes)) + val body: ResponseBody = ResponseBody.create(mediaType, -1L, source) + + Response.Builder() + .request(chain.request()) + .protocol(Protocol.HTTP_1_1) + .code(200) + .message("OK") + .header("content-type", "multipart/mixed; boundary=\"$boundary\"") + .body(body) + .build() + } + + val client = OkHttpClient.Builder().addInterceptor(syntheticInterceptor).build() + val downloader = BundleDownloader(client) + + val done = CountDownLatch(1) + val failure = AtomicReference(null) + val listener = + object : DevBundleDownloadListener { + override fun onSuccess() = done.countDown() + + override fun onProgress(status: String?, done: Int?, total: Int?, percent: Int?) = Unit + + override fun onFailure(cause: Exception) { + failure.set(cause) + done.countDown() + } + } + + val testThreadId = Thread.currentThread().id + AllocationProbe.settle() + AllocationProbe.resetPeakHeap() + val testAllocBefore = AllocationProbe.allocatedBytes(testThreadId) + val totalAllocBefore = AllocationProbe.totalAllocatedBytes() + val nanosBefore = System.nanoTime() + + downloader.downloadBundleFromURL(listener, outputFile, bundleUrl, BundleDownloader.BundleInfo()) + + assertThat(done.await(120, TimeUnit.SECONDS)) + .`as`("Download did not complete within timeout") + .isTrue + assertThat(failure.get()).isNull() + + val elapsedMs = (System.nanoTime() - nanosBefore) / 1_000_000 + val testAllocated = AllocationProbe.allocatedBytes(testThreadId) - testAllocBefore + val totalAllocated = AllocationProbe.totalAllocatedBytes() - totalAllocBefore + val workerAllocated = + if (workerThreadId.get() >= 0) + AllocationProbe.allocatedBytes(workerThreadId.get()) - workerAllocStart.get() + else 0L + val peakHeap = AllocationProbe.peakHeapBytes() + + println( + "[BundleDownloaderPerfTest] payload=${AllocationProbe.fmt(payloadBytes)} " + + "elapsed=${elapsedMs}ms " + + "test-thread-allocated=${AllocationProbe.fmt(testAllocated)} " + + "worker-thread-allocated=${AllocationProbe.fmt(workerAllocated)} " + + "all-threads-allocated=${AllocationProbe.fmt(totalAllocated)} " + + "peak-heap=${AllocationProbe.fmt(peakHeap)} " + + "output-size=${AllocationProbe.fmt(outputFile.length())}" + ) + + // Correctness: the bundle was streamed straight to disk; the file size equals the + // synthetic payload (post multipart-parsing). + assertThat(outputFile.length()).isEqualTo(payloadBytes) + + // Memory: peak heap is the property that proves the bundle is streamed to disk rather + // than retained in heap. For a 100 MB payload it should sit well under the payload size; + // the 80 MB ceiling leaves room for OkHttp dispatcher warmup and cross-machine variance. + // + // We deliberately do NOT assert on allocated bytes: okio's SegmentPool is capped at 64 KB + // so the OkHttp/okio pipeline always churns ~payloadBytes of segment allocations, + // regardless of whether BundleDownloader retains the body. Peak heap is the right metric. + assertThat(peakHeap) + .`as`("Peak heap should be O(buffer size), not O(payload)") + .isLessThan(80L * 1024 * 1024) + } +} diff --git a/packages/react-native/ReactAndroid/src/test/java/com/facebook/react/devsupport/LargeMultipartSource.kt b/packages/react-native/ReactAndroid/src/test/java/com/facebook/react/devsupport/LargeMultipartSource.kt new file mode 100644 index 000000000000..5b6b63d2dcc4 --- /dev/null +++ b/packages/react-native/ReactAndroid/src/test/java/com/facebook/react/devsupport/LargeMultipartSource.kt @@ -0,0 +1,99 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.facebook.react.devsupport + +import okio.Buffer +import okio.Source +import okio.Timeout + +/** + * An [okio.Source] that synthesizes a syntactically valid `multipart/mixed` response containing a + * single application/javascript payload of [payloadBytes] bytes. + * + * The bytes are produced lazily, so the test harness itself never holds the full payload in heap. + * That makes it safe to use this in allocation- and peak-heap-sensitive tests at sizes (e.g. + * 100 MB) that would otherwise dominate any measurement of the code under test. + * + * The emitted framing matches what [MultipartStreamReader] expects: a CRLF preamble followed by + * `--\r\n`, the headers block, the payload, and a final closing delimiter + * `\r\n----\r\n`. + */ +internal class LargeMultipartSource( + private val boundary: String, + private val payloadBytes: Long, + /** Maximum bytes returned per [read] call. Keeps the synthesizer's own memory tiny. */ + private val chunkSize: Int = 64 * 1024, +) : Source { + + private enum class Phase { + PREAMBLE, + HEADERS, + PAYLOAD, + CLOSE, + DONE, + } + + // Bytes that bracket the payload. Computed once, reused by reference. + // NB: the reader's delimiter is "\r\n--\r\n", so the preamble must end with CRLF. + // A bare CRLF is the minimal valid preamble. + private val preamble: ByteArray = "\r\n".toByteArray(Charsets.UTF_8) + // Note: matches what Metro actually sends. The `charset=UTF-8` parameter exists to catch + // any future regression to bare-string Content-Type matching in BundleDownloader. + private val headers: ByteArray = + ("--$boundary\r\n" + + "Content-Type: application/javascript; charset=UTF-8\r\n" + + "Content-Length: $payloadBytes\r\n" + + "\r\n") + .toByteArray(Charsets.UTF_8) + private val close: ByteArray = "\r\n--$boundary--\r\n".toByteArray(Charsets.UTF_8) + + // Single reused filler buffer. The exact byte value doesn't matter as long as it never spells + // out the boundary string. + private val filler: ByteArray = ByteArray(chunkSize) { 'A'.code.toByte() } + + private var phase: Phase = Phase.PREAMBLE + private var payloadRemaining: Long = payloadBytes + + /** Total number of bytes this source will emit over its lifetime. Useful for assertions. */ + val totalBytes: Long + get() = preamble.size + headers.size + payloadBytes + close.size + + override fun read(sink: Buffer, byteCount: Long): Long { + require(byteCount >= 0) { "byteCount < 0: $byteCount" } + if (byteCount == 0L) return 0 + return when (phase) { + Phase.PREAMBLE -> { + sink.write(preamble) + phase = Phase.HEADERS + preamble.size.toLong() + } + Phase.HEADERS -> { + sink.write(headers) + phase = if (payloadRemaining > 0) Phase.PAYLOAD else Phase.CLOSE + headers.size.toLong() + } + Phase.PAYLOAD -> { + val n = minOf(byteCount, payloadRemaining, filler.size.toLong()).toInt() + sink.write(filler, 0, n) + payloadRemaining -= n + if (payloadRemaining == 0L) phase = Phase.CLOSE + n.toLong() + } + Phase.CLOSE -> { + sink.write(close) + phase = Phase.DONE + close.size.toLong() + } + Phase.DONE -> -1L + } + } + + override fun timeout(): Timeout = Timeout.NONE + + override fun close() = Unit +} diff --git a/packages/react-native/ReactAndroid/src/test/java/com/facebook/react/devsupport/MultipartStreamReaderPerfTest.kt b/packages/react-native/ReactAndroid/src/test/java/com/facebook/react/devsupport/MultipartStreamReaderPerfTest.kt new file mode 100644 index 000000000000..142a4d5b4ad6 --- /dev/null +++ b/packages/react-native/ReactAndroid/src/test/java/com/facebook/react/devsupport/MultipartStreamReaderPerfTest.kt @@ -0,0 +1,144 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +@file:Suppress("DEPRECATION_ERROR") // Conflicting okio versions + +package com.facebook.react.devsupport + +import okio.Buffer +import okio.BufferedSink +import okio.Okio +import okio.Sink +import okio.Timeout +import org.assertj.core.api.Assertions.assertThat +import org.junit.Before +import org.junit.Test +import org.junit.experimental.categories.Category + +/** + * Performance & memory regression test for [MultipartStreamReader] against a 100 MB JavaScript + * payload. Exercises the streaming path: the listener returns a [BufferedSink] from + * [MultipartStreamReader.ChunkListener.onChunkHeader], so the body bytes should be transferred + * to the sink without accumulating in heap. + * + * Run with: + * ``` + * ./gradlew :packages:react-native:ReactAndroid:testDebugUnitTest \ + * -PrunPerfTests -Preact.internal.useHermesNightly=true \ + * --tests "*MultipartStreamReaderPerfTest" + * ``` + */ +@Category(PerformanceTest::class) +class MultipartStreamReaderPerfTest { + + private val boundary = "perf_boundary" + private val payloadBytes = 100L * 1024 * 1024 // 100 MB + + @Before + fun setUp() { + AllocationProbe.requireSupported() + } + + @Test + fun streams100MBBundleWithBoundedAllocation() { + val syntheticSource = LargeMultipartSource(boundary, payloadBytes) + val bufferedSource = Okio.buffer(syntheticSource) + val reader = MultipartStreamReader(bufferedSource, boundary) + + val discardingSink = CountingDiscardingSink() + val bufferedSink = Okio.buffer(discardingSink) + var receivedHeaders: Map = emptyMap() + var bufferDeliveredViaComplete = false + + val listener = + object : MultipartStreamReader.ChunkListener { + override fun onChunkHeader(headers: Map): BufferedSink { + receivedHeaders = headers + return bufferedSink + } + + override fun onChunkComplete( + headers: Map, + body: Buffer?, + isLastChunk: Boolean, + ) { + // body must be null when we returned a sink from onChunkHeader. + if (body != null) bufferDeliveredViaComplete = true + } + + override fun onChunkProgress( + headers: Map, + loaded: Long, + total: Long, + ) = Unit + } + + val threadId = Thread.currentThread().id + AllocationProbe.settle() + AllocationProbe.resetPeakHeap() + val allocBefore = AllocationProbe.allocatedBytes(threadId) + val nanosBefore = System.nanoTime() + + val success = reader.readAllParts(listener) + bufferedSink.flush() + + val elapsedMs = (System.nanoTime() - nanosBefore) / 1_000_000 + val allocated = AllocationProbe.allocatedBytes(threadId) - allocBefore + val peakHeap = AllocationProbe.peakHeapBytes() + + println( + "[MultipartStreamReaderPerfTest] payload=${AllocationProbe.fmt(payloadBytes)} " + + "elapsed=${elapsedMs}ms " + + "thread-allocated=${AllocationProbe.fmt(allocated)} " + + "peak-heap=${AllocationProbe.fmt(peakHeap)} " + + "sink-bytes=${AllocationProbe.fmt(discardingSink.bytesWritten)}" + ) + + // Correctness: every payload byte made it to the sink, none was buffered into a Buffer + // and surfaced via onChunkComplete. + assertThat(success).isTrue + assertThat(discardingSink.bytesWritten).isEqualTo(payloadBytes) + assertThat(receivedHeaders["Content-Type"]) + .isEqualTo("application/javascript; charset=UTF-8") + assertThat(bufferDeliveredViaComplete) + .`as`("Body must be streamed to the sink, not delivered as a Buffer") + .isFalse + + // Memory: peak heap must be bounded by the reader's working buffer plus a small overhead + // (class loading, JIT scratch, JMX bookkeeping), not by the payload size. The 64 MB + // ceiling leaves room for cross-machine variance while still proving the reader doesn't + // retain the payload. + // + // We deliberately do NOT assert on `thread-allocated`: okio's SegmentPool is capped at + // 64 KB, so streaming 100 MB through any pipeline (production or test) churns roughly + // 100 MB of segment allocations regardless of whether the reader is well-behaved. Peak + // heap is the property that distinguishes streaming from buffering. + assertThat(peakHeap) + .`as`("Peak heap should be O(buffer size), not O(payload)") + .isLessThan(64L * 1024 * 1024) + } + + /** + * An [okio.Sink] that counts the bytes written to it and discards them. Used so the test's + * assertion budget reflects only the reader's allocations, not a sink buffer's. + */ + private class CountingDiscardingSink : Sink { + var bytesWritten: Long = 0L + private set + + override fun write(source: Buffer, byteCount: Long) { + bytesWritten += byteCount + source.skip(byteCount) + } + + override fun flush() = Unit + + override fun timeout(): Timeout = Timeout.NONE + + override fun close() = Unit + } +} diff --git a/packages/react-native/ReactAndroid/src/test/java/com/facebook/react/devsupport/MultipartStreamReaderTest.kt b/packages/react-native/ReactAndroid/src/test/java/com/facebook/react/devsupport/MultipartStreamReaderTest.kt index de2655d8722f..b6a05d784d33 100644 --- a/packages/react-native/ReactAndroid/src/test/java/com/facebook/react/devsupport/MultipartStreamReaderTest.kt +++ b/packages/react-native/ReactAndroid/src/test/java/com/facebook/react/devsupport/MultipartStreamReaderTest.kt @@ -8,6 +8,7 @@ package com.facebook.react.devsupport import okio.Buffer +import okio.BufferedSink import okio.ByteString import org.assertj.core.api.Assertions.assertThat import org.junit.Test @@ -36,14 +37,14 @@ class MultipartStreamReaderTest { object : CallCountTrackingChunkCallback() { override fun onChunkComplete( headers: Map, - body: Buffer, + body: Buffer?, isLastChunk: Boolean, ) { super.onChunkComplete(headers, body, isLastChunk) assertThat(isLastChunk).isTrue assertThat(headers["Content-Type"]).isEqualTo("application/json; charset=utf-8") - assertThat(body.readUtf8()).isEqualTo("{}") + assertThat(body?.readUtf8()).isEqualTo("{}") } } val success = reader.readAllParts(callback) @@ -76,13 +77,13 @@ class MultipartStreamReaderTest { object : CallCountTrackingChunkCallback() { override fun onChunkComplete( headers: Map, - body: Buffer, + body: Buffer?, isLastChunk: Boolean, ) { super.onChunkComplete(headers, body, isLastChunk) assertThat(isLastChunk).isEqualTo(callCount == 3) - assertThat(body.readUtf8()).isEqualTo("$callCount") + assertThat(body?.readUtf8()).isEqualTo("$callCount") } } val success = reader.readAllParts(callback) @@ -136,7 +137,14 @@ class MultipartStreamReaderTest { var callCount = 0 private set - override fun onChunkComplete(headers: Map, body: Buffer, isLastChunk: Boolean) { + // Buffer body in-memory so individual tests can assert on it. + override fun onChunkHeader(headers: Map): BufferedSink? = null + + override fun onChunkComplete( + headers: Map, + body: Buffer?, + isLastChunk: Boolean, + ) { callCount++ } diff --git a/packages/react-native/ReactAndroid/src/test/java/com/facebook/react/devsupport/PerformanceTest.kt b/packages/react-native/ReactAndroid/src/test/java/com/facebook/react/devsupport/PerformanceTest.kt new file mode 100644 index 000000000000..104cec438235 --- /dev/null +++ b/packages/react-native/ReactAndroid/src/test/java/com/facebook/react/devsupport/PerformanceTest.kt @@ -0,0 +1,15 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.facebook.react.devsupport + +/** + * JUnit category marker for performance / memory tests that are too expensive to run on every + * change. Excluded from the default `testDebugUnitTest` task; opt-in via + * `-PrunPerfTests=true`. + */ +interface PerformanceTest