diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 0e0eb1459..e5b3d7779 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -7,12 +7,14 @@ version: 2 updates: - package-ecosystem: "github-actions" directory: "/" + target-branch: "develop" schedule: interval: "weekly" ignore: - dependency-name: "*" - package-ecosystem: gradle directory: "/" + target-branch: "develop" schedule: interval: "daily" open-pull-requests-limit: 200 diff --git a/README.md b/README.md index c1177d6cb..0f36639aa 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ Root-level Markdown documents: - Choose the backend that fits the moment: your own AUTOMATIC1111 or SwarmUI server, AI Horde, Hugging Face, OpenAI, Stability AI, Fal.ai, ArliAI, or local diffusion where the platform supports it. - Generate with familiar Stable Diffusion controls: prompts, negative prompts where supported, seed, steps, CFG scale, image size, model selectors, LoRA, embeddings, and more. - Use one shared mobile experience across Android and iOS for remote generation workflows. -- Work locally when privacy or connectivity matters with Android ONNX, MediaPipe, stable-diffusion.cpp SDXL, iOS Silicon Diffusion Core ML, or iOS Silicon Diffusion PrismML Bonsai. +- Work locally when privacy or connectivity matters with Android ONNX, MediaPipe, stable-diffusion.cpp SDXL, Android Local Diffusion PrismML Bonsai, iOS Silicon Diffusion Core ML, or iOS Silicon Diffusion PrismML Bonsai. - Keep your creations in a local gallery with image details, zoom, sharing, native platform save flows, and zip export. - Check local-device fit before heavy runs with the on-device benchmark, then review storage and network usage from Settings. - Stay in control: the project is open source and the app does not include ads or telemetry. @@ -49,7 +49,7 @@ Android builds are distributed in three flavors: - `full`: full GitHub/release build. - `foss`: F-Droid friendly build. 
-iOS uses the shared mobile experience with remote generation providers, Silicon Diffusion Core ML, and Silicon Diffusion PrismML Bonsai for on-device generation on supported devices. +iOS uses the shared mobile experience with remote generation providers, Silicon Diffusion Core ML, and Silicon Diffusion PrismML Bonsai for on-device generation on supported devices. Android also exposes Local Diffusion PrismML Bonsai on builds where the native runtime is packaged, with device suitability checked through setup and benchmark flows. | Provider / backend | What it connects to | iOS | Android `playstore` | Android `full` | Android `foss` | Notes | | --- | --- |--------| --- | --- | --- | --- | @@ -65,18 +65,18 @@ iOS uses the shared mobile experience with remote generation providers, Silicon | Local Diffusion: Google AI MediaPipe | On-device MediaPipe image generator | 🔴 No | 🟢 Yes | 🟢 Yes | 🔴 No | Android-only txt2img. Excluded from the FOSS flavor. | | Local Diffusion: stable-diffusion.cpp SDXL | On-device SDXL-compatible model inference | 🔴 No | 🟢 Yes | 🟢 Yes | 🟢 Yes | Android-only txt2img through stable-diffusion.cpp. Supports catalog GGUF/safetensors/ckpt models, CPU/OpenCL/Vulkan backend selection, and custom local model paths outside the Play build. | | Silicon Diffusion Core ML | On-device Core ML Stable Diffusion runtime | 🟢 Yes | 🔴 No | 🔴 No | 🔴 No | iOS-only txt2img and img2img with explicit downloadable/imported Core ML model assets. SDXL catalog entries are disabled until device-gated QA is stable. | -| Silicon Diffusion PrismML Bonsai | On-device PrismML Bonsai Image 4B MLX runtime | 🟢 Yes | 🔴 No | 🔴 No | 🔴 No | iOS-only txt2img with downloadable PrismML Bonsai Ternary and Binary model archives. | +| Local / Silicon Diffusion PrismML Bonsai | On-device PrismML Bonsai Image 4B runtime | 🟢 Yes | 🟢 Yes | 🟢 Yes | 🟢 Yes | iOS uses the Silicon Diffusion MLX runtime and is marked beta. 
Android uses the Local Diffusion NDK/ART runtime, is marked experimental, supports Auto/CPU/Vulkan backend selection, and should be device-checked before heavy runs. | ## AI Feature Matrix | AI-specific feature | Supported providers | Notes | | --- | --- | --- | -| Text to image | AUTOMATIC1111, SwarmUI, AI Horde, Hugging Face, OpenAI, Stability AI, Fal.ai, ArliAI, Local ONNX, Local MediaPipe, Local SDXL, Silicon Diffusion Core ML, Silicon Diffusion PrismML Bonsai | Core generation path exists for every provider exposed by the current platform/build. | +| Text to image | AUTOMATIC1111, SwarmUI, AI Horde, Hugging Face, OpenAI, Stability AI, Fal.ai, ArliAI, Local ONNX, Local MediaPipe, Local SDXL, Local/Silicon Diffusion PrismML Bonsai, Silicon Diffusion Core ML | Core generation path exists for every provider exposed by the current platform/build. | | Image to image | AUTOMATIC1111, SwarmUI, AI Horde, Hugging Face, Stability AI, Fal.ai, ArliAI, Silicon Diffusion Core ML | OpenAI and Android local diffusion providers are txt2img-only in the app. Core ML img2img requires a compatible downloaded model archive. | | Inpaint mask controls | AUTOMATIC1111 | Mask image, mask blur, mask mode, masked content, inpaint area, and only-masked padding are mapped to the A1111 img2img API. | -| Negative prompt | AUTOMATIC1111, SwarmUI, Hugging Face, Stability AI, ArliAI, Local ONNX, Local SDXL, Silicon Diffusion Core ML | Horde, OpenAI, and MediaPipe flows do not expose/send a negative prompt. | +| Negative prompt | AUTOMATIC1111, SwarmUI, Hugging Face, Stability AI, ArliAI, Local ONNX, Local SDXL, Local/Silicon Diffusion PrismML Bonsai, Silicon Diffusion Core ML | Horde, OpenAI, and MediaPipe flows do not expose/send a negative prompt. 
| | Batch generation | AUTOMATIC1111, SwarmUI, AI Horde, Hugging Face, OpenAI, Stability AI, Fal.ai, ArliAI | Fal.ai uses native `num_images`; ArliAI uses the SDNext-compatible batch size field; local providers are treated as single-image generation flows. | -| Model or engine selection | AUTOMATIC1111, SwarmUI, Hugging Face, OpenAI, Stability AI, Fal.ai, ArliAI, Local ONNX, Local MediaPipe, Local SDXL, Silicon Diffusion Core ML, Silicon Diffusion PrismML Bonsai | Depending on provider, this selects an SD checkpoint, SwarmUI model, HF model, OpenAI model, Stability engine, Fal.ai endpoint, ArliAI checkpoint, or local model. | +| Model or engine selection | AUTOMATIC1111, SwarmUI, Hugging Face, OpenAI, Stability AI, Fal.ai, ArliAI, Local ONNX, Local MediaPipe, Local SDXL, Local/Silicon Diffusion PrismML Bonsai, Silicon Diffusion Core ML | Depending on provider, this selects an SD checkpoint, SwarmUI model, HF model, OpenAI model, Stability engine, Fal.ai endpoint, ArliAI checkpoint, or local model. Android SDXL and Bonsai also expose runtime backend selection where supported. | | LoRA picker | AUTOMATIC1111, SwarmUI | Remote LoRA lists are fetched from the active compatible server. | | Textual inversion / embeddings picker | AUTOMATIC1111, SwarmUI | Remote embeddings are fetched from the active compatible server. | | Hypernetwork picker | AUTOMATIC1111 | Hypernetwork discovery is implemented for A1111. | @@ -85,9 +85,9 @@ iOS uses the shared mobile experience with remote generation providers, Silicon | OpenAI model, size, and quality | OpenAI | Uses current GPT Image model options exposed by the Images API. | | Stability style preset and clip guidance | Stability AI | Passed to Stability AI requests when selected. | | NSFW flag | AI Horde, Fal.ai, Silicon Diffusion Core ML | Exposed for Horde requests, mapped to Fal.ai safety-checker settings, and mapped to the local Core ML safety checker. 
| -| Offline generation | Local ONNX, Local MediaPipe, Local SDXL, Silicon Diffusion Core ML, Silicon Diffusion PrismML Bonsai | Runs after the selected local model is available on the current platform. | -| On-device benchmark | Local ONNX, Local MediaPipe, Local SDXL, Silicon Diffusion Core ML, Silicon Diffusion PrismML Bonsai | Runs a safe inference-like CPU and memory workload, stores the latest local result, and recommends local provider settings without loading model files or starting AI runtimes. | -| Generation interrupt | AUTOMATIC1111, AI Horde, Local ONNX, Local SDXL, Silicon Diffusion Core ML, Silicon Diffusion PrismML Bonsai | Other providers rely on request completion when no platform-level interrupt is exposed. | +| Offline generation | Local ONNX, Local MediaPipe, Local SDXL, Local/Silicon Diffusion PrismML Bonsai, Silicon Diffusion Core ML | Runs after the selected local model is available on the current platform. | +| On-device benchmark | Local ONNX, Local MediaPipe, Local SDXL, Local/Silicon Diffusion PrismML Bonsai, Silicon Diffusion Core ML | Runs a safe inference-like CPU and memory workload, includes Android Bonsai Vulkan compute probing, stores the latest local result, and recommends local provider settings without loading model files or starting AI runtimes. | +| Generation interrupt | AUTOMATIC1111, AI Horde, Local ONNX, Local SDXL, Local/Silicon Diffusion PrismML Bonsai, Silicon Diffusion Core ML | Other providers rely on request completion when no platform-level interrupt is exposed. | ## Core Workflow @@ -113,7 +113,8 @@ iOS uses the shared mobile experience with remote generation providers, Silicon - Server URL and credentials for own-server providers. - API keys for hosted providers. - Local model selection and download flow for supported local diffusion providers. -- Hardware benchmark in AI Settings for local providers, including device score, acceleration availability, estimated generation time, and recommended local settings. 
+- Runtime backend selectors for Android local providers where available, including SDXL Auto/CPU/OpenCL/Vulkan and Bonsai Auto/CPU/Vulkan. +- Hardware benchmark in AI Settings for local providers, including device score, acceleration availability, Android Bonsai Vulkan compute support, estimated generation time, and recommended local settings. - Server availability monitoring for compatible own-server modes. - Storage usage and network usage screens for gallery/cache/model files and provider traffic counters. @@ -176,17 +177,19 @@ Use this on iOS for on-device Stable Diffusion generation through Core ML. Downl The first-party catalog intentionally starts with Apple/Hugging Face Stable Diffusion 1.x and 2.x palettized Core ML archives. SDXL archives are not exposed by default until memory and execution-plan compatibility are stable across a tested device matrix. -### Option 13: Silicon Diffusion PrismML Bonsai +### Option 13: PrismML Bonsai Image 4B -Use this on iOS for on-device text-to-image generation through the PrismML Bonsai Image 4B MLX runtime. Download a supported Bonsai Ternary or Binary model archive from the in-app catalog, select it during setup, and generate without sending prompts to a remote service. +Use this for on-device text-to-image generation with PrismML Bonsai Image 4B. Download a supported Bonsai Ternary or Binary model archive from the in-app catalog, select it during setup, and generate without sending prompts to a remote service. -This provider is beta while the custom KMP-to-Swift runtime bridge and model QA stabilize across real iPhone devices. +On iOS the provider is displayed as Silicon Diffusion PrismML Bonsai and uses the MLX runtime. On Android it is displayed as Local Diffusion PrismML Bonsai, uses the custom NDK/ART runtime, and exposes Auto, CPU, and Vulkan backend selection from the generation form. + +iOS support is beta. 
Android support is experimental and device-sensitive: run the benchmark/setup checks before long generations, especially on low-RAM devices. ## Build Flavor Notes -Android flavor availability is driven by the Gradle flavor configuration and runtime provider filtering. Most network providers are available everywhere; Google AI MediaPipe is intentionally unavailable in `foss`. Local SDXL through stable-diffusion.cpp is available in `playstore`, `full`, and `foss`; its model catalog is shared, and model files are downloaded or imported by the user rather than bundled into the app. The Play build avoids custom local model path selection for local diffusion models because broad file access is not generally accepted for Google Play distribution. +Android flavor availability is driven by the Gradle flavor configuration and runtime provider filtering. Most network providers are available everywhere; Google AI MediaPipe is intentionally unavailable in `foss`. Local SDXL through stable-diffusion.cpp and Local Diffusion PrismML Bonsai are available in `playstore`, `full`, and `foss`; their model catalogs are shared, and model files are downloaded or imported by the user rather than bundled into the app. The Play build avoids custom local model path selection for local diffusion models because broad file access is not generally accepted for Google Play distribution. -The iOS app is not split into Android-style flavors. It uses the shared mobile UI, remote-provider stack, Silicon Diffusion Core ML, and Silicon Diffusion PrismML Bonsai as iOS-only local providers. Android ONNX and MediaPipe local diffusion remain Android-specific. +The iOS app is not split into Android-style flavors. It uses the shared mobile UI, remote-provider stack, Silicon Diffusion Core ML, and Silicon Diffusion PrismML Bonsai. Android ONNX, MediaPipe, SDXL, and Local Diffusion PrismML Bonsai local runtimes remain Android-specific. 
For a historical overview of flavor policy, see the project wiki page: [Build flavor difference](https://github.com/ShiftHackZ/Stable-Diffusion-Android/wiki/Build-flavor-difference). diff --git a/app/android/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt b/app/android/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt index 0d87ef808..606c2a5c2 100755 --- a/app/android/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt +++ b/app/android/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt @@ -8,6 +8,7 @@ import com.shifthackz.aisdv1.core.common.appbuild.BuildType import com.shifthackz.aisdv1.core.common.appbuild.BuildVersion import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor import com.shifthackz.aisdv1.core.common.links.LinksProvider +import com.shifthackz.aisdv1.core.common.platform.Platform import com.shifthackz.aisdv1.core.common.schedulers.DispatchersProvider import com.shifthackz.aisdv1.core.common.time.TimeProvider import com.shifthackz.aisdv1.domain.preference.PreferenceManager @@ -52,6 +53,7 @@ val providersModule = module { override val buildNumber: Int = BuildConfig.VERSION_CODE override val version: BuildVersion = BuildVersion(BuildConfig.VERSION_NAME) override val type: BuildType = BuildType.fromBuildConfig(BuildConfig.BUILD_FLAVOR_TYPE) + override val platform: Platform = Platform.ANDROID override fun toString(): String = buildString { append("$version") diff --git a/build.gradle.kts b/build.gradle.kts index 15f2bbaa0..0bed028d7 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -27,8 +27,11 @@ val documentedProjects = listOf( ":domain", ":feature:auth", ":feature:benchmark", + ":feature:bonsai", + ":feature:coreml", ":feature:onnx", ":feature:mediapipe", + ":feature:sdxl", ":feature:work", ":network", ":presentation", diff --git a/core/common/src/androidMain/kotlin/com/shifthackz/aisdv1/core/common/appbuild/PlatformBuildInfoProvider.android.kt 
b/core/common/src/androidMain/kotlin/com/shifthackz/aisdv1/core/common/appbuild/PlatformBuildInfoProvider.android.kt index 20305d0ec..c316a6f2a 100644 --- a/core/common/src/androidMain/kotlin/com/shifthackz/aisdv1/core/common/appbuild/PlatformBuildInfoProvider.android.kt +++ b/core/common/src/androidMain/kotlin/com/shifthackz/aisdv1/core/common/appbuild/PlatformBuildInfoProvider.android.kt @@ -1,9 +1,9 @@ package com.shifthackz.aisdv1.core.common.appbuild /** - * Creates the SDAI value produced by `createPlatformBuildInfoProvider`. + * Supplies Android build metadata from the generated app-level provider. * - * @return Result produced by `createPlatformBuildInfoProvider`. - * @author Dmitriy Moroz + * Core modules use a stub here so Android app variants can override the binding + * with flavor-aware metadata during dependency injection. */ actual fun createPlatformBuildInfoProvider(): BuildInfoProvider = BuildInfoProvider.stub diff --git a/core/common/src/commonMain/kotlin/com/shifthackz/aisdv1/core/common/appbuild/BuildInfoProvider.kt b/core/common/src/commonMain/kotlin/com/shifthackz/aisdv1/core/common/appbuild/BuildInfoProvider.kt index b2a2db1f3..380169bc3 100644 --- a/core/common/src/commonMain/kotlin/com/shifthackz/aisdv1/core/common/appbuild/BuildInfoProvider.kt +++ b/core/common/src/commonMain/kotlin/com/shifthackz/aisdv1/core/common/appbuild/BuildInfoProvider.kt @@ -1,52 +1,33 @@ package com.shifthackz.aisdv1.core.common.appbuild +import com.shifthackz.aisdv1.core.common.platform.Platform + /** - * Defines the `BuildInfoProvider` contract for the SDAI core common layer. + * Runtime build metadata shared by domain and presentation code. * - * @author Dmitriy Moroz + * Provider filtering relies on both the distribution flavor and the current + * platform, so this abstraction is intentionally available outside platform UI + * modules. */ interface BuildInfoProvider { - /** - * Exposes the `isDebug` value used by the SDAI core common layer. 
- * - * @author Dmitriy Moroz - */ + val isDebug: Boolean - /** - * Exposes the `buildNumber` value used by the SDAI core common layer. - * - * @author Dmitriy Moroz - */ + val buildNumber: Int - /** - * Exposes the `version` value used by the SDAI core common layer. - * - * @author Dmitriy Moroz - */ + val version: BuildVersion - /** - * Exposes the `type` value used by the SDAI core common layer. - * - * @author Dmitriy Moroz - */ + val type: BuildType - /** - * Provides the `companion object` singleton used by the SDAI core common layer. - * - * @author Dmitriy Moroz - */ + val platform: Platform + companion object { - /** - * Exposes the `stub` value used by the SDAI core common layer. - * - * @author Dmitriy Moroz - */ val stub = object : BuildInfoProvider { override val isDebug: Boolean = true override val buildNumber: Int = 0 override val version: BuildVersion = BuildVersion() override val type: BuildType = BuildType.FOSS + override val platform: Platform = Platform.ANDROID override fun toString(): String = displayString() } diff --git a/core/common/src/commonMain/kotlin/com/shifthackz/aisdv1/core/common/appbuild/PlatformBuildInfoProvider.kt b/core/common/src/commonMain/kotlin/com/shifthackz/aisdv1/core/common/appbuild/PlatformBuildInfoProvider.kt index 2b221aefd..d703bcbda 100644 --- a/core/common/src/commonMain/kotlin/com/shifthackz/aisdv1/core/common/appbuild/PlatformBuildInfoProvider.kt +++ b/core/common/src/commonMain/kotlin/com/shifthackz/aisdv1/core/common/appbuild/PlatformBuildInfoProvider.kt @@ -1,10 +1,10 @@ package com.shifthackz.aisdv1.core.common.appbuild /** - * Creates the SDAI value produced by `createPlatformBuildInfoProvider`. + * Creates the platform-specific build metadata provider. * - * @return Result produced by `createPlatformBuildInfoProvider`. 
- * @author Dmitriy Moroz + * Each target fills the shared [BuildInfoProvider] contract from its native + * build system so common code can make flavor and platform decisions without + * reaching into Android or iOS APIs. */ expect fun createPlatformBuildInfoProvider(): BuildInfoProvider - diff --git a/core/common/src/commonMain/kotlin/com/shifthackz/aisdv1/core/common/platform/Platform.kt b/core/common/src/commonMain/kotlin/com/shifthackz/aisdv1/core/common/platform/Platform.kt new file mode 100644 index 000000000..b7b73c7c7 --- /dev/null +++ b/core/common/src/commonMain/kotlin/com/shifthackz/aisdv1/core/common/platform/Platform.kt @@ -0,0 +1,9 @@ +package com.shifthackz.aisdv1.core.common.platform + +/** + * Native platform currently running the shared Kotlin code. + */ +enum class Platform { + ANDROID, + IOS, +} diff --git a/core/common/src/commonMain/kotlin/com/shifthackz/aisdv1/core/common/platform/PlatformValue.kt b/core/common/src/commonMain/kotlin/com/shifthackz/aisdv1/core/common/platform/PlatformValue.kt new file mode 100644 index 000000000..589f46cdf --- /dev/null +++ b/core/common/src/commonMain/kotlin/com/shifthackz/aisdv1/core/common/platform/PlatformValue.kt @@ -0,0 +1,22 @@ +package com.shifthackz.aisdv1.core.common.platform + +/** + * Small value holder for metadata that can differ between Android and iOS. + * + * It is used for provider readiness and similar catalog fields where a single + * domain enum should stay stable while the user-facing platform status differs. 
+ */ +data class PlatformValue<T>( + val android: T, + val ios: T, +) { + constructor(value: T) : this( + android = value, + ios = value, + ) + + operator fun get(platform: Platform): T = when (platform) { + Platform.ANDROID -> android + Platform.IOS -> ios + } +} diff --git a/core/common/src/iosMain/kotlin/com/shifthackz/aisdv1/core/common/appbuild/PlatformBuildInfoProvider.ios.kt b/core/common/src/iosMain/kotlin/com/shifthackz/aisdv1/core/common/appbuild/PlatformBuildInfoProvider.ios.kt index d946fcaaa..b4671acee 100644 --- a/core/common/src/iosMain/kotlin/com/shifthackz/aisdv1/core/common/appbuild/PlatformBuildInfoProvider.ios.kt +++ b/core/common/src/iosMain/kotlin/com/shifthackz/aisdv1/core/common/appbuild/PlatformBuildInfoProvider.ios.kt @@ -1,27 +1,32 @@ package com.shifthackz.aisdv1.core.common.appbuild +import com.shifthackz.aisdv1.core.common.platform.Platform import platform.Foundation.NSBundle /** - * Creates the SDAI value produced by `createPlatformBuildInfoProvider`. + * Reads iOS bundle metadata for the shared build info contract. * - * @return Result produced by `createPlatformBuildInfoProvider`. - * @author Dmitriy Moroz + * The iOS target has no Android-style flavor split, so it reports the App Store + * compatible flavor while still exposing [Platform.IOS] to common code. */ actual fun createPlatformBuildInfoProvider(): BuildInfoProvider = object : BuildInfoProvider { private val bundle = NSBundle.mainBundle override val isDebug: Boolean = false + override val buildNumber: Int = bundle .objectForInfoDictionaryKey("CFBundleVersion") .toString() .toIntOrNull() ?: 0 + override val version: BuildVersion = BuildVersion( bundle.objectForInfoDictionaryKey("CFBundleShortVersionString") as? 
String, ) + override val type: BuildType = BuildType.PLAY + override val platform: Platform = Platform.IOS + override fun toString(): String = displayString() } - diff --git a/core/localization/src/androidMain/res/values-ru/strings.xml b/core/localization/src/androidMain/res/values-ru/strings.xml index 9c3855d79..3b13b65d3 100644 --- a/core/localization/src/androidMain/res/values-ru/strings.xml +++ b/core/localization/src/androidMain/res/values-ru/strings.xml @@ -172,6 +172,7 @@ Endpoint Fal.ai Ускорение Рантайм бэкенд + Рантайм бэкенд Синхронный режим Исходное изображение Качество @@ -237,6 +238,7 @@ Эта конфигурация использует stable-diffusion.cpp и позволяет запускать SDXL, SDXL-Turbo, LCM или квантованные GGUF-модели на телефоне. Локальная генерация на iPhone. Локальная генерация Bonsai Image 4B на iPhone. + Локальная генерация Bonsai Image 4B на поддерживаемых устройствах. Веб diff --git a/core/localization/src/androidMain/res/values-tr/strings.xml b/core/localization/src/androidMain/res/values-tr/strings.xml index 0c3e213f2..171507c31 100644 --- a/core/localization/src/androidMain/res/values-tr/strings.xml +++ b/core/localization/src/androidMain/res/values-tr/strings.xml @@ -172,6 +172,7 @@ Fal.ai uç noktası Hızlandırma Çalışma zamanı arka ucu + Çalışma zamanı arka ucu Senkronizasyon modu Girdi görüntüsü Kalite @@ -236,7 +237,8 @@ Bu yapılandırma Google AI MediaPipe çalışma zamanını kullanır ve uzak bir sunucuya/buluta bağlanmaya gerek kalmadan telefonunuzda Stable Diffusion AI nesillerini çalıştırmanıza olanak tanır. Bu yapılandırma stable-diffusion.cpp kullanır ve telefonunuzda SDXL, SDXL-Turbo, LCM veya kuantize GGUF modellerini çalıştırmanızı sağlar. iPhone üzerinde yerel üretim. - iPhone üzerinde yerel Bonsai Image 4B üretimi. + iPhone cihazında yerel Bonsai Image 4B üretimi. + Desteklenen cihazlarda yerel Bonsai Image 4B üretimi. 
Web arayüzü diff --git a/core/localization/src/androidMain/res/values-uk/strings.xml b/core/localization/src/androidMain/res/values-uk/strings.xml index a0aed76b1..e2e65d7e6 100644 --- a/core/localization/src/androidMain/res/values-uk/strings.xml +++ b/core/localization/src/androidMain/res/values-uk/strings.xml @@ -172,6 +172,7 @@ Endpoint Fal.ai Прискорення Рантайм бекенд + Рантайм бекенд Синхронний режим Вхідне зображення Якість @@ -237,6 +238,7 @@ Ця конфігурація використовує stable-diffusion.cpp і дає змогу запускати SDXL, SDXL-Turbo, LCM або квантовані GGUF-моделі на телефоні. Локальна генерація на iPhone. Локальна генерація Bonsai Image 4B на iPhone. + Локальна генерація Bonsai Image 4B на підтримуваних пристроях. Веб diff --git a/core/localization/src/androidMain/res/values-zh/strings.xml b/core/localization/src/androidMain/res/values-zh/strings.xml index 79cdbaafd..5b05d4a2a 100644 --- a/core/localization/src/androidMain/res/values-zh/strings.xml +++ b/core/localization/src/androidMain/res/values-zh/strings.xml @@ -188,6 +188,7 @@ Fal.ai 端点 加速 运行时后端 + 运行时后端 同步模式 输入图像 质量 @@ -262,6 +263,7 @@ 此配置使用 stable-diffusion.cpp,可在手机上运行 SDXL、SDXL-Turbo、LCM 或量化 GGUF 模型。 在 iPhone 上进行本地生成。 在 iPhone 上进行本地 Bonsai Image 4B 生成。 + 在受支持设备上进行本地 Bonsai Image 4B 生成。 网络界面 diff --git a/core/localization/src/androidMain/res/values/strings.xml b/core/localization/src/androidMain/res/values/strings.xml index dfa69abd2..628f152bb 100755 --- a/core/localization/src/androidMain/res/values/strings.xml +++ b/core/localization/src/androidMain/res/values/strings.xml @@ -117,6 +117,8 @@ MediaPipe Local Diffusion stable-diffusion.cpp SDXL (Experimental) SDXL + Silicon Diffusion PrismML Bonsai + Local Diffusion PrismML Bonsai Bonsai Hugging Face Inference HuggingFace @@ -195,6 +197,7 @@ Fal.ai endpoint Acceleration Runtime backend + Runtime backend Sync mode Input image Quality @@ -263,7 +266,10 @@ Local Diffusion SDXL This configuration uses stable-diffusion.cpp and allows to run SDXL, 
SDXL-Turbo, LCM, or quantized GGUF models on your phone. Local generation on iPhone. + Silicon Diffusion PrismML Bonsai + Local Diffusion PrismML Bonsai Local Bonsai Image 4B generation on iPhone. + Local Bonsai Image 4B generation on supported devices. Web UI diff --git a/data/src/androidMain/kotlin/com/shifthackz/aisdv1/data/local/AndroidDownloadableModelFileStore.kt b/data/src/androidMain/kotlin/com/shifthackz/aisdv1/data/local/AndroidDownloadableModelFileStore.kt index 58b19c73f..176c17b4c 100644 --- a/data/src/androidMain/kotlin/com/shifthackz/aisdv1/data/local/AndroidDownloadableModelFileStore.kt +++ b/data/src/androidMain/kotlin/com/shifthackz/aisdv1/data/local/AndroidDownloadableModelFileStore.kt @@ -56,8 +56,7 @@ internal class AndroidDownloadableModelFileStore( } LocalAiModel.Type.Bonsai -> { - val files = getLocalModelFiles(model.id) - files.isNotEmpty() + getLocalModelDirectory(model.id).hasBonsaiModelLayout() } } } @@ -136,3 +135,54 @@ internal class AndroidDownloadableModelFileStore( val singleFileModelExtensions = setOf("ckpt", "gguf", "safetensors") } } + +private const val TEXT_ENCODER_MLX_DIRECTORY = "text_encoder-mlx-4bit" +private const val TEXT_ENCODER_LEGACY_DIRECTORY = "text_encoder" +private const val TRANSFORMER_DIRECTORY = "transformer-packed-mflux" +private const val TOKENIZER_DIRECTORY = "tokenizer" +private const val VAE_DIRECTORY = "vae" +private const val SCHEDULER_DIRECTORY = "scheduler" +private const val RESOURCES_DIRECTORY = "Resources" +private const val EXTRACTED_DIRECTORY = "extracted" +private const val MAX_NESTED_SEARCH_DEPTH = 4 + +private fun File.hasBonsaiModelLayout(): Boolean { + if (!exists()) return false + + directBonsaiCandidates() + .any(File::isBonsaiRoot) + .let { found -> if (found) return true } + + val rootDepth = toPath().nameCount + return walkTopDown() + .filter(File::isDirectory) + .filter { candidate -> + candidate.toPath().nameCount - rootDepth <= MAX_NESTED_SEARCH_DEPTH + } + .any(File::isBonsaiRoot) +} 
+ +private fun File.directBonsaiCandidates(): List<File> = listOf( + this, + File(this, RESOURCES_DIRECTORY), + File(this, EXTRACTED_DIRECTORY), + File(File(this, EXTRACTED_DIRECTORY), RESOURCES_DIRECTORY), +) + +private fun File.isBonsaiRoot(): Boolean { + val quantizationConfig = File( + File(this, TRANSFORMER_DIRECTORY), + "quantization_config.json", + ) + val requiredDirectories = listOf( + File(this, TOKENIZER_DIRECTORY), + File(this, VAE_DIRECTORY), + File(this, SCHEDULER_DIRECTORY), + ) + + return quantizationConfig.isFile && + listOf(TEXT_ENCODER_MLX_DIRECTORY, TEXT_ENCODER_LEGACY_DIRECTORY) + .map { name -> File(this, name) } + .any(File::isDirectory) && + requiredDirectories.all(File::isDirectory) +} diff --git a/data/src/androidMain/kotlin/com/shifthackz/aisdv1/data/remote/AndroidDownloadableModelFileDownloader.kt b/data/src/androidMain/kotlin/com/shifthackz/aisdv1/data/remote/AndroidDownloadableModelFileDownloader.kt index 7e2673807..545a96ef8 100644 --- a/data/src/androidMain/kotlin/com/shifthackz/aisdv1/data/remote/AndroidDownloadableModelFileDownloader.kt +++ b/data/src/androidMain/kotlin/com/shifthackz/aisdv1/data/remote/AndroidDownloadableModelFileDownloader.kt @@ -63,9 +63,10 @@ internal class AndroidDownloadableModelFileDownloader( url: String, ): Flow<DownloadState> = flow { val dir = File("${fileProviderDescriptor.localModelDirPath}/$id") + if (dir.exists()) dir.deleteRecursively() + dir.mkdirs() + val destination = File(getDestinationPath(id, url)) - if (destination.exists()) destination.delete() - if (!dir.exists()) dir.mkdirs() emit(DownloadState.Downloading(0)) try { @@ -80,10 +81,10 @@ internal class AndroidDownloadableModelFileDownloader( } emit(DownloadState.Complete(complete.path)) } catch (e: CancellationException) { - destination.delete() + dir.deleteRecursively() throw e } catch (e: Exception) { - destination.delete() + dir.deleteRecursively() emit(DownloadState.Error(e)) } }.flowOn(Dispatchers.IO) diff --git 
a/data/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/data/local/AndroidDownloadableModelFileStoreTest.kt b/data/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/data/local/AndroidDownloadableModelFileStoreTest.kt index ddd268dbb..438ef2fe0 100644 --- a/data/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/data/local/AndroidDownloadableModelFileStoreTest.kt +++ b/data/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/data/local/AndroidDownloadableModelFileStoreTest.kt @@ -3,6 +3,8 @@ package com.shifthackz.aisdv1.data.local import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor import com.shifthackz.aisdv1.domain.entity.LocalAiModel import org.junit.Assert.assertEquals +import org.junit.Assert.assertFalse +import org.junit.Assert.assertTrue import org.junit.Rule import org.junit.Test import org.junit.rules.TemporaryFolder @@ -68,6 +70,39 @@ class AndroidDownloadableModelFileStoreTest { assertEquals(directory.path, actual) } + + @Test + fun `given bonsai model directory contains complete layout, is downloaded returns true`() { + val model = model(id = "bonsai-model", type = LocalAiModel.Type.Bonsai) + File(temporaryFolder.root, model.id) + .apply(File::mkdirs) + .createBonsaiLayout() + + val actual = fileStore.isDownloaded(model) + + assertTrue(actual) + } + + @Test + fun `given bonsai model directory contains invalid residue, is downloaded returns false`() { + val model = model(id = "bonsai-model", type = LocalAiModel.Type.Bonsai) + val directory = File(temporaryFolder.root, model.id).apply(File::mkdirs) + File(directory, "download-error.html").writeText("pve001 unavailable") + + val actual = fileStore.isDownloaded(model) + + assertFalse(actual) + } +} + +private fun File.createBonsaiLayout() { + File(this, "transformer-packed-mflux").apply(File::mkdirs) + .resolve("quantization_config.json") + .writeText("{}") + File(this, "text_encoder-mlx-4bit").mkdirs() + File(this, "tokenizer").mkdirs() + File(this, "vae").mkdirs() + File(this, 
"scheduler").mkdirs() } private fun model( diff --git a/data/src/iosMain/kotlin/com/shifthackz/aisdv1/data/IosModelArchiveValidation.kt b/data/src/iosMain/kotlin/com/shifthackz/aisdv1/data/IosModelArchiveValidation.kt new file mode 100644 index 000000000..cb9048372 --- /dev/null +++ b/data/src/iosMain/kotlin/com/shifthackz/aisdv1/data/IosModelArchiveValidation.kt @@ -0,0 +1,47 @@ +package com.shifthackz.aisdv1.data + +import kotlinx.cinterop.ExperimentalForeignApi +import kotlinx.cinterop.ByteVar +import kotlinx.cinterop.allocArray +import kotlinx.cinterop.get +import kotlinx.cinterop.memScoped +import platform.posix.fclose +import platform.posix.fopen +import platform.posix.fread + +/** + * Checks that a downloaded model artifact starts with a ZIP archive signature. + * + * iOS keeps downloadable model bundles as `model.zip`, so this catches CDN or host + * failures that return an HTML/error payload with a successful transfer callback. + * + * @receiver Absolute path to the candidate archive. + * @return `true` when the file begins with a ZIP header. 
+ * @author Dmitriy Moroz + */ +@OptIn(ExperimentalForeignApi::class) +internal fun String.hasZipArchiveSignature(): Boolean = memScoped { + val file = fopen(this@hasZipArchiveSignature, "rb") ?: return@memScoped false + return try { + val bytes = allocArray(ZIP_SIGNATURE_LENGTH) + val readBytes = fread(bytes, 1uL, ZIP_SIGNATURE_LENGTH.toULong(), file) + if (readBytes < ZIP_SIGNATURE_LENGTH.toULong()) return@memScoped false + + bytes[0] == ZIP_MAGIC_FIRST && + bytes[1] == ZIP_MAGIC_SECOND && + zipSignatureSuffixes.any { suffix -> + bytes[2] == suffix.first && bytes[3] == suffix.second + } + } finally { + fclose(file) + } +} + +private const val ZIP_SIGNATURE_LENGTH = 4 +private val ZIP_MAGIC_FIRST = 0x50.toByte() +private val ZIP_MAGIC_SECOND = 0x4B.toByte() +private val zipSignatureSuffixes = listOf( + 0x03.toByte() to 0x04.toByte(), + 0x05.toByte() to 0x06.toByte(), + 0x07.toByte() to 0x08.toByte(), +) diff --git a/data/src/iosMain/kotlin/com/shifthackz/aisdv1/data/local/IosDownloadableModelFileStore.kt b/data/src/iosMain/kotlin/com/shifthackz/aisdv1/data/local/IosDownloadableModelFileStore.kt index 235c1c0c3..64ce61556 100644 --- a/data/src/iosMain/kotlin/com/shifthackz/aisdv1/data/local/IosDownloadableModelFileStore.kt +++ b/data/src/iosMain/kotlin/com/shifthackz/aisdv1/data/local/IosDownloadableModelFileStore.kt @@ -1,6 +1,7 @@ package com.shifthackz.aisdv1.data.local import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor +import com.shifthackz.aisdv1.data.hasZipArchiveSignature import com.shifthackz.aisdv1.domain.entity.LocalAiModel import kotlinx.cinterop.ExperimentalForeignApi import platform.Foundation.NSFileManager @@ -57,6 +58,7 @@ internal class IosDownloadableModelFileStore( private fun LocalAiModel.hasDownloadedArchive(archivePath: String): Boolean { if (!NSFileManager.defaultManager.fileExistsAtPath(path = archivePath)) return false + if (!archivePath.hasZipArchiveSignature()) return false val archiveSize = archivePath.fileSize() ?: 
return true val expectedSize = size.expectedSizeBytes() ?: return archiveSize > 0L diff --git a/data/src/iosMain/kotlin/com/shifthackz/aisdv1/data/remote/IosDownloadableModelFileDownloader.kt b/data/src/iosMain/kotlin/com/shifthackz/aisdv1/data/remote/IosDownloadableModelFileDownloader.kt index a5b1c2d90..4401629e9 100644 --- a/data/src/iosMain/kotlin/com/shifthackz/aisdv1/data/remote/IosDownloadableModelFileDownloader.kt +++ b/data/src/iosMain/kotlin/com/shifthackz/aisdv1/data/remote/IosDownloadableModelFileDownloader.kt @@ -1,6 +1,7 @@ package com.shifthackz.aisdv1.data.remote import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor +import com.shifthackz.aisdv1.data.hasZipArchiveSignature import com.shifthackz.aisdv1.domain.entity.DownloadState import com.shifthackz.aisdv1.domain.entity.NetworkUsageBucket import com.shifthackz.aisdv1.domain.repository.NetworkUsageRepository @@ -98,8 +99,17 @@ internal class IosDownloadableModelFileDownloader( error = null, ) ) { - trySend(DownloadState.Downloading(100)) - trySend(DownloadState.Complete(destinationPath)) + if (destinationPath.hasZipArchiveSignature()) { + trySend(DownloadState.Downloading(100)) + trySend(DownloadState.Complete(destinationPath)) + } else { + fileManager.deleteDownloadFiles(destinationPath, temporaryPath) + trySend( + DownloadState.Error( + IllegalStateException("Downloaded model archive is not a valid zip file."), + ), + ) + } } else { fileManager.deleteDownloadFiles(destinationPath, temporaryPath) trySend( diff --git a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/BonsaiBackend.kt b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/BonsaiBackend.kt new file mode 100644 index 000000000..a861fdc24 --- /dev/null +++ b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/BonsaiBackend.kt @@ -0,0 +1,26 @@ +package com.shifthackz.aisdv1.domain.entity + +/** + * Runtime backend requested by Android Bonsai generation. 
+ * + * `AUTO` lets the native runtime choose the safest available path. Explicit + * values are kept in the domain payload so foreground generation, background + * work, and persisted form state all pass the same backend key to the NDK layer. + */ +enum class BonsaiBackend( + val key: String, + val displayName: String, +) { + AUTO("auto", "AUTO"), + CPU("cpu", "CPU"), + VULKAN("vulkan", "Vulkan"), + ; + + companion object { + fun parse(value: String?): BonsaiBackend { + return entries.firstOrNull { backend -> + backend.key == value || backend.name == value + } ?: AUTO + } + } +} diff --git a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/ServerSource.kt b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/ServerSource.kt index 69b1e761d..6295ce406 100644 --- a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/ServerSource.kt +++ b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/ServerSource.kt @@ -1,6 +1,8 @@ package com.shifthackz.aisdv1.domain.entity import com.shifthackz.aisdv1.core.common.appbuild.BuildType +import com.shifthackz.aisdv1.core.common.platform.Platform +import com.shifthackz.aisdv1.core.common.platform.PlatformValue /** * Provider catalog entry used by setup, onboarding, and settings screens. 
@@ -15,15 +17,16 @@ import com.shifthackz.aisdv1.core.common.appbuild.BuildType enum class ServerSource( val key: String, val type: ServerSourceType, - val readiness: ServerSourceReadiness, + val readiness: PlatformValue, val version: String, val featureTags: Set, val allowedInBuilds: Set = setOf(BuildType.FOSS, BuildType.PLAY, BuildType.FULL), + val allowedPlatforms: Set = setOf(Platform.ANDROID, Platform.IOS), ) { AUTOMATIC1111( key = "custom", type = ServerSourceType.SELF_HOSTED, - readiness = ServerSourceReadiness.STABLE, + readiness = PlatformValue(ServerSourceReadiness.STABLE), version = "2026.6.10", featureTags = setOf( FeatureTag.Txt2Img, @@ -39,7 +42,7 @@ enum class ServerSource( SWARM_UI( key = "swarm_ui", type = ServerSourceType.SELF_HOSTED, - readiness = ServerSourceReadiness.STABLE, + readiness = PlatformValue(ServerSourceReadiness.STABLE), version = "2026.6.10", featureTags = setOf( FeatureTag.Txt2Img, @@ -54,18 +57,19 @@ enum class ServerSource( LOCAL_MICROSOFT_ONNX( key = "local", type = ServerSourceType.LOCAL, - readiness = ServerSourceReadiness.BETA, + readiness = PlatformValue(ServerSourceReadiness.BETA), version = "2024.9.23", featureTags = setOf( FeatureTag.Offline, FeatureTag.Txt2Img, FeatureTag.MultipleModels, ), + allowedPlatforms = setOf(Platform.ANDROID), ), LOCAL_GOOGLE_MEDIA_PIPE( key = "local_google_media_pipe", type = ServerSourceType.LOCAL, - readiness = ServerSourceReadiness.BETA, + readiness = PlatformValue(ServerSourceReadiness.BETA), version = "2026.6.10", featureTags = setOf( FeatureTag.Offline, @@ -73,22 +77,27 @@ enum class ServerSource( FeatureTag.MultipleModels, ), allowedInBuilds = setOf(BuildType.PLAY, BuildType.FULL), + allowedPlatforms = setOf(Platform.ANDROID), ), LOCAL_STABLE_DIFFUSION_CPP( key = "local_stable_diffusion_cpp", type = ServerSourceType.LOCAL, - readiness = ServerSourceReadiness.ALPHA, + readiness = PlatformValue(ServerSourceReadiness.ALPHA), version = "2026.6.13", featureTags = setOf( FeatureTag.Offline, 
FeatureTag.Txt2Img, FeatureTag.MultipleModels, ), + allowedPlatforms = setOf(Platform.ANDROID), ), LOCAL_APPLE_CORE_ML( key = "local_apple_core_ml", type = ServerSourceType.LOCAL, - readiness = ServerSourceReadiness.ALPHA, + readiness = PlatformValue( + ios = ServerSourceReadiness.ALPHA, + android = ServerSourceReadiness.EXPERIMENTAL + ), version = "2026.6.12", featureTags = setOf( FeatureTag.Offline, @@ -97,12 +106,16 @@ enum class ServerSource( FeatureTag.MultipleModels, FeatureTag.Batch, ), + allowedPlatforms = setOf(Platform.IOS), ), LOCAL_APPLE_BONSAI( key = "local_apple_bonsai", type = ServerSourceType.LOCAL, - readiness = ServerSourceReadiness.BETA, - version = "2026.6.15", + readiness = PlatformValue( + ios = ServerSourceReadiness.BETA, + android = ServerSourceReadiness.EXPERIMENTAL, + ), + version = "2026.6.20", featureTags = setOf( FeatureTag.Offline, FeatureTag.Txt2Img, @@ -112,7 +125,7 @@ enum class ServerSource( HORDE( key = "horde", type = ServerSourceType.CLOUD, - readiness = ServerSourceReadiness.STABLE, + readiness = PlatformValue(ServerSourceReadiness.STABLE), version = "2026.6.10", featureTags = setOf( FeatureTag.Txt2Img, @@ -123,7 +136,7 @@ enum class ServerSource( HUGGING_FACE( key = "hugging_face", type = ServerSourceType.CLOUD, - readiness = ServerSourceReadiness.STABLE, + readiness = PlatformValue(ServerSourceReadiness.STABLE), version = "2026.6.10", featureTags = setOf( FeatureTag.Txt2Img, @@ -135,7 +148,7 @@ enum class ServerSource( OPEN_AI( key = "open_ai", type = ServerSourceType.CLOUD, - readiness = ServerSourceReadiness.STABLE, + readiness = PlatformValue(ServerSourceReadiness.STABLE), version = "2026.6.10", featureTags = setOf( FeatureTag.Txt2Img, @@ -146,7 +159,7 @@ enum class ServerSource( STABILITY_AI( key = "stability_ai", type = ServerSourceType.CLOUD, - readiness = ServerSourceReadiness.STABLE, + readiness = PlatformValue(ServerSourceReadiness.STABLE), version = "2026.6.10", featureTags = setOf( FeatureTag.Txt2Img, @@ -158,7 
+171,7 @@ enum class ServerSource( FAL_AI( key = "fal_ai", type = ServerSourceType.CLOUD, - readiness = ServerSourceReadiness.ALPHA, + readiness = PlatformValue(ServerSourceReadiness.ALPHA), version = "2026.6.11", featureTags = setOf( FeatureTag.Txt2Img, @@ -170,7 +183,7 @@ enum class ServerSource( ARLI_AI( key = "arli_ai", type = ServerSourceType.CLOUD, - readiness = ServerSourceReadiness.ALPHA, + readiness = PlatformValue(ServerSourceReadiness.ALPHA), version = "2026.6.13", featureTags = setOf( FeatureTag.Txt2Img, diff --git a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/TextToImagePayload.kt b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/TextToImagePayload.kt index a555b16c7..435cf35e1 100644 --- a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/TextToImagePayload.kt +++ b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/TextToImagePayload.kt @@ -1,177 +1,41 @@ package com.shifthackz.aisdv1.domain.entity /** - * Carries `TextToImagePayload` data through the SDAI domain layer. + * Provider-neutral request model for text-to-image generation. * - * @author Dmitriy Moroz + * The payload intentionally carries superset fields from remote APIs and local + * runtimes. Repository and feature implementations pick the fields relevant to + * their provider, including Android local runtime backend choices such as SDXL + * and Bonsai. */ data class TextToImagePayload( - /** - * Exposes the `prompt` value used by the SDAI domain layer. - * - * @author Dmitriy Moroz - */ val prompt: String, - /** - * Exposes the `negativePrompt` value used by the SDAI domain layer. - * - * @author Dmitriy Moroz - */ val negativePrompt: String, - /** - * Exposes the `samplingSteps` value used by the SDAI domain layer. - * - * @author Dmitriy Moroz - */ val samplingSteps: Int, - /** - * Exposes the `cfgScale` value used by the SDAI domain layer. 
- * - * @author Dmitriy Moroz - */ val cfgScale: Float, - /** - * Exposes the `width` value used by the SDAI domain layer. - * - * @author Dmitriy Moroz - */ val width: Int, - /** - * Exposes the `height` value used by the SDAI domain layer. - * - * @author Dmitriy Moroz - */ val height: Int, - /** - * Exposes the `restoreFaces` value used by the SDAI domain layer. - * - * @author Dmitriy Moroz - */ val restoreFaces: Boolean, - /** - * Exposes the `seed` value used by the SDAI domain layer. - * - * @author Dmitriy Moroz - */ val seed: String, - /** - * Exposes the `subSeed` value used by the SDAI domain layer. - * - * @author Dmitriy Moroz - */ val subSeed: String, - /** - * Exposes the `subSeedStrength` value used by the SDAI domain layer. - * - * @author Dmitriy Moroz - */ val subSeedStrength: Float, - /** - * Exposes the `sampler` value used by the SDAI domain layer. - * - * @author Dmitriy Moroz - */ val sampler: String, - /** - * Exposes the `scheduler` value used by the SDAI domain layer. - * - * @author Dmitriy Moroz - */ val scheduler: Scheduler = Scheduler.AUTOMATIC, - /** - * Exposes the `nsfw` value used by the SDAI domain layer. - * - * @author Dmitriy Moroz - */ val nsfw: Boolean, - /** - * Exposes the `batchCount` value used by the SDAI domain layer. - * - * @author Dmitriy Moroz - */ val batchCount: Int, - /** - * Exposes the `style` value used by the SDAI domain layer. - * - * @author Dmitriy Moroz - */ val style: String?, - /** - * Exposes the `quality` value used by the SDAI domain layer. - * - * @author Dmitriy Moroz - */ val quality: String?, - /** - * Exposes the `openAiModel` value used by the SDAI domain layer. - * - * @author Dmitriy Moroz - */ val openAiModel: OpenAiModel?, - /** - * Exposes the `stabilityAiClipGuidance` value used by the SDAI domain layer. - * - * @author Dmitriy Moroz - */ val stabilityAiClipGuidance: StabilityAiClipGuidance?, - /** - * Exposes the `stabilityAiStylePreset` value used by the SDAI domain layer. 
- * - * @author Dmitriy Moroz - */ val stabilityAiStylePreset: StabilityAiStylePreset?, - /** - * Exposes the `aDetailer` value used by the SDAI domain layer. - * - * @author Dmitriy Moroz - */ val aDetailer: ADetailerConfig = ADetailerConfig.DISABLED, - /** - * Exposes the `hires` value used by the SDAI domain layer. - * - * @author Dmitriy Moroz - */ val hires: HiresConfig = HiresConfig.DISABLED, - /** - * Exposes the `forgeModules` value used by the SDAI domain layer. - * - * @author Dmitriy Moroz - */ val forgeModules: List = emptyList(), - /** - * Exposes the `falAiModel` value used by the SDAI domain layer. - * - * @author Dmitriy Moroz - */ val falAiModel: FalAiModel = FalAiModel.defaultTextToImage, - /** - * Exposes the `falAiImageSize` value used by the SDAI domain layer. - * - * @author Dmitriy Moroz - */ val falAiImageSize: FalAiImageSize = FalAiImageSize.default, - /** - * Exposes the `falAiAcceleration` value used by the SDAI domain layer. - * - * @author Dmitriy Moroz - */ val falAiAcceleration: FalAiAcceleration = FalAiAcceleration.default, - /** - * Exposes the `sdxlBackend` value used by the SDAI domain layer. - * - * @author Dmitriy Moroz - */ val sdxlBackend: SdxlBackend = SdxlBackend.AUTO, - /** - * Exposes the `falAiSyncMode` value used by the SDAI domain layer. - * - * @author Dmitriy Moroz - */ + val bonsaiBackend: BonsaiBackend = BonsaiBackend.AUTO, val falAiSyncMode: Boolean = false, - /** - * Exposes the `arliAiModel` value used by the SDAI domain layer. 
- * - * @author Dmitriy Moroz - */ val arliAiModel: String = "", ) diff --git a/feature/benchmark/build.gradle.kts b/feature/benchmark/build.gradle.kts index 481a3eb8c..6c393c987 100644 --- a/feature/benchmark/build.gradle.kts +++ b/feature/benchmark/build.gradle.kts @@ -4,6 +4,21 @@ plugins { android { namespace = "com.shifthackz.aisdv1.feature.benchmark" + + defaultConfig { + externalNativeBuild { + cmake { + arguments += "-DANDROID_STL=c++_shared" + } + } + } + + externalNativeBuild { + cmake { + path = file("src/androidMain/cpp/CMakeLists.txt") + version = "3.22.1" + } + } } kotlin { diff --git a/feature/benchmark/src/androidMain/cpp/CMakeLists.txt b/feature/benchmark/src/androidMain/cpp/CMakeLists.txt new file mode 100644 index 000000000..b8c9c89ff --- /dev/null +++ b/feature/benchmark/src/androidMain/cpp/CMakeLists.txt @@ -0,0 +1,35 @@ +cmake_minimum_required(VERSION 3.22.1) + +project(sdai_benchmark_runtime LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_POSITION_INDEPENDENT_CODE ON) + +add_library( + sdai_benchmark + SHARED + benchmark_vulkan_probe.cpp +) + +target_link_libraries( + sdai_benchmark + PRIVATE + log + vulkan +) + +target_link_options( + sdai_benchmark + PRIVATE + "-Wl,-z,max-page-size=16384" + "-Wl,-z,common-page-size=16384" +) + +target_compile_options( + sdai_benchmark + PRIVATE + -Wall + -Wextra + -Wno-unused-parameter +) diff --git a/feature/benchmark/src/androidMain/cpp/benchmark_vulkan_probe.cpp b/feature/benchmark/src/androidMain/cpp/benchmark_vulkan_probe.cpp new file mode 100644 index 000000000..2b10d9ad6 --- /dev/null +++ b/feature/benchmark/src/androidMain/cpp/benchmark_vulkan_probe.cpp @@ -0,0 +1,780 @@ +#include + +#include "benchmark_vulkan_smoke_spv.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { + +constexpr const char* LOG_TAG = "SDAI-Benchmark"; +constexpr uint64_t BYTES_IN_MB = 1024ULL * 1024ULL; +constexpr 
uint32_t MIN_STORAGE_BUFFER_RANGE = 64U * 1024U * 1024U; +constexpr uint32_t SMOKE_VALUE_COUNT = 256U; +constexpr uint32_t SMOKE_LOCAL_SIZE = 64U; + +struct VulkanSmokeResult { + bool ok = false; + double elapsed_ms = 0.0; + std::string reason; +}; + +struct VulkanSmokeBuffer { + VkBuffer buffer = VK_NULL_HANDLE; + VkDeviceMemory memory = VK_NULL_HANDLE; + VkDeviceSize size = 0; + bool coherent = false; +}; + +struct VulkanSmokeResources { + VkDevice device = VK_NULL_HANDLE; + VkDescriptorSetLayout descriptor_set_layout = VK_NULL_HANDLE; + VkPipelineLayout pipeline_layout = VK_NULL_HANDLE; + VkShaderModule shader_module = VK_NULL_HANDLE; + VkPipeline pipeline = VK_NULL_HANDLE; + VkDescriptorPool descriptor_pool = VK_NULL_HANDLE; + VkCommandPool command_pool = VK_NULL_HANDLE; + VkFence fence = VK_NULL_HANDLE; + VulkanSmokeBuffer input; + VulkanSmokeBuffer output; + + ~VulkanSmokeResources() { + if (device == VK_NULL_HANDLE) { + return; + } + if (fence != VK_NULL_HANDLE) { + vkDestroyFence(device, fence, nullptr); + } + if (command_pool != VK_NULL_HANDLE) { + vkDestroyCommandPool(device, command_pool, nullptr); + } + if (descriptor_pool != VK_NULL_HANDLE) { + vkDestroyDescriptorPool(device, descriptor_pool, nullptr); + } + if (pipeline != VK_NULL_HANDLE) { + vkDestroyPipeline(device, pipeline, nullptr); + } + if (shader_module != VK_NULL_HANDLE) { + vkDestroyShaderModule(device, shader_module, nullptr); + } + if (pipeline_layout != VK_NULL_HANDLE) { + vkDestroyPipelineLayout(device, pipeline_layout, nullptr); + } + if (descriptor_set_layout != VK_NULL_HANDLE) { + vkDestroyDescriptorSetLayout(device, descriptor_set_layout, nullptr); + } + if (input.buffer != VK_NULL_HANDLE) { + vkDestroyBuffer(device, input.buffer, nullptr); + } + if (input.memory != VK_NULL_HANDLE) { + vkFreeMemory(device, input.memory, nullptr); + } + if (output.buffer != VK_NULL_HANDLE) { + vkDestroyBuffer(device, output.buffer, nullptr); + } + if (output.memory != VK_NULL_HANDLE) { + 
vkFreeMemory(device, output.memory, nullptr); + } + vkDestroyDevice(device, nullptr); + } +}; + +std::string bool_value(bool value) { + return value ? "true" : "false"; +} + +std::string version_string(uint32_t version) { + std::ostringstream output; + output << VK_VERSION_MAJOR(version) << "." + << VK_VERSION_MINOR(version) << "." + << VK_VERSION_PATCH(version); + return output.str(); +} + +std::string vk_result_name(VkResult result) { + switch (result) { + case VK_SUCCESS: + return "VK_SUCCESS"; + case VK_NOT_READY: + return "VK_NOT_READY"; + case VK_TIMEOUT: + return "VK_TIMEOUT"; + case VK_EVENT_SET: + return "VK_EVENT_SET"; + case VK_EVENT_RESET: + return "VK_EVENT_RESET"; + case VK_INCOMPLETE: + return "VK_INCOMPLETE"; + case VK_ERROR_OUT_OF_HOST_MEMORY: + return "VK_ERROR_OUT_OF_HOST_MEMORY"; + case VK_ERROR_OUT_OF_DEVICE_MEMORY: + return "VK_ERROR_OUT_OF_DEVICE_MEMORY"; + case VK_ERROR_INITIALIZATION_FAILED: + return "VK_ERROR_INITIALIZATION_FAILED"; + case VK_ERROR_DEVICE_LOST: + return "VK_ERROR_DEVICE_LOST"; + case VK_ERROR_MEMORY_MAP_FAILED: + return "VK_ERROR_MEMORY_MAP_FAILED"; + case VK_ERROR_LAYER_NOT_PRESENT: + return "VK_ERROR_LAYER_NOT_PRESENT"; + case VK_ERROR_EXTENSION_NOT_PRESENT: + return "VK_ERROR_EXTENSION_NOT_PRESENT"; + case VK_ERROR_FEATURE_NOT_PRESENT: + return "VK_ERROR_FEATURE_NOT_PRESENT"; + case VK_ERROR_INCOMPATIBLE_DRIVER: + return "VK_ERROR_INCOMPATIBLE_DRIVER"; + case VK_ERROR_TOO_MANY_OBJECTS: + return "VK_ERROR_TOO_MANY_OBJECTS"; + case VK_ERROR_FORMAT_NOT_SUPPORTED: + return "VK_ERROR_FORMAT_NOT_SUPPORTED"; + case VK_ERROR_FRAGMENTED_POOL: + return "VK_ERROR_FRAGMENTED_POOL"; + default: + return "VK_RESULT_" + std::to_string(static_cast(result)); + } +} + +uint32_t loader_api_version() { + uint32_t version = VK_API_VERSION_1_0; + const auto enumerate_instance_version = + reinterpret_cast( + vkGetInstanceProcAddr(nullptr, "vkEnumerateInstanceVersion") + ); + if (enumerate_instance_version != nullptr && + 
enumerate_instance_version(&version) != VK_SUCCESS) { + return VK_API_VERSION_1_0; + } + return version; +} + +bool find_compute_queue_family(VkPhysicalDevice physical_device, uint32_t& queue_family_index) { + uint32_t queue_count = 0; + vkGetPhysicalDeviceQueueFamilyProperties(physical_device, &queue_count, nullptr); + if (queue_count == 0) { + return false; + } + + std::vector queues(queue_count); + vkGetPhysicalDeviceQueueFamilyProperties(physical_device, &queue_count, queues.data()); + for (uint32_t index = 0; index < queue_count; ++index) { + const VkQueueFamilyProperties& queue = queues[index]; + if ((queue.queueFlags & VK_QUEUE_COMPUTE_BIT) != 0 && queue.queueCount > 0) { + queue_family_index = index; + return true; + } + } + return false; +} + +uint64_t device_local_heap_mb(VkPhysicalDevice physical_device) { + VkPhysicalDeviceMemoryProperties memory_properties {}; + vkGetPhysicalDeviceMemoryProperties(physical_device, &memory_properties); + uint64_t total = 0; + for (uint32_t index = 0; index < memory_properties.memoryHeapCount; ++index) { + const VkMemoryHeap& heap = memory_properties.memoryHeaps[index]; + if ((heap.flags & VK_MEMORY_HEAP_DEVICE_LOCAL_BIT) != 0) { + total += heap.size; + } + } + return total / BYTES_IN_MB; +} + +int32_t find_memory_type( + VkPhysicalDevice physical_device, + uint32_t type_bits, + VkMemoryPropertyFlags required_properties +) { + VkPhysicalDeviceMemoryProperties memory_properties {}; + vkGetPhysicalDeviceMemoryProperties(physical_device, &memory_properties); + for (uint32_t index = 0; index < memory_properties.memoryTypeCount; ++index) { + const bool type_supported = (type_bits & (1U << index)) != 0; + const bool properties_supported = + (memory_properties.memoryTypes[index].propertyFlags & required_properties) == + required_properties; + if (type_supported && properties_supported) { + return static_cast(index); + } + } + return -1; +} + +VulkanSmokeResult smoke_error(const std::string& reason) { + return VulkanSmokeResult 
{ false, 0.0, reason }; +} + +bool create_host_storage_buffer( + VkPhysicalDevice physical_device, + VulkanSmokeResources& resources, + VkDeviceSize size, + VulkanSmokeBuffer& output, + std::string& reason +) { + output.size = size; + + VkBufferCreateInfo buffer_info {}; + buffer_info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; + buffer_info.size = size; + buffer_info.usage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; + buffer_info.sharingMode = VK_SHARING_MODE_EXCLUSIVE; + + VkResult result = vkCreateBuffer(resources.device, &buffer_info, nullptr, &output.buffer); + if (result != VK_SUCCESS) { + reason = "create_buffer_failed:" + vk_result_name(result); + return false; + } + + VkMemoryRequirements requirements {}; + vkGetBufferMemoryRequirements(resources.device, output.buffer, &requirements); + int32_t memory_type = find_memory_type( + physical_device, + requirements.memoryTypeBits, + VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT + ); + output.coherent = memory_type >= 0; + if (memory_type < 0) { + memory_type = find_memory_type( + physical_device, + requirements.memoryTypeBits, + VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT + ); + } + if (memory_type < 0) { + reason = "missing_host_visible_storage_memory"; + return false; + } + + VkMemoryAllocateInfo allocate_info {}; + allocate_info.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO; + allocate_info.allocationSize = requirements.size; + allocate_info.memoryTypeIndex = static_cast(memory_type); + result = vkAllocateMemory(resources.device, &allocate_info, nullptr, &output.memory); + if (result != VK_SUCCESS) { + reason = "allocate_memory_failed:" + vk_result_name(result); + return false; + } + + result = vkBindBufferMemory(resources.device, output.buffer, output.memory, 0); + if (result != VK_SUCCESS) { + reason = "bind_buffer_failed:" + vk_result_name(result); + return false; + } + return true; +} + +bool write_buffer( + const VulkanSmokeResources& resources, + const VulkanSmokeBuffer& buffer, + 
const void* source, + size_t byte_count, + std::string& reason +) { + void* mapped = nullptr; + VkResult result = vkMapMemory(resources.device, buffer.memory, 0, byte_count, 0, &mapped); + if (result != VK_SUCCESS) { + reason = "map_write_failed:" + vk_result_name(result); + return false; + } + std::memcpy(mapped, source, byte_count); + if (!buffer.coherent) { + VkMappedMemoryRange range {}; + range.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE; + range.memory = buffer.memory; + range.offset = 0; + range.size = VK_WHOLE_SIZE; + result = vkFlushMappedMemoryRanges(resources.device, 1, &range); + if (result != VK_SUCCESS) { + vkUnmapMemory(resources.device, buffer.memory); + reason = "flush_write_failed:" + vk_result_name(result); + return false; + } + } + vkUnmapMemory(resources.device, buffer.memory); + return true; +} + +bool read_buffer( + const VulkanSmokeResources& resources, + const VulkanSmokeBuffer& buffer, + void* destination, + size_t byte_count, + std::string& reason +) { + if (!buffer.coherent) { + VkMappedMemoryRange range {}; + range.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE; + range.memory = buffer.memory; + range.offset = 0; + range.size = VK_WHOLE_SIZE; + VkResult result = vkInvalidateMappedMemoryRanges(resources.device, 1, &range); + if (result != VK_SUCCESS) { + reason = "invalidate_read_failed:" + vk_result_name(result); + return false; + } + } + + void* mapped = nullptr; + VkResult result = vkMapMemory(resources.device, buffer.memory, 0, byte_count, 0, &mapped); + if (result != VK_SUCCESS) { + reason = "map_read_failed:" + vk_result_name(result); + return false; + } + std::memcpy(destination, mapped, byte_count); + vkUnmapMemory(resources.device, buffer.memory); + return true; +} + +bool create_smoke_pipeline(VulkanSmokeResources& resources, std::string& reason) { + VkDescriptorSetLayoutBinding bindings[2] {}; + bindings[0].binding = 0; + bindings[0].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; + bindings[0].descriptorCount = 1; + 
bindings[0].stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; + bindings[1].binding = 1; + bindings[1].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; + bindings[1].descriptorCount = 1; + bindings[1].stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; + + VkDescriptorSetLayoutCreateInfo layout_info {}; + layout_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO; + layout_info.bindingCount = 2; + layout_info.pBindings = bindings; + VkResult result = vkCreateDescriptorSetLayout( + resources.device, + &layout_info, + nullptr, + &resources.descriptor_set_layout + ); + if (result != VK_SUCCESS) { + reason = "create_descriptor_set_layout_failed:" + vk_result_name(result); + return false; + } + + VkPushConstantRange push_constant {}; + push_constant.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; + push_constant.offset = 0; + push_constant.size = sizeof(uint32_t); + + VkPipelineLayoutCreateInfo pipeline_layout_info {}; + pipeline_layout_info.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO; + pipeline_layout_info.setLayoutCount = 1; + pipeline_layout_info.pSetLayouts = &resources.descriptor_set_layout; + pipeline_layout_info.pushConstantRangeCount = 1; + pipeline_layout_info.pPushConstantRanges = &push_constant; + result = vkCreatePipelineLayout( + resources.device, + &pipeline_layout_info, + nullptr, + &resources.pipeline_layout + ); + if (result != VK_SUCCESS) { + reason = "create_pipeline_layout_failed:" + vk_result_name(result); + return false; + } + + VkShaderModuleCreateInfo shader_info {}; + shader_info.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO; + shader_info.codeSize = kBenchmarkVulkanSmokeSpvSize; + shader_info.pCode = reinterpret_cast(kBenchmarkVulkanSmokeSpv); + result = vkCreateShaderModule( + resources.device, + &shader_info, + nullptr, + &resources.shader_module + ); + if (result != VK_SUCCESS) { + reason = "create_shader_module_failed:" + vk_result_name(result); + return false; + } + + VkComputePipelineCreateInfo pipeline_info {}; + 
pipeline_info.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO; + pipeline_info.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; + pipeline_info.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT; + pipeline_info.stage.module = resources.shader_module; + pipeline_info.stage.pName = "main"; + pipeline_info.layout = resources.pipeline_layout; + result = vkCreateComputePipelines( + resources.device, + VK_NULL_HANDLE, + 1, + &pipeline_info, + nullptr, + &resources.pipeline + ); + if (result != VK_SUCCESS) { + reason = "create_compute_pipeline_failed:" + vk_result_name(result); + return false; + } + return true; +} + +VulkanSmokeResult run_compute_smoke( + VkPhysicalDevice physical_device, + uint32_t queue_family_index +) { + VulkanSmokeResources resources; + + const float priority = 1.0F; + VkDeviceQueueCreateInfo queue_info {}; + queue_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO; + queue_info.queueFamilyIndex = queue_family_index; + queue_info.queueCount = 1; + queue_info.pQueuePriorities = &priority; + + VkDeviceCreateInfo device_info {}; + device_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO; + device_info.queueCreateInfoCount = 1; + device_info.pQueueCreateInfos = &queue_info; + + VkResult result = vkCreateDevice(physical_device, &device_info, nullptr, &resources.device); + if (result != VK_SUCCESS) { + return smoke_error("create_device_failed:" + vk_result_name(result)); + } + + VkQueue queue = VK_NULL_HANDLE; + vkGetDeviceQueue(resources.device, queue_family_index, 0, &queue); + if (queue == VK_NULL_HANDLE) { + return smoke_error("get_compute_queue_failed"); + } + + const VkDeviceSize byte_count = SMOKE_VALUE_COUNT * sizeof(float); + std::string reason; + if (!create_host_storage_buffer(physical_device, resources, byte_count, resources.input, reason) || + !create_host_storage_buffer(physical_device, resources, byte_count, resources.output, reason)) { + return smoke_error(reason); + } + + std::vector input(SMOKE_VALUE_COUNT); + 
std::vector zeroes(SMOKE_VALUE_COUNT, 0.0F); + for (uint32_t index = 0; index < SMOKE_VALUE_COUNT; ++index) { + input[index] = static_cast(index) * 0.25F - 7.0F; + } + if (!write_buffer(resources, resources.input, input.data(), static_cast(byte_count), reason) || + !write_buffer(resources, resources.output, zeroes.data(), static_cast(byte_count), reason)) { + return smoke_error(reason); + } + + if (!create_smoke_pipeline(resources, reason)) { + return smoke_error(reason); + } + + VkDescriptorPoolSize pool_size {}; + pool_size.type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; + pool_size.descriptorCount = 2; + + VkDescriptorPoolCreateInfo pool_info {}; + pool_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO; + pool_info.maxSets = 1; + pool_info.poolSizeCount = 1; + pool_info.pPoolSizes = &pool_size; + result = vkCreateDescriptorPool(resources.device, &pool_info, nullptr, &resources.descriptor_pool); + if (result != VK_SUCCESS) { + return smoke_error("create_descriptor_pool_failed:" + vk_result_name(result)); + } + + VkDescriptorSet descriptor_set = VK_NULL_HANDLE; + VkDescriptorSetAllocateInfo descriptor_allocate {}; + descriptor_allocate.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO; + descriptor_allocate.descriptorPool = resources.descriptor_pool; + descriptor_allocate.descriptorSetCount = 1; + descriptor_allocate.pSetLayouts = &resources.descriptor_set_layout; + result = vkAllocateDescriptorSets(resources.device, &descriptor_allocate, &descriptor_set); + if (result != VK_SUCCESS) { + return smoke_error("allocate_descriptor_set_failed:" + vk_result_name(result)); + } + + VkDescriptorBufferInfo input_info {}; + input_info.buffer = resources.input.buffer; + input_info.offset = 0; + input_info.range = byte_count; + VkDescriptorBufferInfo output_info {}; + output_info.buffer = resources.output.buffer; + output_info.offset = 0; + output_info.range = byte_count; + + VkWriteDescriptorSet descriptor_writes[2] {}; + descriptor_writes[0].sType = 
VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET; + descriptor_writes[0].dstSet = descriptor_set; + descriptor_writes[0].dstBinding = 0; + descriptor_writes[0].descriptorCount = 1; + descriptor_writes[0].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; + descriptor_writes[0].pBufferInfo = &input_info; + descriptor_writes[1].sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET; + descriptor_writes[1].dstSet = descriptor_set; + descriptor_writes[1].dstBinding = 1; + descriptor_writes[1].descriptorCount = 1; + descriptor_writes[1].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; + descriptor_writes[1].pBufferInfo = &output_info; + vkUpdateDescriptorSets(resources.device, 2, descriptor_writes, 0, nullptr); + + VkCommandPoolCreateInfo command_pool_info {}; + command_pool_info.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO; + command_pool_info.queueFamilyIndex = queue_family_index; + result = vkCreateCommandPool(resources.device, &command_pool_info, nullptr, &resources.command_pool); + if (result != VK_SUCCESS) { + return smoke_error("create_command_pool_failed:" + vk_result_name(result)); + } + + VkCommandBuffer command_buffer = VK_NULL_HANDLE; + VkCommandBufferAllocateInfo command_allocate {}; + command_allocate.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO; + command_allocate.commandPool = resources.command_pool; + command_allocate.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY; + command_allocate.commandBufferCount = 1; + result = vkAllocateCommandBuffers(resources.device, &command_allocate, &command_buffer); + if (result != VK_SUCCESS) { + return smoke_error("allocate_command_buffer_failed:" + vk_result_name(result)); + } + + VkCommandBufferBeginInfo begin_info {}; + begin_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO; + begin_info.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT; + result = vkBeginCommandBuffer(command_buffer, &begin_info); + if (result != VK_SUCCESS) { + return smoke_error("begin_command_buffer_failed:" + vk_result_name(result)); + } + + 
vkCmdBindPipeline(command_buffer, VK_PIPELINE_BIND_POINT_COMPUTE, resources.pipeline); + vkCmdBindDescriptorSets( + command_buffer, + VK_PIPELINE_BIND_POINT_COMPUTE, + resources.pipeline_layout, + 0, + 1, + &descriptor_set, + 0, + nullptr + ); + const uint32_t value_count = SMOKE_VALUE_COUNT; + vkCmdPushConstants( + command_buffer, + resources.pipeline_layout, + VK_SHADER_STAGE_COMPUTE_BIT, + 0, + sizeof(uint32_t), + &value_count + ); + vkCmdDispatch( + command_buffer, + (SMOKE_VALUE_COUNT + SMOKE_LOCAL_SIZE - 1U) / SMOKE_LOCAL_SIZE, + 1, + 1 + ); + + VkBufferMemoryBarrier barrier {}; + barrier.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER; + barrier.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT; + barrier.dstAccessMask = VK_ACCESS_HOST_READ_BIT; + barrier.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED; + barrier.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED; + barrier.buffer = resources.output.buffer; + barrier.offset = 0; + barrier.size = byte_count; + vkCmdPipelineBarrier( + command_buffer, + VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, + VK_PIPELINE_STAGE_HOST_BIT, + 0, + 0, + nullptr, + 1, + &barrier, + 0, + nullptr + ); + + result = vkEndCommandBuffer(command_buffer); + if (result != VK_SUCCESS) { + return smoke_error("end_command_buffer_failed:" + vk_result_name(result)); + } + + VkFenceCreateInfo fence_info {}; + fence_info.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO; + result = vkCreateFence(resources.device, &fence_info, nullptr, &resources.fence); + if (result != VK_SUCCESS) { + return smoke_error("create_fence_failed:" + vk_result_name(result)); + } + + VkSubmitInfo submit_info {}; + submit_info.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO; + submit_info.commandBufferCount = 1; + submit_info.pCommandBuffers = &command_buffer; + const auto start = std::chrono::steady_clock::now(); + result = vkQueueSubmit(queue, 1, &submit_info, resources.fence); + if (result != VK_SUCCESS) { + return smoke_error("queue_submit_failed:" + vk_result_name(result)); + } + result = 
vkWaitForFences(resources.device, 1, &resources.fence, VK_TRUE, 5'000'000'000ULL); + const auto end = std::chrono::steady_clock::now(); + if (result != VK_SUCCESS) { + return smoke_error("wait_fence_failed:" + vk_result_name(result)); + } + + std::vector output(SMOKE_VALUE_COUNT); + if (!read_buffer(resources, resources.output, output.data(), static_cast(byte_count), reason)) { + return smoke_error(reason); + } + + for (uint32_t index = 0; index < SMOKE_VALUE_COUNT; ++index) { + const float expected = input[index] * 2.0F + 1.0F; + if (std::fabs(output[index] - expected) > 0.0001F) { + std::ostringstream mismatch; + mismatch << "compute_mismatch:index=" << index + << ",expected=" << expected + << ",actual=" << output[index]; + return smoke_error(mismatch.str()); + } + } + + const double elapsed_ms = + std::chrono::duration(end - start).count(); + return VulkanSmokeResult { true, elapsed_ms, "ok" }; +} + +std::string probe_physical_device(VkPhysicalDevice physical_device) { + VkPhysicalDeviceProperties properties {}; + vkGetPhysicalDeviceProperties(physical_device, &properties); + + VkPhysicalDeviceFeatures features {}; + vkGetPhysicalDeviceFeatures(physical_device, &features); + + const bool api_1_1 = properties.apiVersion >= VK_API_VERSION_1_1; + uint32_t queue_family_index = 0; + const bool compute_queue = find_compute_queue_family(physical_device, queue_family_index); + const bool storage_range_ok = + properties.limits.maxStorageBufferRange >= MIN_STORAGE_BUFFER_RANGE; + const bool baseline_usable = api_1_1 && compute_queue && storage_range_ok; + const VulkanSmokeResult smoke_result = baseline_usable + ? 
run_compute_smoke(physical_device, queue_family_index) + : VulkanSmokeResult {}; + const bool usable = baseline_usable && smoke_result.ok; + + std::ostringstream output; + output << "usable=" << bool_value(usable) + << ";apiDetected=true" + << ";device=" << properties.deviceName + << ";api=" << version_string(properties.apiVersion) + << ";driver=" << version_string(properties.driverVersion) + << ";computeQueue=" << bool_value(compute_queue) + << ";maxStorageBufferRangeMb=" + << (properties.limits.maxStorageBufferRange / BYTES_IN_MB) + << ";maxComputeWorkGroupInvocations=" + << properties.limits.maxComputeWorkGroupInvocations + << ";maxComputeWorkGroupSize=" + << properties.limits.maxComputeWorkGroupSize[0] << "x" + << properties.limits.maxComputeWorkGroupSize[1] << "x" + << properties.limits.maxComputeWorkGroupSize[2] + << ";maxPushConstantsSize=" << properties.limits.maxPushConstantsSize + << ";deviceLocalHeapMb=" << device_local_heap_mb(physical_device) + << ";shaderInt16=" << bool_value(features.shaderInt16) + << ";computeSmoke=" << bool_value(smoke_result.ok) + << ";computeSmokeMs=" << smoke_result.elapsed_ms + << ";reason="; + + if (usable) { + output << smoke_result.reason; + } else if (!api_1_1) { + output << "requires_vulkan_1_1"; + } else if (!compute_queue) { + output << "missing_compute_queue"; + } else if (!storage_range_ok) { + output << "small_storage_buffer_range"; + } else if (!smoke_result.ok) { + output << smoke_result.reason; + } else { + output << "unknown"; + } + return output.str(); +} + +std::string probe_vulkan() { + const uint32_t loader_version = loader_api_version(); + VkApplicationInfo app_info {}; + app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; + app_info.pApplicationName = "SDAI Bonsai Vulkan Probe"; + app_info.applicationVersion = 1; + app_info.pEngineName = "SDAI"; + app_info.engineVersion = 1; + app_info.apiVersion = loader_version >= VK_API_VERSION_1_1 + ? 
VK_API_VERSION_1_1 + : VK_API_VERSION_1_0; + + VkInstanceCreateInfo instance_info {}; + instance_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; + instance_info.pApplicationInfo = &app_info; + + VkInstance instance = VK_NULL_HANDLE; + const VkResult create_result = vkCreateInstance(&instance_info, nullptr, &instance); + if (create_result != VK_SUCCESS) { + std::ostringstream output; + output << "usable=false;apiDetected=false;loaderApi=" + << version_string(loader_version) + << ";reason=create_instance_failed:" + << vk_result_name(create_result); + return output.str(); + } + + uint32_t device_count = 0; + VkResult enumerate_result = + vkEnumeratePhysicalDevices(instance, &device_count, nullptr); + if (enumerate_result != VK_SUCCESS || device_count == 0) { + vkDestroyInstance(instance, nullptr); + std::ostringstream output; + output << "usable=false;apiDetected=false;loaderApi=" + << version_string(loader_version) + << ";reason=no_physical_devices:" + << vk_result_name(enumerate_result); + return output.str(); + } + + std::vector devices(device_count); + enumerate_result = + vkEnumeratePhysicalDevices(instance, &device_count, devices.data()); + if (enumerate_result != VK_SUCCESS) { + vkDestroyInstance(instance, nullptr); + std::ostringstream output; + output << "usable=false;apiDetected=false;loaderApi=" + << version_string(loader_version) + << ";reason=enumerate_devices_failed:" + << vk_result_name(enumerate_result); + return output.str(); + } + + std::string best_summary; + for (VkPhysicalDevice device : devices) { + const std::string summary = probe_physical_device(device); + if (best_summary.empty() || + summary.find("usable=true") != std::string::npos) { + best_summary = summary; + } + if (summary.find("usable=true") != std::string::npos) { + break; + } + } + + vkDestroyInstance(instance, nullptr); + return best_summary.empty() + ? 
"usable=false;apiDetected=false;reason=no_probe_result" + : best_summary; +} + +} // namespace + +extern "C" JNIEXPORT jstring JNICALL +Java_com_shifthackz_aisdv1_feature_benchmark_AndroidBenchmarkVulkanProbe_probeVulkan( + JNIEnv* env, + jobject +) { + const std::string summary = probe_vulkan(); + __android_log_print(ANDROID_LOG_INFO, LOG_TAG, "bonsai_vulkan_probe %s", summary.c_str()); + return env->NewStringUTF(summary.c_str()); +} diff --git a/feature/benchmark/src/androidMain/cpp/benchmark_vulkan_smoke.comp b/feature/benchmark/src/androidMain/cpp/benchmark_vulkan_smoke.comp new file mode 100644 index 000000000..2c92849c3 --- /dev/null +++ b/feature/benchmark/src/androidMain/cpp/benchmark_vulkan_smoke.comp @@ -0,0 +1,23 @@ +#version 450 + +layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; + +layout(set = 0, binding = 0) readonly buffer InputBuffer { + float values[]; +} input_buffer; + +layout(set = 0, binding = 1) writeonly buffer OutputBuffer { + float values[]; +} output_buffer; + +layout(push_constant) uniform Params { + uint count; +} params; + +void main() { + const uint index = gl_GlobalInvocationID.x; + if (index >= params.count) { + return; + } + output_buffer.values[index] = input_buffer.values[index] * 2.0 + 1.0; +} diff --git a/feature/benchmark/src/androidMain/cpp/benchmark_vulkan_smoke_spv.h b/feature/benchmark/src/androidMain/cpp/benchmark_vulkan_smoke_spv.h new file mode 100644 index 000000000..846a39251 --- /dev/null +++ b/feature/benchmark/src/androidMain/cpp/benchmark_vulkan_smoke_spv.h @@ -0,0 +1,136 @@ +#pragma once + +#include +#include + +alignas(4) static const uint8_t kBenchmarkVulkanSmokeSpv[] = { + 0x03, 0x02, 0x23, 0x07, 0x00, 0x03, 0x01, 0x00, 0x0a, 0x00, 0x0d, 0x00, + 0x34, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x11, 0x00, 0x02, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x47, 0x4c, 0x53, 0x4c, 0x2e, 0x73, 0x74, 0x64, 0x2e, 0x34, 0x35, 0x30, + 0x00, 0x00, 0x00, 0x00, 0x0e, 
0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x06, 0x00, 0x05, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, + 0x0b, 0x00, 0x00, 0x00, 0x10, 0x00, 0x06, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x11, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x03, 0x00, 0x02, 0x00, 0x00, 0x00, + 0xc2, 0x01, 0x00, 0x00, 0x04, 0x00, 0x0a, 0x00, 0x47, 0x4c, 0x5f, 0x47, + 0x4f, 0x4f, 0x47, 0x4c, 0x45, 0x5f, 0x63, 0x70, 0x70, 0x5f, 0x73, 0x74, + 0x79, 0x6c, 0x65, 0x5f, 0x6c, 0x69, 0x6e, 0x65, 0x5f, 0x64, 0x69, 0x72, + 0x65, 0x63, 0x74, 0x69, 0x76, 0x65, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, + 0x47, 0x4c, 0x5f, 0x47, 0x4f, 0x4f, 0x47, 0x4c, 0x45, 0x5f, 0x69, 0x6e, + 0x63, 0x6c, 0x75, 0x64, 0x65, 0x5f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, + 0x69, 0x76, 0x65, 0x00, 0x05, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x04, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x69, 0x6e, 0x64, 0x65, 0x78, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x08, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x67, 0x6c, 0x5f, 0x47, + 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x49, 0x6e, 0x76, 0x6f, 0x63, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x49, 0x44, 0x00, 0x00, 0x00, 0x05, 0x00, 0x04, 0x00, + 0x11, 0x00, 0x00, 0x00, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x00, 0x00, + 0x06, 0x00, 0x05, 0x00, 0x11, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x00, 0x00, 0x00, 0x05, 0x00, 0x04, 0x00, + 0x13, 0x00, 0x00, 0x00, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x00, 0x00, + 0x05, 0x00, 0x06, 0x00, 0x20, 0x00, 0x00, 0x00, 0x4f, 0x75, 0x74, 0x70, + 0x75, 0x74, 0x42, 0x75, 0x66, 0x66, 0x65, 0x72, 0x00, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x05, 0x00, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x00, 0x00, 0x05, 0x00, 0x06, 0x00, + 0x22, 0x00, 0x00, 0x00, 0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x5f, 0x62, + 0x75, 0x66, 0x66, 0x65, 0x72, 
0x00, 0x00, 0x00, 0x05, 0x00, 0x05, 0x00, + 0x25, 0x00, 0x00, 0x00, 0x49, 0x6e, 0x70, 0x75, 0x74, 0x42, 0x75, 0x66, + 0x66, 0x65, 0x72, 0x00, 0x06, 0x00, 0x05, 0x00, 0x25, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x00, 0x00, + 0x05, 0x00, 0x06, 0x00, 0x27, 0x00, 0x00, 0x00, 0x69, 0x6e, 0x70, 0x75, + 0x74, 0x5f, 0x62, 0x75, 0x66, 0x66, 0x65, 0x72, 0x00, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x04, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, + 0x1c, 0x00, 0x00, 0x00, 0x48, 0x00, 0x05, 0x00, 0x11, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x03, 0x00, 0x11, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x04, 0x00, 0x1f, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x48, 0x00, 0x04, 0x00, 0x20, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x19, 0x00, 0x00, 0x00, 0x48, 0x00, 0x05, 0x00, + 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x03, 0x00, 0x20, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, 0x22, 0x00, 0x00, 0x00, + 0x22, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, + 0x22, 0x00, 0x00, 0x00, 0x21, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x04, 0x00, 0x24, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x48, 0x00, 0x04, 0x00, 0x25, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x48, 0x00, 0x05, 0x00, + 0x25, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x03, 0x00, 0x25, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, 0x27, 0x00, 0x00, 0x00, + 0x22, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, + 0x27, 0x00, 0x00, 0x00, 0x21, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x04, 0x00, 0x33, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, + 0x19, 0x00, 0x00, 0x00, 0x13, 
0x00, 0x02, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x21, 0x00, 0x03, 0x00, 0x03, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x15, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x17, 0x00, 0x04, 0x00, + 0x09, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x04, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x09, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, 0x0a, 0x00, 0x00, 0x00, + 0x0b, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x04, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x03, 0x00, 0x11, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00, 0x12, 0x00, 0x00, 0x00, + 0x09, 0x00, 0x00, 0x00, 0x11, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, + 0x12, 0x00, 0x00, 0x00, 0x13, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, + 0x15, 0x00, 0x04, 0x00, 0x14, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x04, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x15, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00, + 0x16, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x02, 0x00, 0x19, 0x00, 0x00, 0x00, 0x16, 0x00, 0x03, 0x00, + 0x1e, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x1d, 0x00, 0x03, 0x00, + 0x1f, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x03, 0x00, + 0x20, 0x00, 0x00, 0x00, 0x1f, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00, + 0x21, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, + 0x3b, 0x00, 0x04, 0x00, 0x21, 0x00, 0x00, 0x00, 0x22, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x1d, 0x00, 0x03, 0x00, 0x24, 0x00, 0x00, 0x00, + 0x1e, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x03, 0x00, 0x25, 0x00, 0x00, 0x00, + 0x24, 0x00, 0x00, 0x00, 0x20, 
0x00, 0x04, 0x00, 0x26, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x25, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, + 0x26, 0x00, 0x00, 0x00, 0x27, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x04, 0x00, 0x29, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x1e, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x04, 0x00, 0x1e, 0x00, 0x00, 0x00, + 0x2c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x40, 0x2b, 0x00, 0x04, 0x00, + 0x1e, 0x00, 0x00, 0x00, 0x2e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, + 0x2b, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x31, 0x00, 0x00, 0x00, + 0x40, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x32, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x2c, 0x00, 0x06, 0x00, + 0x09, 0x00, 0x00, 0x00, 0x33, 0x00, 0x00, 0x00, 0x31, 0x00, 0x00, 0x00, + 0x32, 0x00, 0x00, 0x00, 0x32, 0x00, 0x00, 0x00, 0x36, 0x00, 0x05, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0xf8, 0x00, 0x02, 0x00, 0x05, 0x00, 0x00, 0x00, + 0x3b, 0x00, 0x04, 0x00, 0x07, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x41, 0x00, 0x05, 0x00, 0x0d, 0x00, 0x00, 0x00, + 0x0e, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, + 0x0e, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x0f, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x41, 0x00, 0x05, 0x00, + 0x16, 0x00, 0x00, 0x00, 0x17, 0x00, 0x00, 0x00, 0x13, 0x00, 0x00, 0x00, + 0x15, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x17, 0x00, 0x00, 0x00, 0xae, 0x00, 0x05, 0x00, + 0x19, 0x00, 0x00, 0x00, 0x1a, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0xf7, 0x00, 0x03, 0x00, 0x1c, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0xfa, 0x00, 0x04, 0x00, 0x1a, 0x00, 0x00, 0x00, + 0x1b, 0x00, 0x00, 0x00, 0x1c, 
0x00, 0x00, 0x00, 0xf8, 0x00, 0x02, 0x00, + 0x1b, 0x00, 0x00, 0x00, 0xfd, 0x00, 0x01, 0x00, 0xf8, 0x00, 0x02, 0x00, + 0x1c, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x23, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x41, 0x00, 0x06, 0x00, 0x29, 0x00, 0x00, 0x00, 0x2a, 0x00, 0x00, 0x00, + 0x27, 0x00, 0x00, 0x00, 0x15, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, + 0x3d, 0x00, 0x04, 0x00, 0x1e, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x00, 0x00, + 0x2a, 0x00, 0x00, 0x00, 0x85, 0x00, 0x05, 0x00, 0x1e, 0x00, 0x00, 0x00, + 0x2d, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x00, 0x00, 0x2c, 0x00, 0x00, 0x00, + 0x81, 0x00, 0x05, 0x00, 0x1e, 0x00, 0x00, 0x00, 0x2f, 0x00, 0x00, 0x00, + 0x2d, 0x00, 0x00, 0x00, 0x2e, 0x00, 0x00, 0x00, 0x41, 0x00, 0x06, 0x00, + 0x29, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, 0x22, 0x00, 0x00, 0x00, + 0x15, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, + 0x30, 0x00, 0x00, 0x00, 0x2f, 0x00, 0x00, 0x00, 0xfd, 0x00, 0x01, 0x00, + 0x38, 0x00, 0x01, 0x00 +}; +static constexpr size_t kBenchmarkVulkanSmokeSpvSize = sizeof(kBenchmarkVulkanSmokeSpv); diff --git a/feature/benchmark/src/androidMain/kotlin/com/shifthackz/aisdv1/feature/benchmark/AndroidBenchmarkDeviceProbe.kt b/feature/benchmark/src/androidMain/kotlin/com/shifthackz/aisdv1/feature/benchmark/AndroidBenchmarkDeviceProbe.kt index 384eb4e6b..3516a3d4b 100644 --- a/feature/benchmark/src/androidMain/kotlin/com/shifthackz/aisdv1/feature/benchmark/AndroidBenchmarkDeviceProbe.kt +++ b/feature/benchmark/src/androidMain/kotlin/com/shifthackz/aisdv1/feature/benchmark/AndroidBenchmarkDeviceProbe.kt @@ -21,6 +21,7 @@ internal class AndroidBenchmarkDeviceProbe( val memoryInfo = activityManager()?.let { manager -> ActivityManager.MemoryInfo().also(manager::getMemoryInfo) } + val vulkanProbe = AndroidBenchmarkVulkanProbe.capture() BenchmarkDeviceInfo( platform = BenchmarkPlatform.ANDROID, 
manufacturer = Build.MANUFACTURER.orUnknown(), @@ -33,7 +34,8 @@ internal class AndroidBenchmarkDeviceProbe( availableRamMb = memoryInfo?.availMem?.bytesToMb() ?: maxMemoryMb(), totalVramMb = null, availableVramMb = null, - accelerators = accelerators(), + accelerators = accelerators(vulkanProbe), + acceleratorDiagnostics = listOf("Bonsai Vulkan probe: ${vulkanProbe.summary}"), ) }.getOrElse { fallback() } @@ -68,12 +70,17 @@ internal class AndroidBenchmarkDeviceProbe( return candidates.joinToString(" / ").ifBlank { "Android GPU" } } - private fun accelerators(): List = buildList { + private fun accelerators( + vulkanProbe: AndroidBenchmarkVulkanProbeResult, + ): List = buildList { if (hasSystemFeature("android.hardware.vulkan.level") || hasSystemFeature("android.hardware.vulkan.version") ) { add(BenchmarkAccelerator.VULKAN) } + if (vulkanProbe.usable) { + add(BenchmarkAccelerator.BONSAI_VULKAN) + } if (hasOpenClLibrary()) add(BenchmarkAccelerator.OPEN_CL) if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O_MR1) { add(BenchmarkAccelerator.NNAPI) diff --git a/feature/benchmark/src/androidMain/kotlin/com/shifthackz/aisdv1/feature/benchmark/AndroidBenchmarkVulkanProbe.kt b/feature/benchmark/src/androidMain/kotlin/com/shifthackz/aisdv1/feature/benchmark/AndroidBenchmarkVulkanProbe.kt new file mode 100644 index 000000000..441ee0c13 --- /dev/null +++ b/feature/benchmark/src/androidMain/kotlin/com/shifthackz/aisdv1/feature/benchmark/AndroidBenchmarkVulkanProbe.kt @@ -0,0 +1,65 @@ +package com.shifthackz.aisdv1.feature.benchmark + +import android.util.Log + +/** + * Runs a benchmark-scoped native Vulkan probe for custom Bonsai compute kernels. + * + * This does not initialize Bonsai models or generation runtime. 
+ * + * @author Dmitriy Moroz + */ +internal object AndroidBenchmarkVulkanProbe { + + private val loadResult = runCatching { + System.loadLibrary(LIBRARY_NAME) + } + + fun capture(): AndroidBenchmarkVulkanProbeResult { + val result = loadResult.fold( + onSuccess = { + runCatching { probeVulkan().toProbeResult() } + .getOrElse { error -> + AndroidBenchmarkVulkanProbeResult( + apiDetected = false, + usable = false, + summary = "usable=false;apiDetected=false;reason=probe_failed:${error.message}", + ) + } + }, + onFailure = { error -> + AndroidBenchmarkVulkanProbeResult( + apiDetected = false, + usable = false, + summary = "usable=false;apiDetected=false;reason=native_unavailable:${error.message}", + ) + }, + ) + Log.i(LOG_TAG, "bonsai_vulkan_probe ${result.summary}") + return result + } + + private external fun probeVulkan(): String +} + +internal data class AndroidBenchmarkVulkanProbeResult( + val apiDetected: Boolean, + val usable: Boolean, + val summary: String, +) + +private fun String.toProbeResult(): AndroidBenchmarkVulkanProbeResult = + AndroidBenchmarkVulkanProbeResult( + apiDetected = containsKeyValue("apiDetected", "true"), + usable = containsKeyValue("usable", "true"), + summary = this, + ) + +private fun String.containsKeyValue( + key: String, + value: String, +): Boolean = split(';') + .any { item -> item.substringBefore('=') == key && item.substringAfter('=') == value } + +private const val LIBRARY_NAME = "sdai_benchmark" +private const val LOG_TAG = "SDAI-Benchmark" diff --git a/feature/benchmark/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/feature/benchmark/BenchmarkRecommendationPolicyTest.kt b/feature/benchmark/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/feature/benchmark/BenchmarkRecommendationPolicyTest.kt index 862a513d3..14b9b0b88 100644 --- a/feature/benchmark/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/feature/benchmark/BenchmarkRecommendationPolicyTest.kt +++ 
b/feature/benchmark/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/feature/benchmark/BenchmarkRecommendationPolicyTest.kt @@ -74,6 +74,26 @@ class BenchmarkRecommendationPolicyTest { BenchmarkAccelerationStatus.NOT_RECOMMENDED, capabilities.status(BenchmarkAccelerator.NNAPI), ) + assertEquals( + BenchmarkAccelerationStatus.UNAVAILABLE, + capabilities.status(BenchmarkAccelerator.BONSAI_VULKAN), + ) + } + + @Test + fun `given pixel 3a class device with bonsai vulkan probe, expected bonsai compute is separate`() { + val capabilities = pixel3a() + .copy(accelerators = pixel3a().accelerators + BenchmarkAccelerator.BONSAI_VULKAN) + .accelerationCapabilities() + + assertEquals( + BenchmarkAccelerationStatus.BACKEND_UNAVAILABLE, + capabilities.status(BenchmarkAccelerator.VULKAN), + ) + assertEquals( + BenchmarkAccelerationStatus.SUPPORTED, + capabilities.status(BenchmarkAccelerator.BONSAI_VULKAN), + ) } @Test diff --git a/feature/benchmark/src/commonMain/kotlin/com/shifthackz/aisdv1/feature/benchmark/BenchmarkEntities.kt b/feature/benchmark/src/commonMain/kotlin/com/shifthackz/aisdv1/feature/benchmark/BenchmarkEntities.kt index a820c6b48..410d59ded 100644 --- a/feature/benchmark/src/commonMain/kotlin/com/shifthackz/aisdv1/feature/benchmark/BenchmarkEntities.kt +++ b/feature/benchmark/src/commonMain/kotlin/com/shifthackz/aisdv1/feature/benchmark/BenchmarkEntities.kt @@ -21,6 +21,7 @@ enum class BenchmarkPlatform { */ enum class BenchmarkAccelerator { VULKAN, + BONSAI_VULKAN, OPEN_CL, NNAPI, METAL, @@ -74,6 +75,7 @@ data class BenchmarkDeviceInfo( val totalVramMb: Long?, val availableVramMb: Long?, val accelerators: List, + val acceleratorDiagnostics: List = emptyList(), ) /** @@ -148,6 +150,10 @@ private fun BenchmarkDeviceInfo.accelerationStatus( val apiDetected = accelerator in accelerators return when (platform) { BenchmarkPlatform.ANDROID -> when (accelerator) { + BenchmarkAccelerator.BONSAI_VULKAN -> when { + !apiDetected -> BenchmarkAccelerationStatus.UNAVAILABLE + 
else -> BenchmarkAccelerationStatus.SUPPORTED + } BenchmarkAccelerator.VULKAN, BenchmarkAccelerator.OPEN_CL -> when { !apiDetected -> BenchmarkAccelerationStatus.UNAVAILABLE @@ -176,6 +182,7 @@ private fun BenchmarkDeviceInfo.accelerationStatus( BenchmarkAccelerationStatus.UNAVAILABLE } BenchmarkAccelerator.VULKAN, + BenchmarkAccelerator.BONSAI_VULKAN, BenchmarkAccelerator.OPEN_CL, BenchmarkAccelerator.NNAPI -> BenchmarkAccelerationStatus.UNAVAILABLE } diff --git a/feature/benchmark/src/commonMain/kotlin/com/shifthackz/aisdv1/feature/benchmark/BenchmarkScoreEngine.kt b/feature/benchmark/src/commonMain/kotlin/com/shifthackz/aisdv1/feature/benchmark/BenchmarkScoreEngine.kt index b05e13964..5cb8d7685 100644 --- a/feature/benchmark/src/commonMain/kotlin/com/shifthackz/aisdv1/feature/benchmark/BenchmarkScoreEngine.kt +++ b/feature/benchmark/src/commonMain/kotlin/com/shifthackz/aisdv1/feature/benchmark/BenchmarkScoreEngine.kt @@ -130,6 +130,7 @@ internal class BenchmarkScoreEngine( .fold(base) { score, capability -> score + when (capability.accelerator) { BenchmarkAccelerator.VULKAN -> 180 + BenchmarkAccelerator.BONSAI_VULKAN -> 220 BenchmarkAccelerator.OPEN_CL -> 320 BenchmarkAccelerator.NNAPI -> 240 BenchmarkAccelerator.METAL -> 760 @@ -157,6 +158,7 @@ internal class BenchmarkScoreEngine( deviceInfo: BenchmarkDeviceInfo, workload: WorkloadScore, ): List = buildList { + deviceInfo.acceleratorDiagnostics.forEach(::add) if (deviceInfo.totalVramMb == null) add("VRAM is not directly exposed by this platform.") if (deviceInfo.accelerationCapabilities().none { it.status == BenchmarkAccelerationStatus.SUPPORTED }) { add("No supported local hardware acceleration backend was detected.") diff --git a/feature/bonsai/build.gradle.kts b/feature/bonsai/build.gradle.kts index 3a86da57c..ea277439b 100644 --- a/feature/bonsai/build.gradle.kts +++ b/feature/bonsai/build.gradle.kts @@ -4,6 +4,25 @@ plugins { android { namespace = "com.shifthackz.aisdv1.feature.bonsai" + + defaultConfig { 
+ ndk { + abiFilters += "arm64-v8a" + } + + externalNativeBuild { + cmake { + arguments += "-DANDROID_STL=c++_shared" + } + } + } + + externalNativeBuild { + cmake { + path = file("src/androidMain/cpp/CMakeLists.txt") + version = "3.22.1" + } + } } kotlin { @@ -13,5 +32,9 @@ kotlin { implementation(libs.koin.core) implementation(libs.kotlinx.coroutines.core) } + + androidUnitTest.dependencies { + implementation(libs.test.junit) + } } } diff --git a/feature/bonsai/src/androidMain/cpp/CMakeLists.txt b/feature/bonsai/src/androidMain/cpp/CMakeLists.txt new file mode 100644 index 000000000..7a3fecec8 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/CMakeLists.txt @@ -0,0 +1,77 @@ +cmake_minimum_required(VERSION 3.22.1) + +project(sdai_bonsai_runtime LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_POSITION_INDEPENDENT_CODE ON) + +add_library( + sdai_bonsai + SHARED + bonsai_activation.cpp + bonsai_attention.cpp + bonsai_dequant.cpp + bonsai_embedding.cpp + bonsai_flux_attention_layout.cpp + bonsai_flux_double_block.cpp + bonsai_flux_modulation.cpp + bonsai_flux_output.cpp + bonsai_flux_pos_embed.cpp + bonsai_flux_rope.cpp + bonsai_flux_single_block.cpp + bonsai_flux_time_embedding.cpp + bonsai_flux_transformer.cpp + bonsai_flux_vae.cpp + bonsai_image_encoder.cpp + bonsai_jni.cpp + bonsai_layer_norm.cpp + bonsai_latents.cpp + bonsai_linear.cpp + bonsai_matmul.cpp + bonsai_model_config.cpp + bonsai_model_probe.cpp + bonsai_norm.cpp + bonsai_packed_weight.cpp + bonsai_prompt.cpp + bonsai_qwen.cpp + bonsai_qwen_inputs.cpp + bonsai_rotary.cpp + bonsai_runtime.cpp + bonsai_runtime_context.cpp + bonsai_scheduler.cpp + bonsai_safetensors.cpp + bonsai_tensor.cpp + bonsai_tensor_storage.cpp + bonsai_tokenizer.cpp + bonsai_vae_decoder.cpp + bonsai_vae_ops.cpp + bonsai_vulkan.cpp +) + +target_link_libraries( + sdai_bonsai + PRIVATE + android + log + vulkan + z +) + +target_link_options( + sdai_bonsai + PRIVATE + 
"-Wl,-z,max-page-size=16384" + "-Wl,-z,common-page-size=16384" +) + +target_compile_options( + sdai_bonsai + PRIVATE + -Wall + -Wextra + -Wno-unused-parameter + $<$:-O3> + $<$:-DNDEBUG> + $<$:-fno-math-errno> +) diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_activation.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_activation.cpp new file mode 100644 index 000000000..41663d1fe --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_activation.cpp @@ -0,0 +1,74 @@ +#include "bonsai_activation.h" + +#include +#include + +float bonsai_silu(float value) { + return value / (1.0F + std::exp(-value)); +} + +std::vector bonsai_silu( + const std::vector& input +) { + std::vector output; + output.reserve(input.size()); + for (float value : input) { + output.push_back(bonsai_silu(value)); + } + return output; +} + +std::vector bonsai_silu_times( + const std::vector& gate, + const std::vector& up +) { + if (gate.size() != up.size()) { + throw std::runtime_error("Bonsai SiLU gate/up size mismatch."); + } + + std::vector output; + output.reserve(gate.size()); + for (size_t index = 0; index < gate.size(); index++) { + output.push_back(bonsai_silu(gate[index]) * up[index]); + } + return output; +} + +std::vector bonsai_swiglu( + const std::vector& input +) { + if (input.empty() || input.size() % 2 != 0) { + throw std::runtime_error("Bonsai SwiGLU input size must be positive and even."); + } + + const size_t half = input.size() / 2; + std::vector output; + output.reserve(half); + for (size_t index = 0; index < half; index++) { + output.push_back(bonsai_silu(input[index]) * input[half + index]); + } + return output; +} + +std::vector bonsai_swiglu_last_dimension( + const std::vector& input, + uint64_t last_dimension +) { + if (last_dimension == 0 || last_dimension % 2 != 0) { + throw std::runtime_error("Bonsai SwiGLU last dimension must be positive and even."); + } + const size_t dimension = static_cast(last_dimension); + if (input.empty() || input.size() % dimension 
!= 0) { + throw std::runtime_error("Bonsai SwiGLU input shape mismatch."); + } + + std::vector output; + output.reserve(input.size() / 2); + const size_t half = dimension / 2; + for (size_t offset = 0; offset < input.size(); offset += dimension) { + for (size_t index = 0; index < half; index++) { + output.push_back(bonsai_silu(input[offset + index]) * input[offset + half + index]); + } + } + return output; +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_activation.h b/feature/bonsai/src/androidMain/cpp/bonsai_activation.h new file mode 100644 index 000000000..2c0d83cdf --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_activation.h @@ -0,0 +1,24 @@ +#pragma once + +#include +#include + +float bonsai_silu(float value); + +std::vector bonsai_silu( + const std::vector& input +); + +std::vector bonsai_silu_times( + const std::vector& gate, + const std::vector& up +); + +std::vector bonsai_swiglu( + const std::vector& input +); + +std::vector bonsai_swiglu_last_dimension( + const std::vector& input, + uint64_t last_dimension +); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_attention.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_attention.cpp new file mode 100644 index 000000000..5ff26f01e --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_attention.cpp @@ -0,0 +1,169 @@ +#include "bonsai_attention.h" + +#include +#include +#include +#include + +namespace { + +uint64_t checked_multiply(uint64_t left, uint64_t right, const char* label) { + if (left != 0 && right > std::numeric_limits::max() / left) { + throw std::runtime_error(label); + } + return left * right; +} + +uint64_t attention_size(uint64_t heads, uint64_t length, uint64_t head_dimension) { + return checked_multiply( + checked_multiply(heads, length, "Bonsai attention shape is too large."), + head_dimension, + "Bonsai attention shape is too large." 
+ ); +} + +size_t attention_index( + uint64_t head, + uint64_t row, + uint64_t column, + uint64_t length, + uint64_t head_dimension +) { + return static_cast(((head * length) + row) * head_dimension + column); +} + +float dot_product( + const std::vector& left, + const std::vector& right, + uint64_t left_head, + uint64_t right_head, + uint64_t left_row, + uint64_t right_row, + uint64_t length, + uint64_t head_dimension +) { + float sum = 0.0F; + for (uint64_t column = 0; column < head_dimension; column++) { + sum += left[attention_index(left_head, left_row, column, length, head_dimension)] * + right[attention_index(right_head, right_row, column, length, head_dimension)]; + } + return sum; +} + +} // namespace + +std::vector bonsai_repeat_kv_heads( + const std::vector& input, + uint64_t key_value_heads, + uint64_t repeats, + uint64_t length, + uint64_t head_dimension +) { + if (repeats == 0) { + throw std::runtime_error("Bonsai repeat-KV repeat count must be positive."); + } + const uint64_t expected_input = attention_size(key_value_heads, length, head_dimension); + if (input.size() != static_cast(expected_input)) { + throw std::runtime_error("Bonsai repeat-KV input size mismatch."); + } + + std::vector output; + output.reserve(static_cast(checked_multiply( + expected_input, + repeats, + "Bonsai repeat-KV output shape is too large." 
+ ))); + for (uint64_t head = 0; head < key_value_heads; head++) { + for (uint64_t repeat = 0; repeat < repeats; repeat++) { + for (uint64_t row = 0; row < length; row++) { + for (uint64_t column = 0; column < head_dimension; column++) { + output.push_back(input[attention_index( + head, + row, + column, + length, + head_dimension + )]); + } + } + } + } + return output; +} + +std::vector bonsai_scaled_dot_product_attention( + const std::vector& queries, + const std::vector& keys, + const std::vector& values, + const std::vector& additive_mask, + uint64_t heads, + uint64_t length, + uint64_t head_dimension, + float scale +) { + const uint64_t expected_size = attention_size(heads, length, head_dimension); + if (queries.size() != static_cast(expected_size) || + keys.size() != static_cast(expected_size) || + values.size() != static_cast(expected_size) + ) { + throw std::runtime_error("Bonsai attention input size mismatch."); + } + if (!additive_mask.empty() && additive_mask.size() != static_cast(length * length)) { + throw std::runtime_error("Bonsai attention mask size mismatch."); + } + + std::vector output; + output.resize(static_cast(expected_size), 0.0F); + std::vector scores; + scores.resize(static_cast(length)); + + for (uint64_t head = 0; head < heads; head++) { + for (uint64_t row = 0; row < length; row++) { + float max_score = -std::numeric_limits::infinity(); + for (uint64_t key_row = 0; key_row < length; key_row++) { + float score = dot_product( + queries, + keys, + head, + head, + row, + key_row, + length, + head_dimension + ) * scale; + if (!additive_mask.empty()) { + score += additive_mask[static_cast(row * length + key_row)]; + } + scores[static_cast(key_row)] = score; + max_score = std::max(max_score, score); + } + + float denominator = 0.0F; + for (uint64_t key_row = 0; key_row < length; key_row++) { + const float weight = std::exp(scores[static_cast(key_row)] - max_score); + scores[static_cast(key_row)] = weight; + denominator += weight; + } + + if 
(denominator == 0.0F) { + throw std::runtime_error("Bonsai attention softmax denominator is zero."); + } + + for (uint64_t column = 0; column < head_dimension; column++) { + float value = 0.0F; + for (uint64_t key_row = 0; key_row < length; key_row++) { + const float weight = scores[static_cast(key_row)] / denominator; + value += weight * values[attention_index( + head, + key_row, + column, + length, + head_dimension + )]; + } + output[attention_index(head, row, column, length, head_dimension)] = value; + } + } + } + return output; +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_attention.h b/feature/bonsai/src/androidMain/cpp/bonsai_attention.h new file mode 100644 index 000000000..9b35d904e --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_attention.h @@ -0,0 +1,23 @@ +#pragma once + +#include +#include + +std::vector bonsai_repeat_kv_heads( + const std::vector& input, + uint64_t key_value_heads, + uint64_t repeats, + uint64_t length, + uint64_t head_dimension +); + +std::vector bonsai_scaled_dot_product_attention( + const std::vector& queries, + const std::vector& keys, + const std::vector& values, + const std::vector& additive_mask, + uint64_t heads, + uint64_t length, + uint64_t head_dimension, + float scale +); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_dequant.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_dequant.cpp new file mode 100644 index 000000000..30e516335 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_dequant.cpp @@ -0,0 +1,124 @@ +#include "bonsai_dequant.h" + +#include "bonsai_tensor.h" + +#include +#include +#include + +namespace { + +template +T read_unaligned(const uint8_t* data) { + T value {}; + std::memcpy(&value, data, sizeof(T)); + return value; +} + +uint64_t last_dimension(const BonsaiTensorView& view) { + if (view.descriptor->shape.empty()) { + throw std::runtime_error("Bonsai tensor must have at least one dimension."); + } + return view.descriptor->shape.back(); +} + +void 
require_packed(const BonsaiPackedWeightViews& views) { + if (!views.packed) { + throw std::runtime_error( + "Bonsai weight is dense, not packed: " + views.weight.descriptor->key + ); + } +} + +void require_row(const BonsaiPackedWeightViews& views, uint64_t row) { + if (row >= views.leading_rows) { + throw std::runtime_error( + "Bonsai packed weight row is out of range: " + views.weight.descriptor->key + ); + } +} + +void require_column(const BonsaiPackedWeightViews& views, uint64_t column) { + if (column >= views.input_values) { + throw std::runtime_error( + "Bonsai packed weight column is out of range: " + views.weight.descriptor->key + ); + } +} + +float read_row_scalar( + const BonsaiTensorView& view, + uint64_t row, + uint64_t column +) { + const uint64_t columns = last_dimension(view); + const uint64_t index = row * columns + column; + return bonsai_read_scalar_as_f32( + view.data + index * view.dtype_byte_count, + view.dtype + ); +} + +uint32_t read_packed_word( + const BonsaiPackedWeightViews& views, + uint64_t row, + uint64_t packed_word_column +) { + const uint64_t packed_columns = last_dimension(views.weight); + const uint64_t word_index = row * packed_columns + packed_word_column; + return read_unaligned( + views.weight.data + word_index * views.weight.dtype_byte_count + ); +} + +} // namespace + +uint32_t bonsai_unpack_quantized_value( + uint32_t word, + int bits, + uint64_t packed_value_index +) { + if (bits != 1 && bits != 2 && bits != 4) { + throw std::runtime_error("Unsupported Bonsai quantization bits: " + std::to_string(bits)); + } + const uint32_t mask = (1U << static_cast(bits)) - 1U; + const uint32_t values_per_word = 32U / static_cast(bits); + const uint32_t shift = static_cast( + (packed_value_index % values_per_word) * static_cast(bits) + ); + return (word >> shift) & mask; +} + +float bonsai_dequantize_packed_value( + const BonsaiPackedWeightViews& views, + uint64_t row, + uint64_t column +) { + require_packed(views); + require_row(views, 
row); + require_column(views, column); + + const uint64_t values_per_word = 32ULL / static_cast(views.bits); + const uint64_t word_column = column / values_per_word; + const uint64_t group_column = column / static_cast(views.group_size); + const uint32_t word = read_packed_word(views, row, word_column); + const uint32_t quantized = bonsai_unpack_quantized_value(word, views.bits, column); + const float scale = read_row_scalar(views.scales, row, group_column); + const float bias = read_row_scalar(views.biases, row, group_column); + return static_cast(quantized) * scale + bias; +} + +std::vector bonsai_dequantize_packed_row( + const BonsaiPackedWeightViews& views, + uint64_t row +) { + require_packed(views); + require_row(views, row); + + std::vector output; + output.reserve(static_cast(views.input_values)); + for (uint64_t column = 0; column < views.input_values; column++) { + output.push_back(bonsai_dequantize_packed_value(views, row, column)); + } + return output; +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_dequant.h b/feature/bonsai/src/androidMain/cpp/bonsai_dequant.h new file mode 100644 index 000000000..5ffbff033 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_dequant.h @@ -0,0 +1,23 @@ +#pragma once + +#include "bonsai_packed_weight.h" + +#include +#include + +uint32_t bonsai_unpack_quantized_value( + uint32_t word, + int bits, + uint64_t packed_value_index +); + +std::vector bonsai_dequantize_packed_row( + const BonsaiPackedWeightViews& views, + uint64_t row +); + +float bonsai_dequantize_packed_value( + const BonsaiPackedWeightViews& views, + uint64_t row, + uint64_t column +); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_embedding.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_embedding.cpp new file mode 100644 index 000000000..00350ceb1 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_embedding.cpp @@ -0,0 +1,93 @@ +#include "bonsai_embedding.h" + +#include "bonsai_dequant.h" +#include "bonsai_tensor.h" + 
+#include +#include + +namespace { + +void require_token_id(const BonsaiEmbeddingViews& views, uint64_t token_id) { + if (token_id >= views.rows) { + throw std::runtime_error( + "Bonsai embedding token id is out of range: " + std::to_string(token_id) + ); + } +} + +std::vector dense_embedding_row( + const BonsaiDenseWeightViews& views, + uint64_t token_id +) { + std::vector output; + output.reserve(static_cast(views.input_values)); + for (uint64_t column = 0; column < views.input_values; column++) { + const uint64_t index = token_id * views.input_values + column; + output.push_back( + bonsai_read_scalar_as_f32( + views.weight.data + index * views.weight.dtype_byte_count, + views.weight.dtype + ) + ); + } + return output; +} + +} // namespace + +BonsaiEmbeddingViews bonsai_require_embedding_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + const BonsaiPackedWeightDescriptor& descriptor +) { + BonsaiEmbeddingViews views; + views.kind = descriptor.packed ? 
BonsaiEmbeddingWeightKind::Packed : BonsaiEmbeddingWeightKind::Dense; + if (descriptor.packed) { + views.packed = bonsai_require_packed_weight_views(storage, index, descriptor); + views.rows = views.packed.leading_rows; + views.dimensions = views.packed.input_values; + } else { + views.dense = bonsai_require_dense_weight_view(storage, index, descriptor.weight_key); + views.rows = views.dense.leading_rows; + views.dimensions = views.dense.input_values; + } + return views; +} + +std::vector bonsai_embedding_row( + const BonsaiEmbeddingViews& views, + uint64_t token_id +) { + require_token_id(views, token_id); + switch (views.kind) { + case BonsaiEmbeddingWeightKind::Dense: + return dense_embedding_row(views.dense, token_id); + case BonsaiEmbeddingWeightKind::Packed: + return bonsai_dequantize_packed_row(views.packed, token_id); + } + throw std::runtime_error("Unsupported Bonsai embedding weight kind."); +} + +std::vector bonsai_embedding_lookup( + const BonsaiEmbeddingViews& views, + const std::vector& token_ids +) { + std::vector output; + output.reserve(static_cast(views.dimensions * token_ids.size())); + for (uint64_t token_id : token_ids) { + const std::vector row = bonsai_embedding_row(views, token_id); + output.insert(output.end(), row.begin(), row.end()); + } + return output; +} + +uint64_t bonsai_embedding_byte_count(const BonsaiEmbeddingViews& views) { + switch (views.kind) { + case BonsaiEmbeddingWeightKind::Dense: + return views.dense.weight.byte_count; + case BonsaiEmbeddingWeightKind::Packed: + return bonsai_packed_weight_byte_count(views.packed); + } + return 0; +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_embedding.h b/feature/bonsai/src/androidMain/cpp/bonsai_embedding.h new file mode 100644 index 000000000..d4df0b263 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_embedding.h @@ -0,0 +1,40 @@ +#pragma once + +#include "bonsai_matmul.h" +#include "bonsai_packed_weight.h" +#include "bonsai_safetensors.h" +#include 
"bonsai_tensor_storage.h" + +#include +#include + +enum class BonsaiEmbeddingWeightKind { + Dense, + Packed, +}; + +struct BonsaiEmbeddingViews { + BonsaiEmbeddingWeightKind kind = BonsaiEmbeddingWeightKind::Dense; + BonsaiDenseWeightViews dense; + BonsaiPackedWeightViews packed; + uint64_t rows = 0; + uint64_t dimensions = 0; +}; + +BonsaiEmbeddingViews bonsai_require_embedding_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + const BonsaiPackedWeightDescriptor& descriptor +); + +std::vector bonsai_embedding_row( + const BonsaiEmbeddingViews& views, + uint64_t token_id +); + +std::vector bonsai_embedding_lookup( + const BonsaiEmbeddingViews& views, + const std::vector& token_ids +); + +uint64_t bonsai_embedding_byte_count(const BonsaiEmbeddingViews& views); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_flux_attention_layout.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_flux_attention_layout.cpp new file mode 100644 index 000000000..ec36be30c --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_flux_attention_layout.cpp @@ -0,0 +1,468 @@ +#include "bonsai_flux_attention_layout.h" + +#include +#include +#include + +namespace { + +size_t checked_size(uint64_t left, uint64_t right, const char* label) { + const uint64_t limit = static_cast(std::numeric_limits::max()); + if (left != 0 && right > limit / left) { + throw std::runtime_error(std::string("Bonsai Flux attention layout size overflow: ") + label); + } + return static_cast(left * right); +} + +uint64_t checked_u64_multiply(uint64_t left, uint64_t right, const char* label) { + if (left != 0 && right > std::numeric_limits::max() / left) { + throw std::runtime_error(std::string("Bonsai Flux attention layout size overflow: ") + label); + } + return left * right; +} + +size_t checked_size_3(uint64_t first, uint64_t second, uint64_t third, const char* label) { + const size_t first_second = checked_size(first, second, label); + if (third != 0 && + first_second > 
std::numeric_limits::max() / static_cast(third)) { + throw std::runtime_error(std::string("Bonsai Flux attention layout size overflow: ") + label); + } + return first_second * static_cast(third); +} + +uint64_t checked_width( + uint64_t dimensions, + uint64_t mlp_hidden_dimensions +) { + if (dimensions > (std::numeric_limits::max() / 3U)) { + throw std::runtime_error("Bonsai Flux single projection width overflow."); + } + const uint64_t qkv_width = dimensions * 3U; + if (mlp_hidden_dimensions > (std::numeric_limits::max() / 2U)) { + throw std::runtime_error("Bonsai Flux single projection width overflow."); + } + const uint64_t mlp_width = mlp_hidden_dimensions * 2U; + if (qkv_width > std::numeric_limits::max() - mlp_width) { + throw std::runtime_error("Bonsai Flux single projection width overflow."); + } + return qkv_width + mlp_width; +} + +void require_positive(uint64_t value, const char* label) { + if (value == 0) { + throw std::runtime_error(std::string("Bonsai Flux attention layout value must be positive: ") + label); + } +} + +void copy_projection_chunk( + const std::vector& fused, + uint64_t batch, + uint64_t sequence_length, + uint64_t fused_width, + uint64_t chunk_offset, + uint64_t chunk_width, + std::vector* output +) { + output->assign(checked_size_3(batch, sequence_length, chunk_width, "projection chunk"), 0.0F); + for (uint64_t batch_index = 0; batch_index < batch; batch_index++) { + for (uint64_t token = 0; token < sequence_length; token++) { + const size_t source_offset = static_cast( + (batch_index * sequence_length + token) * fused_width + chunk_offset + ); + const size_t target_offset = static_cast( + (batch_index * sequence_length + token) * chunk_width + ); + for (uint64_t index = 0; index < chunk_width; index++) { + (*output)[target_offset + static_cast(index)] = + fused[source_offset + static_cast(index)]; + } + } + } +} + +size_t head_index( + uint64_t batch_index, + uint64_t head, + uint64_t token, + uint64_t column, + uint64_t heads, + 
uint64_t sequence_length, + uint64_t head_dimension +) { + return static_cast( + ((batch_index * heads + head) * sequence_length + token) * head_dimension + column + ); +} + +size_t sequence_index( + uint64_t batch_index, + uint64_t token, + uint64_t head, + uint64_t column, + uint64_t sequence_length, + uint64_t heads, + uint64_t head_dimension +) { + return static_cast( + ((batch_index * sequence_length + token) * heads + head) * head_dimension + column + ); +} + +} // namespace + +BonsaiFluxSingleProjectionParts bonsai_flux_split_single_projection( + const std::vector& fused, + uint64_t batch, + uint64_t sequence_length, + uint64_t dimensions, + uint64_t mlp_hidden_dimensions +) { + require_positive(batch, "batch"); + require_positive(sequence_length, "sequence"); + require_positive(dimensions, "dimensions"); + require_positive(mlp_hidden_dimensions, "mlp hidden"); + + const uint64_t fused_width = checked_width(dimensions, mlp_hidden_dimensions); + if (fused.size() != checked_size_3(batch, sequence_length, fused_width, "single projection")) { + throw std::runtime_error("Bonsai Flux single projection shape mismatch."); + } + + BonsaiFluxSingleProjectionParts output { + batch, + sequence_length, + dimensions, + mlp_hidden_dimensions, + {}, + {}, + {}, + {}, + }; + copy_projection_chunk( + fused, + batch, + sequence_length, + fused_width, + 0, + dimensions, + &output.query + ); + copy_projection_chunk( + fused, + batch, + sequence_length, + fused_width, + dimensions, + dimensions, + &output.key + ); + copy_projection_chunk( + fused, + batch, + sequence_length, + fused_width, + dimensions * 2U, + dimensions, + &output.value + ); + copy_projection_chunk( + fused, + batch, + sequence_length, + fused_width, + dimensions * 3U, + mlp_hidden_dimensions * 2U, + &output.mlp_values + ); + return output; +} + +std::vector bonsai_flux_sequence_to_heads( + const std::vector& input, + uint64_t batch, + uint64_t sequence_length, + uint64_t heads, + uint64_t head_dimension +) { + 
require_positive(batch, "batch"); + require_positive(sequence_length, "sequence"); + require_positive(heads, "heads"); + require_positive(head_dimension, "head dimension"); + const uint64_t dimensions = checked_u64_multiply(heads, head_dimension, "sequence dimensions"); + const uint64_t batch_heads = checked_u64_multiply(batch, heads, "batch heads"); + const size_t expected = checked_size_3(batch, sequence_length, dimensions, "sequence"); + if (input.size() != expected) { + throw std::runtime_error("Bonsai Flux sequence-to-heads shape mismatch."); + } + + std::vector output(checked_size_3(batch_heads, sequence_length, head_dimension, "heads"), 0.0F); + for (uint64_t batch_index = 0; batch_index < batch; batch_index++) { + for (uint64_t token = 0; token < sequence_length; token++) { + for (uint64_t head = 0; head < heads; head++) { + for (uint64_t column = 0; column < head_dimension; column++) { + output[head_index( + batch_index, + head, + token, + column, + heads, + sequence_length, + head_dimension + )] = input[sequence_index( + batch_index, + token, + head, + column, + sequence_length, + heads, + head_dimension + )]; + } + } + } + } + return output; +} + +std::vector bonsai_flux_heads_to_sequence( + const std::vector& input, + uint64_t batch, + uint64_t sequence_length, + uint64_t heads, + uint64_t head_dimension +) { + require_positive(batch, "batch"); + require_positive(sequence_length, "sequence"); + require_positive(heads, "heads"); + require_positive(head_dimension, "head dimension"); + const uint64_t dimensions = checked_u64_multiply(heads, head_dimension, "sequence dimensions"); + const uint64_t batch_heads = checked_u64_multiply(batch, heads, "batch heads"); + const size_t expected = checked_size_3(batch_heads, sequence_length, head_dimension, "heads"); + if (input.size() != expected) { + throw std::runtime_error("Bonsai Flux heads-to-sequence shape mismatch."); + } + + std::vector output(checked_size_3(batch, sequence_length, dimensions, "sequence"), 
0.0F); + for (uint64_t batch_index = 0; batch_index < batch; batch_index++) { + for (uint64_t token = 0; token < sequence_length; token++) { + for (uint64_t head = 0; head < heads; head++) { + for (uint64_t column = 0; column < head_dimension; column++) { + output[sequence_index( + batch_index, + token, + head, + column, + sequence_length, + heads, + head_dimension + )] = input[head_index( + batch_index, + head, + token, + column, + heads, + sequence_length, + head_dimension + )]; + } + } + } + } + return output; +} + +std::vector bonsai_flux_concat_head_sequences( + const std::vector& first, + const std::vector& second, + uint64_t batch, + uint64_t heads, + uint64_t first_sequence_length, + uint64_t second_sequence_length, + uint64_t head_dimension +) { + require_positive(batch, "batch"); + require_positive(heads, "heads"); + require_positive(first_sequence_length, "first sequence"); + require_positive(second_sequence_length, "second sequence"); + require_positive(head_dimension, "head dimension"); + const uint64_t batch_heads = checked_u64_multiply(batch, heads, "batch heads"); + const size_t first_size = checked_size_3(batch_heads, first_sequence_length, head_dimension, "first"); + const size_t second_size = checked_size_3(batch_heads, second_sequence_length, head_dimension, "second"); + if (first.size() != first_size || second.size() != second_size) { + throw std::runtime_error("Bonsai Flux concat-head-sequences shape mismatch."); + } + + const uint64_t combined_sequence_length = first_sequence_length + second_sequence_length; + if (combined_sequence_length < first_sequence_length) { + throw std::runtime_error("Bonsai Flux concat-head-sequences length overflow."); + } + std::vector output( + checked_size_3(batch_heads, combined_sequence_length, head_dimension, "concat"), + 0.0F + ); + for (uint64_t batch_head = 0; batch_head < batch_heads; batch_head++) { + for (uint64_t token = 0; token < first_sequence_length; token++) { + for (uint64_t column = 0; column < 
head_dimension; column++) { + const size_t source = static_cast( + (batch_head * first_sequence_length + token) * head_dimension + column + ); + const size_t target = static_cast( + (batch_head * combined_sequence_length + token) * head_dimension + column + ); + output[target] = first[source]; + } + } + for (uint64_t token = 0; token < second_sequence_length; token++) { + for (uint64_t column = 0; column < head_dimension; column++) { + const size_t source = static_cast( + (batch_head * second_sequence_length + token) * head_dimension + column + ); + const size_t target = static_cast( + ( + batch_head * combined_sequence_length + + first_sequence_length + + token + ) * head_dimension + column + ); + output[target] = second[source]; + } + } + } + return output; +} + +BonsaiFluxHeadSequenceParts bonsai_flux_split_head_sequences( + const std::vector& input, + uint64_t batch, + uint64_t heads, + uint64_t first_sequence_length, + uint64_t second_sequence_length, + uint64_t head_dimension +) { + require_positive(batch, "batch"); + require_positive(heads, "heads"); + require_positive(first_sequence_length, "first sequence"); + require_positive(second_sequence_length, "second sequence"); + require_positive(head_dimension, "head dimension"); + const uint64_t batch_heads = checked_u64_multiply(batch, heads, "batch heads"); + const uint64_t combined_sequence_length = first_sequence_length + second_sequence_length; + if (combined_sequence_length < first_sequence_length) { + throw std::runtime_error("Bonsai Flux split-head-sequences length overflow."); + } + const size_t expected = checked_size_3( + batch_heads, + combined_sequence_length, + head_dimension, + "split" + ); + if (input.size() != expected) { + throw std::runtime_error("Bonsai Flux split-head-sequences shape mismatch."); + } + + BonsaiFluxHeadSequenceParts output { + batch, + heads, + first_sequence_length, + second_sequence_length, + head_dimension, + {}, + {}, + }; + output.first.assign( + 
checked_size_3(batch_heads, first_sequence_length, head_dimension, "split first"), + 0.0F + ); + output.second.assign( + checked_size_3(batch_heads, second_sequence_length, head_dimension, "split second"), + 0.0F + ); + + for (uint64_t batch_head = 0; batch_head < batch_heads; batch_head++) { + for (uint64_t token = 0; token < first_sequence_length; token++) { + for (uint64_t column = 0; column < head_dimension; column++) { + const size_t source = static_cast( + (batch_head * combined_sequence_length + token) * head_dimension + column + ); + const size_t target = static_cast( + (batch_head * first_sequence_length + token) * head_dimension + column + ); + output.first[target] = input[source]; + } + } + for (uint64_t token = 0; token < second_sequence_length; token++) { + for (uint64_t column = 0; column < head_dimension; column++) { + const size_t source = static_cast( + ( + batch_head * combined_sequence_length + + first_sequence_length + + token + ) * head_dimension + column + ); + const size_t target = static_cast( + (batch_head * second_sequence_length + token) * head_dimension + column + ); + output.second[target] = input[source]; + } + } + } + return output; +} + +std::vector bonsai_flux_concat_last_dimension( + const std::vector& first, + const std::vector& second, + uint64_t batch, + uint64_t sequence_length, + uint64_t first_dimensions, + uint64_t second_dimensions +) { + require_positive(batch, "batch"); + require_positive(sequence_length, "sequence"); + require_positive(first_dimensions, "first dimensions"); + require_positive(second_dimensions, "second dimensions"); + const size_t first_size = checked_size_3( + batch, + sequence_length, + first_dimensions, + "concat last first" + ); + const size_t second_size = checked_size_3( + batch, + sequence_length, + second_dimensions, + "concat last second" + ); + if (first.size() != first_size || second.size() != second_size) { + throw std::runtime_error("Bonsai Flux concat-last-dimension shape mismatch."); + } + 
+ const uint64_t output_dimensions = first_dimensions + second_dimensions; + if (output_dimensions < first_dimensions) { + throw std::runtime_error("Bonsai Flux concat-last-dimension width overflow."); + } + std::vector output( + checked_size_3(batch, sequence_length, output_dimensions, "concat last"), + 0.0F + ); + for (uint64_t batch_index = 0; batch_index < batch; batch_index++) { + for (uint64_t token = 0; token < sequence_length; token++) { + const size_t first_offset = static_cast( + (batch_index * sequence_length + token) * first_dimensions + ); + const size_t second_offset = static_cast( + (batch_index * sequence_length + token) * second_dimensions + ); + const size_t output_offset = static_cast( + (batch_index * sequence_length + token) * output_dimensions + ); + for (uint64_t index = 0; index < first_dimensions; index++) { + output[output_offset + static_cast(index)] = + first[first_offset + static_cast(index)]; + } + for (uint64_t index = 0; index < second_dimensions; index++) { + output[output_offset + static_cast(first_dimensions + index)] = + second[second_offset + static_cast(index)]; + } + } + } + return output; +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_flux_attention_layout.h b/feature/bonsai/src/androidMain/cpp/bonsai_flux_attention_layout.h new file mode 100644 index 000000000..a687ffbf9 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_flux_attention_layout.h @@ -0,0 +1,77 @@ +#pragma once + +#include +#include + +struct BonsaiFluxSingleProjectionParts { + uint64_t batch = 0; + uint64_t sequence_length = 0; + uint64_t dimensions = 0; + uint64_t mlp_hidden_dimensions = 0; + std::vector query; + std::vector key; + std::vector value; + std::vector mlp_values; +}; + +struct BonsaiFluxHeadSequenceParts { + uint64_t batch = 0; + uint64_t heads = 0; + uint64_t first_sequence_length = 0; + uint64_t second_sequence_length = 0; + uint64_t head_dimension = 0; + std::vector first; + std::vector second; +}; + 
+BonsaiFluxSingleProjectionParts bonsai_flux_split_single_projection( + const std::vector& fused, + uint64_t batch, + uint64_t sequence_length, + uint64_t dimensions, + uint64_t mlp_hidden_dimensions +); + +std::vector bonsai_flux_sequence_to_heads( + const std::vector& input, + uint64_t batch, + uint64_t sequence_length, + uint64_t heads, + uint64_t head_dimension +); + +std::vector bonsai_flux_heads_to_sequence( + const std::vector& input, + uint64_t batch, + uint64_t sequence_length, + uint64_t heads, + uint64_t head_dimension +); + +std::vector bonsai_flux_concat_head_sequences( + const std::vector& first, + const std::vector& second, + uint64_t batch, + uint64_t heads, + uint64_t first_sequence_length, + uint64_t second_sequence_length, + uint64_t head_dimension +); + +BonsaiFluxHeadSequenceParts bonsai_flux_split_head_sequences( + const std::vector& input, + uint64_t batch, + uint64_t heads, + uint64_t first_sequence_length, + uint64_t second_sequence_length, + uint64_t head_dimension +); + +std::vector bonsai_flux_concat_last_dimension( + const std::vector& first, + const std::vector& second, + uint64_t batch, + uint64_t sequence_length, + uint64_t first_dimensions, + uint64_t second_dimensions +); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_flux_double_block.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_flux_double_block.cpp new file mode 100644 index 000000000..d728f91b9 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_flux_double_block.cpp @@ -0,0 +1,360 @@ +#include "bonsai_flux_double_block.h" + +#include "bonsai_attention.h" +#include "bonsai_flux_attention_layout.h" +#include "bonsai_flux_modulation.h" +#include "bonsai_flux_rope.h" + +#include +#include +#include +#include + +namespace { + +uint64_t checked_multiply(uint64_t left, uint64_t right, const char* label) { + if (left != 0 && right > std::numeric_limits::max() / left) { + throw std::runtime_error(std::string("Bonsai Flux double block shape overflow: ") + label); + 
} + return left * right; +} + +uint64_t checked_add(uint64_t left, uint64_t right, const char* label) { + if (left > std::numeric_limits::max() - right) { + throw std::runtime_error(std::string("Bonsai Flux double block shape overflow: ") + label); + } + return left + right; +} + +size_t checked_size_3(uint64_t first, uint64_t second, uint64_t third, const char* label) { + const uint64_t first_second = checked_multiply(first, second, label); + const uint64_t total = checked_multiply(first_second, third, label); + if (total > static_cast(std::numeric_limits::max())) { + throw std::runtime_error(std::string("Bonsai Flux double block shape overflow: ") + label); + } + return static_cast(total); +} + +void require_positive(uint64_t value, const char* label) { + if (value == 0) { + throw std::runtime_error(std::string("Bonsai Flux double block value must be positive: ") + label); + } +} + +void require_finite_positive(float value, const char* label) { + if (value <= 0.0F || !std::isfinite(value)) { + throw std::runtime_error(std::string("Bonsai Flux double block value must be finite and positive: ") + label); + } +} + +void require_projection_size( + const std::vector& values, + uint64_t batch, + uint64_t sequence_length, + uint64_t dimensions, + const char* label +) { + if (values.size() != checked_size_3(batch, sequence_length, dimensions, label)) { + throw std::runtime_error(std::string("Bonsai Flux double block projection shape mismatch: ") + label); + } +} + +std::vector apply_flux_rope( + const std::vector& projection, + const std::vector& norm_weight, + const std::vector& cos_values, + const std::vector& sin_values, + uint64_t batch, + uint64_t sequence_length, + uint64_t heads, + uint64_t head_dimension, + float epsilon +) { + return bonsai_flux_apply_rms_norm_and_rope( + bonsai_flux_sequence_to_heads(projection, batch, sequence_length, heads, head_dimension), + norm_weight, + cos_values, + sin_values, + checked_multiply(batch, heads, "batch heads"), + 
sequence_length, + head_dimension, + epsilon + ); +} + +} // namespace + +BonsaiFluxDoubleBlockReferenceOutput bonsai_flux_double_block_reference( + const std::vector& text, + const std::vector& image, + const std::vector& text_modulation_values, + const std::vector& image_modulation_values, + const std::vector& text_query_projection, + const std::vector& text_key_projection, + const std::vector& text_value_projection, + const std::vector& image_query_projection, + const std::vector& image_key_projection, + const std::vector& image_value_projection, + const std::vector& text_attention_update, + const std::vector& image_attention_update, + const std::vector& text_mlp_update, + const std::vector& image_mlp_update, + const std::vector& text_query_norm_weight, + const std::vector& text_key_norm_weight, + const std::vector& image_query_norm_weight, + const std::vector& image_key_norm_weight, + const std::vector& text_cos_values, + const std::vector& text_sin_values, + const std::vector& image_cos_values, + const std::vector& image_sin_values, + uint64_t batch, + uint64_t text_sequence_length, + uint64_t image_sequence_length, + uint64_t heads, + uint64_t head_dimension, + float layer_norm_epsilon, + float rms_norm_epsilon +) { + require_positive(batch, "batch"); + require_positive(text_sequence_length, "text sequence"); + require_positive(image_sequence_length, "image sequence"); + require_positive(heads, "heads"); + require_positive(head_dimension, "head dimension"); + require_finite_positive(layer_norm_epsilon, "layer norm epsilon"); + require_finite_positive(rms_norm_epsilon, "rms norm epsilon"); + + const uint64_t dimensions = checked_multiply(heads, head_dimension, "hidden dimensions"); + require_projection_size(text, batch, text_sequence_length, dimensions, "text"); + require_projection_size(image, batch, image_sequence_length, dimensions, "image"); + require_projection_size( + text_attention_update, + batch, + text_sequence_length, + dimensions, + "text attention 
update" + ); + require_projection_size( + image_attention_update, + batch, + image_sequence_length, + dimensions, + "image attention update" + ); + require_projection_size(text_mlp_update, batch, text_sequence_length, dimensions, "text mlp"); + require_projection_size(image_mlp_update, batch, image_sequence_length, dimensions, "image mlp"); + if (text_query_norm_weight.size() != static_cast(head_dimension) || + text_key_norm_weight.size() != static_cast(head_dimension) || + image_query_norm_weight.size() != static_cast(head_dimension) || + image_key_norm_weight.size() != static_cast(head_dimension)) { + throw std::runtime_error("Bonsai Flux double block norm weight shape mismatch."); + } + + const BonsaiFluxDoubleModulation text_modulation = bonsai_flux_split_double_modulation( + text_modulation_values, + batch, + dimensions + ); + const BonsaiFluxDoubleModulation image_modulation = bonsai_flux_split_double_modulation( + image_modulation_values, + batch, + dimensions + ); + const std::vector normalized_text_msa = bonsai_flux_apply_modulated_layer_norm( + text, + text_modulation.shift_msa, + text_modulation.scale_msa, + batch, + text_sequence_length, + dimensions, + layer_norm_epsilon + ); + const std::vector normalized_image_msa = bonsai_flux_apply_modulated_layer_norm( + image, + image_modulation.shift_msa, + image_modulation.scale_msa, + batch, + image_sequence_length, + dimensions, + layer_norm_epsilon + ); + + const uint64_t combined_sequence_length = checked_add( + text_sequence_length, + image_sequence_length, + "combined sequence" + ); + const uint64_t batch_heads = checked_multiply(batch, heads, "batch heads"); + const std::vector full_queries = bonsai_flux_concat_head_sequences( + apply_flux_rope( + text_query_projection, + text_query_norm_weight, + text_cos_values, + text_sin_values, + batch, + text_sequence_length, + heads, + head_dimension, + rms_norm_epsilon + ), + apply_flux_rope( + image_query_projection, + image_query_norm_weight, + 
image_cos_values, + image_sin_values, + batch, + image_sequence_length, + heads, + head_dimension, + rms_norm_epsilon + ), + batch, + heads, + text_sequence_length, + image_sequence_length, + head_dimension + ); + const std::vector full_keys = bonsai_flux_concat_head_sequences( + apply_flux_rope( + text_key_projection, + text_key_norm_weight, + text_cos_values, + text_sin_values, + batch, + text_sequence_length, + heads, + head_dimension, + rms_norm_epsilon + ), + apply_flux_rope( + image_key_projection, + image_key_norm_weight, + image_cos_values, + image_sin_values, + batch, + image_sequence_length, + heads, + head_dimension, + rms_norm_epsilon + ), + batch, + heads, + text_sequence_length, + image_sequence_length, + head_dimension + ); + const std::vector full_values = bonsai_flux_concat_head_sequences( + bonsai_flux_sequence_to_heads( + text_value_projection, + batch, + text_sequence_length, + heads, + head_dimension + ), + bonsai_flux_sequence_to_heads( + image_value_projection, + batch, + image_sequence_length, + heads, + head_dimension + ), + batch, + heads, + text_sequence_length, + image_sequence_length, + head_dimension + ); + const BonsaiFluxHeadSequenceParts attention_parts = bonsai_flux_split_head_sequences( + bonsai_scaled_dot_product_attention( + full_queries, + full_keys, + full_values, + {}, + batch_heads, + combined_sequence_length, + head_dimension, + 1.0F / std::sqrt(static_cast(head_dimension)) + ), + batch, + heads, + text_sequence_length, + image_sequence_length, + head_dimension + ); + const std::vector attention_text = bonsai_flux_heads_to_sequence( + attention_parts.first, + batch, + text_sequence_length, + heads, + head_dimension + ); + const std::vector attention_image = bonsai_flux_heads_to_sequence( + attention_parts.second, + batch, + image_sequence_length, + heads, + head_dimension + ); + const std::vector text_after_attention = bonsai_flux_apply_gated_residual( + text, + text_attention_update, + text_modulation.gate_msa, + batch, + 
text_sequence_length, + dimensions + ); + const std::vector image_after_attention = bonsai_flux_apply_gated_residual( + image, + image_attention_update, + image_modulation.gate_msa, + batch, + image_sequence_length, + dimensions + ); + const std::vector normalized_text_mlp = bonsai_flux_apply_modulated_layer_norm( + text_after_attention, + text_modulation.shift_mlp, + text_modulation.scale_mlp, + batch, + text_sequence_length, + dimensions, + layer_norm_epsilon + ); + const std::vector normalized_image_mlp = bonsai_flux_apply_modulated_layer_norm( + image_after_attention, + image_modulation.shift_mlp, + image_modulation.scale_mlp, + batch, + image_sequence_length, + dimensions, + layer_norm_epsilon + ); + + return BonsaiFluxDoubleBlockReferenceOutput { + batch, + text_sequence_length, + image_sequence_length, + dimensions, + normalized_text_msa, + normalized_image_msa, + attention_text, + attention_image, + normalized_text_mlp, + normalized_image_mlp, + bonsai_flux_apply_gated_residual( + text_after_attention, + text_mlp_update, + text_modulation.gate_mlp, + batch, + text_sequence_length, + dimensions + ), + bonsai_flux_apply_gated_residual( + image_after_attention, + image_mlp_update, + image_modulation.gate_mlp, + batch, + image_sequence_length, + dimensions + ), + }; +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_flux_double_block.h b/feature/bonsai/src/androidMain/cpp/bonsai_flux_double_block.h new file mode 100644 index 000000000..36e8b0df8 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_flux_double_block.h @@ -0,0 +1,51 @@ +#pragma once + +#include +#include + +struct BonsaiFluxDoubleBlockReferenceOutput { + uint64_t batch = 0; + uint64_t text_sequence_length = 0; + uint64_t image_sequence_length = 0; + uint64_t dimensions = 0; + std::vector normalized_text_msa; + std::vector normalized_image_msa; + std::vector attention_text; + std::vector attention_image; + std::vector normalized_text_mlp; + std::vector normalized_image_mlp; + 
std::vector text_output; + std::vector image_output; +}; + +BonsaiFluxDoubleBlockReferenceOutput bonsai_flux_double_block_reference( + const std::vector& text, + const std::vector& image, + const std::vector& text_modulation_values, + const std::vector& image_modulation_values, + const std::vector& text_query_projection, + const std::vector& text_key_projection, + const std::vector& text_value_projection, + const std::vector& image_query_projection, + const std::vector& image_key_projection, + const std::vector& image_value_projection, + const std::vector& text_attention_update, + const std::vector& image_attention_update, + const std::vector& text_mlp_update, + const std::vector& image_mlp_update, + const std::vector& text_query_norm_weight, + const std::vector& text_key_norm_weight, + const std::vector& image_query_norm_weight, + const std::vector& image_key_norm_weight, + const std::vector& text_cos_values, + const std::vector& text_sin_values, + const std::vector& image_cos_values, + const std::vector& image_sin_values, + uint64_t batch, + uint64_t text_sequence_length, + uint64_t image_sequence_length, + uint64_t heads, + uint64_t head_dimension, + float layer_norm_epsilon, + float rms_norm_epsilon +); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_flux_modulation.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_flux_modulation.cpp new file mode 100644 index 000000000..b4855edd7 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_flux_modulation.cpp @@ -0,0 +1,247 @@ +#include "bonsai_flux_modulation.h" + +#include +#include +#include +#include + +namespace { + +size_t checked_multiply_size(uint64_t left, uint64_t right, const char* label) { + const uint64_t limit = static_cast(std::numeric_limits::max()); + if (left != 0 && right > limit / left) { + throw std::runtime_error(std::string("Bonsai Flux modulation size overflow: ") + label); + } + return static_cast(left * right); +} + +size_t checked_modulation_size( + uint64_t batch, + uint64_t 
chunks, + uint64_t dimensions, + const char* label +) { + const size_t batch_chunks = checked_multiply_size(batch, chunks, label); + if (dimensions != 0 && + batch_chunks > std::numeric_limits::max() / static_cast(dimensions)) { + throw std::runtime_error(std::string("Bonsai Flux modulation size overflow: ") + label); + } + return batch_chunks * static_cast(dimensions); +} + +void require_positive_shape(uint64_t batch, uint64_t dimensions, const char* label) { + if (batch == 0 || dimensions == 0) { + throw std::runtime_error(std::string("Bonsai Flux modulation shape must be positive: ") + label); + } +} + +void copy_chunk( + const std::vector& values, + uint64_t batch, + uint64_t chunks, + uint64_t dimensions, + uint64_t chunk, + std::vector* output +) { + output->assign(checked_multiply_size(batch, dimensions, "chunk"), 0.0F); + for (uint64_t batch_index = 0; batch_index < batch; batch_index++) { + for (uint64_t index = 0; index < dimensions; index++) { + const size_t source_index = static_cast( + (batch_index * chunks + chunk) * dimensions + index + ); + const size_t target_index = static_cast( + batch_index * dimensions + index + ); + (*output)[target_index] = values[source_index]; + } + } +} + +void require_input_shape( + const std::vector& input, + const std::vector& shift, + const std::vector& scale, + uint64_t batch, + uint64_t sequence_length, + uint64_t dimensions +) { + if (batch == 0 || sequence_length == 0 || dimensions == 0) { + throw std::runtime_error("Bonsai Flux modulated LayerNorm shape must be positive."); + } + const size_t input_size = checked_modulation_size( + batch, + sequence_length, + dimensions, + "layer norm" + ); + const size_t modulation_size = checked_multiply_size(batch, dimensions, "layer norm modulation"); + if (input.size() != input_size || + shift.size() != modulation_size || + scale.size() != modulation_size) { + throw std::runtime_error("Bonsai Flux modulated LayerNorm input shape mismatch."); + } +} + +} // namespace + 
+BonsaiFluxSingleModulation bonsai_flux_split_single_modulation( + const std::vector& values, + uint64_t batch, + uint64_t dimensions +) { + require_positive_shape(batch, dimensions, "single"); + if (values.size() != checked_modulation_size(batch, 3, dimensions, "single")) { + throw std::runtime_error("Bonsai Flux single modulation shape mismatch."); + } + + BonsaiFluxSingleModulation output { + batch, + dimensions, + {}, + {}, + {}, + }; + copy_chunk(values, batch, 3, dimensions, 0, &output.shift); + copy_chunk(values, batch, 3, dimensions, 1, &output.scale); + copy_chunk(values, batch, 3, dimensions, 2, &output.gate); + return output; +} + +BonsaiFluxDoubleModulation bonsai_flux_split_double_modulation( + const std::vector& values, + uint64_t batch, + uint64_t dimensions +) { + require_positive_shape(batch, dimensions, "double"); + if (values.size() != checked_modulation_size(batch, 6, dimensions, "double")) { + throw std::runtime_error("Bonsai Flux double modulation shape mismatch."); + } + + BonsaiFluxDoubleModulation output { + batch, + dimensions, + {}, + {}, + {}, + {}, + {}, + {}, + }; + copy_chunk(values, batch, 6, dimensions, 0, &output.shift_msa); + copy_chunk(values, batch, 6, dimensions, 1, &output.scale_msa); + copy_chunk(values, batch, 6, dimensions, 2, &output.gate_msa); + copy_chunk(values, batch, 6, dimensions, 3, &output.shift_mlp); + copy_chunk(values, batch, 6, dimensions, 4, &output.scale_mlp); + copy_chunk(values, batch, 6, dimensions, 5, &output.gate_mlp); + return output; +} + +BonsaiFluxNormOutModulation bonsai_flux_split_norm_out_modulation( + const std::vector& values, + uint64_t batch, + uint64_t dimensions +) { + require_positive_shape(batch, dimensions, "norm out"); + if (values.size() != checked_modulation_size(batch, 2, dimensions, "norm out")) { + throw std::runtime_error("Bonsai Flux norm-out modulation shape mismatch."); + } + + BonsaiFluxNormOutModulation output { + batch, + dimensions, + {}, + {}, + }; + copy_chunk(values, 
batch, 2, dimensions, 0, &output.scale); + copy_chunk(values, batch, 2, dimensions, 1, &output.shift); + return output; +} + +std::vector bonsai_flux_apply_modulated_layer_norm( + const std::vector& input, + const std::vector& shift, + const std::vector& scale, + uint64_t batch, + uint64_t sequence_length, + uint64_t dimensions, + float epsilon +) { + if (epsilon <= 0.0F || !std::isfinite(epsilon)) { + throw std::runtime_error("Bonsai Flux modulated LayerNorm epsilon must be finite and positive."); + } + require_input_shape(input, shift, scale, batch, sequence_length, dimensions); + + std::vector output(input.size(), 0.0F); + for (uint64_t batch_index = 0; batch_index < batch; batch_index++) { + for (uint64_t token = 0; token < sequence_length; token++) { + const size_t row_offset = static_cast( + (batch_index * sequence_length + token) * dimensions + ); + const size_t modulation_offset = static_cast(batch_index * dimensions); + + double mean = 0.0; + for (uint64_t index = 0; index < dimensions; index++) { + mean += static_cast(input[row_offset + static_cast(index)]); + } + mean /= static_cast(dimensions); + + double variance = 0.0; + for (uint64_t index = 0; index < dimensions; index++) { + const double centered = + static_cast(input[row_offset + static_cast(index)]) - mean; + variance += centered * centered; + } + variance /= static_cast(dimensions); + + const float norm_scale = 1.0F / std::sqrt(static_cast(variance) + epsilon); + for (uint64_t index = 0; index < dimensions; index++) { + const size_t input_index = row_offset + static_cast(index); + const size_t modulation_index = modulation_offset + static_cast(index); + const float normalized = + (input[input_index] - static_cast(mean)) * norm_scale; + output[input_index] = + normalized * (1.0F + scale[modulation_index]) + shift[modulation_index]; + } + } + } + return output; +} + +std::vector bonsai_flux_apply_gated_residual( + const std::vector& residual, + const std::vector& update, + const std::vector& gate, 
+ uint64_t batch, + uint64_t sequence_length, + uint64_t dimensions +) { + if (batch == 0 || sequence_length == 0 || dimensions == 0) { + throw std::runtime_error("Bonsai Flux gated residual shape must be positive."); + } + const size_t input_size = checked_modulation_size( + batch, + sequence_length, + dimensions, + "gated residual" + ); + const size_t gate_size = checked_multiply_size(batch, dimensions, "gated residual gate"); + if (residual.size() != input_size || update.size() != input_size || gate.size() != gate_size) { + throw std::runtime_error("Bonsai Flux gated residual input shape mismatch."); + } + + std::vector output(residual.size(), 0.0F); + for (uint64_t batch_index = 0; batch_index < batch; batch_index++) { + for (uint64_t token = 0; token < sequence_length; token++) { + const size_t row_offset = static_cast( + (batch_index * sequence_length + token) * dimensions + ); + const size_t gate_offset = static_cast(batch_index * dimensions); + for (uint64_t index = 0; index < dimensions; index++) { + const size_t value_index = row_offset + static_cast(index); + const size_t gate_index = gate_offset + static_cast(index); + output[value_index] = residual[value_index] + gate[gate_index] * update[value_index]; + } + } + } + return output; +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_flux_modulation.h b/feature/bonsai/src/androidMain/cpp/bonsai_flux_modulation.h new file mode 100644 index 000000000..138c72bc6 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_flux_modulation.h @@ -0,0 +1,67 @@ +#pragma once + +#include +#include + +struct BonsaiFluxSingleModulation { + uint64_t batch = 0; + uint64_t dimensions = 0; + std::vector shift; + std::vector scale; + std::vector gate; +}; + +struct BonsaiFluxDoubleModulation { + uint64_t batch = 0; + uint64_t dimensions = 0; + std::vector shift_msa; + std::vector scale_msa; + std::vector gate_msa; + std::vector shift_mlp; + std::vector scale_mlp; + std::vector gate_mlp; +}; + +struct 
BonsaiFluxNormOutModulation { + uint64_t batch = 0; + uint64_t dimensions = 0; + std::vector scale; + std::vector shift; +}; + +BonsaiFluxSingleModulation bonsai_flux_split_single_modulation( + const std::vector& values, + uint64_t batch, + uint64_t dimensions +); + +BonsaiFluxDoubleModulation bonsai_flux_split_double_modulation( + const std::vector& values, + uint64_t batch, + uint64_t dimensions +); + +BonsaiFluxNormOutModulation bonsai_flux_split_norm_out_modulation( + const std::vector& values, + uint64_t batch, + uint64_t dimensions +); + +std::vector bonsai_flux_apply_modulated_layer_norm( + const std::vector& input, + const std::vector& shift, + const std::vector& scale, + uint64_t batch, + uint64_t sequence_length, + uint64_t dimensions, + float epsilon +); + +std::vector bonsai_flux_apply_gated_residual( + const std::vector& residual, + const std::vector& update, + const std::vector& gate, + uint64_t batch, + uint64_t sequence_length, + uint64_t dimensions +); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_flux_output.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_flux_output.cpp new file mode 100644 index 000000000..2f941fc5b --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_flux_output.cpp @@ -0,0 +1,65 @@ +#include "bonsai_flux_output.h" + +#include "bonsai_flux_modulation.h" + +#include +#include +#include +#include + +namespace { + +uint64_t checked_multiply(uint64_t left, uint64_t right, const char* label) { + if (left != 0 && right > std::numeric_limits::max() / left) { + throw std::runtime_error(std::string("Bonsai Flux output shape overflow: ") + label); + } + return left * right; +} + +void require_positive(uint64_t value, const char* label) { + if (value == 0) { + throw std::runtime_error(std::string("Bonsai Flux output value must be positive: ") + label); + } +} + +} // namespace + +std::vector bonsai_flux_final_projection_input( + const std::vector& image_tokens, + const std::vector& norm_out_modulation_values, + uint64_t 
batch, + uint64_t image_sequence_length, + uint64_t dimensions, + float layer_norm_epsilon +) { + require_positive(batch, "batch"); + require_positive(image_sequence_length, "image sequence"); + require_positive(dimensions, "dimensions"); + if (layer_norm_epsilon <= 0.0F || !std::isfinite(layer_norm_epsilon)) { + throw std::runtime_error("Bonsai Flux output LayerNorm epsilon must be finite and positive."); + } + + const uint64_t expected = checked_multiply( + checked_multiply(batch, image_sequence_length, "image tokens"), + dimensions, + "image tokens" + ); + if (image_tokens.size() != static_cast(expected)) { + throw std::runtime_error("Bonsai Flux output image token shape mismatch."); + } + + const BonsaiFluxNormOutModulation modulation = bonsai_flux_split_norm_out_modulation( + norm_out_modulation_values, + batch, + dimensions + ); + return bonsai_flux_apply_modulated_layer_norm( + image_tokens, + modulation.shift, + modulation.scale, + batch, + image_sequence_length, + dimensions, + layer_norm_epsilon + ); +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_flux_output.h b/feature/bonsai/src/androidMain/cpp/bonsai_flux_output.h new file mode 100644 index 000000000..ad9848084 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_flux_output.h @@ -0,0 +1,13 @@ +#pragma once + +#include +#include + +std::vector bonsai_flux_final_projection_input( + const std::vector& image_tokens, + const std::vector& norm_out_modulation_values, + uint64_t batch, + uint64_t image_sequence_length, + uint64_t dimensions, + float layer_norm_epsilon +); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_flux_pos_embed.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_flux_pos_embed.cpp new file mode 100644 index 000000000..00e5d0259 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_flux_pos_embed.cpp @@ -0,0 +1,48 @@ +#include "bonsai_flux_pos_embed.h" + +#include +#include + +namespace { + +constexpr float FLUX_POS_THETA = 2000.0F; +constexpr uint64_t 
FLUX_AXIS_COUNT = 4; +constexpr uint64_t FLUX_AXIS_DIMENSION = 32; +constexpr uint64_t FLUX_AXIS_HALF_DIMENSION = FLUX_AXIS_DIMENSION / 2; +constexpr uint64_t FLUX_ROTARY_DIMENSION = FLUX_AXIS_COUNT * FLUX_AXIS_HALF_DIMENSION; + +float flux_axis_omega(uint64_t index) { + const float scale = static_cast(index * 2U) / static_cast(FLUX_AXIS_DIMENSION); + return 1.0F / std::pow(FLUX_POS_THETA, scale); +} + +} // namespace + +BonsaiFluxRotaryEmbedding bonsai_flux_pos_embed( + const std::vector>& ids +) { + if (ids.empty()) { + throw std::runtime_error("Bonsai Flux position ids must not be empty."); + } + + BonsaiFluxRotaryEmbedding output { + static_cast(ids.size()), + FLUX_ROTARY_DIMENSION, + {}, + {}, + }; + output.cos.reserve(ids.size() * static_cast(FLUX_ROTARY_DIMENSION)); + output.sin.reserve(ids.size() * static_cast(FLUX_ROTARY_DIMENSION)); + + for (const std::array& id : ids) { + for (uint64_t axis = 0; axis < FLUX_AXIS_COUNT; axis++) { + for (uint64_t index = 0; index < FLUX_AXIS_HALF_DIMENSION; index++) { + const float value = id[static_cast(axis)] * flux_axis_omega(index); + output.cos.push_back(std::cos(value)); + output.sin.push_back(std::sin(value)); + } + } + } + + return output; +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_flux_pos_embed.h b/feature/bonsai/src/androidMain/cpp/bonsai_flux_pos_embed.h new file mode 100644 index 000000000..5436f3254 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_flux_pos_embed.h @@ -0,0 +1,16 @@ +#pragma once + +#include +#include +#include + +struct BonsaiFluxRotaryEmbedding { + uint64_t token_count = 0; + uint64_t dimensions = 0; + std::vector cos; + std::vector sin; +}; + +BonsaiFluxRotaryEmbedding bonsai_flux_pos_embed( + const std::vector>& ids +); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_flux_rope.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_flux_rope.cpp new file mode 100644 index 000000000..5e2475bc9 --- /dev/null +++ 
b/feature/bonsai/src/androidMain/cpp/bonsai_flux_rope.cpp @@ -0,0 +1,71 @@ +#include "bonsai_flux_rope.h" + +#include +#include + +namespace { + +uint64_t attention_size(uint64_t heads, uint64_t sequence_length, uint64_t head_dimension) { + return heads * sequence_length * head_dimension; +} + +} // namespace + +std::vector bonsai_flux_apply_rms_norm_and_rope( + const std::vector& input, + const std::vector& norm_weight, + const std::vector& cos_values, + const std::vector& sin_values, + uint64_t heads, + uint64_t sequence_length, + uint64_t head_dimension, + float epsilon +) { + if (heads == 0 || sequence_length == 0 || head_dimension == 0 || head_dimension % 2 != 0) { + throw std::runtime_error("Bonsai Flux RoPE dimensions must be positive and even."); + } + if (epsilon <= 0.0F || !std::isfinite(epsilon)) { + throw std::runtime_error("Bonsai Flux RoPE epsilon must be finite and positive."); + } + + const uint64_t expected_input = attention_size(heads, sequence_length, head_dimension); + const uint64_t rotary_dimension = head_dimension / 2; + const uint64_t expected_rotary = sequence_length * rotary_dimension; + if (input.size() != static_cast(expected_input) || + norm_weight.size() != static_cast(head_dimension) || + cos_values.size() != static_cast(expected_rotary) || + sin_values.size() != static_cast(expected_rotary)) { + throw std::runtime_error("Bonsai Flux RoPE input shape mismatch."); + } + + std::vector output(input.size(), 0.0F); + for (uint64_t head = 0; head < heads; head++) { + for (uint64_t position = 0; position < sequence_length; position++) { + const uint64_t offset = (head * sequence_length + position) * head_dimension; + double mean_square = 0.0; + for (uint64_t index = 0; index < head_dimension; index++) { + const float value = input[static_cast(offset + index)]; + mean_square += static_cast(value) * static_cast(value); + } + mean_square /= static_cast(head_dimension); + const float scale = 1.0F / std::sqrt(static_cast(mean_square) + epsilon); 
+ + for (uint64_t pair = 0; pair < rotary_dimension; pair++) { + const size_t real_index = static_cast(offset + pair * 2); + const size_t imag_index = static_cast(real_index + 1); + const size_t rotary_index = static_cast( + position * rotary_dimension + pair + ); + + const float real = input[real_index] * scale * norm_weight[static_cast(pair * 2)]; + const float imag = input[imag_index] * scale * + norm_weight[static_cast(pair * 2 + 1)]; + const float cos_value = cos_values[rotary_index]; + const float sin_value = sin_values[rotary_index]; + output[real_index] = real * cos_value - imag * sin_value; + output[imag_index] = imag * cos_value + real * sin_value; + } + } + } + return output; +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_flux_rope.h b/feature/bonsai/src/androidMain/cpp/bonsai_flux_rope.h new file mode 100644 index 000000000..c6578f9b9 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_flux_rope.h @@ -0,0 +1,15 @@ +#pragma once + +#include +#include + +std::vector bonsai_flux_apply_rms_norm_and_rope( + const std::vector& input, + const std::vector& norm_weight, + const std::vector& cos_values, + const std::vector& sin_values, + uint64_t heads, + uint64_t sequence_length, + uint64_t head_dimension, + float epsilon +); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_flux_single_block.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_flux_single_block.cpp new file mode 100644 index 000000000..2caf1a42a --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_flux_single_block.cpp @@ -0,0 +1,183 @@ +#include "bonsai_flux_single_block.h" + +#include "bonsai_activation.h" +#include "bonsai_attention.h" +#include "bonsai_flux_attention_layout.h" +#include "bonsai_flux_modulation.h" +#include "bonsai_flux_rope.h" + +#include +#include +#include +#include + +namespace { + +uint64_t checked_multiply(uint64_t left, uint64_t right, const char* label) { + if (left != 0 && right > std::numeric_limits::max() / left) { + throw 
std::runtime_error(std::string("Bonsai Flux single block shape overflow: ") + label); + } + return left * right; +} + +uint64_t checked_add(uint64_t left, uint64_t right, const char* label) { + if (left > std::numeric_limits::max() - right) { + throw std::runtime_error(std::string("Bonsai Flux single block shape overflow: ") + label); + } + return left + right; +} + +size_t checked_size_3(uint64_t first, uint64_t second, uint64_t third, const char* label) { + const uint64_t first_second = checked_multiply(first, second, label); + const uint64_t total = checked_multiply(first_second, third, label); + if (total > static_cast(std::numeric_limits::max())) { + throw std::runtime_error(std::string("Bonsai Flux single block shape overflow: ") + label); + } + return static_cast(total); +} + +void require_positive(uint64_t value, const char* label) { + if (value == 0) { + throw std::runtime_error(std::string("Bonsai Flux single block value must be positive: ") + label); + } +} + +void require_finite_positive(float value, const char* label) { + if (value <= 0.0F || !std::isfinite(value)) { + throw std::runtime_error(std::string("Bonsai Flux single block value must be finite and positive: ") + label); + } +} + +} // namespace + +BonsaiFluxSingleBlockReferenceOutput bonsai_flux_single_block_reference( + const std::vector& hidden, + const std::vector& modulation_values, + const std::vector& fused_projection, + const std::vector& projected_update, + const std::vector& norm_q_weight, + const std::vector& norm_k_weight, + const std::vector& cos_values, + const std::vector& sin_values, + uint64_t batch, + uint64_t sequence_length, + uint64_t heads, + uint64_t head_dimension, + uint64_t mlp_hidden_dimensions, + float layer_norm_epsilon, + float rms_norm_epsilon +) { + require_positive(batch, "batch"); + require_positive(sequence_length, "sequence"); + require_positive(heads, "heads"); + require_positive(head_dimension, "head dimension"); + require_positive(mlp_hidden_dimensions, 
"mlp hidden"); + require_finite_positive(layer_norm_epsilon, "layer norm epsilon"); + require_finite_positive(rms_norm_epsilon, "rms norm epsilon"); + + const uint64_t dimensions = checked_multiply(heads, head_dimension, "hidden dimensions"); + const size_t hidden_size = checked_size_3(batch, sequence_length, dimensions, "hidden"); + if (hidden.size() != hidden_size || projected_update.size() != hidden_size) { + throw std::runtime_error("Bonsai Flux single block hidden/update shape mismatch."); + } + if (norm_q_weight.size() != static_cast(head_dimension) || + norm_k_weight.size() != static_cast(head_dimension)) { + throw std::runtime_error("Bonsai Flux single block norm weight shape mismatch."); + } + + const BonsaiFluxSingleModulation modulation = bonsai_flux_split_single_modulation( + modulation_values, + batch, + dimensions + ); + const std::vector normed = bonsai_flux_apply_modulated_layer_norm( + hidden, + modulation.shift, + modulation.scale, + batch, + sequence_length, + dimensions, + layer_norm_epsilon + ); + const BonsaiFluxSingleProjectionParts parts = bonsai_flux_split_single_projection( + fused_projection, + batch, + sequence_length, + dimensions, + mlp_hidden_dimensions + ); + + const uint64_t batch_heads = checked_multiply(batch, heads, "batch heads"); + const std::vector query = bonsai_flux_apply_rms_norm_and_rope( + bonsai_flux_sequence_to_heads(parts.query, batch, sequence_length, heads, head_dimension), + norm_q_weight, + cos_values, + sin_values, + batch_heads, + sequence_length, + head_dimension, + rms_norm_epsilon + ); + const std::vector key = bonsai_flux_apply_rms_norm_and_rope( + bonsai_flux_sequence_to_heads(parts.key, batch, sequence_length, heads, head_dimension), + norm_k_weight, + cos_values, + sin_values, + batch_heads, + sequence_length, + head_dimension, + rms_norm_epsilon + ); + const std::vector value = bonsai_flux_sequence_to_heads( + parts.value, + batch, + sequence_length, + heads, + head_dimension + ); + const std::vector 
attended = bonsai_flux_heads_to_sequence( + bonsai_scaled_dot_product_attention( + query, + key, + value, + {}, + batch_heads, + sequence_length, + head_dimension, + 1.0F / std::sqrt(static_cast(head_dimension)) + ), + batch, + sequence_length, + heads, + head_dimension + ); + const std::vector mlp_output = bonsai_swiglu_last_dimension( + parts.mlp_values, + checked_multiply(mlp_hidden_dimensions, 2U, "mlp width") + ); + const std::vector out_projection_input = bonsai_flux_concat_last_dimension( + attended, + mlp_output, + batch, + sequence_length, + dimensions, + mlp_hidden_dimensions + ); + + return BonsaiFluxSingleBlockReferenceOutput { + batch, + sequence_length, + dimensions, + checked_add(dimensions, mlp_hidden_dimensions, "out projection input"), + normed, + out_projection_input, + bonsai_flux_apply_gated_residual( + hidden, + projected_update, + modulation.gate, + batch, + sequence_length, + dimensions + ), + }; +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_flux_single_block.h b/feature/bonsai/src/androidMain/cpp/bonsai_flux_single_block.h new file mode 100644 index 000000000..46b78b681 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_flux_single_block.h @@ -0,0 +1,32 @@ +#pragma once + +#include +#include + +struct BonsaiFluxSingleBlockReferenceOutput { + uint64_t batch = 0; + uint64_t sequence_length = 0; + uint64_t dimensions = 0; + uint64_t out_projection_input_dimensions = 0; + std::vector normalized_hidden; + std::vector out_projection_input; + std::vector residual_output; +}; + +BonsaiFluxSingleBlockReferenceOutput bonsai_flux_single_block_reference( + const std::vector& hidden, + const std::vector& modulation_values, + const std::vector& fused_projection, + const std::vector& projected_update, + const std::vector& norm_q_weight, + const std::vector& norm_k_weight, + const std::vector& cos_values, + const std::vector& sin_values, + uint64_t batch, + uint64_t sequence_length, + uint64_t heads, + uint64_t head_dimension, + 
uint64_t mlp_hidden_dimensions, + float layer_norm_epsilon, + float rms_norm_epsilon +); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_flux_time_embedding.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_flux_time_embedding.cpp new file mode 100644 index 000000000..8072eb0b2 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_flux_time_embedding.cpp @@ -0,0 +1,50 @@ +#include "bonsai_flux_time_embedding.h" + +#include +#include + +namespace { + +constexpr float FLUX_TIMESTEP_THETA = 10000.0F; + +} // namespace + +BonsaiFluxTimestepEmbedding bonsai_flux_timestep_embedding( + const std::vector& timesteps, + uint64_t dimensions +) { + if (timesteps.empty()) { + throw std::runtime_error("Bonsai Flux timesteps must not be empty."); + } + if (dimensions == 0 || dimensions % 2 != 0) { + throw std::runtime_error("Bonsai Flux timestep dimensions must be positive and even."); + } + + const uint64_t half = dimensions / 2; + BonsaiFluxTimestepEmbedding output { + static_cast(timesteps.size()), + dimensions, + {}, + }; + output.values.reserve(timesteps.size() * static_cast(dimensions)); + + std::vector frequencies; + frequencies.reserve(static_cast(half)); + for (uint64_t index = 0; index < half; index++) { + frequencies.push_back( + std::exp(-std::log(FLUX_TIMESTEP_THETA) * static_cast(index) / + static_cast(half)) + ); + } + + for (float timestep : timesteps) { + for (float frequency : frequencies) { + output.values.push_back(std::cos(timestep * frequency)); + } + for (float frequency : frequencies) { + output.values.push_back(std::sin(timestep * frequency)); + } + } + + return output; +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_flux_time_embedding.h b/feature/bonsai/src/androidMain/cpp/bonsai_flux_time_embedding.h new file mode 100644 index 000000000..c49ce7a8f --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_flux_time_embedding.h @@ -0,0 +1,15 @@ +#pragma once + +#include +#include + +struct BonsaiFluxTimestepEmbedding { + 
uint64_t timestep_count = 0; + uint64_t dimensions = 0; + std::vector values; +}; + +BonsaiFluxTimestepEmbedding bonsai_flux_timestep_embedding( + const std::vector& timesteps, + uint64_t dimensions +); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_flux_transformer.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_flux_transformer.cpp new file mode 100644 index 000000000..9a0612b61 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_flux_transformer.cpp @@ -0,0 +1,1255 @@ +#include "bonsai_flux_transformer.h" + +#include "bonsai_activation.h" +#include "bonsai_attention.h" +#include "bonsai_flux_attention_layout.h" +#include "bonsai_flux_modulation.h" +#include "bonsai_flux_output.h" +#include "bonsai_flux_pos_embed.h" +#include "bonsai_flux_rope.h" +#include "bonsai_flux_time_embedding.h" +#include "bonsai_linear.h" +#include "bonsai_norm.h" +#include "bonsai_tensor.h" + +#include + +#include +#include +#include +#include + +namespace { + +constexpr const char* LOG_TAG = "SDAI-Bonsai"; +constexpr uint64_t FLUX_DOUBLE_BLOCKS = 5; +constexpr uint64_t FLUX_SINGLE_BLOCKS = 20; +constexpr uint64_t FLUX_DIM = 3072; +constexpr uint64_t FLUX_TEXT_HIDDEN = 7680; +constexpr uint64_t FLUX_LATENT_CHANNELS = 128; +constexpr uint64_t FLUX_TIME_EMBED = 256; +constexpr uint64_t FLUX_MLP_HIDDEN = 9216; + +uint64_t checked_multiply(uint64_t left, uint64_t right, const std::string& key) { + if (left != 0 && right > UINT64_MAX / left) { + throw std::runtime_error("Bonsai Flux tensor shape is too large: " + key); + } + return left * right; +} + +void log_flux_block_phase( + const char* phase, + uint64_t block, + uint64_t text_sequence_length, + uint64_t image_sequence_length +) { + __android_log_print( + ANDROID_LOG_INFO, + LOG_TAG, + "phase=%s block=%llu text=%llu image=%llu", + phase, + static_cast(block), + static_cast(text_sequence_length), + static_cast(image_sequence_length) + ); +} + +uint64_t leading_rows(const BonsaiTensorDescriptor& descriptor) { + if 
(descriptor.shape.empty()) { + throw std::runtime_error("Bonsai Flux tensor must have dimensions: " + descriptor.key); + } + uint64_t rows = 1; + for (size_t index = 0; index + 1 < descriptor.shape.size(); index++) { + rows = checked_multiply(rows, descriptor.shape[index], descriptor.key); + } + return rows; +} + +uint64_t trailing_columns(const BonsaiTensorDescriptor& descriptor) { + if (descriptor.shape.empty()) { + throw std::runtime_error("Bonsai Flux tensor must have dimensions: " + descriptor.key); + } + return descriptor.shape.back(); +} + +void require_dense_linear( + const BonsaiSafetensorsIndex& index, + const std::string& key, + uint64_t expected_input, + uint64_t expected_output, + uint64_t* count +) { + const BonsaiTensorDescriptor& descriptor = index.require(key); + if (leading_rows(descriptor) != expected_output || + trailing_columns(descriptor) != expected_input) { + throw std::runtime_error("Bonsai Flux dense linear shape mismatch: " + key); + } + (*count)++; +} + +void require_norm( + const BonsaiSafetensorsIndex& index, + const std::string& key, + uint64_t expected_elements, + uint64_t* count +) { + const BonsaiTensorDescriptor& descriptor = index.require(key); + if (bonsai_shape_element_count(descriptor.shape, key) != expected_elements) { + throw std::runtime_error("Bonsai Flux norm shape mismatch: " + key); + } + (*count)++; +} + +void require_packed_linear( + const BonsaiSafetensorsIndex& index, + const std::string& key, + int bits, + int group_size, + uint64_t expected_input, + uint64_t expected_output, + uint64_t* count +) { + const BonsaiPackedWeightDescriptor descriptor = index.require_packed_weight( + key, + bits, + group_size + ); + if (!descriptor.packed) { + const BonsaiTensorDescriptor& dense = index.require(key); + if (leading_rows(dense) != expected_output || trailing_columns(dense) != expected_input) { + throw std::runtime_error("Bonsai Flux dense fallback shape mismatch: " + key); + } + } else { + const BonsaiTensorDescriptor& 
scales = index.require(descriptor.scales_key); + if (leading_rows(scales) != expected_output || + checked_multiply( + trailing_columns(scales), + static_cast(group_size), + scales.key + ) != expected_input) { + throw std::runtime_error("Bonsai Flux packed linear shape mismatch: " + key); + } + } + (*count)++; +} + +BonsaiLinearViews require_dense_linear_view_checked( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + const std::string& key, + uint64_t expected_input, + uint64_t expected_output +) { + BonsaiLinearViews views = bonsai_require_dense_linear_views( + storage, + index, + key, + key.substr(0, key.size() - std::string(".weight").size()) + ".bias" + ); + if (views.input_values != expected_input || views.output_rows != expected_output) { + throw std::runtime_error("Bonsai Flux dense linear view shape mismatch: " + key); + } + return views; +} + +BonsaiLinearViews require_packed_linear_view_checked( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + const std::string& key, + int bits, + int group_size, + uint64_t expected_input, + uint64_t expected_output +) { + BonsaiLinearViews views = bonsai_require_packed_linear_views( + storage, + index, + index.require_packed_weight(key, bits, group_size), + key.substr(0, key.size() - std::string(".weight").size()) + ".bias" + ); + if (views.input_values != expected_input || views.output_rows != expected_output) { + throw std::runtime_error("Bonsai Flux packed linear view shape mismatch: " + key); + } + return views; +} + +BonsaiRmsNormWeightViews require_norm_view_checked( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + const std::string& key, + uint64_t expected_elements +) { + BonsaiRmsNormWeightViews views = bonsai_require_rms_norm_weight(storage, index, key); + if (views.dimensions != expected_elements) { + throw std::runtime_error("Bonsai Flux norm view shape mismatch: " + key); + } + return views; +} + +void add_bytes(uint64_t* 
bytes, uint64_t extra, const char* label) { + if (*bytes > std::numeric_limits::max() - extra) { + throw std::runtime_error(std::string("Bonsai Flux byte count overflow: ") + label); + } + *bytes += extra; +} + +uint64_t checked_add(uint64_t left, uint64_t right, const std::string& key) { + if (left > std::numeric_limits::max() - right) { + throw std::runtime_error("Bonsai Flux tensor shape is too large: " + key); + } + return left + right; +} + +size_t checked_size(uint64_t value, const std::string& key) { + if (value > static_cast(std::numeric_limits::max())) { + throw std::runtime_error("Bonsai Flux tensor shape is too large: " + key); + } + return static_cast(value); +} + +size_t checked_size_3(uint64_t first, uint64_t second, uint64_t third, const std::string& key) { + return checked_size( + checked_multiply(checked_multiply(first, second, key), third, key), + key + ); +} + +void require_sequence_shape( + const std::vector& values, + uint64_t batch, + uint64_t sequence_length, + uint64_t dimensions, + const std::string& key +) { + if (values.size() != checked_size_3(batch, sequence_length, dimensions, key)) { + throw std::runtime_error("Bonsai Flux sequence shape mismatch: " + key); + } +} + +std::vector norm_weight_values(const BonsaiRmsNormWeightViews& views) { + return bonsai_tensor_view_to_f32_vector(views.weight); +} + +std::vector linear_batch_vector( + const BonsaiLinearViews& views, + const std::vector& input, + uint64_t batch +) { + return bonsai_linear_sequence(views, input, batch, 1); +} + +std::vector concat_token_sequences( + const std::vector& first, + const std::vector& second, + uint64_t batch, + uint64_t first_sequence_length, + uint64_t second_sequence_length, + uint64_t dimensions +) { + require_sequence_shape(first, batch, first_sequence_length, dimensions, "concat first"); + require_sequence_shape(second, batch, second_sequence_length, dimensions, "concat second"); + const uint64_t output_sequence_length = checked_add( + 
first_sequence_length, + second_sequence_length, + "concat sequence" + ); + std::vector output( + checked_size_3(batch, output_sequence_length, dimensions, "concat output"), + 0.0F + ); + for (uint64_t batch_index = 0; batch_index < batch; batch_index++) { + for (uint64_t token = 0; token < first_sequence_length; token++) { + for (uint64_t column = 0; column < dimensions; column++) { + output[static_cast( + (batch_index * output_sequence_length + token) * dimensions + column + )] = first[static_cast( + (batch_index * first_sequence_length + token) * dimensions + column + )]; + } + } + for (uint64_t token = 0; token < second_sequence_length; token++) { + for (uint64_t column = 0; column < dimensions; column++) { + output[static_cast( + ( + batch_index * output_sequence_length + + first_sequence_length + + token + ) * dimensions + column + )] = second[static_cast( + (batch_index * second_sequence_length + token) * dimensions + column + )]; + } + } + } + return output; +} + +std::vector slice_token_sequence( + const std::vector& input, + uint64_t batch, + uint64_t sequence_length, + uint64_t dimensions, + uint64_t start_token, + uint64_t token_count +) { + require_sequence_shape(input, batch, sequence_length, dimensions, "slice input"); + if (start_token > sequence_length || token_count > sequence_length - start_token) { + throw std::runtime_error("Bonsai Flux token slice is out of range."); + } + std::vector output( + checked_size_3(batch, token_count, dimensions, "slice output"), + 0.0F + ); + for (uint64_t batch_index = 0; batch_index < batch; batch_index++) { + for (uint64_t token = 0; token < token_count; token++) { + for (uint64_t column = 0; column < dimensions; column++) { + output[static_cast( + (batch_index * token_count + token) * dimensions + column + )] = input[static_cast( + (batch_index * sequence_length + start_token + token) * dimensions + column + )]; + } + } + } + return output; +} + +std::vector concat_rotary_rows( + const 
BonsaiFluxRotaryEmbedding& first, + const BonsaiFluxRotaryEmbedding& second, + bool cosine +) { + if (first.dimensions != second.dimensions || first.dimensions == 0) { + throw std::runtime_error("Bonsai Flux rotary concat dimension mismatch."); + } + const std::vector& first_values = cosine ? first.cos : first.sin; + const std::vector& second_values = cosine ? second.cos : second.sin; + std::vector output; + output.reserve(first_values.size() + second_values.size()); + output.insert(output.end(), first_values.begin(), first_values.end()); + output.insert(output.end(), second_values.begin(), second_values.end()); + return output; +} + +std::vector flux_apply_projection_rope( + const std::vector& projection, + const BonsaiRmsNormWeightViews& norm, + const std::vector& cos_values, + const std::vector& sin_values, + uint64_t batch, + uint64_t sequence_length, + uint64_t heads, + uint64_t head_dimension +) { + return bonsai_flux_apply_rms_norm_and_rope( + bonsai_flux_sequence_to_heads(projection, batch, sequence_length, heads, head_dimension), + norm_weight_values(norm), + cos_values, + sin_values, + checked_multiply(batch, heads, "flux rope batch heads"), + sequence_length, + head_dimension, + 1e-5F + ); +} + +std::vector flux_swiglu_projection( + const BonsaiLinearViews& input_projection, + const BonsaiLinearViews& output_projection, + const std::vector& input, + uint64_t batch, + uint64_t sequence_length, + uint64_t mlp_hidden_dimensions +) { + return bonsai_linear_sequence( + output_projection, + bonsai_swiglu_last_dimension( + bonsai_linear_sequence(input_projection, input, batch, sequence_length), + checked_multiply(mlp_hidden_dimensions, 2U, "flux swiglu width") + ), + batch, + sequence_length + ); +} + +void apply_flux_double_block( + const BonsaiFluxDoubleBlockViews& block, + const std::vector& text_modulation_values, + const std::vector& image_modulation_values, + const BonsaiFluxRotaryEmbedding& text_rotary, + const BonsaiFluxRotaryEmbedding& image_rotary, + 
uint64_t batch, + std::vector* text, + std::vector* image +) { + const uint64_t text_sequence_length = text_rotary.token_count; + const uint64_t image_sequence_length = image_rotary.token_count; + const uint64_t dimensions = block.dimensions; + require_sequence_shape(*text, batch, text_sequence_length, dimensions, "double text"); + require_sequence_shape(*image, batch, image_sequence_length, dimensions, "double image"); + + const BonsaiFluxDoubleModulation text_modulation = bonsai_flux_split_double_modulation( + text_modulation_values, + batch, + dimensions + ); + const BonsaiFluxDoubleModulation image_modulation = bonsai_flux_split_double_modulation( + image_modulation_values, + batch, + dimensions + ); + const std::vector normalized_text_msa = bonsai_flux_apply_modulated_layer_norm( + *text, + text_modulation.shift_msa, + text_modulation.scale_msa, + batch, + text_sequence_length, + dimensions, + 1e-6F + ); + const std::vector normalized_image_msa = bonsai_flux_apply_modulated_layer_norm( + *image, + image_modulation.shift_msa, + image_modulation.scale_msa, + batch, + image_sequence_length, + dimensions, + 1e-6F + ); + + const uint64_t combined_sequence_length = checked_add( + text_sequence_length, + image_sequence_length, + "double attention sequence" + ); + const uint64_t batch_heads = checked_multiply(batch, block.heads, "double attention heads"); + const std::vector full_queries = bonsai_flux_concat_head_sequences( + flux_apply_projection_rope( + bonsai_linear_sequence(block.add_q, normalized_text_msa, batch, text_sequence_length), + block.norm_added_q, + text_rotary.cos, + text_rotary.sin, + batch, + text_sequence_length, + block.heads, + block.head_dimension + ), + flux_apply_projection_rope( + bonsai_linear_sequence(block.to_q, normalized_image_msa, batch, image_sequence_length), + block.norm_q, + image_rotary.cos, + image_rotary.sin, + batch, + image_sequence_length, + block.heads, + block.head_dimension + ), + batch, + block.heads, + 
text_sequence_length, + image_sequence_length, + block.head_dimension + ); + const std::vector full_keys = bonsai_flux_concat_head_sequences( + flux_apply_projection_rope( + bonsai_linear_sequence(block.add_k, normalized_text_msa, batch, text_sequence_length), + block.norm_added_k, + text_rotary.cos, + text_rotary.sin, + batch, + text_sequence_length, + block.heads, + block.head_dimension + ), + flux_apply_projection_rope( + bonsai_linear_sequence(block.to_k, normalized_image_msa, batch, image_sequence_length), + block.norm_k, + image_rotary.cos, + image_rotary.sin, + batch, + image_sequence_length, + block.heads, + block.head_dimension + ), + batch, + block.heads, + text_sequence_length, + image_sequence_length, + block.head_dimension + ); + const std::vector full_values = bonsai_flux_concat_head_sequences( + bonsai_flux_sequence_to_heads( + bonsai_linear_sequence(block.add_v, normalized_text_msa, batch, text_sequence_length), + batch, + text_sequence_length, + block.heads, + block.head_dimension + ), + bonsai_flux_sequence_to_heads( + bonsai_linear_sequence(block.to_v, normalized_image_msa, batch, image_sequence_length), + batch, + image_sequence_length, + block.heads, + block.head_dimension + ), + batch, + block.heads, + text_sequence_length, + image_sequence_length, + block.head_dimension + ); + const BonsaiFluxHeadSequenceParts attention_parts = bonsai_flux_split_head_sequences( + bonsai_scaled_dot_product_attention( + full_queries, + full_keys, + full_values, + {}, + batch_heads, + combined_sequence_length, + block.head_dimension, + 1.0F / std::sqrt(static_cast(block.head_dimension)) + ), + batch, + block.heads, + text_sequence_length, + image_sequence_length, + block.head_dimension + ); + const std::vector text_attention_update = bonsai_linear_sequence( + block.to_add_out, + bonsai_flux_heads_to_sequence( + attention_parts.first, + batch, + text_sequence_length, + block.heads, + block.head_dimension + ), + batch, + text_sequence_length + ); + const 
std::vector image_attention_update = bonsai_linear_sequence( + block.to_out, + bonsai_flux_heads_to_sequence( + attention_parts.second, + batch, + image_sequence_length, + block.heads, + block.head_dimension + ), + batch, + image_sequence_length + ); + + std::vector text_after_attention = bonsai_flux_apply_gated_residual( + *text, + text_attention_update, + text_modulation.gate_msa, + batch, + text_sequence_length, + dimensions + ); + std::vector image_after_attention = bonsai_flux_apply_gated_residual( + *image, + image_attention_update, + image_modulation.gate_msa, + batch, + image_sequence_length, + dimensions + ); + const std::vector normalized_text_mlp = bonsai_flux_apply_modulated_layer_norm( + text_after_attention, + text_modulation.shift_mlp, + text_modulation.scale_mlp, + batch, + text_sequence_length, + dimensions, + 1e-6F + ); + const std::vector normalized_image_mlp = bonsai_flux_apply_modulated_layer_norm( + image_after_attention, + image_modulation.shift_mlp, + image_modulation.scale_mlp, + batch, + image_sequence_length, + dimensions, + 1e-6F + ); + + *text = bonsai_flux_apply_gated_residual( + text_after_attention, + flux_swiglu_projection( + block.ff_context_in, + block.ff_context_out, + normalized_text_mlp, + batch, + text_sequence_length, + block.mlp_hidden_dimensions + ), + text_modulation.gate_mlp, + batch, + text_sequence_length, + dimensions + ); + *image = bonsai_flux_apply_gated_residual( + image_after_attention, + flux_swiglu_projection( + block.ff_in, + block.ff_out, + normalized_image_mlp, + batch, + image_sequence_length, + block.mlp_hidden_dimensions + ), + image_modulation.gate_mlp, + batch, + image_sequence_length, + dimensions + ); +} + +std::vector apply_flux_single_block( + const BonsaiFluxSingleBlockViews& block, + const std::vector& modulation_values, + const std::vector& rotary_cos, + const std::vector& rotary_sin, + const std::vector& hidden, + uint64_t batch, + uint64_t sequence_length +) { + const uint64_t dimensions = 
block.dimensions; + require_sequence_shape(hidden, batch, sequence_length, dimensions, "single hidden"); + const BonsaiFluxSingleModulation modulation = bonsai_flux_split_single_modulation( + modulation_values, + batch, + dimensions + ); + const std::vector normalized = bonsai_flux_apply_modulated_layer_norm( + hidden, + modulation.shift, + modulation.scale, + batch, + sequence_length, + dimensions, + 1e-6F + ); + const BonsaiFluxSingleProjectionParts parts = bonsai_flux_split_single_projection( + bonsai_linear_sequence(block.qkv_mlp_proj, normalized, batch, sequence_length), + batch, + sequence_length, + dimensions, + block.mlp_hidden_dimensions + ); + const uint64_t batch_heads = checked_multiply(batch, block.heads, "single batch heads"); + const std::vector attended = bonsai_flux_heads_to_sequence( + bonsai_scaled_dot_product_attention( + flux_apply_projection_rope( + parts.query, + block.norm_q, + rotary_cos, + rotary_sin, + batch, + sequence_length, + block.heads, + block.head_dimension + ), + flux_apply_projection_rope( + parts.key, + block.norm_k, + rotary_cos, + rotary_sin, + batch, + sequence_length, + block.heads, + block.head_dimension + ), + bonsai_flux_sequence_to_heads( + parts.value, + batch, + sequence_length, + block.heads, + block.head_dimension + ), + {}, + batch_heads, + sequence_length, + block.head_dimension, + 1.0F / std::sqrt(static_cast(block.head_dimension)) + ), + batch, + sequence_length, + block.heads, + block.head_dimension + ); + const std::vector projection_input = bonsai_flux_concat_last_dimension( + attended, + bonsai_swiglu_last_dimension( + parts.mlp_values, + checked_multiply(block.mlp_hidden_dimensions, 2U, "single swiglu width") + ), + batch, + sequence_length, + dimensions, + block.mlp_hidden_dimensions + ); + return bonsai_flux_apply_gated_residual( + hidden, + bonsai_linear_sequence(block.out_proj, projection_input, batch, sequence_length), + modulation.gate, + batch, + sequence_length, + dimensions + ); +} + 
+BonsaiFluxDoubleBlockViews require_double_block_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + uint64_t block_index, + int bits, + int group_size +) { + const std::string block = "transformer_blocks." + std::to_string(block_index); + const std::string attn = block + ".attn"; + return BonsaiFluxDoubleBlockViews { + require_packed_linear_view_checked(storage, index, attn + ".to_q.weight", bits, group_size, FLUX_DIM, FLUX_DIM), + require_packed_linear_view_checked(storage, index, attn + ".to_k.weight", bits, group_size, FLUX_DIM, FLUX_DIM), + require_packed_linear_view_checked(storage, index, attn + ".to_v.weight", bits, group_size, FLUX_DIM, FLUX_DIM), + require_packed_linear_view_checked(storage, index, attn + ".add_q_proj.weight", bits, group_size, FLUX_DIM, FLUX_DIM), + require_packed_linear_view_checked(storage, index, attn + ".add_k_proj.weight", bits, group_size, FLUX_DIM, FLUX_DIM), + require_packed_linear_view_checked(storage, index, attn + ".add_v_proj.weight", bits, group_size, FLUX_DIM, FLUX_DIM), + require_packed_linear_view_checked(storage, index, attn + ".to_out.0.weight", bits, group_size, FLUX_DIM, FLUX_DIM), + require_packed_linear_view_checked(storage, index, attn + ".to_add_out.weight", bits, group_size, FLUX_DIM, FLUX_DIM), + require_packed_linear_view_checked(storage, index, block + ".ff.linear_in.weight", bits, group_size, FLUX_DIM, FLUX_MLP_HIDDEN * 2), + require_packed_linear_view_checked(storage, index, block + ".ff.linear_out.weight", bits, group_size, FLUX_MLP_HIDDEN, FLUX_DIM), + require_packed_linear_view_checked(storage, index, block + ".ff_context.linear_in.weight", bits, group_size, FLUX_DIM, FLUX_MLP_HIDDEN * 2), + require_packed_linear_view_checked(storage, index, block + ".ff_context.linear_out.weight", bits, group_size, FLUX_MLP_HIDDEN, FLUX_DIM), + require_norm_view_checked(storage, index, attn + ".norm_q.weight", 128), + require_norm_view_checked(storage, index, attn + ".norm_k.weight", 
128), + require_norm_view_checked(storage, index, attn + ".norm_added_q.weight", 128), + require_norm_view_checked(storage, index, attn + ".norm_added_k.weight", 128), + FLUX_DIM, + 24, + 128, + FLUX_MLP_HIDDEN, + }; +} + +BonsaiFluxSingleBlockViews require_single_block_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + uint64_t block_index, + int bits, + int group_size +) { + const std::string attn = "single_transformer_blocks." + + std::to_string(block_index) + + ".attn"; + return BonsaiFluxSingleBlockViews { + require_packed_linear_view_checked( + storage, + index, + attn + ".to_qkv_mlp_proj.weight", + bits, + group_size, + FLUX_DIM, + FLUX_DIM * 3 + FLUX_MLP_HIDDEN * 2 + ), + require_packed_linear_view_checked( + storage, + index, + attn + ".to_out.weight", + bits, + group_size, + FLUX_DIM + FLUX_MLP_HIDDEN, + FLUX_DIM + ), + require_norm_view_checked(storage, index, attn + ".norm_q.weight", 128), + require_norm_view_checked(storage, index, attn + ".norm_k.weight", 128), + FLUX_DIM, + 24, + 128, + FLUX_MLP_HIDDEN, + }; +} + +void require_double_block( + const BonsaiSafetensorsIndex& index, + uint64_t block_index, + int bits, + int group_size, + uint64_t* count +) { + const std::string block = "transformer_blocks." 
+ std::to_string(block_index); + const std::string attn = block + ".attn"; + + require_packed_linear(index, attn + ".to_q.weight", bits, group_size, FLUX_DIM, FLUX_DIM, count); + require_packed_linear(index, attn + ".to_k.weight", bits, group_size, FLUX_DIM, FLUX_DIM, count); + require_packed_linear(index, attn + ".to_v.weight", bits, group_size, FLUX_DIM, FLUX_DIM, count); + require_packed_linear(index, attn + ".add_q_proj.weight", bits, group_size, FLUX_DIM, FLUX_DIM, count); + require_packed_linear(index, attn + ".add_k_proj.weight", bits, group_size, FLUX_DIM, FLUX_DIM, count); + require_packed_linear(index, attn + ".add_v_proj.weight", bits, group_size, FLUX_DIM, FLUX_DIM, count); + require_packed_linear(index, attn + ".to_out.0.weight", bits, group_size, FLUX_DIM, FLUX_DIM, count); + require_packed_linear(index, attn + ".to_add_out.weight", bits, group_size, FLUX_DIM, FLUX_DIM, count); + require_packed_linear( + index, + block + ".ff.linear_in.weight", + bits, + group_size, + FLUX_DIM, + FLUX_MLP_HIDDEN * 2, + count + ); + require_packed_linear( + index, + block + ".ff.linear_out.weight", + bits, + group_size, + FLUX_MLP_HIDDEN, + FLUX_DIM, + count + ); + require_packed_linear( + index, + block + ".ff_context.linear_in.weight", + bits, + group_size, + FLUX_DIM, + FLUX_MLP_HIDDEN * 2, + count + ); + require_packed_linear( + index, + block + ".ff_context.linear_out.weight", + bits, + group_size, + FLUX_MLP_HIDDEN, + FLUX_DIM, + count + ); + + require_norm(index, attn + ".norm_q.weight", 128, count); + require_norm(index, attn + ".norm_k.weight", 128, count); + require_norm(index, attn + ".norm_added_q.weight", 128, count); + require_norm(index, attn + ".norm_added_k.weight", 128, count); +} + +void require_single_block( + const BonsaiSafetensorsIndex& index, + uint64_t block_index, + int bits, + int group_size, + uint64_t* count +) { + const std::string attn = "single_transformer_blocks." 
+ + std::to_string(block_index) + + ".attn"; + + require_packed_linear( + index, + attn + ".to_qkv_mlp_proj.weight", + bits, + group_size, + FLUX_DIM, + FLUX_DIM * 3 + FLUX_MLP_HIDDEN * 2, + count + ); + require_packed_linear( + index, + attn + ".to_out.weight", + bits, + group_size, + FLUX_DIM + FLUX_MLP_HIDDEN, + FLUX_DIM, + count + ); + require_norm(index, attn + ".norm_q.weight", 128, count); + require_norm(index, attn + ".norm_k.weight", 128, count); +} + +} // namespace + +BonsaiFluxTransformerInventorySummary bonsai_require_flux_transformer_tensors( + const BonsaiSafetensorsIndex& index, + int bits, + int group_size +) { + BonsaiFluxTransformerInventorySummary summary { + FLUX_DOUBLE_BLOCKS, + FLUX_SINGLE_BLOCKS, + 0, + }; + + require_dense_linear( + index, + "x_embedder.weight", + FLUX_LATENT_CHANNELS, + FLUX_DIM, + &summary.logical_tensor_count + ); + require_dense_linear( + index, + "context_embedder.weight", + FLUX_TEXT_HIDDEN, + FLUX_DIM, + &summary.logical_tensor_count + ); + require_dense_linear( + index, + "norm_out.linear.weight", + FLUX_DIM, + FLUX_DIM * 2, + &summary.logical_tensor_count + ); + require_dense_linear( + index, + "proj_out.weight", + FLUX_DIM, + FLUX_LATENT_CHANNELS, + &summary.logical_tensor_count + ); + require_dense_linear( + index, + "time_guidance_embed.timestep_embedder.linear_1.weight", + FLUX_TIME_EMBED, + FLUX_DIM, + &summary.logical_tensor_count + ); + require_dense_linear( + index, + "time_guidance_embed.timestep_embedder.linear_2.weight", + FLUX_DIM, + FLUX_DIM, + &summary.logical_tensor_count + ); + require_dense_linear( + index, + "double_stream_modulation_img.linear.weight", + FLUX_DIM, + FLUX_DIM * 6, + &summary.logical_tensor_count + ); + require_dense_linear( + index, + "double_stream_modulation_txt.linear.weight", + FLUX_DIM, + FLUX_DIM * 6, + &summary.logical_tensor_count + ); + require_dense_linear( + index, + "single_stream_modulation.linear.weight", + FLUX_DIM, + FLUX_DIM * 3, + &summary.logical_tensor_count + 
); + + for (uint64_t index_value = 0; index_value < FLUX_DOUBLE_BLOCKS; index_value++) { + require_double_block( + index, + index_value, + bits, + group_size, + &summary.logical_tensor_count + ); + } + for (uint64_t index_value = 0; index_value < FLUX_SINGLE_BLOCKS; index_value++) { + require_single_block( + index, + index_value, + bits, + group_size, + &summary.logical_tensor_count + ); + } + + return summary; +} + +BonsaiFluxTransformerViews bonsai_require_flux_transformer_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + int bits, + int group_size +) { + BonsaiFluxTransformerViews views { + require_dense_linear_view_checked(storage, index, "x_embedder.weight", FLUX_LATENT_CHANNELS, FLUX_DIM), + require_dense_linear_view_checked(storage, index, "context_embedder.weight", FLUX_TEXT_HIDDEN, FLUX_DIM), + require_dense_linear_view_checked(storage, index, "time_guidance_embed.timestep_embedder.linear_1.weight", FLUX_TIME_EMBED, FLUX_DIM), + require_dense_linear_view_checked(storage, index, "time_guidance_embed.timestep_embedder.linear_2.weight", FLUX_DIM, FLUX_DIM), + require_dense_linear_view_checked(storage, index, "double_stream_modulation_img.linear.weight", FLUX_DIM, FLUX_DIM * 6), + require_dense_linear_view_checked(storage, index, "double_stream_modulation_txt.linear.weight", FLUX_DIM, FLUX_DIM * 6), + require_dense_linear_view_checked(storage, index, "single_stream_modulation.linear.weight", FLUX_DIM, FLUX_DIM * 3), + require_dense_linear_view_checked(storage, index, "norm_out.linear.weight", FLUX_DIM, FLUX_DIM * 2), + require_dense_linear_view_checked(storage, index, "proj_out.weight", FLUX_DIM, FLUX_LATENT_CHANNELS), + {}, + {}, + FLUX_DIM, + FLUX_TEXT_HIDDEN, + FLUX_LATENT_CHANNELS, + FLUX_TIME_EMBED, + }; + views.double_blocks.reserve(static_cast(FLUX_DOUBLE_BLOCKS)); + for (uint64_t block = 0; block < FLUX_DOUBLE_BLOCKS; block++) { + views.double_blocks.push_back(require_double_block_views( + storage, + index, + block, 
+ bits, + group_size + )); + } + views.single_blocks.reserve(static_cast(FLUX_SINGLE_BLOCKS)); + for (uint64_t block = 0; block < FLUX_SINGLE_BLOCKS; block++) { + views.single_blocks.push_back(require_single_block_views( + storage, + index, + block, + bits, + group_size + )); + } + return views; +} + +BonsaiFluxTransformerOutput bonsai_flux_transformer_forward( + const BonsaiFluxTransformerViews& views, + const std::vector& latent_tokens, + const std::vector& prompt_embeddings, + const std::vector>& image_ids, + const std::vector>& text_ids, + float timestep +) { + if (image_ids.empty() || text_ids.empty()) { + throw std::runtime_error("Bonsai Flux transformer ids must not be empty."); + } + if (!std::isfinite(timestep)) { + throw std::runtime_error("Bonsai Flux transformer timestep must be finite."); + } + if (views.dimensions == 0 || + views.text_hidden_size == 0 || + views.latent_channels == 0 || + views.timestep_embedding_size == 0 || + views.double_blocks.empty() || + views.single_blocks.empty()) { + throw std::runtime_error("Bonsai Flux transformer views are incomplete."); + } + + const uint64_t image_sequence_length = static_cast(image_ids.size()); + const uint64_t text_sequence_length = static_cast(text_ids.size()); + const uint64_t latent_batch_stride = checked_multiply( + image_sequence_length, + views.latent_channels, + "flux latent stride" + ); + if (latent_batch_stride == 0 || + latent_tokens.empty() || + latent_tokens.size() % checked_size(latent_batch_stride, "flux latent stride") != 0) { + throw std::runtime_error("Bonsai Flux latent token shape mismatch."); + } + const uint64_t batch = static_cast( + latent_tokens.size() / checked_size(latent_batch_stride, "flux latent stride") + ); + require_sequence_shape( + prompt_embeddings, + batch, + text_sequence_length, + views.text_hidden_size, + "flux prompt embeddings" + ); + + const float timestep_value = timestep <= 1.0F ? 
timestep * 1000.0F : timestep; + const BonsaiFluxTimestepEmbedding timestep_embedding = bonsai_flux_timestep_embedding( + std::vector(static_cast(batch), timestep_value), + views.timestep_embedding_size + ); + std::vector timestep_values = linear_batch_vector( + views.timestep_linear1, + timestep_embedding.values, + batch + ); + timestep_values = linear_batch_vector( + views.timestep_linear2, + bonsai_silu(timestep_values), + batch + ); + const std::vector modulation_input = bonsai_silu(timestep_values); + + const BonsaiFluxRotaryEmbedding text_rotary = bonsai_flux_pos_embed(text_ids); + const BonsaiFluxRotaryEmbedding image_rotary = bonsai_flux_pos_embed(image_ids); + const std::vector combined_rotary_cos = concat_rotary_rows( + text_rotary, + image_rotary, + true + ); + const std::vector combined_rotary_sin = concat_rotary_rows( + text_rotary, + image_rotary, + false + ); + + std::vector image = bonsai_linear_sequence( + views.x_embedder, + latent_tokens, + batch, + image_sequence_length + ); + std::vector text = bonsai_linear_sequence( + views.context_embedder, + prompt_embeddings, + batch, + text_sequence_length + ); + + const std::vector image_double_modulation = linear_batch_vector( + views.double_modulation_img, + modulation_input, + batch + ); + const std::vector text_double_modulation = linear_batch_vector( + views.double_modulation_txt, + modulation_input, + batch + ); + for (size_t index = 0; index < views.double_blocks.size(); index++) { + const BonsaiFluxDoubleBlockViews& block = views.double_blocks[index]; + log_flux_block_phase( + "flux_double_block_start", + static_cast(index + 1U), + text_sequence_length, + image_sequence_length + ); + apply_flux_double_block( + block, + text_double_modulation, + image_double_modulation, + text_rotary, + image_rotary, + batch, + &text, + &image + ); + log_flux_block_phase( + "flux_double_block_done", + static_cast(index + 1U), + text_sequence_length, + image_sequence_length + ); + } + + const uint64_t 
combined_sequence_length = checked_add( + text_sequence_length, + image_sequence_length, + "flux combined sequence" + ); + std::vector hidden = concat_token_sequences( + text, + image, + batch, + text_sequence_length, + image_sequence_length, + views.dimensions + ); + const std::vector single_modulation = linear_batch_vector( + views.single_modulation, + modulation_input, + batch + ); + for (size_t index = 0; index < views.single_blocks.size(); index++) { + const BonsaiFluxSingleBlockViews& block = views.single_blocks[index]; + log_flux_block_phase( + "flux_single_block_start", + static_cast(index + 1U), + combined_sequence_length, + image_sequence_length + ); + hidden = apply_flux_single_block( + block, + single_modulation, + combined_rotary_cos, + combined_rotary_sin, + hidden, + batch, + combined_sequence_length + ); + log_flux_block_phase( + "flux_single_block_done", + static_cast(index + 1U), + combined_sequence_length, + image_sequence_length + ); + } + + const std::vector image_output = slice_token_sequence( + hidden, + batch, + combined_sequence_length, + views.dimensions, + text_sequence_length, + image_sequence_length + ); + BonsaiFluxTransformerOutput output; + output.batch = batch; + output.sequence_length = image_sequence_length; + output.channels = views.latent_channels; + output.values = bonsai_linear_sequence( + views.proj_out, + bonsai_flux_final_projection_input( + image_output, + linear_batch_vector(views.norm_out_linear, modulation_input, batch), + batch, + image_sequence_length, + views.dimensions, + 1e-6F + ), + batch, + image_sequence_length + ); + return output; +} + +uint64_t bonsai_flux_single_block_byte_count(const BonsaiFluxSingleBlockViews& views) { + uint64_t bytes = bonsai_linear_byte_count(views.qkv_mlp_proj); + add_bytes(&bytes, bonsai_linear_byte_count(views.out_proj), "single out"); + add_bytes(&bytes, bonsai_rms_norm_byte_count(views.norm_q), "single norm q"); + add_bytes(&bytes, bonsai_rms_norm_byte_count(views.norm_k), "single 
norm k"); + return bytes; +} + +uint64_t bonsai_flux_double_block_byte_count(const BonsaiFluxDoubleBlockViews& views) { + uint64_t bytes = bonsai_linear_byte_count(views.to_q); + add_bytes(&bytes, bonsai_linear_byte_count(views.to_k), "double to_k"); + add_bytes(&bytes, bonsai_linear_byte_count(views.to_v), "double to_v"); + add_bytes(&bytes, bonsai_linear_byte_count(views.add_q), "double add_q"); + add_bytes(&bytes, bonsai_linear_byte_count(views.add_k), "double add_k"); + add_bytes(&bytes, bonsai_linear_byte_count(views.add_v), "double add_v"); + add_bytes(&bytes, bonsai_linear_byte_count(views.to_out), "double to_out"); + add_bytes(&bytes, bonsai_linear_byte_count(views.to_add_out), "double to_add_out"); + add_bytes(&bytes, bonsai_linear_byte_count(views.ff_in), "double ff_in"); + add_bytes(&bytes, bonsai_linear_byte_count(views.ff_out), "double ff_out"); + add_bytes(&bytes, bonsai_linear_byte_count(views.ff_context_in), "double ff_context_in"); + add_bytes(&bytes, bonsai_linear_byte_count(views.ff_context_out), "double ff_context_out"); + add_bytes(&bytes, bonsai_rms_norm_byte_count(views.norm_q), "double norm_q"); + add_bytes(&bytes, bonsai_rms_norm_byte_count(views.norm_k), "double norm_k"); + add_bytes(&bytes, bonsai_rms_norm_byte_count(views.norm_added_q), "double norm_added_q"); + add_bytes(&bytes, bonsai_rms_norm_byte_count(views.norm_added_k), "double norm_added_k"); + return bytes; +} + +uint64_t bonsai_flux_transformer_byte_count(const BonsaiFluxTransformerViews& views) { + uint64_t bytes = bonsai_linear_byte_count(views.x_embedder); + add_bytes(&bytes, bonsai_linear_byte_count(views.context_embedder), "context embedder"); + add_bytes(&bytes, bonsai_linear_byte_count(views.timestep_linear1), "timestep linear1"); + add_bytes(&bytes, bonsai_linear_byte_count(views.timestep_linear2), "timestep linear2"); + add_bytes(&bytes, bonsai_linear_byte_count(views.double_modulation_img), "double mod img"); + add_bytes(&bytes, 
bonsai_linear_byte_count(views.double_modulation_txt), "double mod txt"); + add_bytes(&bytes, bonsai_linear_byte_count(views.single_modulation), "single mod"); + add_bytes(&bytes, bonsai_linear_byte_count(views.norm_out_linear), "norm out"); + add_bytes(&bytes, bonsai_linear_byte_count(views.proj_out), "proj out"); + for (const BonsaiFluxDoubleBlockViews& block : views.double_blocks) { + add_bytes(&bytes, bonsai_flux_double_block_byte_count(block), "double block"); + } + for (const BonsaiFluxSingleBlockViews& block : views.single_blocks) { + add_bytes(&bytes, bonsai_flux_single_block_byte_count(block), "single block"); + } + return bytes; +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_flux_transformer.h b/feature/bonsai/src/androidMain/cpp/bonsai_flux_transformer.h new file mode 100644 index 000000000..d1d6def40 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_flux_transformer.h @@ -0,0 +1,103 @@ +#pragma once + +#include "bonsai_linear.h" +#include "bonsai_norm.h" +#include "bonsai_safetensors.h" +#include "bonsai_tensor_storage.h" + +#include +#include +#include + +struct BonsaiFluxTransformerInventorySummary { + uint64_t double_block_count = 0; + uint64_t single_block_count = 0; + uint64_t logical_tensor_count = 0; +}; + +struct BonsaiFluxSingleBlockViews { + BonsaiLinearViews qkv_mlp_proj; + BonsaiLinearViews out_proj; + BonsaiRmsNormWeightViews norm_q; + BonsaiRmsNormWeightViews norm_k; + uint64_t dimensions = 0; + uint64_t heads = 0; + uint64_t head_dimension = 0; + uint64_t mlp_hidden_dimensions = 0; +}; + +struct BonsaiFluxDoubleBlockViews { + BonsaiLinearViews to_q; + BonsaiLinearViews to_k; + BonsaiLinearViews to_v; + BonsaiLinearViews add_q; + BonsaiLinearViews add_k; + BonsaiLinearViews add_v; + BonsaiLinearViews to_out; + BonsaiLinearViews to_add_out; + BonsaiLinearViews ff_in; + BonsaiLinearViews ff_out; + BonsaiLinearViews ff_context_in; + BonsaiLinearViews ff_context_out; + BonsaiRmsNormWeightViews norm_q; + 
BonsaiRmsNormWeightViews norm_k; + BonsaiRmsNormWeightViews norm_added_q; + BonsaiRmsNormWeightViews norm_added_k; + uint64_t dimensions = 0; + uint64_t heads = 0; + uint64_t head_dimension = 0; + uint64_t mlp_hidden_dimensions = 0; +}; + +struct BonsaiFluxTransformerViews { + BonsaiLinearViews x_embedder; + BonsaiLinearViews context_embedder; + BonsaiLinearViews timestep_linear1; + BonsaiLinearViews timestep_linear2; + BonsaiLinearViews double_modulation_img; + BonsaiLinearViews double_modulation_txt; + BonsaiLinearViews single_modulation; + BonsaiLinearViews norm_out_linear; + BonsaiLinearViews proj_out; + std::vector double_blocks; + std::vector single_blocks; + uint64_t dimensions = 0; + uint64_t text_hidden_size = 0; + uint64_t latent_channels = 0; + uint64_t timestep_embedding_size = 0; +}; + +struct BonsaiFluxTransformerOutput { + std::vector values; + uint64_t batch = 1; + uint64_t sequence_length = 0; + uint64_t channels = 0; +}; + +BonsaiFluxTransformerInventorySummary bonsai_require_flux_transformer_tensors( + const BonsaiSafetensorsIndex& index, + int bits, + int group_size +); + +BonsaiFluxTransformerViews bonsai_require_flux_transformer_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + int bits, + int group_size +); + +BonsaiFluxTransformerOutput bonsai_flux_transformer_forward( + const BonsaiFluxTransformerViews& views, + const std::vector& latent_tokens, + const std::vector& prompt_embeddings, + const std::vector>& image_ids, + const std::vector>& text_ids, + float timestep +); + +uint64_t bonsai_flux_single_block_byte_count(const BonsaiFluxSingleBlockViews& views); + +uint64_t bonsai_flux_double_block_byte_count(const BonsaiFluxDoubleBlockViews& views); + +uint64_t bonsai_flux_transformer_byte_count(const BonsaiFluxTransformerViews& views); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_flux_vae.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_flux_vae.cpp new file mode 100644 index 000000000..c22fa7187 
--- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_flux_vae.cpp @@ -0,0 +1,312 @@ +#include "bonsai_flux_vae.h" + +#include "bonsai_tensor.h" + +#include +#include + +namespace { + +void require_tensor( + const BonsaiSafetensorsIndex& index, + const std::string& key, + uint64_t* count +) { + index.require(key); + (*count)++; +} + +void require_conv( + const BonsaiSafetensorsIndex& index, + const std::string& prefix, + uint64_t expected_input_channels, + uint64_t expected_output_channels, + uint64_t* count +) { + const std::string weight_key = prefix + ".weight"; + const BonsaiTensorDescriptor& weight = index.require(weight_key); + if (weight.shape.size() != 4) { + throw std::runtime_error("Bonsai VAE conv weight must be 4D: " + weight_key); + } + if (expected_output_channels != 0 && weight.shape[0] != expected_output_channels) { + throw std::runtime_error("Bonsai VAE conv output channel mismatch: " + weight_key); + } + if (expected_input_channels != 0 && weight.shape[1] != expected_input_channels) { + throw std::runtime_error("Bonsai VAE conv input channel mismatch: " + weight_key); + } + (*count)++; + + const BonsaiTensorDescriptor* bias = index.optional(prefix + ".bias"); + if (bias != nullptr) { + if (bonsai_shape_element_count(bias->shape, bias->key) != weight.shape[0]) { + throw std::runtime_error("Bonsai VAE conv bias shape mismatch: " + bias->key); + } + (*count)++; + } +} + +void require_group_norm( + const BonsaiSafetensorsIndex& index, + const std::string& prefix, + uint64_t expected_channels, + uint64_t group_count, + uint64_t* count +) { + const BonsaiTensorDescriptor& weight = index.require(prefix + ".weight"); + const BonsaiTensorDescriptor& bias = index.require(prefix + ".bias"); + const uint64_t weight_elements = bonsai_shape_element_count(weight.shape, weight.key); + const uint64_t bias_elements = bonsai_shape_element_count(bias.shape, bias.key); + if (weight_elements != bias_elements || weight_elements != expected_channels) { + throw 
std::runtime_error("Bonsai VAE group norm shape mismatch: " + prefix); + } + if (group_count == 0 || expected_channels % group_count != 0) { + throw std::runtime_error("Bonsai VAE group norm channel/group mismatch: " + prefix); + } + *count += 2; +} + +uint64_t leading_rows(const BonsaiTensorDescriptor& descriptor) { + if (descriptor.shape.empty()) { + throw std::runtime_error("Bonsai VAE linear weight must have dimensions: " + descriptor.key); + } + uint64_t rows = 1; + for (size_t index = 0; index + 1 < descriptor.shape.size(); index++) { + if (rows != 0 && descriptor.shape[index] > UINT64_MAX / rows) { + throw std::runtime_error("Bonsai VAE linear shape is too large: " + descriptor.key); + } + rows *= descriptor.shape[index]; + } + return rows; +} + +uint64_t trailing_columns(const BonsaiTensorDescriptor& descriptor) { + if (descriptor.shape.empty()) { + throw std::runtime_error("Bonsai VAE linear weight must have dimensions: " + descriptor.key); + } + return descriptor.shape.back(); +} + +void require_dense_linear( + const BonsaiSafetensorsIndex& index, + const std::string& prefix, + const std::string& fallback_prefix, + uint64_t expected_input_channels, + uint64_t expected_output_channels, + uint64_t* count +) { + const std::string preferred = prefix + ".weight"; + const std::string fallback = fallback_prefix.empty() ? preferred : fallback_prefix + ".weight"; + const std::string weight_key = index.contains(preferred) ? 
preferred : fallback; + const BonsaiTensorDescriptor& weight = index.require(weight_key); + const uint64_t rows = leading_rows(weight); + const uint64_t columns = trailing_columns(weight); + if (rows != expected_output_channels || columns != expected_input_channels) { + throw std::runtime_error("Bonsai VAE linear shape mismatch: " + weight_key); + } + (*count)++; + const std::string bias_key = weight_key.substr(0, weight_key.size() - std::string(".weight").size()) + + ".bias"; + const BonsaiTensorDescriptor* bias = index.optional(bias_key); + if (bias != nullptr) { + if (bonsai_shape_element_count(bias->shape, bias->key) != rows) { + throw std::runtime_error("Bonsai VAE linear bias shape mismatch: " + bias->key); + } + (*count)++; + } +} + +void require_resnet( + const BonsaiSafetensorsIndex& index, + const std::string& prefix, + uint64_t input_channels, + uint64_t output_channels, + uint64_t group_count, + BonsaiFluxVaeInventorySummary* summary +) { + require_group_norm( + index, + prefix + ".norm1", + input_channels, + group_count, + &summary->logical_tensor_count + ); + require_conv( + index, + prefix + ".conv1", + input_channels, + output_channels, + &summary->logical_tensor_count + ); + require_group_norm( + index, + prefix + ".norm2", + output_channels, + group_count, + &summary->logical_tensor_count + ); + require_conv( + index, + prefix + ".conv2", + output_channels, + output_channels, + &summary->logical_tensor_count + ); + if (index.optional(prefix + ".conv_shortcut.weight") != nullptr) { + require_conv( + index, + prefix + ".conv_shortcut", + input_channels, + output_channels, + &summary->logical_tensor_count + ); + } + summary->resnet_block_count++; +} + +void require_attention( + const BonsaiSafetensorsIndex& index, + const std::string& prefix, + uint64_t channels, + uint64_t group_count, + BonsaiFluxVaeInventorySummary* summary +) { + require_group_norm( + index, + prefix + ".group_norm", + channels, + group_count, + &summary->logical_tensor_count + 
); + require_dense_linear(index, prefix + ".to_q", "", channels, channels, &summary->logical_tensor_count); + require_dense_linear(index, prefix + ".to_k", "", channels, channels, &summary->logical_tensor_count); + require_dense_linear(index, prefix + ".to_v", "", channels, channels, &summary->logical_tensor_count); + require_dense_linear( + index, + prefix + ".to_out.0", + prefix + ".to_out", + channels, + channels, + &summary->logical_tensor_count + ); + summary->attention_block_count++; +} + +void require_mid_block( + const BonsaiSafetensorsIndex& index, + uint64_t channels, + uint64_t group_count, + BonsaiFluxVaeInventorySummary* summary +) { + require_resnet( + index, + "decoder.mid_block.resnets.0", + channels, + channels, + group_count, + summary + ); + require_attention(index, "decoder.mid_block.attentions.0", channels, group_count, summary); + require_resnet( + index, + "decoder.mid_block.resnets.1", + channels, + channels, + group_count, + summary + ); +} + +void require_up_block( + const BonsaiSafetensorsIndex& index, + uint64_t block_index, + uint64_t input_channels, + uint64_t output_channels, + uint64_t layer_count, + uint64_t group_count, + bool add_upsample, + BonsaiFluxVaeInventorySummary* summary +) { + const std::string prefix = "decoder.up_blocks." + std::to_string(block_index); + for (uint64_t layer = 0; layer < layer_count; layer++) { + require_resnet( + index, + prefix + ".resnets." + std::to_string(layer), + layer == 0 ? 
input_channels : output_channels, + output_channels, + group_count, + summary + ); + } + if (add_upsample) { + require_conv( + index, + prefix + ".upsamplers.0.conv", + output_channels, + output_channels, + &summary->logical_tensor_count + ); + } + summary->up_block_count++; +} + +} // namespace + +BonsaiFluxVaeInventorySummary bonsai_require_flux_vae_tensors( + const BonsaiSafetensorsIndex& index, + const BonsaiFluxVaeConfig& config +) { + if (config.block_out_channels_count != 4 || + config.block_out_channels.size() != 4 || + config.layers_per_block == 0 || + config.norm_num_groups == 0) { + throw std::runtime_error("invalid Flux VAE config for Bonsai native runtime."); + } + for (uint64_t channels : config.block_out_channels) { + if (channels == 0 || channels % config.norm_num_groups != 0) { + throw std::runtime_error("invalid Flux VAE channel/group config."); + } + } + + BonsaiFluxVaeInventorySummary summary; + require_conv(index, "post_quant_conv", 0, 0, &summary.logical_tensor_count); + require_tensor(index, "bn.running_mean", &summary.logical_tensor_count); + require_tensor(index, "bn.running_var", &summary.logical_tensor_count); + require_conv(index, "decoder.conv_in", 0, config.block_out_channels.back(), &summary.logical_tensor_count); + require_mid_block( + index, + config.block_out_channels.back(), + config.norm_num_groups, + &summary + ); + + const uint64_t up_block_layer_count = config.layers_per_block + 1; + for (uint64_t block = 0; block < config.block_out_channels_count; block++) { + const uint64_t output_channels = + config.block_out_channels[static_cast(config.block_out_channels_count - 1 - block)]; + const uint64_t input_channels = block == 0 + ? 
output_channels + : config.block_out_channels[ + static_cast(config.block_out_channels_count - block) + ]; + require_up_block( + index, + block, + input_channels, + output_channels, + up_block_layer_count, + config.norm_num_groups, + block + 1 < config.block_out_channels_count, + &summary + ); + } + + require_group_norm( + index, + "decoder.conv_norm_out", + config.block_out_channels.front(), + config.norm_num_groups, + &summary.logical_tensor_count + ); + require_conv(index, "decoder.conv_out", config.block_out_channels.front(), 0, &summary.logical_tensor_count); + return summary; +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_flux_vae.h b/feature/bonsai/src/androidMain/cpp/bonsai_flux_vae.h new file mode 100644 index 000000000..f5d8a9659 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_flux_vae.h @@ -0,0 +1,26 @@ +#pragma once + +#include "bonsai_safetensors.h" + +#include +#include + +struct BonsaiFluxVaeConfig { + uint64_t block_out_channels_count = 0; + uint64_t layers_per_block = 0; + uint64_t norm_num_groups = 0; + float batch_norm_eps = 0.0F; + std::vector block_out_channels; +}; + +struct BonsaiFluxVaeInventorySummary { + uint64_t up_block_count = 0; + uint64_t resnet_block_count = 0; + uint64_t attention_block_count = 0; + uint64_t logical_tensor_count = 0; +}; + +BonsaiFluxVaeInventorySummary bonsai_require_flux_vae_tensors( + const BonsaiSafetensorsIndex& index, + const BonsaiFluxVaeConfig& config +); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_image_encoder.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_image_encoder.cpp new file mode 100644 index 000000000..f1cd9fff1 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_image_encoder.cpp @@ -0,0 +1,217 @@ +#include "bonsai_image_encoder.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { + +constexpr uint8_t PNG_SIGNATURE[] = {137, 80, 78, 71, 13, 10, 26, 10}; +constexpr char BASE64_ALPHABET[] = + 
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + +struct OutputStats { + float min_value = std::numeric_limits::max(); + float max_value = -std::numeric_limits::max(); + double sum = 0.0; + uint64_t finite_count = 0; + uint64_t non_finite_count = 0; +}; + +uint64_t checked_multiply(uint64_t left, uint64_t right, const char* label) { + if (left != 0 && right > std::numeric_limits::max() / left) { + throw std::runtime_error(std::string("Bonsai image encoder shape overflow: ") + label); + } + return left * right; +} + +size_t checked_size(uint64_t value, const char* label) { + if (value > static_cast(std::numeric_limits::max())) { + throw std::runtime_error(std::string("Bonsai image encoder shape overflow: ") + label); + } + return static_cast(value); +} + +void append_u32_be(std::vector* output, uint32_t value) { + output->push_back(static_cast((value >> 24U) & 0xFFU)); + output->push_back(static_cast((value >> 16U) & 0xFFU)); + output->push_back(static_cast((value >> 8U) & 0xFFU)); + output->push_back(static_cast(value & 0xFFU)); +} + +void append_chunk( + std::vector* output, + const char type[4], + const std::vector& data +) { + if (data.size() > static_cast(std::numeric_limits::max())) { + throw std::runtime_error("Bonsai PNG chunk is too large."); + } + append_u32_be(output, static_cast(data.size())); + const size_t type_offset = output->size(); + output->insert(output->end(), type, type + 4); + output->insert(output->end(), data.begin(), data.end()); + + uLong crc = crc32(0L, Z_NULL, 0); + crc = crc32( + crc, + reinterpret_cast(output->data() + type_offset), + static_cast(4 + data.size()) + ); + append_u32_be(output, static_cast(crc)); +} + +uint8_t to_byte(float value) { + const float scaled = std::round(value * 255.0F); + const int integer = static_cast(scaled); + return static_cast(std::clamp(integer, 0, 255)); +} + +float normalized_pixel(const BonsaiNchwTensor& decoded, uint64_t y, uint64_t x, uint64_t channel) { + const uint64_t index 
= ( + (channel * decoded.height + y) * decoded.width + x + ); + return std::clamp( + decoded.values[checked_size(index, "pixel index")] / 2.0F + 0.5F, + 0.0F, + 1.0F + ); +} + +void record(OutputStats* stats, float value) { + if (!std::isfinite(value)) { + stats->non_finite_count++; + return; + } + stats->min_value = std::min(stats->min_value, value); + stats->max_value = std::max(stats->max_value, value); + stats->sum += static_cast(value); + stats->finite_count++; +} + +void validate_stats(const OutputStats& stats) { + if (stats.finite_count == 0) { + throw std::runtime_error("Bonsai output image is invalid: all RGB values are non-finite."); + } + if (stats.non_finite_count != 0) { + throw std::runtime_error("Bonsai output image is invalid: non-finite RGB values."); + } + const double mean = stats.sum / static_cast(stats.finite_count); + if (mean <= 0.02 || (stats.max_value <= 0.08F && mean <= 0.03)) { + throw std::runtime_error("Bonsai output image is invalid: nearly black output."); + } +} + +std::vector rgb_scanlines(const BonsaiNchwTensor& decoded) { + if (decoded.batch_size != 1 || decoded.channels < 3 || decoded.height == 0 || decoded.width == 0) { + throw std::runtime_error("Bonsai decoded image tensor shape is unsupported."); + } + const uint64_t expected = checked_multiply( + checked_multiply( + checked_multiply(decoded.batch_size, decoded.channels, "decoded tensor"), + decoded.height, + "decoded tensor" + ), + decoded.width, + "decoded tensor" + ); + if (decoded.values.size() != checked_size(expected, "decoded tensor")) { + throw std::runtime_error("Bonsai decoded image tensor value count mismatch."); + } + + const uint64_t row_bytes = checked_multiply(decoded.width, 3U, "png row"); + const uint64_t scanline_bytes = checked_multiply( + checked_multiply(decoded.height, row_bytes + 1U, "png scanlines"), + 1U, + "png scanlines" + ); + std::vector output; + output.reserve(checked_size(scanline_bytes, "png scanlines")); + + OutputStats stats; + for (uint64_t 
y = 0; y < decoded.height; y++) { + output.push_back(0); + for (uint64_t x = 0; x < decoded.width; x++) { + const float red = normalized_pixel(decoded, y, x, 0); + const float green = normalized_pixel(decoded, y, x, 1); + const float blue = normalized_pixel(decoded, y, x, 2); + record(&stats, red); + record(&stats, green); + record(&stats, blue); + output.push_back(std::isfinite(red) ? to_byte(red) : 0); + output.push_back(std::isfinite(green) ? to_byte(green) : 0); + output.push_back(std::isfinite(blue) ? to_byte(blue) : 0); + } + } + validate_stats(stats); + return output; +} + +std::vector zlib_compress(const std::vector& input) { + uLongf compressed_size = compressBound(static_cast(input.size())); + std::vector output(static_cast(compressed_size), 0); + const int result = compress2( + output.data(), + &compressed_size, + input.data(), + static_cast(input.size()), + Z_BEST_SPEED + ); + if (result != Z_OK) { + throw std::runtime_error("Bonsai PNG compression failed."); + } + output.resize(static_cast(compressed_size)); + return output; +} + +std::vector png_bytes(const BonsaiNchwTensor& decoded) { + if (decoded.width > std::numeric_limits::max() || + decoded.height > std::numeric_limits::max()) { + throw std::runtime_error("Bonsai decoded image is too large for PNG."); + } + + std::vector png; + png.insert(png.end(), std::begin(PNG_SIGNATURE), std::end(PNG_SIGNATURE)); + + std::vector ihdr; + append_u32_be(&ihdr, static_cast(decoded.width)); + append_u32_be(&ihdr, static_cast(decoded.height)); + ihdr.push_back(8); + ihdr.push_back(2); + ihdr.push_back(0); + ihdr.push_back(0); + ihdr.push_back(0); + append_chunk(&png, "IHDR", ihdr); + append_chunk(&png, "IDAT", zlib_compress(rgb_scanlines(decoded))); + append_chunk(&png, "IEND", {}); + return png; +} + +std::string base64_encode(const std::vector& input) { + std::string output; + output.reserve(((input.size() + 2U) / 3U) * 4U); + for (size_t index = 0; index < input.size(); index += 3U) { + const uint32_t octet_a 
= input[index]; + const uint32_t octet_b = index + 1U < input.size() ? input[index + 1U] : 0U; + const uint32_t octet_c = index + 2U < input.size() ? input[index + 2U] : 0U; + const uint32_t triple = (octet_a << 16U) | (octet_b << 8U) | octet_c; + output.push_back(BASE64_ALPHABET[(triple >> 18U) & 0x3FU]); + output.push_back(BASE64_ALPHABET[(triple >> 12U) & 0x3FU]); + output.push_back(index + 1U < input.size() ? BASE64_ALPHABET[(triple >> 6U) & 0x3FU] : '='); + output.push_back(index + 2U < input.size() ? BASE64_ALPHABET[triple & 0x3FU] : '='); + } + return output; +} + +} // namespace + +std::string bonsai_encode_nchw_tensor_as_png_base64( + const BonsaiNchwTensor& decoded +) { + return base64_encode(png_bytes(decoded)); +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_image_encoder.h b/feature/bonsai/src/androidMain/cpp/bonsai_image_encoder.h new file mode 100644 index 000000000..badc38ab9 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_image_encoder.h @@ -0,0 +1,9 @@ +#pragma once + +#include "bonsai_vae_ops.h" + +#include + +std::string bonsai_encode_nchw_tensor_as_png_base64( + const BonsaiNchwTensor& decoded +); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_jni.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_jni.cpp new file mode 100644 index 000000000..8795a9c67 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_jni.cpp @@ -0,0 +1,250 @@ +#include "bonsai_model_probe.h" +#include "bonsai_runtime.h" + +#include + +#include + +#include +#include +#include + +namespace { + +constexpr const char* LOG_TAG = "SDAI-Bonsai"; +constexpr const char* BRIDGE_CLASS = + "com/shifthackz/aisdv1/feature/bonsai/BonsaiNativeBridge"; +constexpr const char* CALLBACK_CLASS = + "com/shifthackz/aisdv1/feature/bonsai/BonsaiNativeBridge$ProgressCallback"; + +std::atomic_bool g_cancel_requested(false); + +class JStringChars { +public: + JStringChars(JNIEnv* env, jstring value) : env_(env), value_(value) { + if (value_ != nullptr) { + chars_ 
= env_->GetStringUTFChars(value_, nullptr); + } + } + + ~JStringChars() { + if (value_ != nullptr && chars_ != nullptr) { + env_->ReleaseStringUTFChars(value_, chars_); + } + } + + std::string str() const { + return chars_ == nullptr ? "" : std::string(chars_); + } + +private: + JNIEnv* env_; + jstring value_; + const char* chars_ = nullptr; +}; + +void throw_java(JNIEnv* env, const char* class_name, const std::string& message) { + jclass clazz = env->FindClass(class_name); + if (clazz != nullptr) { + env->ThrowNew(clazz, message.c_str()); + } +} + +void throw_illegal_state(JNIEnv* env, const std::string& message) { + throw_java(env, "java/lang/IllegalStateException", message); +} + +void emit_progress(JNIEnv* env, jobject callback, jint current, jint total) { + if (callback == nullptr) { + return; + } + + jclass callback_class = env->FindClass(CALLBACK_CLASS); + if (callback_class == nullptr) { + return; + } + + jmethodID on_progress = env->GetMethodID(callback_class, "onProgress", "(II)V"); + if (on_progress == nullptr) { + return; + } + + env->CallVoidMethod(callback, on_progress, current, total); +} + +jstring native_probe_model( + JNIEnv* env, + jobject, + jstring root_path, + jstring packed_transformer_path, + jstring text_encoder_path, + jstring tokenizer_path, + jstring vae_path, + jstring scheduler_path +) { + const JStringChars root_path_chars(env, root_path); + const JStringChars packed_transformer_path_chars(env, packed_transformer_path); + const JStringChars text_encoder_path_chars(env, text_encoder_path); + const JStringChars tokenizer_path_chars(env, tokenizer_path); + const JStringChars vae_path_chars(env, vae_path); + const JStringChars scheduler_path_chars(env, scheduler_path); + + try { + const std::string summary = probe_bonsai_model( + BonsaiModelPaths { + root_path_chars.str(), + packed_transformer_path_chars.str(), + text_encoder_path_chars.str(), + tokenizer_path_chars.str(), + vae_path_chars.str(), + scheduler_path_chars.str(), + } + ); + 
__android_log_print(ANDROID_LOG_INFO, LOG_TAG, "model probe %s", summary.c_str()); + return env->NewStringUTF(summary.c_str()); + } catch (const std::exception& error) { + throw_illegal_state(env, error.what()); + return nullptr; + } +} + +jstring native_generate( + JNIEnv* env, + jobject, + jstring root_path, + jstring packed_transformer_path, + jstring text_encoder_path, + jstring tokenizer_path, + jstring vae_path, + jstring scheduler_path, + jstring prompt, + jstring negative_prompt, + jint sampling_steps, + jfloat cfg_scale, + jint width, + jint height, + jstring seed, + jint batch_count, + jboolean allow_nsfw, + jstring backend, + jobject callback +) { + g_cancel_requested.store(false); + const JStringChars root_path_chars(env, root_path); + const JStringChars packed_transformer_path_chars(env, packed_transformer_path); + const JStringChars text_encoder_path_chars(env, text_encoder_path); + const JStringChars tokenizer_path_chars(env, tokenizer_path); + const JStringChars vae_path_chars(env, vae_path); + const JStringChars scheduler_path_chars(env, scheduler_path); + const JStringChars prompt_chars(env, prompt); + const JStringChars negative_prompt_chars(env, negative_prompt); + const JStringChars seed_chars(env, seed); + const JStringChars backend_chars(env, backend); + + __android_log_print( + ANDROID_LOG_INFO, + LOG_TAG, + "request rootPath=%s transformerPath=%s textEncoderPath=%s vaePath=%s size=%dx%d steps=%d cfg=%.3f batch=%d allowNsfw=%s backend=%s seedBlank=%s promptChars=%zu negativeChars=%zu", + root_path_chars.str().c_str(), + packed_transformer_path_chars.str().c_str(), + text_encoder_path_chars.str().c_str(), + vae_path_chars.str().c_str(), + width, + height, + sampling_steps, + cfg_scale, + batch_count, + allow_nsfw == JNI_TRUE ? "true" : "false", + backend_chars.str().c_str(), + seed_chars.str().empty() ? 
"true" : "false", + prompt_chars.str().size(), + negative_prompt_chars.str().size() + ); + + try { + const std::string output = bonsai_generate_image( + BonsaiGenerationRequest { + BonsaiModelPaths { + root_path_chars.str(), + packed_transformer_path_chars.str(), + text_encoder_path_chars.str(), + tokenizer_path_chars.str(), + vae_path_chars.str(), + scheduler_path_chars.str(), + }, + prompt_chars.str(), + negative_prompt_chars.str(), + sampling_steps, + cfg_scale, + width, + height, + seed_chars.str(), + batch_count, + allow_nsfw == JNI_TRUE, + backend_chars.str(), + }, + [env, callback](int current, int total) { + emit_progress(env, callback, current, total); + }, + g_cancel_requested + ); + return env->NewStringUTF(output.c_str()); + } catch (const BonsaiGenerationCancelled& error) { + throw_java(env, "java/util/concurrent/CancellationException", error.what()); + return nullptr; + } catch (const std::exception& error) { + throw_illegal_state(env, error.what()); + return nullptr; + } +} + +void native_interrupt(JNIEnv*, jobject) { + g_cancel_requested.store(true); + __android_log_print(ANDROID_LOG_INFO, LOG_TAG, "interrupt requested"); +} + +JNINativeMethod g_methods[] = { + { + const_cast("probeModel"), + const_cast( + "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)Ljava/lang/String;" + ), + reinterpret_cast(native_probe_model), + }, + { + const_cast("generateModel"), + const_cast( + "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;IFIILjava/lang/String;IZLjava/lang/String;Lcom/shifthackz/aisdv1/feature/bonsai/BonsaiNativeBridge$ProgressCallback;)Ljava/lang/String;" + ), + reinterpret_cast(native_generate), + }, + { + const_cast("interrupt"), + const_cast("()V"), + reinterpret_cast(native_interrupt), + }, +}; + +} // namespace + +JNIEXPORT jint JNI_OnLoad(JavaVM* vm, void*) { + JNIEnv* env = nullptr; + 
if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6) != JNI_OK || env == nullptr) { + return JNI_ERR; + } + + jclass bridge_class = env->FindClass(BRIDGE_CLASS); + if (bridge_class == nullptr) { + return JNI_ERR; + } + + if (env->RegisterNatives( + bridge_class, + g_methods, + sizeof(g_methods) / sizeof(g_methods[0]) + ) != JNI_OK) { + return JNI_ERR; + } + + return JNI_VERSION_1_6; +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_latents.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_latents.cpp new file mode 100644 index 000000000..711ad2057 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_latents.cpp @@ -0,0 +1,215 @@ +#include "bonsai_latents.h" + +#include +#include +#include +#include + +namespace { + +uint64_t checked_multiply(uint64_t left, uint64_t right, const char* label) { + if (left != 0 && right > std::numeric_limits::max() / left) { + throw std::runtime_error(std::string("Bonsai latent shape is too large: ") + label); + } + return left * right; +} + +uint64_t checked_size( + uint64_t first, + uint64_t second, + uint64_t third, + uint64_t fourth, + const char* label +) { + return checked_multiply( + checked_multiply(checked_multiply(first, second, label), third, label), + fourth, + label + ); +} + +void require_positive(uint64_t value, const char* label) { + if (value == 0) { + throw std::runtime_error(std::string("Bonsai latent ") + label + " must be positive."); + } +} + +} // namespace + +BonsaiLatentShape bonsai_packed_latent_shape( + uint64_t image_height, + uint64_t image_width, + uint64_t batch_size, + uint64_t channels, + uint64_t vae_scale_factor +) { + require_positive(image_height, "image height"); + require_positive(image_width, "image width"); + require_positive(batch_size, "batch size"); + require_positive(channels, "channels"); + require_positive(vae_scale_factor, "VAE scale factor"); + + const uint64_t divisor = checked_multiply(vae_scale_factor, 2, "vae scale divisor"); + const uint64_t normalized_height = 2 
* (image_height / divisor); + const uint64_t normalized_width = 2 * (image_width / divisor); + const uint64_t latent_height = normalized_height / 2; + const uint64_t latent_width = normalized_width / 2; + require_positive(latent_height, "height"); + require_positive(latent_width, "width"); + + return BonsaiLatentShape { + batch_size, + channels, + latent_height, + latent_width, + checked_multiply(latent_height, latent_width, "sequence length"), + }; +} + +std::vector bonsai_latent_grid_ids( + uint64_t batch_size, + uint64_t latent_height, + uint64_t latent_width +) { + require_positive(batch_size, "batch size"); + require_positive(latent_height, "height"); + require_positive(latent_width, "width"); + + const uint64_t value_count = checked_size( + batch_size, + latent_height, + latent_width, + 4, + "grid ids" + ); + if (latent_height > static_cast(std::numeric_limits::max()) || + latent_width > static_cast(std::numeric_limits::max())) { + throw std::runtime_error("Bonsai latent grid is too large."); + } + + std::vector values; + values.reserve(static_cast(value_count)); + for (uint64_t batch = 0; batch < batch_size; batch++) { + for (uint64_t row = 0; row < latent_height; row++) { + for (uint64_t column = 0; column < latent_width; column++) { + values.push_back(0); + values.push_back(static_cast(row)); + values.push_back(static_cast(column)); + values.push_back(0); + } + } + } + return values; +} + +std::vector bonsai_random_latents_nchw( + uint64_t batch_size, + uint64_t channels, + uint64_t latent_height, + uint64_t latent_width, + int64_t seed +) { + require_positive(batch_size, "batch size"); + require_positive(channels, "channels"); + require_positive(latent_height, "height"); + require_positive(latent_width, "width"); + const uint64_t value_count = checked_size( + batch_size, + channels, + latent_height, + latent_width, + "random latents" + ); + + std::mt19937 generator(static_cast(seed)); + std::normal_distribution distribution(0.0F, 1.0F); + std::vector 
output; + output.reserve(static_cast(value_count)); + for (uint64_t index = 0; index < value_count; index++) { + output.push_back(distribution(generator)); + } + return output; +} + +std::vector bonsai_pack_latents_nchw( + const std::vector& latents, + uint64_t batch_size, + uint64_t channels, + uint64_t latent_height, + uint64_t latent_width +) { + const uint64_t input_count = checked_size( + batch_size, + channels, + latent_height, + latent_width, + "pack" + ); + if (latents.size() != static_cast(input_count)) { + throw std::runtime_error("Bonsai latent pack input size mismatch."); + } + + std::vector output; + output.reserve(static_cast(input_count)); + const uint64_t sequence_length = checked_multiply(latent_height, latent_width, "pack seq"); + for (uint64_t batch = 0; batch < batch_size; batch++) { + for (uint64_t position = 0; position < sequence_length; position++) { + const uint64_t row = position / latent_width; + const uint64_t column = position % latent_width; + for (uint64_t channel = 0; channel < channels; channel++) { + const uint64_t input_index = + ((batch * channels + channel) * latent_height + row) * latent_width + column; + output.push_back(latents[static_cast(input_index)]); + } + } + } + return output; +} + +std::vector bonsai_unpack_packed_latents( + const std::vector& latents, + uint64_t batch_size, + uint64_t sequence_length, + uint64_t channels, + uint64_t image_height, + uint64_t image_width, + uint64_t vae_scale_factor +) { + const BonsaiLatentShape shape = bonsai_packed_latent_shape( + image_height, + image_width, + batch_size, + channels, + vae_scale_factor + ); + if (sequence_length != shape.sequence_length) { + throw std::runtime_error("Bonsai packed latent sequence length mismatch."); + } + + const uint64_t input_count = checked_multiply( + checked_multiply(batch_size, sequence_length, "unpack"), + channels, + "unpack" + ); + if (latents.size() != static_cast(input_count)) { + throw std::runtime_error("Bonsai latent unpack input size 
mismatch."); + } + + std::vector output(static_cast(input_count), 0.0F); + for (uint64_t batch = 0; batch < batch_size; batch++) { + for (uint64_t position = 0; position < sequence_length; position++) { + const uint64_t row = position / shape.latent_width; + const uint64_t column = position % shape.latent_width; + for (uint64_t channel = 0; channel < channels; channel++) { + const uint64_t input_index = (batch * sequence_length + position) * channels + channel; + const uint64_t output_index = + ((batch * channels + channel) * shape.latent_height + row) * + shape.latent_width + + column; + output[static_cast(output_index)] = + latents[static_cast(input_index)]; + } + } + } + return output; +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_latents.h b/feature/bonsai/src/androidMain/cpp/bonsai_latents.h new file mode 100644 index 000000000..eea58349b --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_latents.h @@ -0,0 +1,53 @@ +#pragma once + +#include +#include +#include + +struct BonsaiLatentShape { + uint64_t batch_size = 0; + uint64_t channels = 0; + uint64_t latent_height = 0; + uint64_t latent_width = 0; + uint64_t sequence_length = 0; +}; + +BonsaiLatentShape bonsai_packed_latent_shape( + uint64_t image_height, + uint64_t image_width, + uint64_t batch_size, + uint64_t channels, + uint64_t vae_scale_factor +); + +std::vector bonsai_latent_grid_ids( + uint64_t batch_size, + uint64_t latent_height, + uint64_t latent_width +); + +std::vector bonsai_random_latents_nchw( + uint64_t batch_size, + uint64_t channels, + uint64_t latent_height, + uint64_t latent_width, + int64_t seed +); + +std::vector bonsai_pack_latents_nchw( + const std::vector& latents, + uint64_t batch_size, + uint64_t channels, + uint64_t latent_height, + uint64_t latent_width +); + +std::vector bonsai_unpack_packed_latents( + const std::vector& latents, + uint64_t batch_size, + uint64_t sequence_length, + uint64_t channels, + uint64_t image_height, + uint64_t image_width, + 
uint64_t vae_scale_factor +); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_layer_norm.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_layer_norm.cpp new file mode 100644 index 000000000..a8137c450 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_layer_norm.cpp @@ -0,0 +1,70 @@ +#include "bonsai_layer_norm.h" + +#include +#include +#include + +namespace { + +void require_optional_vector_size( + const std::vector* values, + size_t expected, + const char* label +) { + if (values != nullptr && values->size() != expected) { + throw std::runtime_error(std::string("Bonsai LayerNorm ") + label + " size mismatch."); + } +} + +} // namespace + +std::vector bonsai_layer_norm( + const std::vector& input, + uint64_t last_dimension, + float epsilon, + const std::vector* weight, + const std::vector* bias +) { + if (last_dimension == 0) { + throw std::runtime_error("Bonsai LayerNorm last dimension must be positive."); + } + if (epsilon <= 0.0F || !std::isfinite(epsilon)) { + throw std::runtime_error("Bonsai LayerNorm epsilon must be finite and positive."); + } + + const size_t dimension = static_cast(last_dimension); + if (input.empty() || input.size() % dimension != 0) { + throw std::runtime_error("Bonsai LayerNorm input shape mismatch."); + } + require_optional_vector_size(weight, dimension, "weight"); + require_optional_vector_size(bias, dimension, "bias"); + + std::vector output(input.size(), 0.0F); + for (size_t offset = 0; offset < input.size(); offset += dimension) { + double mean = 0.0; + for (size_t index = 0; index < dimension; index++) { + mean += static_cast(input[offset + index]); + } + mean /= static_cast(dimension); + + double variance = 0.0; + for (size_t index = 0; index < dimension; index++) { + const double centered = static_cast(input[offset + index]) - mean; + variance += centered * centered; + } + variance /= static_cast(dimension); + + const float scale = 1.0F / std::sqrt(static_cast(variance) + epsilon); + for (size_t index = 0; index 
< dimension; index++) { + float value = (input[offset + index] - static_cast(mean)) * scale; + if (weight != nullptr) { + value *= (*weight)[index]; + } + if (bias != nullptr) { + value += (*bias)[index]; + } + output[offset + index] = value; + } + } + return output; +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_layer_norm.h b/feature/bonsai/src/androidMain/cpp/bonsai_layer_norm.h new file mode 100644 index 000000000..03808a16b --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_layer_norm.h @@ -0,0 +1,12 @@ +#pragma once + +#include +#include + +std::vector bonsai_layer_norm( + const std::vector& input, + uint64_t last_dimension, + float epsilon, + const std::vector* weight, + const std::vector* bias +); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_linear.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_linear.cpp new file mode 100644 index 000000000..6f8006685 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_linear.cpp @@ -0,0 +1,309 @@ +#include "bonsai_linear.h" + +#include "bonsai_tensor.h" +#include "bonsai_vulkan.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { + +uint64_t checked_multiply(uint64_t left, uint64_t right, const char* label) { + if (left != 0 && right > std::numeric_limits::max() / left) { + throw std::runtime_error(std::string("Bonsai linear shape overflow: ") + label); + } + return left * right; +} + +BonsaiTensorView optional_bias_view( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + const std::string& bias_key +) { + if (bias_key.empty()) { + return {}; + } + + const BonsaiTensorDescriptor* descriptor = index.optional(bias_key); + if (descriptor == nullptr) { + return {}; + } + return storage.view(*descriptor); +} + +void validate_bias( + const BonsaiTensorView& bias, + uint64_t output_rows, + const std::string& weight_key +) { + if (!bonsai_dtype_is_floating_point(bias.dtype)) { + throw std::runtime_error("Bonsai 
linear bias must be floating point: " + bias.descriptor->key); + } + if (bias.element_count != output_rows) { + throw std::runtime_error( + "Bonsai linear bias size mismatch for " + + weight_key + + ": " + + bias.descriptor->key + ); + } +} + +float bias_at(const BonsaiLinearViews& views, uint64_t row) { + if (!views.has_bias) { + return 0.0F; + } + return bonsai_read_scalar_as_f32( + views.bias.data + row * views.bias.dtype_byte_count, + views.bias.dtype + ); +} + +void require_input(const BonsaiLinearViews& views, const std::vector& input) { + if (input.size() != static_cast(views.input_values)) { + throw std::runtime_error("Bonsai linear input size mismatch."); + } +} + +void require_row(const BonsaiLinearViews& views, uint64_t row) { + if (row >= views.output_rows) { + throw std::runtime_error("Bonsai linear row is out of range."); + } +} + +void add_bias_to_output(const BonsaiLinearViews& views, float* output) { + if (!views.has_bias) { + return; + } + for (uint64_t row = 0; row < views.output_rows; row++) { + output[row] += bias_at(views, row); + } +} + +void linear_into_unchecked( + const BonsaiLinearViews& views, + const float* input, + float* output +) { + switch (views.kind) { + case BonsaiLinearWeightKind::Dense: + bonsai_dense_matvec_into(views.dense, input, output); + break; + case BonsaiLinearWeightKind::Packed: + bonsai_quantized_matvec_into(views.packed, input, output); + break; + } + add_bias_to_output(views, output); +} + +uint64_t linear_sequence_worker_count( + uint64_t token_count, + uint64_t input_values, + uint64_t output_rows +) { + if (token_count < 4 || input_values < 512 || output_rows < 512) { + return 1; + } + const unsigned hardware_threads = std::thread::hardware_concurrency(); + const uint64_t available_threads = hardware_threads == 0 + ? 
2U + : static_cast(hardware_threads); + return std::max( + 1, + std::min({ token_count, available_threads, 4U }) + ); +} + +} // namespace + +BonsaiLinearViews bonsai_require_dense_linear_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + const std::string& weight_key, + const std::string& bias_key +) { + BonsaiLinearViews views; + views.kind = BonsaiLinearWeightKind::Dense; + views.dense = bonsai_require_dense_weight_view(storage, index, weight_key); + views.output_rows = views.dense.leading_rows; + views.input_values = views.dense.input_values; + views.bias = optional_bias_view(storage, index, bias_key); + views.has_bias = views.bias.descriptor != nullptr; + if (views.has_bias) { + validate_bias(views.bias, views.output_rows, weight_key); + } + return views; +} + +BonsaiLinearViews bonsai_require_packed_linear_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + const BonsaiPackedWeightDescriptor& descriptor, + const std::string& bias_key +) { + BonsaiLinearViews views; + views.kind = descriptor.packed ? 
BonsaiLinearWeightKind::Packed : BonsaiLinearWeightKind::Dense; + if (descriptor.packed) { + views.packed = bonsai_require_packed_weight_views(storage, index, descriptor); + views.output_rows = views.packed.leading_rows; + views.input_values = views.packed.input_values; + } else { + views.dense = bonsai_require_dense_weight_view(storage, index, descriptor.weight_key); + views.output_rows = views.dense.leading_rows; + views.input_values = views.dense.input_values; + } + + views.bias = optional_bias_view(storage, index, bias_key); + views.has_bias = views.bias.descriptor != nullptr; + if (views.has_bias) { + validate_bias(views.bias, views.output_rows, descriptor.weight_key); + } + return views; +} + +float bonsai_linear_row( + const BonsaiLinearViews& views, + const std::vector& input, + uint64_t row +) { + require_input(views, input); + require_row(views, row); + + float output = 0.0F; + switch (views.kind) { + case BonsaiLinearWeightKind::Dense: + output = bonsai_dense_matvec_row(views.dense, input, row); + break; + case BonsaiLinearWeightKind::Packed: + output = bonsai_quantized_matvec_row(views.packed, input, row); + break; + } + return output + bias_at(views, row); +} + +std::vector bonsai_linear( + const BonsaiLinearViews& views, + const std::vector& input +) { + require_input(views, input); + + std::vector output(static_cast(views.output_rows)); + linear_into_unchecked(views, input.data(), output.data()); + return output; +} + +std::vector bonsai_linear_sequence( + const BonsaiLinearViews& views, + const std::vector& input, + uint64_t batch, + uint64_t sequence_length +) { + if (batch == 0 || sequence_length == 0) { + throw std::runtime_error("Bonsai linear sequence shape must be positive."); + } + + const uint64_t token_count = checked_multiply(batch, sequence_length, "sequence tokens"); + const uint64_t expected_input = checked_multiply( + token_count, + views.input_values, + "sequence input" + ); + if (input.size() != static_cast(expected_input)) { + throw 
std::runtime_error("Bonsai linear sequence input size mismatch."); + } + + const uint64_t expected_output = checked_multiply( + token_count, + views.output_rows, + "sequence output" + ); + if (expected_output > static_cast(std::numeric_limits::max())) { + throw std::runtime_error("Bonsai linear sequence output is too large."); + } + + std::vector output(static_cast(expected_output)); + if (views.kind == BonsaiLinearWeightKind::Packed && + bonsai_vulkan_quantized_matvec_sequence_into( + views.packed, + input.data(), + output.data(), + token_count + )) { + for (uint64_t token = 0; token < token_count; token++) { + add_bias_to_output(views, output.data() + token * views.output_rows); + } + return output; + } + + const uint64_t worker_count = linear_sequence_worker_count( + token_count, + views.input_values, + views.output_rows + ); + const auto run_range = [&views, &input, &output](uint64_t begin, uint64_t end) { + for (uint64_t token = begin; token < end; token++) { + linear_into_unchecked( + views, + input.data() + token * views.input_values, + output.data() + token * views.output_rows + ); + } + }; + + if (worker_count == 1) { + run_range(0, token_count); + return output; + } + + std::vector workers; + workers.reserve(static_cast(worker_count)); + std::exception_ptr first_error = nullptr; + std::mutex error_mutex; + const uint64_t chunk_size = (token_count + worker_count - 1U) / worker_count; + for (uint64_t worker = 0; worker < worker_count; worker++) { + const uint64_t begin = worker * chunk_size; + const uint64_t end = std::min(token_count, begin + chunk_size); + if (begin >= end) { + continue; + } + workers.emplace_back([&, begin, end]() { + try { + run_range(begin, end); + } catch (...) 
{ + std::lock_guard lock(error_mutex); + if (first_error == nullptr) { + first_error = std::current_exception(); + } + } + }); + } + for (std::thread& worker : workers) { + worker.join(); + } + if (first_error != nullptr) { + std::rethrow_exception(first_error); + } + return output; +} + +uint64_t bonsai_linear_byte_count(const BonsaiLinearViews& views) { + uint64_t bytes = 0; + switch (views.kind) { + case BonsaiLinearWeightKind::Dense: + bytes = views.dense.weight.byte_count; + break; + case BonsaiLinearWeightKind::Packed: + bytes = bonsai_packed_weight_byte_count(views.packed); + break; + } + if (views.has_bias) { + bytes += views.bias.byte_count; + } + return bytes; +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_linear.h b/feature/bonsai/src/androidMain/cpp/bonsai_linear.h new file mode 100644 index 000000000..39b9a9436 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_linear.h @@ -0,0 +1,59 @@ +#pragma once + +#include "bonsai_matmul.h" +#include "bonsai_packed_weight.h" +#include "bonsai_safetensors.h" +#include "bonsai_tensor_storage.h" + +#include +#include +#include + +enum class BonsaiLinearWeightKind { + Dense, + Packed, +}; + +struct BonsaiLinearViews { + BonsaiLinearWeightKind kind = BonsaiLinearWeightKind::Dense; + BonsaiDenseWeightViews dense; + BonsaiPackedWeightViews packed; + BonsaiTensorView bias; + bool has_bias = false; + uint64_t output_rows = 0; + uint64_t input_values = 0; +}; + +BonsaiLinearViews bonsai_require_dense_linear_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + const std::string& weight_key, + const std::string& bias_key +); + +BonsaiLinearViews bonsai_require_packed_linear_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + const BonsaiPackedWeightDescriptor& descriptor, + const std::string& bias_key +); + +float bonsai_linear_row( + const BonsaiLinearViews& views, + const std::vector& input, + uint64_t row +); + +std::vector 
bonsai_linear( + const BonsaiLinearViews& views, + const std::vector& input +); + +std::vector bonsai_linear_sequence( + const BonsaiLinearViews& views, + const std::vector& input, + uint64_t batch, + uint64_t sequence_length +); + +uint64_t bonsai_linear_byte_count(const BonsaiLinearViews& views); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_matmul.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_matmul.cpp new file mode 100644 index 000000000..73562b134 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_matmul.cpp @@ -0,0 +1,312 @@ +#include "bonsai_matmul.h" + +#include "bonsai_dequant.h" +#include "bonsai_tensor.h" +#include "bonsai_vulkan.h" + +#include +#include +#include +#include + +namespace { + +template +T read_unaligned(const uint8_t* data) { + T value {}; + std::memcpy(&value, data, sizeof(T)); + return value; +} + +uint64_t checked_multiply(uint64_t left, uint64_t right, const std::string& tensor_key) { + if (left != 0 && right > UINT64_MAX / left) { + throw std::runtime_error("Bonsai dense weight shape is too large: " + tensor_key); + } + return left * right; +} + +uint64_t leading_row_count(const BonsaiTensorView& view) { + if (view.descriptor->shape.empty()) { + throw std::runtime_error("Bonsai dense weight must have at least one dimension."); + } + + uint64_t rows = 1; + for (size_t index = 0; index + 1 < view.descriptor->shape.size(); index++) { + rows = checked_multiply(rows, view.descriptor->shape[index], view.descriptor->key); + } + return rows; +} + +uint64_t last_dimension(const BonsaiTensorView& view) { + if (view.descriptor->shape.empty()) { + throw std::runtime_error("Bonsai dense weight must have at least one dimension."); + } + return view.descriptor->shape.back(); +} + +float dot_u2_word(uint32_t word, const float* input) { + return input[0] * static_cast(word & 0x3U) + + input[1] * static_cast((word >> 2U) & 0x3U) + + input[2] * static_cast((word >> 4U) & 0x3U) + + input[3] * static_cast((word >> 6U) & 0x3U) + + 
input[4] * static_cast((word >> 8U) & 0x3U) + + input[5] * static_cast((word >> 10U) & 0x3U) + + input[6] * static_cast((word >> 12U) & 0x3U) + + input[7] * static_cast((word >> 14U) & 0x3U) + + input[8] * static_cast((word >> 16U) & 0x3U) + + input[9] * static_cast((word >> 18U) & 0x3U) + + input[10] * static_cast((word >> 20U) & 0x3U) + + input[11] * static_cast((word >> 22U) & 0x3U) + + input[12] * static_cast((word >> 24U) & 0x3U) + + input[13] * static_cast((word >> 26U) & 0x3U) + + input[14] * static_cast((word >> 28U) & 0x3U) + + input[15] * static_cast((word >> 30U) & 0x3U); +} + +float sum_16_values(const float* input) { + return input[0] + input[1] + input[2] + input[3] + + input[4] + input[5] + input[6] + input[7] + + input[8] + input[9] + input[10] + input[11] + + input[12] + input[13] + input[14] + input[15]; +} + +void require_dense_input( + const BonsaiDenseWeightViews& views, + const std::vector& input +) { + if (input.size() != static_cast(views.input_values)) { + throw std::runtime_error( + "Bonsai dense matvec input size mismatch: " + views.weight.descriptor->key + ); + } +} + +void require_packed_input( + const BonsaiPackedWeightViews& views, + const std::vector& input +) { + if (!views.packed) { + throw std::runtime_error( + "Bonsai quantized matvec requires packed weight: " + views.weight.descriptor->key + ); + } + if (input.size() != static_cast(views.input_values)) { + throw std::runtime_error( + "Bonsai quantized matvec input size mismatch: " + views.weight.descriptor->key + ); + } +} + +void require_io_pointers(const float* input, const float* output) { + if (input == nullptr || output == nullptr) { + throw std::runtime_error("Bonsai matvec input/output pointer must not be null."); + } +} + +float dense_matvec_row_unchecked( + const BonsaiDenseWeightViews& views, + const float* input, + uint64_t row +) { + float sum = 0.0F; + const uint8_t* weight_row = views.weight.data + + row * views.input_values * views.weight.dtype_byte_count; + for 
(uint64_t column = 0; column < views.input_values; column++) { + const float weight = bonsai_read_scalar_as_f32( + weight_row + column * views.weight.dtype_byte_count, + views.weight.dtype + ); + sum += input[column] * weight; + } + return sum; +} + +float quantized_matvec_row_unchecked( + const BonsaiPackedWeightViews& views, + const float* input, + uint64_t row, + const float* group_input_sums +) { + const uint64_t values_per_word = 32ULL / static_cast(views.bits); + const uint64_t packed_columns = last_dimension(views.weight); + const uint64_t scale_groups = last_dimension(views.scales); + const uint64_t group_size = static_cast(views.group_size); + const uint32_t mask = (1U << static_cast(views.bits)) - 1U; + + const uint8_t* packed_row = views.weight.data + + row * packed_columns * views.weight.dtype_byte_count; + const uint8_t* scale_row = views.scales.data + + row * scale_groups * views.scales.dtype_byte_count; + const uint8_t* bias_row = views.biases.data + + row * scale_groups * views.biases.dtype_byte_count; + + float sum = 0.0F; + for (uint64_t group = 0; group < scale_groups; group++) { + const float scale = bonsai_read_scalar_as_f32( + scale_row + group * views.scales.dtype_byte_count, + views.scales.dtype + ); + const float bias = bonsai_read_scalar_as_f32( + bias_row + group * views.biases.dtype_byte_count, + views.biases.dtype + ); + const uint64_t group_start = group * group_size; + const uint64_t group_end = std::min(group_start + group_size, views.input_values); + float quantized_sum = 0.0F; + float input_sum = group_input_sums == nullptr ? 
0.0F : group_input_sums[group]; + uint64_t column = group_start; + while (column < group_end) { + const uint64_t word_column = column / values_per_word; + const uint64_t first_offset = column % values_per_word; + const uint64_t word_base_column = word_column * values_per_word; + const uint64_t last_offset = std::min(values_per_word, group_end - word_base_column); + const uint32_t word = read_unaligned( + packed_row + word_column * views.weight.dtype_byte_count + ); + if (views.bits == 2 && + first_offset == 0 && + last_offset == values_per_word && + word_base_column + values_per_word <= views.input_values) { + const float* word_input = input + word_base_column; + quantized_sum += dot_u2_word(word, word_input); + if (group_input_sums == nullptr) { + input_sum += sum_16_values(word_input); + } + } else { + for (uint64_t offset = first_offset; offset < last_offset; offset++) { + const uint64_t input_index = word_base_column + offset; + const float input_value = input[input_index]; + const uint32_t quantized = + (word >> (offset * static_cast(views.bits))) & mask; + quantized_sum += input_value * static_cast(quantized); + if (group_input_sums == nullptr) { + input_sum += input_value; + } + } + } + column = word_base_column + last_offset; + } + sum += quantized_sum * scale + input_sum * bias; + } + return sum; +} + +std::vector packed_group_input_sums( + const BonsaiPackedWeightViews& views, + const float* input +) { + const uint64_t scale_groups = last_dimension(views.scales); + const uint64_t group_size = static_cast(views.group_size); + std::vector sums(static_cast(scale_groups), 0.0F); + for (uint64_t group = 0; group < scale_groups; group++) { + const uint64_t group_start = group * group_size; + const uint64_t group_end = std::min(group_start + group_size, views.input_values); + float sum = 0.0F; + for (uint64_t column = group_start; column < group_end; column++) { + sum += input[column]; + } + sums[static_cast(group)] = sum; + } + return sums; +} + +} // namespace 
+ +BonsaiDenseWeightViews bonsai_require_dense_weight_view( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + const std::string& key +) { + BonsaiTensorView weight = storage.require_view(index, key); + if (!bonsai_dtype_is_floating_point(weight.dtype)) { + throw std::runtime_error("Bonsai dense weight must be floating point: " + key); + } + return BonsaiDenseWeightViews { + weight, + leading_row_count(weight), + last_dimension(weight), + }; +} + +float bonsai_dense_matvec_row( + const BonsaiDenseWeightViews& views, + const std::vector& input, + uint64_t row +) { + require_dense_input(views, input); + if (row >= views.leading_rows) { + throw std::runtime_error( + "Bonsai dense matvec row is out of range: " + views.weight.descriptor->key + ); + } + + return dense_matvec_row_unchecked(views, input.data(), row); +} + +std::vector bonsai_dense_matvec( + const BonsaiDenseWeightViews& views, + const std::vector& input +) { + require_dense_input(views, input); + + std::vector output(static_cast(views.leading_rows)); + bonsai_dense_matvec_into(views, input.data(), output.data()); + return output; +} + +void bonsai_dense_matvec_into( + const BonsaiDenseWeightViews& views, + const float* input, + float* output +) { + require_io_pointers(input, output); + for (uint64_t row = 0; row < views.leading_rows; row++) { + output[row] = dense_matvec_row_unchecked(views, input, row); + } +} + +float bonsai_quantized_matvec_row( + const BonsaiPackedWeightViews& views, + const std::vector& input, + uint64_t row +) { + require_packed_input(views, input); + if (row >= views.leading_rows) { + throw std::runtime_error( + "Bonsai quantized matvec row is out of range: " + views.weight.descriptor->key + ); + } + + return quantized_matvec_row_unchecked(views, input.data(), row, nullptr); +} + +std::vector bonsai_quantized_matvec( + const BonsaiPackedWeightViews& views, + const std::vector& input +) { + require_packed_input(views, input); + + std::vector 
output(static_cast(views.leading_rows)); + bonsai_quantized_matvec_into(views, input.data(), output.data()); + return output; +} + +void bonsai_quantized_matvec_into( + const BonsaiPackedWeightViews& views, + const float* input, + float* output +) { + require_io_pointers(input, output); + if (!views.packed) { + throw std::runtime_error( + "Bonsai quantized matvec requires packed weight: " + views.weight.descriptor->key + ); + } + if (bonsai_vulkan_quantized_matvec_into(views, input, output)) { + return; + } + const std::vector group_input_sums = packed_group_input_sums(views, input); + for (uint64_t row = 0; row < views.leading_rows; row++) { + output[row] = quantized_matvec_row_unchecked(views, input, row, group_input_sums.data()); + } +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_matmul.h b/feature/bonsai/src/androidMain/cpp/bonsai_matmul.h new file mode 100644 index 000000000..5fa9e55d8 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_matmul.h @@ -0,0 +1,53 @@ +#pragma once + +#include "bonsai_packed_weight.h" +#include "bonsai_tensor_storage.h" + +#include +#include + +struct BonsaiDenseWeightViews { + BonsaiTensorView weight; + uint64_t leading_rows = 0; + uint64_t input_values = 0; +}; + +BonsaiDenseWeightViews bonsai_require_dense_weight_view( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + const std::string& key +); + +float bonsai_dense_matvec_row( + const BonsaiDenseWeightViews& views, + const std::vector& input, + uint64_t row +); + +std::vector bonsai_dense_matvec( + const BonsaiDenseWeightViews& views, + const std::vector& input +); + +void bonsai_dense_matvec_into( + const BonsaiDenseWeightViews& views, + const float* input, + float* output +); + +float bonsai_quantized_matvec_row( + const BonsaiPackedWeightViews& views, + const std::vector& input, + uint64_t row +); + +std::vector bonsai_quantized_matvec( + const BonsaiPackedWeightViews& views, + const std::vector& input +); + +void 
bonsai_quantized_matvec_into( + const BonsaiPackedWeightViews& views, + const float* input, + float* output +); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_model_config.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_model_config.cpp new file mode 100644 index 000000000..9b1485785 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_model_config.cpp @@ -0,0 +1,239 @@ +#include "bonsai_model_config.h" + +#include +#include +#include +#include +#include + +namespace { + +bool stat_path(const std::string& path, struct stat* output) { + return lstat(path.c_str(), output) == 0; +} + +} // namespace + +std::string bonsai_join_path(const std::string& parent, const std::string& child) { + if (parent.empty() || parent.back() == '/') { + return parent + child; + } + return parent + "/" + child; +} + +bool bonsai_path_is_directory(const std::string& path) { + struct stat info {}; + return stat_path(path, &info) && S_ISDIR(info.st_mode); +} + +bool bonsai_path_is_regular_file(const std::string& path) { + struct stat info {}; + return stat_path(path, &info) && S_ISREG(info.st_mode); +} + +void bonsai_require_directory(const std::string& path, const std::string& label) { + if (!bonsai_path_is_directory(path)) { + throw std::runtime_error("missing Bonsai " + label + " directory: " + path); + } +} + +void bonsai_require_file(const std::string& path, const std::string& label) { + if (!bonsai_path_is_regular_file(path)) { + throw std::runtime_error("missing Bonsai " + label + " file: " + path); + } +} + +std::string bonsai_read_text_file(const std::string& path) { + std::ifstream input(path); + if (!input) { + throw std::runtime_error("could not read Bonsai file: " + path); + } + std::ostringstream output; + output << input.rdbuf(); + return output.str(); +} + +int bonsai_parse_json_int(const std::string& json, const std::string& key) { + const std::string quoted_key = "\"" + key + "\""; + const size_t key_index = json.find(quoted_key); + if (key_index == 
std::string::npos) { + throw std::runtime_error("missing " + key + " in Bonsai JSON config"); + } + + const size_t colon_index = json.find(':', key_index + quoted_key.size()); + if (colon_index == std::string::npos) { + throw std::runtime_error("invalid Bonsai JSON config"); + } + + size_t value_index = colon_index + 1; + while (value_index < json.size() && + std::isspace(static_cast(json[value_index])) != 0) { + value_index++; + } + + const size_t start = value_index; + while (value_index < json.size() && + std::isdigit(static_cast(json[value_index])) != 0) { + value_index++; + } + + if (start == value_index) { + throw std::runtime_error("invalid " + key + " in Bonsai JSON config"); + } + + return std::stoi(json.substr(start, value_index - start)); +} + +float bonsai_parse_json_float(const std::string& json, const std::string& key) { + const std::string quoted_key = "\"" + key + "\""; + const size_t key_index = json.find(quoted_key); + if (key_index == std::string::npos) { + throw std::runtime_error("missing " + key + " in Bonsai JSON config"); + } + + const size_t colon_index = json.find(':', key_index + quoted_key.size()); + if (colon_index == std::string::npos) { + throw std::runtime_error("invalid Bonsai JSON config"); + } + + size_t value_index = colon_index + 1; + while (value_index < json.size() && + std::isspace(static_cast(json[value_index])) != 0) { + value_index++; + } + + const size_t start = value_index; + while (value_index < json.size()) { + const char value = json[value_index]; + if (std::isdigit(static_cast(value)) == 0 && + value != '.' 
&& + value != '-' && + value != '+' && + value != 'e' && + value != 'E') { + break; + } + value_index++; + } + if (start == value_index) { + throw std::runtime_error("invalid " + key + " in Bonsai JSON config"); + } + + return std::stof(json.substr(start, value_index - start)); +} + +std::vector bonsai_parse_json_uint_array_values( + const std::string& json, + const std::string& key +) { + const std::string quoted_key = "\"" + key + "\""; + const size_t key_index = json.find(quoted_key); + if (key_index == std::string::npos) { + throw std::runtime_error("missing " + key + " in Bonsai config JSON"); + } + + const size_t colon_index = json.find(':', key_index + quoted_key.size()); + if (colon_index == std::string::npos) { + throw std::runtime_error("invalid Bonsai config JSON"); + } + + size_t value_index = colon_index + 1; + while (value_index < json.size() && + std::isspace(static_cast(json[value_index])) != 0) { + value_index++; + } + if (value_index >= json.size() || json[value_index] != '[') { + throw std::runtime_error("invalid " + key + " in Bonsai config JSON"); + } + value_index++; + + std::vector values; + while (value_index < json.size()) { + while (value_index < json.size() && + std::isspace(static_cast(json[value_index])) != 0) { + value_index++; + } + if (value_index < json.size() && json[value_index] == ']') { + break; + } + + const size_t start = value_index; + while (value_index < json.size() && + std::isdigit(static_cast(json[value_index])) != 0) { + value_index++; + } + if (start == value_index) { + throw std::runtime_error("invalid " + key + " in Bonsai config JSON"); + } + values.push_back(std::stoull(json.substr(start, value_index - start))); + + while (value_index < json.size() && + std::isspace(static_cast(json[value_index])) != 0) { + value_index++; + } + if (value_index < json.size() && json[value_index] == ',') { + value_index++; + continue; + } + if (value_index < json.size() && json[value_index] == ']') { + break; + } + throw 
std::runtime_error("invalid " + key + " in Bonsai config JSON"); + } + + if (values.empty()) { + throw std::runtime_error("empty " + key + " in Bonsai config JSON"); + } + return values; +} + +BonsaiQuantizationConfig bonsai_read_quantization_config( + const std::string& packed_transformer_path +) { + const std::string path = bonsai_join_path( + packed_transformer_path, + "quantization_config.json" + ); + bonsai_require_file(path, "quantization config"); + const std::string json = bonsai_read_text_file(path); + BonsaiQuantizationConfig config { + bonsai_parse_json_int(json, "bits"), + bonsai_parse_json_int(json, "group_size"), + }; + + if ((config.bits != 1 && config.bits != 2) || config.group_size != 128) { + throw std::runtime_error( + "Unsupported Bonsai quantization: " + + std::to_string(config.bits) + + "-bit group " + + std::to_string(config.group_size) + + "." + ); + } + + return config; +} + +BonsaiFluxVaeConfig bonsai_read_vae_config(const std::string& vae_path) { + const std::string path = bonsai_join_path(vae_path, "config.json"); + bonsai_require_file(path, "vae config"); + const std::string json = bonsai_read_text_file(path); + const std::vector block_out_channels = bonsai_parse_json_uint_array_values( + json, + "block_out_channels" + ); + BonsaiFluxVaeConfig config { + static_cast(block_out_channels.size()), + static_cast(bonsai_parse_json_int(json, "layers_per_block")), + static_cast(bonsai_parse_json_int(json, "norm_num_groups")), + bonsai_parse_json_float(json, "batch_norm_eps"), + block_out_channels, + }; + if (config.layers_per_block == 0 || + config.block_out_channels_count != 4 || + config.norm_num_groups == 0 || + config.batch_norm_eps <= 0.0F) { + throw std::runtime_error("invalid Bonsai VAE decoder config."); + } + return config; +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_model_config.h b/feature/bonsai/src/androidMain/cpp/bonsai_model_config.h new file mode 100644 index 000000000..1660e45a5 --- /dev/null +++ 
b/feature/bonsai/src/androidMain/cpp/bonsai_model_config.h @@ -0,0 +1,39 @@ +#pragma once + +#include "bonsai_flux_vae.h" + +#include +#include +#include + +struct BonsaiQuantizationConfig { + int bits = 0; + int group_size = 0; +}; + +std::string bonsai_join_path(const std::string& parent, const std::string& child); + +bool bonsai_path_is_directory(const std::string& path); + +bool bonsai_path_is_regular_file(const std::string& path); + +void bonsai_require_directory(const std::string& path, const std::string& label); + +void bonsai_require_file(const std::string& path, const std::string& label); + +std::string bonsai_read_text_file(const std::string& path); + +int bonsai_parse_json_int(const std::string& json, const std::string& key); + +float bonsai_parse_json_float(const std::string& json, const std::string& key); + +std::vector bonsai_parse_json_uint_array_values( + const std::string& json, + const std::string& key +); + +BonsaiQuantizationConfig bonsai_read_quantization_config( + const std::string& packed_transformer_path +); + +BonsaiFluxVaeConfig bonsai_read_vae_config(const std::string& vae_path); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_model_probe.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_model_probe.cpp new file mode 100644 index 000000000..5ac8de70c --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_model_probe.cpp @@ -0,0 +1,1505 @@ +#include "bonsai_activation.h" +#include "bonsai_model_probe.h" + +#include "bonsai_attention.h" +#include "bonsai_dequant.h" +#include "bonsai_embedding.h" +#include "bonsai_flux_attention_layout.h" +#include "bonsai_flux_double_block.h" +#include "bonsai_flux_modulation.h" +#include "bonsai_flux_output.h" +#include "bonsai_flux_pos_embed.h" +#include "bonsai_flux_rope.h" +#include "bonsai_flux_single_block.h" +#include "bonsai_flux_time_embedding.h" +#include "bonsai_flux_transformer.h" +#include "bonsai_flux_vae.h" +#include "bonsai_latents.h" +#include "bonsai_layer_norm.h" +#include 
"bonsai_linear.h" +#include "bonsai_matmul.h" +#include "bonsai_model_config.h" +#include "bonsai_norm.h" +#include "bonsai_packed_weight.h" +#include "bonsai_prompt.h" +#include "bonsai_qwen.h" +#include "bonsai_rotary.h" +#include "bonsai_scheduler.h" +#include "bonsai_safetensors.h" +#include "bonsai_tensor_storage.h" +#include "bonsai_tokenizer.h" +#include "bonsai_vae_decoder.h" +#include "bonsai_vae_ops.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { + +void require_text_encoder_tensors(const BonsaiSafetensorsIndex& tensors) { + const std::string embedding_key = tensors.resolve_model_prefixed_key("embed_tokens.weight"); + tensors.require_packed_weight(embedding_key, 4, 64); + bonsai_require_qwen_text_encoder_tensors(tensors); +} + +BonsaiFluxTransformerInventorySummary require_transformer_tensors( + const BonsaiSafetensorsIndex& tensors, + const BonsaiQuantizationConfig& quantization +) { + return bonsai_require_flux_transformer_tensors( + tensors, + quantization.bits, + quantization.group_size + ); +} + +BonsaiFluxVaeInventorySummary require_vae_tensors( + const BonsaiSafetensorsIndex& tensors, + const BonsaiFluxVaeConfig& config +) { + return bonsai_require_flux_vae_tensors(tensors, config); +} + +uint64_t require_tensor_view( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& tensors, + const std::string& key +) { + return storage.require_view(tensors, key).byte_count; +} + +std::vector synthetic_input(uint64_t size) { + std::vector input; + input.reserve(static_cast(size)); + for (uint64_t index = 0; index < size; index++) { + input.push_back((static_cast(index % 7U) - 3.0F) * 0.125F); + } + return input; +} + +uint64_t require_packed_weight_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& tensors, + const BonsaiPackedWeightDescriptor& weight, + double* dequant_checksum, + double* matvec_checksum, + double* linear_checksum +) { + const 
BonsaiPackedWeightViews views = bonsai_require_packed_weight_views( + storage, + tensors, + weight + ); + if (views.packed) { + const std::vector row = bonsai_dequantize_packed_row(views, 0); + const size_t limit = std::min(row.size(), 32); + for (size_t index = 0; index < limit; index++) { + *dequant_checksum += static_cast(row[index]); + } + *matvec_checksum += static_cast( + bonsai_quantized_matvec_row(views, synthetic_input(views.input_values), 0) + ); + } + const BonsaiLinearViews linear = bonsai_require_packed_linear_views( + storage, + tensors, + weight, + weight.weight_key.substr(0, weight.weight_key.size() - std::string(".weight").size()) + + ".bias" + ); + *linear_checksum += static_cast( + bonsai_linear_row(linear, synthetic_input(linear.input_values), 0) + ); + return bonsai_linear_byte_count(linear); +} + +uint64_t require_dense_weight_view( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& tensors, + const std::string& key, + double* dense_checksum, + double* linear_checksum +) { + const BonsaiDenseWeightViews views = bonsai_require_dense_weight_view( + storage, + tensors, + key + ); + *dense_checksum += static_cast( + bonsai_dense_matvec_row(views, synthetic_input(views.input_values), 0) + ); + + const BonsaiLinearViews linear = bonsai_require_dense_linear_views( + storage, + tensors, + key, + key.substr(0, key.size() - std::string(".weight").size()) + ".bias" + ); + *linear_checksum += static_cast( + bonsai_linear_row(linear, synthetic_input(linear.input_values), 0) + ); + return bonsai_linear_byte_count(linear); +} + +uint64_t require_embedding_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& tensors, + const BonsaiPackedWeightDescriptor& descriptor, + double* embedding_checksum +) { + const BonsaiEmbeddingViews views = bonsai_require_embedding_views(storage, tensors, descriptor); + const std::vector lookup = bonsai_embedding_lookup(views, {0, 1, 2}); + const size_t limit = std::min(lookup.size(), 
96); + for (size_t index = 0; index < limit; index++) { + *embedding_checksum += static_cast(lookup[index]); + } + return bonsai_embedding_byte_count(views); +} + +uint64_t require_rms_norm_view( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& tensors, + const std::string& key, + double* norm_checksum +) { + const BonsaiRmsNormWeightViews views = bonsai_require_rms_norm_weight(storage, tensors, key); + const std::vector output = bonsai_rms_norm(synthetic_input(views.dimensions), views, 1e-6F); + const size_t limit = std::min(output.size(), 64); + for (size_t index = 0; index < limit; index++) { + *norm_checksum += static_cast(output[index]); + } + return bonsai_rms_norm_byte_count(views); +} + +double activation_checksum() { + const std::vector gate = synthetic_input(128); + const std::vector up = synthetic_input(128); + const std::vector output = bonsai_silu_times(gate, up); + + double checksum = 0.0; + for (float value : output) { + checksum += static_cast(value); + } + return checksum; +} + +double flux_activation_checksum() { + const std::vector output = bonsai_swiglu_last_dimension( + synthetic_input(3 * 16), + 16 + ); + + double checksum = 0.0; + for (float value : output) { + checksum += static_cast(value); + } + return checksum; +} + +double flux_layer_norm_checksum() { + std::vector weight; + std::vector bias; + weight.reserve(16); + bias.reserve(16); + for (uint64_t index = 0; index < 16; index++) { + weight.push_back(1.0F + static_cast(index) * 0.01F); + bias.push_back((static_cast(index % 5U) - 2.0F) * 0.05F); + } + + const std::vector output = bonsai_layer_norm( + synthetic_input(3 * 16), + 16, + 1e-6F, + &weight, + &bias + ); + + double checksum = 0.0; + for (float value : output) { + checksum += static_cast(value); + } + return checksum; +} + +double rotary_checksum() { + const std::vector input = synthetic_input(256); + const std::vector output = bonsai_apply_rotary_to_heads( + input, + 128, + 3, + 1000000.0F + ); + + double 
checksum = 0.0; + const size_t limit = std::min(output.size(), 128); + for (size_t index = 0; index < limit; index++) { + checksum += static_cast(output[index]); + } + return checksum; +} + +double flux_time_embedding_checksum() { + const BonsaiFluxTimestepEmbedding embedding = bonsai_flux_timestep_embedding( + {1000.0F, 500.0F, 125.0F}, + 256 + ); + if (embedding.timestep_count != 3 || embedding.dimensions != 256) { + throw std::runtime_error("Bonsai Flux timestep embedding shape mismatch."); + } + + double checksum = 0.0; + for (float value : embedding.values) { + checksum += static_cast(value); + } + return checksum; +} + +double flux_pos_embed_checksum() { + const BonsaiFluxRotaryEmbedding embedding = bonsai_flux_pos_embed({ + std::array {0.0F, 0.0F, 0.0F, 0.0F}, + std::array {0.0F, 0.0F, 1.0F, 2.0F}, + std::array {0.0F, 1.0F, 2.0F, 3.0F}, + }); + if (embedding.token_count != 3 || embedding.dimensions != 64) { + throw std::runtime_error("Bonsai Flux position embedding shape mismatch."); + } + + double checksum = 0.0; + for (float value : embedding.cos) { + checksum += static_cast(value); + } + for (float value : embedding.sin) { + checksum += static_cast(value); + } + return checksum; +} + +double flux_rope_checksum() { + const uint64_t heads = 2; + const uint64_t sequence_length = 3; + const uint64_t head_dimension = 8; + std::vector norm_weight; + norm_weight.reserve(static_cast(head_dimension)); + for (uint64_t index = 0; index < head_dimension; index++) { + norm_weight.push_back(1.0F + 0.01F * static_cast(index)); + } + + const BonsaiFluxRotaryEmbedding embedding = bonsai_flux_pos_embed({ + std::array {0.0F, 0.0F, 0.0F, 0.0F}, + std::array {0.0F, 1.0F, 0.0F, 1.0F}, + std::array {0.0F, 2.0F, 1.0F, 2.0F}, + }); + std::vector cos_values; + std::vector sin_values; + cos_values.reserve(static_cast(sequence_length * head_dimension / 2)); + sin_values.reserve(static_cast(sequence_length * head_dimension / 2)); + for (uint64_t position = 0; position < 
sequence_length; position++) { + for (uint64_t index = 0; index < head_dimension / 2; index++) { + cos_values.push_back( + embedding.cos[static_cast(position * embedding.dimensions + index)] + ); + sin_values.push_back( + embedding.sin[static_cast(position * embedding.dimensions + index)] + ); + } + } + + const std::vector output = bonsai_flux_apply_rms_norm_and_rope( + synthetic_input(heads * sequence_length * head_dimension), + norm_weight, + cos_values, + sin_values, + heads, + sequence_length, + head_dimension, + 1e-5F + ); + + double checksum = 0.0; + for (float value : output) { + checksum += static_cast(value); + } + return checksum; +} + +double flux_modulation_checksum() { + const uint64_t batch = 1; + const uint64_t sequence_length = 3; + const uint64_t dimensions = 16; + const BonsaiFluxSingleModulation single = bonsai_flux_split_single_modulation( + synthetic_input(batch * 3 * dimensions), + batch, + dimensions + ); + const BonsaiFluxDoubleModulation double_mod = bonsai_flux_split_double_modulation( + synthetic_input(batch * 6 * dimensions), + batch, + dimensions + ); + const BonsaiFluxNormOutModulation norm_out = bonsai_flux_split_norm_out_modulation( + synthetic_input(batch * 2 * dimensions), + batch, + dimensions + ); + const std::vector normed = bonsai_flux_apply_modulated_layer_norm( + synthetic_input(batch * sequence_length * dimensions), + single.shift, + single.scale, + batch, + sequence_length, + dimensions, + 1e-6F + ); + const std::vector gated = bonsai_flux_apply_gated_residual( + synthetic_input(batch * sequence_length * dimensions), + normed, + single.gate, + batch, + sequence_length, + dimensions + ); + const std::vector final_norm = bonsai_flux_apply_modulated_layer_norm( + gated, + norm_out.shift, + norm_out.scale, + batch, + sequence_length, + dimensions, + 1e-6F + ); + + double checksum = 0.0; + for (float value : double_mod.shift_msa) { + checksum += static_cast(value); + } + for (float value : double_mod.scale_mlp) { + checksum += 
static_cast(value); + } + for (float value : double_mod.gate_mlp) { + checksum += static_cast(value); + } + for (float value : final_norm) { + checksum += static_cast(value); + } + return checksum; +} + +double flux_attention_layout_checksum() { + const uint64_t batch = 1; + const uint64_t sequence_length = 3; + const uint64_t heads = 2; + const uint64_t head_dimension = 8; + const uint64_t dimensions = heads * head_dimension; + const uint64_t mlp_hidden_dimensions = 24; + const BonsaiFluxSingleProjectionParts parts = bonsai_flux_split_single_projection( + synthetic_input(batch * sequence_length * (dimensions * 3 + mlp_hidden_dimensions * 2)), + batch, + sequence_length, + dimensions, + mlp_hidden_dimensions + ); + const std::vector query_heads = bonsai_flux_sequence_to_heads( + parts.query, + batch, + sequence_length, + heads, + head_dimension + ); + const std::vector key_heads = bonsai_flux_sequence_to_heads( + parts.key, + batch, + sequence_length, + heads, + head_dimension + ); + const std::vector value_heads = bonsai_flux_sequence_to_heads( + parts.value, + batch, + sequence_length, + heads, + head_dimension + ); + const std::vector attended_heads = bonsai_scaled_dot_product_attention( + query_heads, + key_heads, + value_heads, + {}, + batch * heads, + sequence_length, + head_dimension, + 1.0F / std::sqrt(static_cast(head_dimension)) + ); + const std::vector attended_sequence = bonsai_flux_heads_to_sequence( + attended_heads, + batch, + sequence_length, + heads, + head_dimension + ); + const std::vector mlp_output = bonsai_swiglu_last_dimension( + parts.mlp_values, + mlp_hidden_dimensions * 2 + ); + const std::vector concatenated = bonsai_flux_concat_head_sequences( + query_heads, + value_heads, + batch, + heads, + sequence_length, + sequence_length, + head_dimension + ); + + double checksum = 0.0; + for (float value : attended_sequence) { + checksum += static_cast(value); + } + for (float value : mlp_output) { + checksum += static_cast(value); + } + for 
(float value : concatenated) { + checksum += static_cast(value); + } + return checksum; +} + +double flux_single_block_checksum() { + const uint64_t batch = 1; + const uint64_t sequence_length = 3; + const uint64_t heads = 2; + const uint64_t head_dimension = 8; + const uint64_t dimensions = heads * head_dimension; + const uint64_t mlp_hidden_dimensions = 24; + std::vector norm_q_weight; + std::vector norm_k_weight; + norm_q_weight.reserve(static_cast(head_dimension)); + norm_k_weight.reserve(static_cast(head_dimension)); + for (uint64_t index = 0; index < head_dimension; index++) { + norm_q_weight.push_back(1.0F + 0.01F * static_cast(index)); + norm_k_weight.push_back(0.9F + 0.02F * static_cast(index)); + } + + const BonsaiFluxRotaryEmbedding embedding = bonsai_flux_pos_embed({ + std::array {0.0F, 0.0F, 0.0F, 0.0F}, + std::array {0.0F, 1.0F, 0.0F, 1.0F}, + std::array {0.0F, 2.0F, 1.0F, 2.0F}, + }); + std::vector cos_values; + std::vector sin_values; + cos_values.reserve(static_cast(sequence_length * head_dimension / 2)); + sin_values.reserve(static_cast(sequence_length * head_dimension / 2)); + for (uint64_t position = 0; position < sequence_length; position++) { + for (uint64_t index = 0; index < head_dimension / 2; index++) { + cos_values.push_back( + embedding.cos[static_cast(position * embedding.dimensions + index)] + ); + sin_values.push_back( + embedding.sin[static_cast(position * embedding.dimensions + index)] + ); + } + } + + const BonsaiFluxSingleBlockReferenceOutput output = bonsai_flux_single_block_reference( + synthetic_input(batch * sequence_length * dimensions), + synthetic_input(batch * 3 * dimensions), + synthetic_input( + batch * + sequence_length * + (dimensions * 3 + mlp_hidden_dimensions * 2) + ), + synthetic_input(batch * sequence_length * dimensions), + norm_q_weight, + norm_k_weight, + cos_values, + sin_values, + batch, + sequence_length, + heads, + head_dimension, + mlp_hidden_dimensions, + 1e-6F, + 1e-5F + ); + if 
(output.out_projection_input_dimensions != dimensions + mlp_hidden_dimensions) { + throw std::runtime_error("Bonsai Flux single block projection shape mismatch."); + } + + double checksum = 0.0; + for (float value : output.normalized_hidden) { + checksum += static_cast(value); + } + for (float value : output.out_projection_input) { + checksum += static_cast(value); + } + for (float value : output.residual_output) { + checksum += static_cast(value); + } + return checksum; +} + +double flux_double_block_checksum() { + const uint64_t batch = 1; + const uint64_t text_sequence_length = 2; + const uint64_t image_sequence_length = 3; + const uint64_t heads = 2; + const uint64_t head_dimension = 8; + const uint64_t dimensions = heads * head_dimension; + std::vector text_norm_q; + std::vector text_norm_k; + std::vector image_norm_q; + std::vector image_norm_k; + text_norm_q.reserve(static_cast(head_dimension)); + text_norm_k.reserve(static_cast(head_dimension)); + image_norm_q.reserve(static_cast(head_dimension)); + image_norm_k.reserve(static_cast(head_dimension)); + for (uint64_t index = 0; index < head_dimension; index++) { + text_norm_q.push_back(1.0F + 0.01F * static_cast(index)); + text_norm_k.push_back(0.95F + 0.015F * static_cast(index)); + image_norm_q.push_back(0.9F + 0.02F * static_cast(index)); + image_norm_k.push_back(1.05F + 0.005F * static_cast(index)); + } + + const BonsaiFluxRotaryEmbedding embedding = bonsai_flux_pos_embed({ + std::array {0.0F, 0.0F, 0.0F, 0.0F}, + std::array {0.0F, 1.0F, 0.0F, 1.0F}, + std::array {0.0F, 2.0F, 1.0F, 2.0F}, + std::array {0.0F, 3.0F, 1.0F, 3.0F}, + std::array {0.0F, 4.0F, 2.0F, 4.0F}, + }); + std::vector text_cos; + std::vector text_sin; + std::vector image_cos; + std::vector image_sin; + text_cos.reserve(static_cast(text_sequence_length * head_dimension / 2)); + text_sin.reserve(static_cast(text_sequence_length * head_dimension / 2)); + image_cos.reserve(static_cast(image_sequence_length * head_dimension / 2)); + 
image_sin.reserve(static_cast(image_sequence_length * head_dimension / 2)); + for (uint64_t position = 0; position < text_sequence_length + image_sequence_length; position++) { + std::vector* target_cos = position < text_sequence_length ? &text_cos : &image_cos; + std::vector* target_sin = position < text_sequence_length ? &text_sin : &image_sin; + for (uint64_t index = 0; index < head_dimension / 2; index++) { + target_cos->push_back( + embedding.cos[static_cast(position * embedding.dimensions + index)] + ); + target_sin->push_back( + embedding.sin[static_cast(position * embedding.dimensions + index)] + ); + } + } + + const BonsaiFluxDoubleBlockReferenceOutput output = bonsai_flux_double_block_reference( + synthetic_input(batch * text_sequence_length * dimensions), + synthetic_input(batch * image_sequence_length * dimensions), + synthetic_input(batch * 6 * dimensions), + synthetic_input(batch * 6 * dimensions), + synthetic_input(batch * text_sequence_length * dimensions), + synthetic_input(batch * text_sequence_length * dimensions), + synthetic_input(batch * text_sequence_length * dimensions), + synthetic_input(batch * image_sequence_length * dimensions), + synthetic_input(batch * image_sequence_length * dimensions), + synthetic_input(batch * image_sequence_length * dimensions), + synthetic_input(batch * text_sequence_length * dimensions), + synthetic_input(batch * image_sequence_length * dimensions), + synthetic_input(batch * text_sequence_length * dimensions), + synthetic_input(batch * image_sequence_length * dimensions), + text_norm_q, + text_norm_k, + image_norm_q, + image_norm_k, + text_cos, + text_sin, + image_cos, + image_sin, + batch, + text_sequence_length, + image_sequence_length, + heads, + head_dimension, + 1e-6F, + 1e-5F + ); + if (output.dimensions != dimensions || + output.text_sequence_length != text_sequence_length || + output.image_sequence_length != image_sequence_length) { + throw std::runtime_error("Bonsai Flux double block output shape 
mismatch."); + } + + double checksum = 0.0; + for (float value : output.normalized_text_msa) { + checksum += static_cast(value); + } + for (float value : output.normalized_image_msa) { + checksum += static_cast(value); + } + for (float value : output.attention_text) { + checksum += static_cast(value); + } + for (float value : output.attention_image) { + checksum += static_cast(value); + } + for (float value : output.normalized_text_mlp) { + checksum += static_cast(value); + } + for (float value : output.normalized_image_mlp) { + checksum += static_cast(value); + } + for (float value : output.text_output) { + checksum += static_cast(value); + } + for (float value : output.image_output) { + checksum += static_cast(value); + } + return checksum; +} + +double flux_output_checksum() { + const uint64_t batch = 1; + const uint64_t image_sequence_length = 4; + const uint64_t dimensions = 16; + const std::vector output = bonsai_flux_final_projection_input( + synthetic_input(batch * image_sequence_length * dimensions), + synthetic_input(batch * 2 * dimensions), + batch, + image_sequence_length, + dimensions, + 1e-6F + ); + + double checksum = 0.0; + for (float value : output) { + checksum += static_cast(value); + } + return checksum; +} + +double attention_checksum() { + const uint64_t heads = 2; + const uint64_t key_value_heads = 1; + const uint64_t length = 4; + const uint64_t head_dimension = 8; + + const std::vector queries = synthetic_input(heads * length * head_dimension); + const std::vector keys = bonsai_repeat_kv_heads( + synthetic_input(key_value_heads * length * head_dimension), + key_value_heads, + heads / key_value_heads, + length, + head_dimension + ); + const std::vector values = bonsai_repeat_kv_heads( + synthetic_input(key_value_heads * length * head_dimension), + key_value_heads, + heads / key_value_heads, + length, + head_dimension + ); + + std::vector mask; + mask.reserve(static_cast(length * length)); + for (uint64_t row = 0; row < length; row++) { + for 
(uint64_t column = 0; column < length; column++) { + mask.push_back(column > row ? -std::numeric_limits::infinity() : 0.0F); + } + } + + const std::vector output = bonsai_scaled_dot_product_attention( + queries, + keys, + values, + mask, + heads, + length, + head_dimension, + 1.0F / std::sqrt(static_cast(head_dimension)) + ); + + double checksum = 0.0; + for (float value : output) { + checksum += static_cast(value); + } + return checksum; +} + +double scheduler_checksum() { + const BonsaiFlowMatchEulerSchedule schedule = bonsai_flow_match_euler_schedule(4096, 4); + const std::vector updated = bonsai_flow_match_euler_step( + synthetic_input(64), + 1, + synthetic_input(64), + schedule + ); + + double checksum = 0.0; + for (float value : schedule.timesteps) { + checksum += static_cast(value); + } + for (float value : schedule.sigmas) { + checksum += static_cast(value); + } + for (float value : updated) { + checksum += static_cast(value); + } + return checksum; +} + +double latent_checksum() { + const BonsaiLatentShape shape = bonsai_packed_latent_shape( + 512, + 512, + 1, + 128, + 8 + ); + if (shape.latent_height != 32 || shape.latent_width != 32 || shape.sequence_length != 1024) { + throw std::runtime_error("Bonsai latent shape mismatch."); + } + + const uint64_t small_batch = 1; + const uint64_t small_channels = 4; + const uint64_t small_height = 2; + const uint64_t small_width = 3; + const std::vector raw = synthetic_input( + small_batch * small_channels * small_height * small_width + ); + const std::vector packed = bonsai_pack_latents_nchw( + raw, + small_batch, + small_channels, + small_height, + small_width + ); + const std::vector unpacked = bonsai_unpack_packed_latents( + packed, + small_batch, + small_height * small_width, + small_channels, + 32, + 48, + 8 + ); + if (unpacked != raw) { + throw std::runtime_error("Bonsai latent pack/unpack mismatch."); + } + + const std::vector ids = bonsai_latent_grid_ids(1, 2, 3); + double checksum = 
static_cast(shape.sequence_length); + for (float value : packed) { + checksum += static_cast(value); + } + for (int32_t value : ids) { + checksum += static_cast(value); + } + return checksum; +} + +double prompt_checksum() { + const BonsaiQwenPromptSpec spec = bonsai_qwen_prompt_spec(); + const std::string formatted = bonsai_qwen_chat_formatted_prompt("Cat"); + if (formatted != + "<|im_start|>user\nCat<|im_end|>\n<|im_start|>assistant\n\n\n\n\n") { + throw std::runtime_error("Bonsai Qwen prompt template mismatch."); + } + const std::vector text_ids = bonsai_qwen_text_ids(4); + + double checksum = static_cast(spec.max_sequence_length) + + static_cast(spec.pad_token_id) + + static_cast(spec.eos_token_id); + for (char value : formatted) { + checksum += static_cast(static_cast(value)); + } + for (int32_t value : text_ids) { + checksum += static_cast(value); + } + return checksum; +} + +double vae_ops_checksum() { + BonsaiNchwTensor input { + 1, + 4, + 3, + 3, + synthetic_input(1 * 4 * 3 * 3), + }; + std::vector norm_weight; + std::vector norm_bias; + norm_weight.reserve(4); + norm_bias.reserve(4); + for (uint64_t index = 0; index < 4; index++) { + norm_weight.push_back(1.0F + 0.02F * static_cast(index)); + norm_bias.push_back((static_cast(index) - 1.5F) * 0.03F); + } + + const BonsaiNchwTensor normed = bonsai_vae_group_norm_nchw( + input, + 2, + norm_weight, + norm_bias, + 1e-6F + ); + BonsaiNchwTensor activated = normed; + activated.values = bonsai_silu(normed.values); + + std::vector weight; + weight.reserve(2 * 4 * 3 * 3); + for (uint64_t index = 0; index < 2 * 4 * 3 * 3; index++) { + weight.push_back((static_cast(index % 11U) - 5.0F) * 0.02F); + } + const std::vector bias {0.05F, -0.025F}; + const BonsaiNchwTensor conv = bonsai_vae_conv2d_nchw( + activated, + weight, + 2, + 3, + 3, + 1, + &bias + ); + const BonsaiNchwTensor upsampled = bonsai_vae_upsample_nearest2x_nchw(conv); + if (conv.height != 3 || conv.width != 3 || upsampled.height != 6 || upsampled.width != 
6) { + throw std::runtime_error("Bonsai VAE op shape mismatch."); + } + + double checksum = 0.0; + for (float value : upsampled.values) { + checksum += static_cast(value); + } + return checksum; +} + +double vae_attention_checksum() { + BonsaiNchwTensor queries { + 1, + 4, + 2, + 3, + synthetic_input(1 * 4 * 2 * 3), + }; + BonsaiNchwTensor keys { + 1, + 4, + 2, + 3, + synthetic_input(1 * 4 * 2 * 3), + }; + BonsaiNchwTensor values { + 1, + 4, + 2, + 3, + synthetic_input(1 * 4 * 2 * 3), + }; + const BonsaiNchwTensor attended = bonsai_vae_spatial_attention_nchw( + queries, + keys, + values, + 1.0F / std::sqrt(4.0F) + ); + const BonsaiNchwTensor residual = bonsai_vae_add_nchw(queries, attended); + + double checksum = 0.0; + for (float value : residual.values) { + checksum += static_cast(value); + } + return checksum; +} + +double vae_decode_prelude_checksum() { + BonsaiNchwTensor packed { + 1, + 8, + 2, + 2, + synthetic_input(1 * 8 * 2 * 2), + }; + std::vector mean; + std::vector variance; + mean.reserve(8); + variance.reserve(8); + for (uint64_t index = 0; index < 8; index++) { + mean.push_back((static_cast(index) - 4.0F) * 0.01F); + variance.push_back(0.75F + 0.05F * static_cast(index)); + } + + const BonsaiNchwTensor denormalized = bonsai_vae_denormalize_channels_nchw( + packed, + mean, + variance, + 1e-6F + ); + const BonsaiNchwTensor unpatchified = bonsai_vae_unpatchify_nchw(denormalized); + if (unpatchified.channels != 2 || unpatchified.height != 4 || unpatchified.width != 4) { + throw std::runtime_error("Bonsai VAE decode prelude shape mismatch."); + } + + double checksum = 0.0; + for (float value : unpatchified.values) { + checksum += static_cast(value); + } + return checksum; +} + +uint64_t validate_text_encoder_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& tensors, + double* dequant_checksum, + double* matvec_checksum, + double* linear_checksum, + double* embedding_checksum, + double* norm_checksum, + double* qwen_checksum, + 
uint64_t* qwen_layers, + uint64_t* qwen_logical_tensors, + uint64_t* qwen_view_layers, + uint64_t* qwen_view_bytes +) { + uint64_t bytes = 0; + const std::string embedding_key = tensors.resolve_model_prefixed_key("embed_tokens.weight"); + bytes += require_embedding_views( + storage, + tensors, + tensors.require_packed_weight(embedding_key, 4, 64), + embedding_checksum + ); + bytes += require_rms_norm_view( + storage, + tensors, + tensors.resolve_model_prefixed_key("norm.weight"), + norm_checksum + ); + const BonsaiQwenInventorySummary inventory = bonsai_require_qwen_text_encoder_tensors( + tensors + ); + *qwen_layers = inventory.layer_count; + *qwen_logical_tensors = inventory.logical_tensor_count; + const BonsaiQwenLayerProbeSummary qwen = bonsai_probe_qwen_text_encoder_layer0( + storage, + tensors + ); + bytes += qwen.bytes; + *qwen_checksum += qwen.checksum; + const BonsaiQwenTextEncoderViews qwen_views = bonsai_require_qwen_text_encoder_views( + storage, + tensors + ); + *qwen_view_layers = static_cast(qwen_views.layers.size()); + *qwen_view_bytes = bonsai_qwen_text_encoder_byte_count(qwen_views); + bytes += *qwen_view_bytes; + return bytes; +} + +uint64_t validate_transformer_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& tensors, + const BonsaiQuantizationConfig& quantization, + double* dequant_checksum, + double* matvec_checksum, + double* dense_checksum, + double* linear_checksum, + double* linear_sequence_checksum, + uint64_t* flux_double_view_blocks, + uint64_t* flux_single_view_blocks, + uint64_t* flux_transformer_view_bytes +) { + uint64_t bytes = 0; + const BonsaiLinearViews x_embedder = bonsai_require_dense_linear_views( + storage, + tensors, + "x_embedder.weight", + "x_embedder.bias" + ); + const std::vector x_input = synthetic_input(x_embedder.input_values); + *dense_checksum += static_cast(bonsai_dense_matvec_row( + x_embedder.dense, + x_input, + 0 + )); + *linear_checksum += static_cast(bonsai_linear_row( + 
x_embedder, + x_input, + 0 + )); + const std::vector sequence_output = bonsai_linear_sequence( + x_embedder, + synthetic_input(2 * x_embedder.input_values), + 1, + 2 + ); + const size_t sequence_limit = std::min(sequence_output.size(), 128); + for (size_t index = 0; index < sequence_limit; index++) { + *linear_sequence_checksum += static_cast(sequence_output[index]); + } + bytes += bonsai_linear_byte_count(x_embedder); + + bytes += require_dense_weight_view( + storage, + tensors, + "context_embedder.weight", + dense_checksum, + linear_checksum + ); + bytes += require_dense_weight_view( + storage, + tensors, + "norm_out.linear.weight", + dense_checksum, + linear_checksum + ); + bytes += require_dense_weight_view( + storage, + tensors, + "proj_out.weight", + dense_checksum, + linear_checksum + ); + bytes += require_dense_weight_view( + storage, + tensors, + "time_guidance_embed.timestep_embedder.linear_1.weight", + dense_checksum, + linear_checksum + ); + bytes += require_dense_weight_view( + storage, + tensors, + "time_guidance_embed.timestep_embedder.linear_2.weight", + dense_checksum, + linear_checksum + ); + bytes += require_dense_weight_view( + storage, + tensors, + "double_stream_modulation_img.linear.weight", + dense_checksum, + linear_checksum + ); + bytes += require_dense_weight_view( + storage, + tensors, + "double_stream_modulation_txt.linear.weight", + dense_checksum, + linear_checksum + ); + bytes += require_dense_weight_view( + storage, + tensors, + "single_stream_modulation.linear.weight", + dense_checksum, + linear_checksum + ); + bytes += require_packed_weight_views( + storage, + tensors, + tensors.require_packed_weight( + "transformer_blocks.0.attn.to_q.weight", + quantization.bits, + quantization.group_size + ), + dequant_checksum, + matvec_checksum, + linear_checksum + ); + bytes += require_packed_weight_views( + storage, + tensors, + tensors.require_packed_weight( + "single_transformer_blocks.0.attn.to_qkv_mlp_proj.weight", + quantization.bits, + 
quantization.group_size + ), + dequant_checksum, + matvec_checksum, + linear_checksum + ); + const BonsaiFluxTransformerViews transformer_views = bonsai_require_flux_transformer_views( + storage, + tensors, + quantization.bits, + quantization.group_size + ); + *flux_double_view_blocks = static_cast(transformer_views.double_blocks.size()); + *flux_single_view_blocks = static_cast(transformer_views.single_blocks.size()); + *flux_transformer_view_bytes = bonsai_flux_transformer_byte_count(transformer_views); + bytes += *flux_transformer_view_bytes; + return bytes; +} + +uint64_t validate_vae_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& tensors, + const BonsaiFluxVaeConfig& config, + double* tensor_vector_checksum, + double* vae_conv_view_checksum, + double* vae_group_norm_view_checksum, + double* vae_attention_view_checksum, + double* vae_resnet_view_checksum, + double* vae_up_block_view_checksum, + double* vae_decode_prelude_view_checksum +) { + uint64_t bytes = 0; + const BonsaiTensorView mean = storage.require_view(tensors, "bn.running_mean"); + const BonsaiTensorView variance = storage.require_view(tensors, "bn.running_var"); + const std::vector mean_values = bonsai_tensor_view_to_f32_vector(mean); + const std::vector variance_values = bonsai_tensor_view_to_f32_vector(variance); + const size_t mean_limit = std::min(mean_values.size(), 64); + const size_t variance_limit = std::min(variance_values.size(), 64); + for (size_t index = 0; index < mean_limit; index++) { + *tensor_vector_checksum += static_cast(mean_values[index]); + } + for (size_t index = 0; index < variance_limit; index++) { + *tensor_vector_checksum += static_cast(variance_values[index]); + } + bytes += mean.byte_count; + bytes += variance.byte_count; + const BonsaiVaeConv2dViews post_quant_conv = bonsai_vae_require_conv2d_views( + storage, + tensors, + "post_quant_conv" + ); + const BonsaiNchwTensor conv_input { + 1, + post_quant_conv.input_channels, + 
post_quant_conv.kernel_height, + post_quant_conv.kernel_width, + synthetic_input( + post_quant_conv.input_channels * + post_quant_conv.kernel_height * + post_quant_conv.kernel_width + ), + }; + const BonsaiNchwTensor conv_output = bonsai_vae_conv2d_view_nchw( + conv_input, + post_quant_conv + ); + const size_t conv_limit = std::min(conv_output.values.size(), 64); + for (size_t index = 0; index < conv_limit; index++) { + *vae_conv_view_checksum += static_cast(conv_output.values[index]); + } + bytes += bonsai_vae_conv2d_byte_count(post_quant_conv); + const BonsaiVaeGroupNormViews norm_out = bonsai_vae_require_group_norm_views( + storage, + tensors, + "decoder.conv_norm_out", + config.block_out_channels[0], + config.norm_num_groups, + 1e-6F + ); + const BonsaiNchwTensor norm_input { + 1, + norm_out.channels, + 2, + 2, + synthetic_input(norm_out.channels * 2 * 2), + }; + const BonsaiNchwTensor norm_output = bonsai_vae_group_norm_view_nchw( + norm_input, + norm_out + ); + const size_t norm_limit = std::min(norm_output.values.size(), 64); + for (size_t index = 0; index < norm_limit; index++) { + *vae_group_norm_view_checksum += static_cast(norm_output.values[index]); + } + bytes += bonsai_vae_group_norm_byte_count(norm_out); + const BonsaiVaeAttentionViews mid_attention = bonsai_vae_require_attention_views( + storage, + tensors, + "decoder.mid_block.attentions.0", + config.block_out_channels.back(), + config.norm_num_groups, + 1e-6F + ); + const BonsaiNchwTensor attention_input { + 1, + mid_attention.channels, + 1, + 1, + synthetic_input(mid_attention.channels) + }; + const BonsaiNchwTensor attention_output = bonsai_vae_attention_view_nchw( + attention_input, + mid_attention + ); + const size_t attention_limit = std::min(attention_output.values.size(), 64); + for (size_t index = 0; index < attention_limit; index++) { + *vae_attention_view_checksum += static_cast(attention_output.values[index]); + } + bytes += bonsai_vae_attention_byte_count(mid_attention); + const 
BonsaiVaeResnetViews mid_resnet = bonsai_vae_require_resnet_views( + storage, + tensors, + "decoder.mid_block.resnets.0", + config.block_out_channels.back(), + config.block_out_channels.back(), + config.norm_num_groups, + 1e-6F + ); + const BonsaiNchwTensor resnet_input { + 1, + mid_resnet.input_channels, + 1, + 1, + synthetic_input(mid_resnet.input_channels) + }; + const BonsaiNchwTensor resnet_output = bonsai_vae_resnet_view_nchw( + resnet_input, + mid_resnet + ); + const size_t resnet_limit = std::min(resnet_output.values.size(), 64); + for (size_t index = 0; index < resnet_limit; index++) { + *vae_resnet_view_checksum += static_cast(resnet_output.values[index]); + } + bytes += bonsai_vae_resnet_byte_count(mid_resnet); + + const uint64_t last_up_block = config.block_out_channels_count - 1U; + const uint64_t last_up_output_channels = config.block_out_channels.front(); + const uint64_t last_up_input_channels = config.block_out_channels_count == 1 + ? last_up_output_channels + : config.block_out_channels[1]; + const BonsaiVaeUpBlockViews last_up = bonsai_vae_require_up_block_views( + storage, + tensors, + "decoder.up_blocks." 
+ std::to_string(last_up_block), + last_up_input_channels, + last_up_output_channels, + config.layers_per_block + 1U, + config.norm_num_groups, + false, + 1e-6F + ); + const BonsaiNchwTensor up_input { + 1, + last_up.input_channels, + 1, + 1, + synthetic_input(last_up.input_channels) + }; + const BonsaiNchwTensor up_output = bonsai_vae_up_block_view_nchw(up_input, last_up); + const size_t up_limit = std::min(up_output.values.size(), 64); + for (size_t index = 0; index < up_limit; index++) { + *vae_up_block_view_checksum += static_cast(up_output.values[index]); + } + bytes += bonsai_vae_up_block_byte_count(last_up); + const BonsaiVaeDecodeViews decode_views = bonsai_vae_require_decode_views( + storage, + tensors, + config + ); + const BonsaiNchwTensor packed_input { + 1, + decode_views.packed_channels, + 1, + 1, + synthetic_input(decode_views.packed_channels) + }; + const BonsaiNchwTensor prelude_output = bonsai_vae_decode_prelude_view_nchw( + packed_input, + decode_views + ); + const size_t prelude_limit = std::min(prelude_output.values.size(), 64); + for (size_t index = 0; index < prelude_limit; index++) { + *vae_decode_prelude_view_checksum += static_cast(prelude_output.values[index]); + } + bytes += bonsai_vae_decode_byte_count(decode_views); + bytes += require_tensor_view(storage, tensors, "decoder.conv_in.weight"); + return bytes; +} + +} // namespace + +std::string probe_bonsai_model(const BonsaiModelPaths& paths) { + bonsai_require_directory(paths.root_path, "root"); + bonsai_require_directory(paths.tokenizer_path, "tokenizer"); + bonsai_require_directory(paths.scheduler_path, "scheduler"); + bonsai_require_file(bonsai_join_path(paths.tokenizer_path, "tokenizer.json"), "tokenizer"); + bonsai_require_file( + bonsai_join_path(paths.tokenizer_path, "tokenizer_config.json"), + "tokenizer config" + ); + const BonsaiQuantizationConfig quantization = bonsai_read_quantization_config( + paths.packed_transformer_path + ); + const BonsaiFluxVaeConfig vae_config = 
bonsai_read_vae_config(paths.vae_path); + const BonsaiTokenizerData tokenizer_data = bonsai_load_tokenizer_data( + paths.tokenizer_path + ); + const BonsaiTokenizerMetadata& tokenizer_metadata = tokenizer_data.metadata; + const BonsaiSafetensorsIndex transformer = BonsaiSafetensorsIndex::load_directory( + paths.packed_transformer_path, + "transformer" + ); + const BonsaiSafetensorsIndex text_encoder = BonsaiSafetensorsIndex::load_directory( + paths.text_encoder_path, + "text encoder" + ); + const BonsaiSafetensorsIndex vae = BonsaiSafetensorsIndex::load_directory( + paths.vae_path, + "vae" + ); + + const BonsaiFluxTransformerInventorySummary transformer_inventory = + require_transformer_tensors(transformer, quantization); + require_text_encoder_tensors(text_encoder); + const BonsaiFluxVaeInventorySummary vae_inventory = require_vae_tensors(vae, vae_config); + + const BonsaiTensorStorage transformer_storage(transformer); + const BonsaiTensorStorage text_encoder_storage(text_encoder); + const BonsaiTensorStorage vae_storage(vae); + double dequant_checksum = 0.0; + double matvec_checksum = 0.0; + double dense_checksum = 0.0; + double linear_checksum = 0.0; + double linear_sequence_checksum = 0.0; + double tensor_vector_checksum = 0.0; + double vae_conv_view_checksum = 0.0; + double vae_group_norm_view_checksum = 0.0; + double vae_attention_view_checksum = 0.0; + double vae_resnet_view_checksum = 0.0; + double vae_up_block_view_checksum = 0.0; + double vae_decode_prelude_view_checksum = 0.0; + double embedding_checksum = 0.0; + double norm_checksum = 0.0; + double qwen_checksum = 0.0; + uint64_t qwen_layers = 0; + uint64_t qwen_logical_tensors = 0; + uint64_t qwen_view_layers = 0; + uint64_t qwen_view_bytes = 0; + uint64_t flux_double_view_blocks = 0; + uint64_t flux_single_view_blocks = 0; + uint64_t flux_transformer_view_bytes = 0; + const double activation_probe_checksum = activation_checksum(); + const double flux_activation_probe_checksum = 
flux_activation_checksum(); + const double flux_layer_norm_probe_checksum = flux_layer_norm_checksum(); + const double flux_time_embedding_probe_checksum = flux_time_embedding_checksum(); + const double rotary_probe_checksum = rotary_checksum(); + const double flux_pos_embed_probe_checksum = flux_pos_embed_checksum(); + const double flux_rope_probe_checksum = flux_rope_checksum(); + const double flux_modulation_probe_checksum = flux_modulation_checksum(); + const double flux_attention_layout_probe_checksum = flux_attention_layout_checksum(); + const double flux_single_block_probe_checksum = flux_single_block_checksum(); + const double flux_double_block_probe_checksum = flux_double_block_checksum(); + const double flux_output_probe_checksum = flux_output_checksum(); + const double attention_probe_checksum = attention_checksum(); + const double scheduler_probe_checksum = scheduler_checksum(); + const double latent_probe_checksum = latent_checksum(); + const double prompt_probe_checksum = prompt_checksum(); + const double vae_ops_probe_checksum = vae_ops_checksum(); + const double vae_attention_probe_checksum = vae_attention_checksum(); + const double vae_decode_prelude_probe_checksum = vae_decode_prelude_checksum(); + const uint64_t transformer_bytes = validate_transformer_views( + transformer_storage, + transformer, + quantization, + &dequant_checksum, + &matvec_checksum, + &dense_checksum, + &linear_checksum, + &linear_sequence_checksum, + &flux_double_view_blocks, + &flux_single_view_blocks, + &flux_transformer_view_bytes + ); + const uint64_t text_encoder_bytes = validate_text_encoder_views( + text_encoder_storage, + text_encoder, + &dequant_checksum, + &matvec_checksum, + &linear_checksum, + &embedding_checksum, + &norm_checksum, + &qwen_checksum, + &qwen_layers, + &qwen_logical_tensors, + &qwen_view_layers, + &qwen_view_bytes + ); + const uint64_t vae_bytes = validate_vae_views( + vae_storage, + vae, + vae_config, + &tensor_vector_checksum, + 
&vae_conv_view_checksum, + &vae_group_norm_view_checksum, + &vae_attention_view_checksum, + &vae_resnet_view_checksum, + &vae_up_block_view_checksum, + &vae_decode_prelude_view_checksum + ); + + return "bits=" + std::to_string(quantization.bits) + + " group_size=" + std::to_string(quantization.group_size) + + " tokenizer_vocab=" + std::to_string(tokenizer_metadata.vocab_size) + + " tokenizer_merges=" + std::to_string(tokenizer_metadata.merge_count) + + " tokenizer_pad_id=" + std::to_string(tokenizer_metadata.pad_token_id) + + " tokenizer_eos_id=" + std::to_string(tokenizer_metadata.eos_token_id) + + " tokenizer_checksum=" + + std::to_string(bonsai_tokenizer_data_checksum(tokenizer_data)) + + " transformer_files=" + std::to_string(transformer.file_count()) + + " transformer_tensors=" + std::to_string(transformer.tensor_count()) + + " transformer_probe_bytes=" + std::to_string(transformer_bytes) + + " transformer_double_blocks=" + std::to_string(transformer_inventory.double_block_count) + + " transformer_single_blocks=" + std::to_string(transformer_inventory.single_block_count) + + " flux_double_view_blocks=" + std::to_string(flux_double_view_blocks) + + " flux_single_view_blocks=" + std::to_string(flux_single_view_blocks) + + " flux_transformer_view_bytes=" + std::to_string(flux_transformer_view_bytes) + + " transformer_logical_tensors=" + + std::to_string(transformer_inventory.logical_tensor_count) + + " text_encoder_files=" + std::to_string(text_encoder.file_count()) + + " text_encoder_tensors=" + std::to_string(text_encoder.tensor_count()) + + " text_encoder_probe_bytes=" + std::to_string(text_encoder_bytes) + + " qwen_layers=" + std::to_string(qwen_layers) + + " qwen_logical_tensors=" + std::to_string(qwen_logical_tensors) + + " qwen_view_layers=" + std::to_string(qwen_view_layers) + + " qwen_view_bytes=" + std::to_string(qwen_view_bytes) + + " vae_files=" + std::to_string(vae.file_count()) + + " vae_tensors=" + std::to_string(vae.tensor_count()) + + " 
vae_probe_bytes=" + std::to_string(vae_bytes) + + " vae_up_blocks=" + std::to_string(vae_inventory.up_block_count) + + " vae_resnet_blocks=" + std::to_string(vae_inventory.resnet_block_count) + + " vae_attention_blocks=" + std::to_string(vae_inventory.attention_block_count) + + " vae_norm_groups=" + std::to_string(vae_config.norm_num_groups) + + " vae_batch_norm_eps=" + std::to_string(vae_config.batch_norm_eps) + + " vae_logical_tensors=" + std::to_string(vae_inventory.logical_tensor_count) + + " dequant_checksum=" + std::to_string(dequant_checksum) + + " matvec_checksum=" + std::to_string(matvec_checksum) + + " dense_checksum=" + std::to_string(dense_checksum) + + " linear_checksum=" + std::to_string(linear_checksum) + + " linear_sequence_checksum=" + std::to_string(linear_sequence_checksum) + + " tensor_vector_checksum=" + std::to_string(tensor_vector_checksum) + + " vae_conv_view_checksum=" + std::to_string(vae_conv_view_checksum) + + " vae_group_norm_view_checksum=" + std::to_string(vae_group_norm_view_checksum) + + " vae_attention_view_checksum=" + std::to_string(vae_attention_view_checksum) + + " vae_resnet_view_checksum=" + std::to_string(vae_resnet_view_checksum) + + " vae_up_block_view_checksum=" + std::to_string(vae_up_block_view_checksum) + + " vae_decode_prelude_view_checksum=" + + std::to_string(vae_decode_prelude_view_checksum) + + " embedding_checksum=" + std::to_string(embedding_checksum) + + " norm_checksum=" + std::to_string(norm_checksum) + + " qwen_checksum=" + std::to_string(qwen_checksum) + + " activation_checksum=" + std::to_string(activation_probe_checksum) + + " flux_activation_checksum=" + std::to_string(flux_activation_probe_checksum) + + " flux_layer_norm_checksum=" + std::to_string(flux_layer_norm_probe_checksum) + + " flux_time_embedding_checksum=" + std::to_string(flux_time_embedding_probe_checksum) + + " rotary_checksum=" + std::to_string(rotary_probe_checksum) + + " flux_pos_embed_checksum=" + 
std::to_string(flux_pos_embed_probe_checksum) + + " flux_rope_checksum=" + std::to_string(flux_rope_probe_checksum) + + " flux_modulation_checksum=" + std::to_string(flux_modulation_probe_checksum) + + " flux_attention_layout_checksum=" + + std::to_string(flux_attention_layout_probe_checksum) + + " flux_single_block_checksum=" + std::to_string(flux_single_block_probe_checksum) + + " flux_double_block_checksum=" + std::to_string(flux_double_block_probe_checksum) + + " flux_output_checksum=" + std::to_string(flux_output_probe_checksum) + + " attention_checksum=" + std::to_string(attention_probe_checksum) + + " scheduler_checksum=" + std::to_string(scheduler_probe_checksum) + + " latent_checksum=" + std::to_string(latent_probe_checksum) + + " prompt_checksum=" + std::to_string(prompt_probe_checksum) + + " vae_ops_checksum=" + std::to_string(vae_ops_probe_checksum) + + " vae_attention_checksum=" + std::to_string(vae_attention_probe_checksum) + + " vae_decode_prelude_checksum=" + + std::to_string(vae_decode_prelude_probe_checksum); +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_model_probe.h b/feature/bonsai/src/androidMain/cpp/bonsai_model_probe.h new file mode 100644 index 000000000..cc4ae0389 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_model_probe.h @@ -0,0 +1,14 @@ +#pragma once + +#include + +struct BonsaiModelPaths { + std::string root_path; + std::string packed_transformer_path; + std::string text_encoder_path; + std::string tokenizer_path; + std::string vae_path; + std::string scheduler_path; +}; + +std::string probe_bonsai_model(const BonsaiModelPaths& paths); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_norm.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_norm.cpp new file mode 100644 index 000000000..28b19c1b9 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_norm.cpp @@ -0,0 +1,62 @@ +#include "bonsai_norm.h" + +#include "bonsai_tensor.h" + +#include +#include + +BonsaiRmsNormWeightViews 
bonsai_require_rms_norm_weight( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + const std::string& weight_key +) { + BonsaiTensorView weight = storage.require_view(index, weight_key); + if (!bonsai_dtype_is_floating_point(weight.dtype)) { + throw std::runtime_error("Bonsai RMSNorm weight must be floating point: " + weight_key); + } + if (weight.element_count == 0) { + throw std::runtime_error("Bonsai RMSNorm weight must not be empty: " + weight_key); + } + return BonsaiRmsNormWeightViews { + weight, + weight.element_count, + }; +} + +std::vector bonsai_rms_norm( + const std::vector& input, + const BonsaiRmsNormWeightViews& views, + float eps +) { + if (views.dimensions == 0 || input.size() % static_cast(views.dimensions) != 0) { + throw std::runtime_error( + "Bonsai RMSNorm input size mismatch: " + views.weight.descriptor->key + ); + } + + std::vector output; + output.reserve(input.size()); + for (size_t offset = 0; offset < input.size(); offset += static_cast(views.dimensions)) { + double squared_sum = 0.0; + for (uint64_t index = 0; index < views.dimensions; index++) { + const float value = input[offset + static_cast(index)]; + squared_sum += static_cast(value) * static_cast(value); + } + + const float scale = 1.0F / std::sqrt( + static_cast(squared_sum / static_cast(views.dimensions)) + eps + ); + for (uint64_t index = 0; index < views.dimensions; index++) { + const float weight = bonsai_read_scalar_as_f32( + views.weight.data + index * views.weight.dtype_byte_count, + views.weight.dtype + ); + output.push_back(input[offset + static_cast(index)] * scale * weight); + } + } + return output; +} + +uint64_t bonsai_rms_norm_byte_count(const BonsaiRmsNormWeightViews& views) { + return views.weight.byte_count; +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_norm.h b/feature/bonsai/src/androidMain/cpp/bonsai_norm.h new file mode 100644 index 000000000..feaf26b60 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_norm.h 
@@ -0,0 +1,27 @@ +#pragma once + +#include "bonsai_safetensors.h" +#include "bonsai_tensor_storage.h" + +#include +#include +#include + +struct BonsaiRmsNormWeightViews { + BonsaiTensorView weight; + uint64_t dimensions = 0; +}; + +BonsaiRmsNormWeightViews bonsai_require_rms_norm_weight( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + const std::string& weight_key +); + +std::vector bonsai_rms_norm( + const std::vector& input, + const BonsaiRmsNormWeightViews& views, + float eps +); + +uint64_t bonsai_rms_norm_byte_count(const BonsaiRmsNormWeightViews& views); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_packed_weight.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_packed_weight.cpp new file mode 100644 index 000000000..a14691e0b --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_packed_weight.cpp @@ -0,0 +1,177 @@ +#include "bonsai_packed_weight.h" + +#include +#include + +namespace { + +void require_supported_bits(int bits) { + if (bits != 1 && bits != 2 && bits != 4) { + throw std::runtime_error("Unsupported Bonsai quantization bits: " + std::to_string(bits)); + } +} + +uint64_t checked_multiply(uint64_t left, uint64_t right, const std::string& tensor_key) { + if (left != 0 && right > UINT64_MAX / left) { + throw std::runtime_error("Bonsai packed weight shape is too large: " + tensor_key); + } + return left * right; +} + +uint64_t leading_row_count(const BonsaiTensorView& view) { + if (view.descriptor->shape.empty()) { + throw std::runtime_error("Bonsai packed weight must have at least one dimension."); + } + + uint64_t rows = 1; + for (size_t index = 0; index + 1 < view.descriptor->shape.size(); index++) { + rows = checked_multiply(rows, view.descriptor->shape[index], view.descriptor->key); + } + return rows; +} + +uint64_t last_dimension(const BonsaiTensorView& view) { + if (view.descriptor->shape.empty()) { + throw std::runtime_error("Bonsai packed weight must have at least one dimension."); + } + return 
view.descriptor->shape.back(); +} + +void require_same_shape( + const BonsaiTensorView& left, + const BonsaiTensorView& right, + const std::string& label +) { + if (left.descriptor->shape != right.descriptor->shape) { + throw std::runtime_error( + "Bonsai packed weight " + + label + + " shape mismatch: " + + left.descriptor->key + + " vs " + + right.descriptor->key + ); + } +} + +void require_same_leading_shape( + const BonsaiTensorView& packed, + const BonsaiTensorView& scales +) { + const auto& packed_shape = packed.descriptor->shape; + const auto& scales_shape = scales.descriptor->shape; + if (packed_shape.size() != scales_shape.size()) { + throw std::runtime_error( + "Bonsai packed weight rank mismatch: " + packed.descriptor->key + ); + } + for (size_t index = 0; index + 1 < packed_shape.size(); index++) { + if (packed_shape[index] != scales_shape[index]) { + throw std::runtime_error( + "Bonsai packed weight leading shape mismatch: " + packed.descriptor->key + ); + } + } +} + +uint64_t expected_packed_last_dimension( + uint64_t scale_groups, + int bits, + int group_size, + const std::string& tensor_key +) { + const uint64_t values_per_word = 32ULL / static_cast(bits); + const uint64_t values = checked_multiply( + scale_groups, + static_cast(group_size), + tensor_key + ); + if (values % values_per_word != 0) { + throw std::runtime_error( + "Bonsai packed weight group shape is not aligned to uint32 packing: " + tensor_key + ); + } + return values / values_per_word; +} + +void validate_packed_views(const BonsaiPackedWeightViews& views) { + require_supported_bits(views.bits); + if (views.group_size <= 0) { + throw std::runtime_error("Bonsai packed weight group size must be positive."); + } + if (32 % views.bits != 0) { + throw std::runtime_error("Bonsai packed weight bits must divide 32."); + } + if (views.weight.dtype != BonsaiDType::U32) { + throw std::runtime_error( + "packed tensor " + views.weight.descriptor->key + " must be uint32" + ); + } + if 
(!bonsai_dtype_is_floating_point(views.scales.dtype) || + !bonsai_dtype_is_floating_point(views.biases.dtype) + ) { + throw std::runtime_error( + "Bonsai packed weight scales and biases must be floating point: " + + views.weight.descriptor->key + ); + } + + require_same_shape(views.scales, views.biases, "scale/bias"); + require_same_leading_shape(views.weight, views.scales); + const uint64_t expected_last = expected_packed_last_dimension( + last_dimension(views.scales), + views.bits, + views.group_size, + views.weight.descriptor->key + ); + if (last_dimension(views.weight) != expected_last) { + throw std::runtime_error( + "Bonsai packed weight data shape mismatch: " + views.weight.descriptor->key + ); + } +} + +} // namespace + +BonsaiPackedWeightViews bonsai_require_packed_weight_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + const BonsaiPackedWeightDescriptor& descriptor +) { + BonsaiPackedWeightViews views { + descriptor.packed, + storage.require_view(index, descriptor.weight_key), + {}, + {}, + descriptor.bits, + descriptor.group_size, + 0, + 0, + }; + + views.leading_rows = leading_row_count(views.weight); + if (!descriptor.packed) { + views.input_values = last_dimension(views.weight); + return views; + } + + views.scales = storage.require_view(index, descriptor.scales_key); + views.biases = storage.require_view(index, descriptor.biases_key); + validate_packed_views(views); + views.leading_rows = leading_row_count(views.weight); + views.input_values = checked_multiply( + last_dimension(views.scales), + static_cast(descriptor.group_size), + views.weight.descriptor->key + ); + return views; +} + +uint64_t bonsai_packed_weight_byte_count(const BonsaiPackedWeightViews& views) { + uint64_t bytes = views.weight.byte_count; + if (views.packed) { + bytes += views.scales.byte_count; + bytes += views.biases.byte_count; + } + return bytes; +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_packed_weight.h 
b/feature/bonsai/src/androidMain/cpp/bonsai_packed_weight.h new file mode 100644 index 000000000..3987436ab --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_packed_weight.h @@ -0,0 +1,25 @@ +#pragma once + +#include "bonsai_safetensors.h" +#include "bonsai_tensor_storage.h" + +#include + +struct BonsaiPackedWeightViews { + bool packed = false; + BonsaiTensorView weight; + BonsaiTensorView scales; + BonsaiTensorView biases; + int bits = 0; + int group_size = 0; + uint64_t leading_rows = 0; + uint64_t input_values = 0; +}; + +BonsaiPackedWeightViews bonsai_require_packed_weight_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + const BonsaiPackedWeightDescriptor& descriptor +); + +uint64_t bonsai_packed_weight_byte_count(const BonsaiPackedWeightViews& views); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_prompt.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_prompt.cpp new file mode 100644 index 000000000..1425e9a11 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_prompt.cpp @@ -0,0 +1,86 @@ +#include "bonsai_prompt.h" + +#include "bonsai_tokenizer.h" + +#include +#include +#include + +BonsaiQwenPromptSpec bonsai_qwen_prompt_spec() { + return BonsaiQwenPromptSpec {}; +} + +std::string bonsai_qwen_chat_formatted_prompt(const std::string& prompt) { + return "<|im_start|>user\n" + + prompt + + "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"; +} + +std::vector bonsai_qwen_text_ids(uint64_t length) { + if (length > static_cast(std::numeric_limits::max())) { + throw std::runtime_error("Bonsai Qwen text id length is too large."); + } + std::vector values; + values.reserve(static_cast(length * 4U)); + for (uint64_t index = 0; index < length; index++) { + values.push_back(0); + values.push_back(0); + values.push_back(0); + values.push_back(static_cast(index)); + } + return values; +} + +BonsaiPromptEncodingPlan bonsai_prepare_qwen_prompt_encoding_plan( + const std::string& prompt, + const std::string& 
negative_prompt, + float cfg_scale, + const BonsaiTokenizerMetadata& tokenizer_metadata +) { + if (!std::isfinite(cfg_scale)) { + throw std::runtime_error("Bonsai CFG scale must be finite."); + } + + const BonsaiQwenPromptSpec spec = bonsai_qwen_prompt_spec(); + BonsaiPromptEncodingPlan plan; + plan.formatted_prompt = bonsai_qwen_chat_formatted_prompt(prompt); + plan.max_sequence_length = spec.max_sequence_length; + plan.pad_token_id = tokenizer_metadata.pad_token_id != 0 + ? tokenizer_metadata.pad_token_id + : spec.pad_token_id; + plan.eos_token_id = tokenizer_metadata.eos_token_id != 0 + ? tokenizer_metadata.eos_token_id + : spec.eos_token_id; + plan.uses_negative_prompt = cfg_scale > 1.0F; + plan.text_ids = bonsai_qwen_text_ids(plan.max_sequence_length); + + if (plan.uses_negative_prompt) { + const std::string effective_negative_prompt = negative_prompt.empty() + ? " " + : negative_prompt; + plan.formatted_negative_prompt = bonsai_qwen_chat_formatted_prompt( + effective_negative_prompt + ); + plan.negative_text_ids = bonsai_qwen_text_ids(plan.max_sequence_length); + } + + return plan; +} + +uint64_t bonsai_prompt_encoding_plan_checksum(const BonsaiPromptEncodingPlan& plan) { + uint64_t checksum = plan.max_sequence_length * 3U; + checksum += static_cast(plan.pad_token_id) * 5U; + checksum += static_cast(plan.eos_token_id) * 7U; + checksum += plan.uses_negative_prompt ? 
11U : 0U; + checksum += static_cast(plan.formatted_prompt.size()) * 13U; + checksum += static_cast(plan.formatted_negative_prompt.size()) * 17U; + checksum += static_cast(plan.text_ids.size()) * 19U; + checksum += static_cast(plan.negative_text_ids.size()) * 23U; + for (char value : plan.formatted_prompt) { + checksum += static_cast(static_cast(value)); + } + for (char value : plan.formatted_negative_prompt) { + checksum += static_cast(static_cast(value)) * 2U; + } + return checksum; +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_prompt.h b/feature/bonsai/src/androidMain/cpp/bonsai_prompt.h new file mode 100644 index 000000000..6c8f90eac --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_prompt.h @@ -0,0 +1,39 @@ +#pragma once + +#include +#include +#include + +struct BonsaiTokenizerMetadata; + +struct BonsaiQwenPromptSpec { + uint64_t max_sequence_length = 512; + int32_t pad_token_id = 151643; + int32_t eos_token_id = 151645; +}; + +struct BonsaiPromptEncodingPlan { + std::string formatted_prompt; + std::string formatted_negative_prompt; + uint64_t max_sequence_length = 0; + int32_t pad_token_id = 0; + int32_t eos_token_id = 0; + bool uses_negative_prompt = false; + std::vector text_ids; + std::vector negative_text_ids; +}; + +BonsaiQwenPromptSpec bonsai_qwen_prompt_spec(); + +std::string bonsai_qwen_chat_formatted_prompt(const std::string& prompt); + +std::vector bonsai_qwen_text_ids(uint64_t length); + +BonsaiPromptEncodingPlan bonsai_prepare_qwen_prompt_encoding_plan( + const std::string& prompt, + const std::string& negative_prompt, + float cfg_scale, + const BonsaiTokenizerMetadata& tokenizer_metadata +); + +uint64_t bonsai_prompt_encoding_plan_checksum(const BonsaiPromptEncodingPlan& plan); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_qwen.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_qwen.cpp new file mode 100644 index 000000000..4a623a963 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_qwen.cpp @@ -0,0 +1,964 
@@ +#include "bonsai_qwen.h" + +#include "bonsai_activation.h" +#include "bonsai_attention.h" +#include "bonsai_linear.h" +#include "bonsai_norm.h" +#include "bonsai_rotary.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace { + +constexpr const char* LOG_TAG = "SDAI-Bonsai"; +constexpr uint64_t QWEN_HIDDEN_SIZE = 2560; +constexpr uint64_t QWEN_ATTENTION_HEADS = 32; +constexpr uint64_t QWEN_KEY_VALUE_HEADS = 8; +constexpr uint64_t QWEN_HEAD_DIMENSION = 128; +constexpr uint64_t QWEN_LAYER_COUNT = 36; +constexpr int QWEN_BITS = 4; +constexpr int QWEN_GROUP_SIZE = 64; + +std::vector synthetic_input(uint64_t size) { + std::vector input; + input.reserve(static_cast(size)); + for (uint64_t index = 0; index < size; index++) { + input.push_back((static_cast(index % 11U) - 5.0F) * 0.0625F); + } + return input; +} + +uint64_t checked_multiply(uint64_t left, uint64_t right, const char* label) { + if (left != 0 && right > std::numeric_limits::max() / left) { + throw std::runtime_error(std::string("Bonsai Qwen shape is too large: ") + label); + } + return left * right; +} + +void add_bytes(uint64_t* bytes, uint64_t extra, const char* label) { + if (*bytes > std::numeric_limits::max() - extra) { + throw std::runtime_error(std::string("Bonsai Qwen byte count overflow: ") + label); + } + *bytes += extra; +} + +std::string qwen_key( + const BonsaiSafetensorsIndex& index, + const std::string& suffix +) { + return index.resolve_model_prefixed_key(suffix); +} + +void require_qwen_linear_tensor( + const BonsaiSafetensorsIndex& index, + const std::string& prefix, + const std::string& name +) { + index.require_packed_weight( + qwen_key(index, prefix + "." 
+ name + ".weight"), + QWEN_BITS, + QWEN_GROUP_SIZE + ); +} + +void require_qwen_norm_tensor( + const BonsaiSafetensorsIndex& index, + const std::string& suffix +) { + index.require(qwen_key(index, suffix)); +} + +BonsaiLinearViews require_qwen_linear( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + const std::string& prefix, + const std::string& name +) { + const std::string weight_key = qwen_key(index, prefix + "." + name + ".weight"); + return bonsai_require_packed_linear_views( + storage, + index, + index.require_packed_weight(weight_key, QWEN_BITS, QWEN_GROUP_SIZE), + qwen_key(index, prefix + "." + name + ".bias") + ); +} + +BonsaiRmsNormWeightViews require_qwen_norm( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + const std::string& suffix +) { + return bonsai_require_rms_norm_weight(storage, index, qwen_key(index, suffix)); +} + +void require_linear_shape( + const BonsaiLinearViews& linear, + uint64_t expected_input, + uint64_t expected_output, + const std::string& label +) { + if (linear.input_values != expected_input || linear.output_rows != expected_output) { + throw std::runtime_error("Bonsai Qwen linear shape mismatch: " + label); + } +} + +void require_norm_shape( + const BonsaiRmsNormWeightViews& norm, + uint64_t expected_dimensions, + const std::string& label +) { + if (norm.dimensions != expected_dimensions) { + throw std::runtime_error("Bonsai Qwen norm shape mismatch: " + label); + } +} + +std::vector add_vectors( + const std::vector& left, + const std::vector& right, + const char* label +) { + if (left.size() != right.size()) { + throw std::runtime_error(std::string("Bonsai Qwen residual size mismatch: ") + label); + } + std::vector output; + output.reserve(left.size()); + for (size_t index = 0; index < left.size(); index++) { + output.push_back(left[index] + right[index]); + } + return output; +} + +std::vector sequence_to_heads( + const std::vector& input, + uint64_t batch, + 
uint64_t sequence_length, + uint64_t heads, + uint64_t head_dimension +) { + const uint64_t hidden_size = checked_multiply(heads, head_dimension, "sequence heads"); + const uint64_t expected_size = checked_multiply( + checked_multiply(batch, sequence_length, "sequence heads"), + hidden_size, + "sequence heads" + ); + if (input.size() != static_cast(expected_size)) { + throw std::runtime_error("Bonsai Qwen sequence-to-heads input size mismatch."); + } + + std::vector output; + output.reserve(static_cast(expected_size)); + for (uint64_t batch_index = 0; batch_index < batch; batch_index++) { + for (uint64_t head = 0; head < heads; head++) { + for (uint64_t token = 0; token < sequence_length; token++) { + for (uint64_t column = 0; column < head_dimension; column++) { + const uint64_t source = ((batch_index * sequence_length + token) * hidden_size) + + head * head_dimension + + column; + output.push_back(input[static_cast(source)]); + } + } + } + } + return output; +} + +std::vector heads_to_sequence( + const std::vector& input, + uint64_t batch, + uint64_t sequence_length, + uint64_t heads, + uint64_t head_dimension +) { + const uint64_t hidden_size = checked_multiply(heads, head_dimension, "heads sequence"); + const uint64_t expected_size = checked_multiply( + checked_multiply(batch, sequence_length, "heads sequence"), + hidden_size, + "heads sequence" + ); + if (input.size() != static_cast(expected_size)) { + throw std::runtime_error("Bonsai Qwen heads-to-sequence input size mismatch."); + } + + std::vector output(static_cast(expected_size), 0.0F); + for (uint64_t batch_index = 0; batch_index < batch; batch_index++) { + for (uint64_t head = 0; head < heads; head++) { + for (uint64_t token = 0; token < sequence_length; token++) { + for (uint64_t column = 0; column < head_dimension; column++) { + const uint64_t source = (((batch_index * heads + head) * sequence_length + token) * + head_dimension) + column; + const uint64_t target = ((batch_index * sequence_length + 
token) * hidden_size) + + head * head_dimension + + column; + output[static_cast(target)] = input[static_cast(source)]; + } + } + } + } + return output; +} + +std::vector apply_rotary_to_batched_heads( + const std::vector& input, + uint64_t batch, + uint64_t heads, + uint64_t sequence_length, + uint64_t head_dimension +) { + const uint64_t expected_size = checked_multiply( + checked_multiply(checked_multiply(batch, heads, "qwen rotary"), sequence_length, "qwen rotary"), + head_dimension, + "qwen rotary" + ); + if (input.size() != static_cast(expected_size)) { + throw std::runtime_error("Bonsai Qwen rotary input size mismatch."); + } + + std::vector output; + output.reserve(input.size()); + for (uint64_t batch_index = 0; batch_index < batch; batch_index++) { + for (uint64_t head = 0; head < heads; head++) { + for (uint64_t token = 0; token < sequence_length; token++) { + std::vector head_values; + head_values.reserve(static_cast(head_dimension)); + const uint64_t offset = (((batch_index * heads + head) * sequence_length + token) * + head_dimension); + for (uint64_t column = 0; column < head_dimension; column++) { + head_values.push_back(input[static_cast(offset + column)]); + } + const std::vector rotated = bonsai_apply_rotary_to_heads( + head_values, + head_dimension, + token, + 1000000.0F + ); + output.insert(output.end(), rotated.begin(), rotated.end()); + } + } + } + return output; +} + +std::vector repeat_kv_heads_batched( + const std::vector& input, + uint64_t batch, + uint64_t key_value_heads, + uint64_t repeats, + uint64_t sequence_length, + uint64_t head_dimension +) { + const uint64_t batch_stride = checked_multiply( + checked_multiply(key_value_heads, sequence_length, "qwen repeat kv"), + head_dimension, + "qwen repeat kv" + ); + if (input.size() != static_cast(batch_stride * batch)) { + throw std::runtime_error("Bonsai Qwen repeat-KV input size mismatch."); + } + + std::vector output; + output.reserve(static_cast(batch_stride * repeats * batch)); + for 
(uint64_t batch_index = 0; batch_index < batch; batch_index++) { + const auto start = input.begin() + static_cast(batch_index * batch_stride); + const auto end = start + static_cast(batch_stride); + const std::vector repeated = bonsai_repeat_kv_heads( + std::vector(start, end), + key_value_heads, + repeats, + sequence_length, + head_dimension + ); + output.insert(output.end(), repeated.begin(), repeated.end()); + } + return output; +} + +std::vector qwen_attention_sequence( + const BonsaiQwenAttentionViews& views, + const std::vector& hidden_states, + uint64_t batch, + uint64_t sequence_length, + const std::vector& additive_attention_mask +) { + const std::vector query_sequence = bonsai_linear_sequence( + views.q_proj, + hidden_states, + batch, + sequence_length + ); + const std::vector key_sequence = bonsai_linear_sequence( + views.k_proj, + hidden_states, + batch, + sequence_length + ); + const std::vector value_sequence = bonsai_linear_sequence( + views.v_proj, + hidden_states, + batch, + sequence_length + ); + + std::vector queries = sequence_to_heads( + bonsai_rms_norm(query_sequence, views.q_norm, 1e-6F), + batch, + sequence_length, + views.attention_heads, + views.head_dimension + ); + std::vector keys = sequence_to_heads( + bonsai_rms_norm(key_sequence, views.k_norm, 1e-6F), + batch, + sequence_length, + views.key_value_heads, + views.head_dimension + ); + std::vector values = sequence_to_heads( + value_sequence, + batch, + sequence_length, + views.key_value_heads, + views.head_dimension + ); + + queries = apply_rotary_to_batched_heads( + queries, + batch, + views.attention_heads, + sequence_length, + views.head_dimension + ); + keys = apply_rotary_to_batched_heads( + keys, + batch, + views.key_value_heads, + sequence_length, + views.head_dimension + ); + keys = repeat_kv_heads_batched( + keys, + batch, + views.key_value_heads, + views.attention_heads / views.key_value_heads, + sequence_length, + views.head_dimension + ); + values = repeat_kv_heads_batched( 
+ values, + batch, + views.key_value_heads, + views.attention_heads / views.key_value_heads, + sequence_length, + views.head_dimension + ); + + const std::vector attended = bonsai_scaled_dot_product_attention( + queries, + keys, + values, + additive_attention_mask, + batch * views.attention_heads, + sequence_length, + views.head_dimension, + views.scale + ); + return bonsai_linear_sequence( + views.o_proj, + heads_to_sequence( + attended, + batch, + sequence_length, + views.attention_heads, + views.head_dimension + ), + batch, + sequence_length + ); +} + +std::vector qwen_mlp_sequence( + const BonsaiQwenMlpViews& views, + const std::vector& hidden_states, + uint64_t batch, + uint64_t sequence_length +) { + return bonsai_linear_sequence( + views.down_proj, + bonsai_silu_times( + bonsai_linear_sequence(views.gate_proj, hidden_states, batch, sequence_length), + bonsai_linear_sequence(views.up_proj, hidden_states, batch, sequence_length) + ), + batch, + sequence_length + ); +} + +double checksum_linear_rows( + const BonsaiLinearViews& linear, + uint64_t row_count +) { + const std::vector input = synthetic_input(linear.input_values); + const uint64_t limit = std::min(row_count, linear.output_rows); + double checksum = 0.0; + for (uint64_t row = 0; row < limit; row++) { + checksum += static_cast(bonsai_linear_row(linear, input, row)); + } + return checksum; +} + +std::vector sample_linear_rows( + const BonsaiLinearViews& linear, + uint64_t row_count +) { + const std::vector input = synthetic_input(linear.input_values); + const uint64_t limit = std::min(row_count, linear.output_rows); + std::vector output; + output.reserve(static_cast(limit)); + for (uint64_t row = 0; row < limit; row++) { + output.push_back(bonsai_linear_row(linear, input, row)); + } + return output; +} + +double checksum_norm( + const BonsaiRmsNormWeightViews& norm +) { + const std::vector output = bonsai_rms_norm( + synthetic_input(norm.dimensions), + norm, + 1e-6F + ); + const size_t limit = 
std::min(output.size(), 32); + double checksum = 0.0; + for (size_t index = 0; index < limit; index++) { + checksum += static_cast(output[index]); + } + return checksum; +} + +std::vector checked_token_ids( + const BonsaiQwenTextEncoderViews& views, + const std::vector& input_ids +) { + std::vector token_ids; + token_ids.reserve(input_ids.size()); + for (int32_t input_id : input_ids) { + if (input_id < 0) { + throw std::runtime_error("Bonsai Qwen token id must be non-negative."); + } + const uint64_t token_id = static_cast(input_id); + if (token_id >= views.embedding.rows) { + throw std::runtime_error("Bonsai Qwen token id is outside embedding rows."); + } + token_ids.push_back(token_id); + } + return token_ids; +} + +std::vector qwen_additive_attention_mask( + const std::vector& attention_mask, + uint64_t sequence_length +) { + if (attention_mask.size() != static_cast(sequence_length)) { + throw std::runtime_error("Bonsai Qwen attention mask length mismatch."); + } + + std::vector output; + output.reserve(static_cast(checked_multiply( + sequence_length, + sequence_length, + "qwen attention mask" + ))); + for (uint64_t row = 0; row < sequence_length; row++) { + for (uint64_t column = 0; column < sequence_length; column++) { + const bool masked_future_token = column > row; + const bool masked_padding_token = + attention_mask[static_cast(column)] != 1; + output.push_back( + masked_future_token || masked_padding_token + ? 
-std::numeric_limits::infinity() + : 0.0F + ); + } + } + return output; +} + +std::vector flatten_selected_states( + const std::vector>& selected_states, + uint64_t sequence_length, + uint64_t hidden_size +) { + if (selected_states.empty()) { + throw std::runtime_error("Bonsai Qwen selected hidden states are empty."); + } + const uint64_t selected_count = static_cast(selected_states.size()); + const uint64_t state_size = checked_multiply( + sequence_length, + hidden_size, + "qwen selected state" + ); + const uint64_t output_size = checked_multiply( + state_size, + selected_count, + "qwen selected states" + ); + for (const std::vector& state : selected_states) { + if (state.size() != static_cast(state_size)) { + throw std::runtime_error("Bonsai Qwen selected hidden state size mismatch."); + } + } + + std::vector output; + output.reserve(static_cast(output_size)); + for (uint64_t token = 0; token < sequence_length; token++) { + for (const std::vector& state : selected_states) { + const uint64_t offset = token * hidden_size; + output.insert( + output.end(), + state.begin() + static_cast(offset), + state.begin() + static_cast(offset + hidden_size) + ); + } + } + return output; +} + +void log_qwen_layer_phase(const char* phase, uint64_t layer, uint64_t sequence_length) { + __android_log_print( + ANDROID_LOG_INFO, + LOG_TAG, + "phase=%s layer=%llu sequence=%llu", + phase, + static_cast(layer), + static_cast(sequence_length) + ); +} + +} // namespace + +BonsaiQwenInventorySummary bonsai_require_qwen_text_encoder_tensors( + const BonsaiSafetensorsIndex& index +) { + BonsaiQwenInventorySummary summary { + QWEN_LAYER_COUNT, + 0, + }; + + require_qwen_norm_tensor(index, "norm.weight"); + summary.logical_tensor_count++; + + for (uint64_t layer = 0; layer < QWEN_LAYER_COUNT; layer++) { + const std::string prefix = "layers." 
+ std::to_string(layer); + require_qwen_norm_tensor(index, prefix + ".input_layernorm.weight"); + require_qwen_norm_tensor(index, prefix + ".post_attention_layernorm.weight"); + require_qwen_norm_tensor(index, prefix + ".self_attn.q_norm.weight"); + require_qwen_norm_tensor(index, prefix + ".self_attn.k_norm.weight"); + summary.logical_tensor_count += 4; + + require_qwen_linear_tensor(index, prefix + ".self_attn", "q_proj"); + require_qwen_linear_tensor(index, prefix + ".self_attn", "k_proj"); + require_qwen_linear_tensor(index, prefix + ".self_attn", "v_proj"); + require_qwen_linear_tensor(index, prefix + ".self_attn", "o_proj"); + require_qwen_linear_tensor(index, prefix + ".mlp", "gate_proj"); + require_qwen_linear_tensor(index, prefix + ".mlp", "up_proj"); + require_qwen_linear_tensor(index, prefix + ".mlp", "down_proj"); + summary.logical_tensor_count += 7; + } + + return summary; +} + +BonsaiQwenLayerViews bonsai_require_qwen_layer_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + uint64_t layer +) { + if (layer >= QWEN_LAYER_COUNT) { + throw std::runtime_error("Bonsai Qwen layer index is out of range."); + } + const std::string prefix = "layers." 
+ std::to_string(layer); + BonsaiQwenLayerViews views { + require_qwen_norm(storage, index, prefix + ".input_layernorm.weight"), + require_qwen_norm(storage, index, prefix + ".post_attention_layernorm.weight"), + {}, + {}, + QWEN_HIDDEN_SIZE, + }; + views.attention = BonsaiQwenAttentionViews { + require_qwen_linear(storage, index, prefix + ".self_attn", "q_proj"), + require_qwen_linear(storage, index, prefix + ".self_attn", "k_proj"), + require_qwen_linear(storage, index, prefix + ".self_attn", "v_proj"), + require_qwen_linear(storage, index, prefix + ".self_attn", "o_proj"), + require_qwen_norm(storage, index, prefix + ".self_attn.q_norm.weight"), + require_qwen_norm(storage, index, prefix + ".self_attn.k_norm.weight"), + QWEN_HIDDEN_SIZE, + QWEN_ATTENTION_HEADS, + QWEN_KEY_VALUE_HEADS, + QWEN_HEAD_DIMENSION, + 1.0F / std::sqrt(static_cast(QWEN_HEAD_DIMENSION)), + }; + views.mlp = BonsaiQwenMlpViews { + require_qwen_linear(storage, index, prefix + ".mlp", "gate_proj"), + require_qwen_linear(storage, index, prefix + ".mlp", "up_proj"), + require_qwen_linear(storage, index, prefix + ".mlp", "down_proj"), + QWEN_HIDDEN_SIZE, + 0, + }; + views.mlp.intermediate_size = views.mlp.gate_proj.output_rows; + + require_norm_shape(views.input_norm, QWEN_HIDDEN_SIZE, "input_layernorm"); + require_norm_shape(views.post_attention_norm, QWEN_HIDDEN_SIZE, "post_attention_layernorm"); + require_norm_shape(views.attention.q_norm, QWEN_HEAD_DIMENSION, "q_norm"); + require_norm_shape(views.attention.k_norm, QWEN_HEAD_DIMENSION, "k_norm"); + require_linear_shape( + views.attention.q_proj, + QWEN_HIDDEN_SIZE, + QWEN_ATTENTION_HEADS * QWEN_HEAD_DIMENSION, + "q_proj" + ); + require_linear_shape( + views.attention.k_proj, + QWEN_HIDDEN_SIZE, + QWEN_KEY_VALUE_HEADS * QWEN_HEAD_DIMENSION, + "k_proj" + ); + require_linear_shape( + views.attention.v_proj, + QWEN_HIDDEN_SIZE, + QWEN_KEY_VALUE_HEADS * QWEN_HEAD_DIMENSION, + "v_proj" + ); + require_linear_shape( + views.attention.o_proj, + 
QWEN_ATTENTION_HEADS * QWEN_HEAD_DIMENSION, + QWEN_HIDDEN_SIZE, + "o_proj" + ); + require_linear_shape( + views.mlp.gate_proj, + QWEN_HIDDEN_SIZE, + views.mlp.intermediate_size, + "gate_proj" + ); + require_linear_shape( + views.mlp.up_proj, + QWEN_HIDDEN_SIZE, + views.mlp.intermediate_size, + "up_proj" + ); + require_linear_shape( + views.mlp.down_proj, + views.mlp.intermediate_size, + QWEN_HIDDEN_SIZE, + "down_proj" + ); + return views; +} + +BonsaiQwenTextEncoderViews bonsai_require_qwen_text_encoder_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index +) { + const std::string embedding_key = qwen_key(index, "embed_tokens.weight"); + BonsaiQwenTextEncoderViews views { + bonsai_require_embedding_views( + storage, + index, + index.require_packed_weight(embedding_key, QWEN_BITS, QWEN_GROUP_SIZE) + ), + require_qwen_norm(storage, index, "norm.weight"), + {}, + {9, 18, 27}, + QWEN_HIDDEN_SIZE, + }; + if (views.embedding.dimensions != QWEN_HIDDEN_SIZE || + views.final_norm.dimensions != QWEN_HIDDEN_SIZE) { + throw std::runtime_error("Bonsai Qwen text encoder top-level shape mismatch."); + } + views.layers.reserve(static_cast(QWEN_LAYER_COUNT)); + for (uint64_t layer = 0; layer < QWEN_LAYER_COUNT; layer++) { + views.layers.push_back(bonsai_require_qwen_layer_views(storage, index, layer)); + } + return views; +} + +std::vector bonsai_qwen_layer_sequence( + const BonsaiQwenLayerViews& views, + const std::vector& hidden_states, + uint64_t batch, + uint64_t sequence_length, + const std::vector& additive_attention_mask +) { + if (batch == 0 || sequence_length == 0 || views.hidden_size == 0) { + throw std::runtime_error("Bonsai Qwen layer sequence shape must be positive."); + } + const uint64_t expected_input = checked_multiply( + checked_multiply(batch, sequence_length, "qwen layer sequence"), + views.hidden_size, + "qwen layer sequence" + ); + if (hidden_states.size() != static_cast(expected_input)) { + throw std::runtime_error("Bonsai Qwen 
layer sequence input size mismatch."); + } + + const std::vector attended = qwen_attention_sequence( + views.attention, + bonsai_rms_norm(hidden_states, views.input_norm, 1e-6F), + batch, + sequence_length, + additive_attention_mask + ); + const std::vector after_attention = add_vectors( + hidden_states, + attended, + "attention" + ); + return add_vectors( + after_attention, + qwen_mlp_sequence( + views.mlp, + bonsai_rms_norm(after_attention, views.post_attention_norm, 1e-6F), + batch, + sequence_length + ), + "mlp" + ); +} + +BonsaiQwenPromptEmbeddings bonsai_qwen_text_encoder_forward( + const BonsaiQwenTextEncoderViews& views, + const std::vector& input_ids, + const std::vector& attention_mask +) { + if (input_ids.empty()) { + throw std::runtime_error("Bonsai Qwen input ids must not be empty."); + } + if (input_ids.size() != attention_mask.size()) { + throw std::runtime_error("Bonsai Qwen input ids and mask length mismatch."); + } + if (views.hidden_size == 0 || + views.layers.empty() || + views.embedding.dimensions != views.hidden_size) { + throw std::runtime_error("Bonsai Qwen text encoder views are incomplete."); + } + + const uint64_t batch = 1; + const uint64_t sequence_length = static_cast(input_ids.size()); + std::vector hidden_states = bonsai_embedding_lookup( + views.embedding, + checked_token_ids(views, input_ids) + ); + const std::vector additive_attention_mask = qwen_additive_attention_mask( + attention_mask, + sequence_length + ); + + std::vector> selected_states; + selected_states.reserve(views.hidden_state_layers.size()); + for (uint64_t layer = 0; layer < static_cast(views.layers.size()); layer++) { + log_qwen_layer_phase("qwen_layer_start", layer + 1U, sequence_length); + hidden_states = bonsai_qwen_layer_sequence( + views.layers[static_cast(layer)], + hidden_states, + batch, + sequence_length, + additive_attention_mask + ); + log_qwen_layer_phase("qwen_layer_done", layer + 1U, sequence_length); + const uint64_t state_index = layer + 1; + if 
(std::find( + views.hidden_state_layers.begin(), + views.hidden_state_layers.end(), + state_index + ) != views.hidden_state_layers.end()) { + selected_states.push_back(hidden_states); + } + } + + if (selected_states.size() != views.hidden_state_layers.size()) { + throw std::runtime_error("Bonsai Qwen did not produce all selected hidden states."); + } + + BonsaiQwenPromptEmbeddings embeddings; + embeddings.batch = batch; + embeddings.sequence_length = sequence_length; + embeddings.hidden_size = checked_multiply( + views.hidden_size, + static_cast(selected_states.size()), + "qwen prompt embedding width" + ); + embeddings.selected_layer_count = static_cast(selected_states.size()); + embeddings.values = flatten_selected_states( + selected_states, + sequence_length, + views.hidden_size + ); + return embeddings; +} + +uint64_t bonsai_qwen_layer_byte_count(const BonsaiQwenLayerViews& views) { + uint64_t bytes = bonsai_rms_norm_byte_count(views.input_norm); + add_bytes(&bytes, bonsai_rms_norm_byte_count(views.post_attention_norm), "post norm"); + add_bytes(&bytes, bonsai_linear_byte_count(views.attention.q_proj), "q proj"); + add_bytes(&bytes, bonsai_linear_byte_count(views.attention.k_proj), "k proj"); + add_bytes(&bytes, bonsai_linear_byte_count(views.attention.v_proj), "v proj"); + add_bytes(&bytes, bonsai_linear_byte_count(views.attention.o_proj), "o proj"); + add_bytes(&bytes, bonsai_rms_norm_byte_count(views.attention.q_norm), "q norm"); + add_bytes(&bytes, bonsai_rms_norm_byte_count(views.attention.k_norm), "k norm"); + add_bytes(&bytes, bonsai_linear_byte_count(views.mlp.gate_proj), "gate proj"); + add_bytes(&bytes, bonsai_linear_byte_count(views.mlp.up_proj), "up proj"); + add_bytes(&bytes, bonsai_linear_byte_count(views.mlp.down_proj), "down proj"); + return bytes; +} + +uint64_t bonsai_qwen_text_encoder_byte_count(const BonsaiQwenTextEncoderViews& views) { + uint64_t bytes = bonsai_embedding_byte_count(views.embedding); + add_bytes(&bytes, 
bonsai_rms_norm_byte_count(views.final_norm), "final norm"); + for (const BonsaiQwenLayerViews& layer : views.layers) { + add_bytes(&bytes, bonsai_qwen_layer_byte_count(layer), "layer"); + } + return bytes; +} + +BonsaiQwenLayerProbeSummary bonsai_probe_qwen_text_encoder_layer0( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index +) { + BonsaiQwenLayerProbeSummary summary; + + const BonsaiRmsNormWeightViews input_norm = require_qwen_norm( + storage, + index, + "layers.0.input_layernorm.weight" + ); + const BonsaiRmsNormWeightViews post_attention_norm = require_qwen_norm( + storage, + index, + "layers.0.post_attention_layernorm.weight" + ); + const BonsaiRmsNormWeightViews q_norm = require_qwen_norm( + storage, + index, + "layers.0.self_attn.q_norm.weight" + ); + const BonsaiRmsNormWeightViews k_norm = require_qwen_norm( + storage, + index, + "layers.0.self_attn.k_norm.weight" + ); + + require_norm_shape(input_norm, QWEN_HIDDEN_SIZE, "input_layernorm"); + require_norm_shape(post_attention_norm, QWEN_HIDDEN_SIZE, "post_attention_layernorm"); + require_norm_shape(q_norm, QWEN_HEAD_DIMENSION, "q_norm"); + require_norm_shape(k_norm, QWEN_HEAD_DIMENSION, "k_norm"); + + const BonsaiLinearViews q_proj = require_qwen_linear( + storage, + index, + "layers.0.self_attn", + "q_proj" + ); + const BonsaiLinearViews k_proj = require_qwen_linear( + storage, + index, + "layers.0.self_attn", + "k_proj" + ); + const BonsaiLinearViews v_proj = require_qwen_linear( + storage, + index, + "layers.0.self_attn", + "v_proj" + ); + const BonsaiLinearViews o_proj = require_qwen_linear( + storage, + index, + "layers.0.self_attn", + "o_proj" + ); + const BonsaiLinearViews gate_proj = require_qwen_linear( + storage, + index, + "layers.0.mlp", + "gate_proj" + ); + const BonsaiLinearViews up_proj = require_qwen_linear( + storage, + index, + "layers.0.mlp", + "up_proj" + ); + const BonsaiLinearViews down_proj = require_qwen_linear( + storage, + index, + "layers.0.mlp", + 
"down_proj" + ); + + require_linear_shape( + q_proj, + QWEN_HIDDEN_SIZE, + QWEN_ATTENTION_HEADS * QWEN_HEAD_DIMENSION, + "q_proj" + ); + require_linear_shape( + k_proj, + QWEN_HIDDEN_SIZE, + QWEN_KEY_VALUE_HEADS * QWEN_HEAD_DIMENSION, + "k_proj" + ); + require_linear_shape( + v_proj, + QWEN_HIDDEN_SIZE, + QWEN_KEY_VALUE_HEADS * QWEN_HEAD_DIMENSION, + "v_proj" + ); + require_linear_shape( + o_proj, + QWEN_ATTENTION_HEADS * QWEN_HEAD_DIMENSION, + QWEN_HIDDEN_SIZE, + "o_proj" + ); + require_linear_shape(gate_proj, QWEN_HIDDEN_SIZE, up_proj.output_rows, "gate_proj"); + require_linear_shape(up_proj, QWEN_HIDDEN_SIZE, gate_proj.output_rows, "up_proj"); + require_linear_shape(down_proj, gate_proj.output_rows, QWEN_HIDDEN_SIZE, "down_proj"); + + summary.bytes += bonsai_rms_norm_byte_count(input_norm); + summary.bytes += bonsai_rms_norm_byte_count(post_attention_norm); + summary.bytes += bonsai_rms_norm_byte_count(q_norm); + summary.bytes += bonsai_rms_norm_byte_count(k_norm); + summary.bytes += bonsai_linear_byte_count(q_proj); + summary.bytes += bonsai_linear_byte_count(k_proj); + summary.bytes += bonsai_linear_byte_count(v_proj); + summary.bytes += bonsai_linear_byte_count(o_proj); + summary.bytes += bonsai_linear_byte_count(gate_proj); + summary.bytes += bonsai_linear_byte_count(up_proj); + summary.bytes += bonsai_linear_byte_count(down_proj); + + summary.checksum += checksum_norm(input_norm); + summary.checksum += checksum_norm(post_attention_norm); + summary.checksum += checksum_norm(q_norm); + summary.checksum += checksum_norm(k_norm); + summary.checksum += checksum_linear_rows(q_proj, 4); + summary.checksum += checksum_linear_rows(k_proj, 4); + summary.checksum += checksum_linear_rows(v_proj, 4); + summary.checksum += checksum_linear_rows(o_proj, 4); + summary.checksum += checksum_linear_rows(down_proj, 4); + + const std::vector gate_values = sample_linear_rows(gate_proj, 8); + const std::vector up_values = sample_linear_rows(up_proj, 8); + const std::vector gated = 
bonsai_silu_times(gate_values, up_values); + for (float value : gated) { + summary.checksum += static_cast(value); + } + + return summary; +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_qwen.h b/feature/bonsai/src/androidMain/cpp/bonsai_qwen.h new file mode 100644 index 000000000..0d6b07a79 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_qwen.h @@ -0,0 +1,104 @@ +#pragma once + +#include "bonsai_embedding.h" +#include "bonsai_linear.h" +#include "bonsai_norm.h" +#include "bonsai_safetensors.h" +#include "bonsai_tensor_storage.h" + +#include +#include + +struct BonsaiQwenLayerProbeSummary { + uint64_t bytes = 0; + double checksum = 0.0; +}; + +struct BonsaiQwenInventorySummary { + uint64_t layer_count = 0; + uint64_t logical_tensor_count = 0; +}; + +struct BonsaiQwenAttentionViews { + BonsaiLinearViews q_proj; + BonsaiLinearViews k_proj; + BonsaiLinearViews v_proj; + BonsaiLinearViews o_proj; + BonsaiRmsNormWeightViews q_norm; + BonsaiRmsNormWeightViews k_norm; + uint64_t hidden_size = 0; + uint64_t attention_heads = 0; + uint64_t key_value_heads = 0; + uint64_t head_dimension = 0; + float scale = 0.0F; +}; + +struct BonsaiQwenMlpViews { + BonsaiLinearViews gate_proj; + BonsaiLinearViews up_proj; + BonsaiLinearViews down_proj; + uint64_t hidden_size = 0; + uint64_t intermediate_size = 0; +}; + +struct BonsaiQwenLayerViews { + BonsaiRmsNormWeightViews input_norm; + BonsaiRmsNormWeightViews post_attention_norm; + BonsaiQwenAttentionViews attention; + BonsaiQwenMlpViews mlp; + uint64_t hidden_size = 0; +}; + +struct BonsaiQwenTextEncoderViews { + BonsaiEmbeddingViews embedding; + BonsaiRmsNormWeightViews final_norm; + std::vector layers; + std::vector hidden_state_layers; + uint64_t hidden_size = 0; +}; + +struct BonsaiQwenPromptEmbeddings { + std::vector values; + uint64_t batch = 1; + uint64_t sequence_length = 0; + uint64_t hidden_size = 0; + uint64_t selected_layer_count = 0; +}; + +BonsaiQwenInventorySummary 
bonsai_require_qwen_text_encoder_tensors( + const BonsaiSafetensorsIndex& index +); + +BonsaiQwenLayerViews bonsai_require_qwen_layer_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + uint64_t layer +); + +BonsaiQwenTextEncoderViews bonsai_require_qwen_text_encoder_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index +); + +std::vector bonsai_qwen_layer_sequence( + const BonsaiQwenLayerViews& views, + const std::vector& hidden_states, + uint64_t batch, + uint64_t sequence_length, + const std::vector& additive_attention_mask +); + +BonsaiQwenPromptEmbeddings bonsai_qwen_text_encoder_forward( + const BonsaiQwenTextEncoderViews& views, + const std::vector& input_ids, + const std::vector& attention_mask +); + +uint64_t bonsai_qwen_layer_byte_count(const BonsaiQwenLayerViews& views); + +uint64_t bonsai_qwen_text_encoder_byte_count(const BonsaiQwenTextEncoderViews& views); + +BonsaiQwenLayerProbeSummary bonsai_probe_qwen_text_encoder_layer0( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index +); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_qwen_inputs.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_qwen_inputs.cpp new file mode 100644 index 000000000..1289d0c01 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_qwen_inputs.cpp @@ -0,0 +1,141 @@ +#include "bonsai_qwen_inputs.h" + +#include + +#include +#include +#include + +namespace { + +constexpr const char* LOG_TAG = "SDAI-Bonsai"; + +void log_input_phase(const char* phase, uint64_t value = 0) { + __android_log_print( + ANDROID_LOG_INFO, + LOG_TAG, + "phase=%s value=%llu", + phase, + static_cast(value) + ); +} + +void require_text_ids_shape( + const std::vector& text_ids, + uint64_t sequence_length, + const std::string& label +) { + if (text_ids.size() != static_cast(sequence_length * 4U)) { + throw std::runtime_error("Bonsai Qwen " + label + " text id shape mismatch."); + } +} + +std::vector limited_ids( + 
const std::vector& ids, + uint64_t max_sequence_length, + int32_t fallback_token_id +) { + std::vector output; + output.reserve(static_cast(max_sequence_length)); + const size_t token_limit = std::min(ids.size(), static_cast(max_sequence_length)); + output.insert(output.end(), ids.begin(), ids.begin() + static_cast(token_limit)); + if (output.empty()) { + output.push_back(fallback_token_id); + } + return output; +} + +std::vector attention_mask(size_t sequence_length) { + return std::vector(sequence_length, 1); +} + +} // namespace + +BonsaiQwenInputShell bonsai_prepare_qwen_input_shell( + const BonsaiPromptEncodingPlan& prompt_plan, + const BonsaiQwenTextEncoderViews& text_encoder_views, + const BonsaiTokenizerData& tokenizer_data +) { + if (prompt_plan.max_sequence_length == 0) { + throw std::runtime_error("Bonsai Qwen prompt sequence length must be positive."); + } + if (text_encoder_views.hidden_size == 0 || + text_encoder_views.layers.empty() || + text_encoder_views.embedding.dimensions != text_encoder_views.hidden_size || + text_encoder_views.final_norm.dimensions != text_encoder_views.hidden_size) { + throw std::runtime_error("Bonsai Qwen text encoder views are incomplete."); + } + + require_text_ids_shape( + prompt_plan.text_ids, + prompt_plan.max_sequence_length, + "prompt" + ); + if (prompt_plan.uses_negative_prompt) { + require_text_ids_shape( + prompt_plan.negative_text_ids, + prompt_plan.max_sequence_length, + "negative prompt" + ); + } + + BonsaiQwenInputShell input; + input.hidden_size = text_encoder_views.hidden_size; + input.tokenization_pending = false; + input.has_negative_prompt = prompt_plan.uses_negative_prompt; + log_input_phase("qwen_prompt_tokenize_start", prompt_plan.formatted_prompt.size()); + const std::vector prompt_ids = bonsai_tokenizer_encode( + tokenizer_data, + prompt_plan.formatted_prompt + ); + log_input_phase("qwen_prompt_tokenize_done", prompt_ids.size()); + input.prompt_input_ids = limited_ids( + prompt_ids, + 
prompt_plan.max_sequence_length, + prompt_plan.eos_token_id + ); + input.prompt_token_count = static_cast(input.prompt_input_ids.size()); + input.sequence_length = input.prompt_token_count; + input.prompt_attention_mask = attention_mask(input.prompt_input_ids.size()); + log_input_phase("qwen_prompt_sequence_ready", input.prompt_token_count); + if (input.has_negative_prompt) { + log_input_phase( + "qwen_negative_tokenize_start", + prompt_plan.formatted_negative_prompt.size() + ); + const std::vector negative_ids = bonsai_tokenizer_encode( + tokenizer_data, + prompt_plan.formatted_negative_prompt + ); + log_input_phase("qwen_negative_tokenize_done", negative_ids.size()); + input.negative_input_ids = limited_ids( + negative_ids, + prompt_plan.max_sequence_length, + prompt_plan.eos_token_id + ); + input.negative_token_count = static_cast(input.negative_input_ids.size()); + input.sequence_length = std::max(input.sequence_length, input.negative_token_count); + input.negative_attention_mask = attention_mask(input.negative_input_ids.size()); + log_input_phase("qwen_negative_sequence_ready", input.negative_token_count); + } + return input; +} + +uint64_t bonsai_qwen_input_shell_checksum(const BonsaiQwenInputShell& input) { + uint64_t checksum = input.batch * 3U; + checksum += input.sequence_length * 5U; + checksum += input.hidden_size * 7U; + checksum += input.tokenization_pending ? 11U : 0U; + checksum += input.has_negative_prompt ? 
13U : 0U; + checksum += input.prompt_token_count * 17U; + checksum += input.negative_token_count * 19U; + checksum += static_cast(input.prompt_input_ids.size()) * 23U; + checksum += static_cast(input.negative_input_ids.size()) * 29U; + checksum += static_cast(input.prompt_attention_mask.size()) * 31U; + checksum += static_cast(input.negative_attention_mask.size()) * 37U; + const size_t prompt_limit = std::min(input.prompt_input_ids.size(), 16); + for (size_t index = 0; index < prompt_limit; index++) { + checksum += static_cast(input.prompt_input_ids[index]) * (index + 1U); + } + return checksum; +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_qwen_inputs.h b/feature/bonsai/src/androidMain/cpp/bonsai_qwen_inputs.h new file mode 100644 index 000000000..90e3096b8 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_qwen_inputs.h @@ -0,0 +1,30 @@ +#pragma once + +#include "bonsai_prompt.h" +#include "bonsai_qwen.h" +#include "bonsai_tokenizer.h" + +#include +#include + +struct BonsaiQwenInputShell { + uint64_t batch = 1; + uint64_t sequence_length = 0; + uint64_t hidden_size = 0; + bool tokenization_pending = true; + bool has_negative_prompt = false; + uint64_t prompt_token_count = 0; + uint64_t negative_token_count = 0; + std::vector prompt_input_ids; + std::vector negative_input_ids; + std::vector prompt_attention_mask; + std::vector negative_attention_mask; +}; + +BonsaiQwenInputShell bonsai_prepare_qwen_input_shell( + const BonsaiPromptEncodingPlan& prompt_plan, + const BonsaiQwenTextEncoderViews& text_encoder_views, + const BonsaiTokenizerData& tokenizer_data +); + +uint64_t bonsai_qwen_input_shell_checksum(const BonsaiQwenInputShell& input); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_rotary.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_rotary.cpp new file mode 100644 index 000000000..3cc023c4a --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_rotary.cpp @@ -0,0 +1,127 @@ +#include "bonsai_rotary.h" + +#include 
+#include + +namespace { + +void require_even_head_dimension(uint64_t head_dimension) { + if (head_dimension == 0 || head_dimension % 2 != 0) { + throw std::runtime_error("Bonsai rotary head dimension must be positive and even."); + } +} + +} // namespace + +std::vector bonsai_rotary_inv_frequencies( + uint64_t head_dimension, + float base +) { + require_even_head_dimension(head_dimension); + + std::vector output; + output.reserve(static_cast(head_dimension / 2)); + for (uint64_t index = 0; index < head_dimension; index += 2) { + output.push_back(1.0F / std::pow(base, static_cast(index) / head_dimension)); + } + return output; +} + +std::vector bonsai_rotary_cos_values( + uint64_t position, + uint64_t head_dimension, + float base +) { + const std::vector inv_frequencies = bonsai_rotary_inv_frequencies( + head_dimension, + base + ); + + std::vector output; + output.reserve(static_cast(head_dimension)); + for (float frequency : inv_frequencies) { + output.push_back(std::cos(static_cast(position) * frequency)); + } + for (float frequency : inv_frequencies) { + output.push_back(std::cos(static_cast(position) * frequency)); + } + return output; +} + +std::vector bonsai_rotary_sin_values( + uint64_t position, + uint64_t head_dimension, + float base +) { + const std::vector inv_frequencies = bonsai_rotary_inv_frequencies( + head_dimension, + base + ); + + std::vector output; + output.reserve(static_cast(head_dimension)); + for (float frequency : inv_frequencies) { + output.push_back(std::sin(static_cast(position) * frequency)); + } + for (float frequency : inv_frequencies) { + output.push_back(std::sin(static_cast(position) * frequency)); + } + return output; +} + +std::vector bonsai_rotate_half( + const std::vector& input, + uint64_t head_dimension +) { + require_even_head_dimension(head_dimension); + if (input.size() % static_cast(head_dimension) != 0) { + throw std::runtime_error("Bonsai rotate-half input size mismatch."); + } + + std::vector output; + 
output.reserve(input.size()); + const size_t dimension = static_cast(head_dimension); + const size_t half = dimension / 2; + for (size_t offset = 0; offset < input.size(); offset += dimension) { + for (size_t index = 0; index < half; index++) { + output.push_back(-input[offset + half + index]); + } + for (size_t index = 0; index < half; index++) { + output.push_back(input[offset + index]); + } + } + return output; +} + +std::vector bonsai_apply_rotary_to_heads( + const std::vector& input, + uint64_t head_dimension, + uint64_t position, + float base +) { + require_even_head_dimension(head_dimension); + if (input.size() % static_cast(head_dimension) != 0) { + throw std::runtime_error("Bonsai rotary input size mismatch."); + } + + const std::vector rotated = bonsai_rotate_half(input, head_dimension); + const std::vector cos_values = bonsai_rotary_cos_values( + position, + head_dimension, + base + ); + const std::vector sin_values = bonsai_rotary_sin_values( + position, + head_dimension, + base + ); + + std::vector output; + output.reserve(input.size()); + const size_t dimension = static_cast(head_dimension); + for (size_t index = 0; index < input.size(); index++) { + const size_t column = index % dimension; + output.push_back(input[index] * cos_values[column] + rotated[index] * sin_values[column]); + } + return output; +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_rotary.h b/feature/bonsai/src/androidMain/cpp/bonsai_rotary.h new file mode 100644 index 000000000..49dc8fa5f --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_rotary.h @@ -0,0 +1,33 @@ +#pragma once + +#include +#include + +std::vector bonsai_rotary_inv_frequencies( + uint64_t head_dimension, + float base +); + +std::vector bonsai_rotary_cos_values( + uint64_t position, + uint64_t head_dimension, + float base +); + +std::vector bonsai_rotary_sin_values( + uint64_t position, + uint64_t head_dimension, + float base +); + +std::vector bonsai_rotate_half( + const std::vector& input, + 
uint64_t head_dimension +); + +std::vector bonsai_apply_rotary_to_heads( + const std::vector& input, + uint64_t head_dimension, + uint64_t position, + float base +); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_runtime.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_runtime.cpp new file mode 100644 index 000000000..74d65884f --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_runtime.cpp @@ -0,0 +1,582 @@ +#include "bonsai_runtime.h" + +#include "bonsai_image_encoder.h" +#include "bonsai_latents.h" +#include "bonsai_prompt.h" +#include "bonsai_qwen_inputs.h" +#include "bonsai_runtime_context.h" +#include "bonsai_scheduler.h" +#include "bonsai_vulkan.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { + +constexpr const char* LOG_TAG = "SDAI-Bonsai"; +constexpr uint64_t GIB_KB = 1024ULL * 1024ULL; +constexpr uint64_t PERFORMANCE_RAM_KB = 7ULL * GIB_KB; +constexpr uint64_t BALANCED_RAM_KB = 5ULL * GIB_KB; + +enum class BonsaiMemoryPolicy { + Performance, + Balanced, + Survival, +}; + +struct BonsaiTextPhaseOutput { + BonsaiQwenInputShell input_shell; + BonsaiQwenPromptEmbeddings prompt_embeddings; + BonsaiQwenPromptEmbeddings negative_prompt_embeddings; +}; + +std::string trim_ascii(const std::string& value); + +void log_phase(const char* phase) { + __android_log_print(ANDROID_LOG_INFO, LOG_TAG, "phase=%s", phase); +} + +const char* memory_policy_name(BonsaiMemoryPolicy policy) { + switch (policy) { + case BonsaiMemoryPolicy::Performance: + return "performance"; + case BonsaiMemoryPolicy::Balanced: + return "balanced"; + case BonsaiMemoryPolicy::Survival: + return "survival"; + } + return "unknown"; +} + +uint64_t android_total_ram_kb() { + std::ifstream meminfo("/proc/meminfo"); + std::string line; + while (std::getline(meminfo, line)) { + if (line.rfind("MemTotal:", 0) != 0) { + continue; + } + std::istringstream stream(line); + std::string label; + 
uint64_t value = 0; + std::string unit; + stream >> label >> value >> unit; + return value; + } + return 0; +} + +uint64_t process_status_value_kb(const char* key) { + std::ifstream status("/proc/self/status"); + std::string line; + const std::string prefix = std::string(key) + ":"; + while (std::getline(status, line)) { + if (line.rfind(prefix, 0) != 0) { + continue; + } + std::istringstream stream(line); + std::string label; + uint64_t value = 0; + std::string unit; + stream >> label >> value >> unit; + return value; + } + return 0; +} + +BonsaiMemoryPolicy auto_memory_policy(uint64_t total_ram_kb) { + if (total_ram_kb >= PERFORMANCE_RAM_KB) { + return BonsaiMemoryPolicy::Performance; + } + if (total_ram_kb >= BALANCED_RAM_KB) { + return BonsaiMemoryPolicy::Balanced; + } + return BonsaiMemoryPolicy::Survival; +} + +void log_memory_snapshot(const char* point, BonsaiMemoryPolicy policy) { + const uint64_t total_ram_kb = android_total_ram_kb(); + const uint64_t vm_rss_kb = process_status_value_kb("VmRSS"); + const uint64_t vm_hwm_kb = process_status_value_kb("VmHWM"); + const uint64_t vm_size_kb = process_status_value_kb("VmSize"); + const struct mallinfo2 info = mallinfo2(); + __android_log_print( + ANDROID_LOG_INFO, + LOG_TAG, + "phase=memory_snapshot point=%s memory_policy=%s total_ram_mb=%llu vm_rss_mb=%llu vm_hwm_mb=%llu vm_size_mb=%llu heap_alloc_mb=%llu heap_free_mb=%llu heap_releasable_mb=%llu", + point, + memory_policy_name(policy), + static_cast(total_ram_kb / 1024ULL), + static_cast(vm_rss_kb / 1024ULL), + static_cast(vm_hwm_kb / 1024ULL), + static_cast(vm_size_kb / 1024ULL), + static_cast(info.uordblks / (1024ULL * 1024ULL)), + static_cast(info.fordblks / (1024ULL * 1024ULL)), + static_cast(info.keepcost / (1024ULL * 1024ULL)) + ); +} + +void purge_android_allocator_pages() { +#ifdef M_PURGE + using MalloptFunction = int (*)(int, int); + void* symbol = dlsym(RTLD_DEFAULT, "mallopt"); + if (symbol == nullptr) { + return; + } + 
reinterpret_cast(symbol)(M_PURGE, 0); +#endif +} + +void purge_allocator(const char* point, BonsaiMemoryPolicy policy) { + log_memory_snapshot(point, policy); + purge_android_allocator_pages(); + std::string after_point(point); + after_point += "_after_purge"; + log_memory_snapshot(after_point.c_str(), policy); +} + +void require_not_cancelled(const std::atomic_bool& cancel_requested) { + if (cancel_requested.load()) { + throw BonsaiGenerationCancelled(); + } +} + +void validate_request(const BonsaiGenerationRequest& request) { + if (request.model_paths.root_path.empty()) { + throw std::runtime_error("Bonsai model path is required."); + } + if (request.prompt.empty()) { + throw std::runtime_error("Prompt is required."); + } + if (request.sampling_steps <= 0) { + throw std::runtime_error("Bonsai sampling steps must be positive."); + } + if (!std::isfinite(request.cfg_scale)) { + throw std::runtime_error("Bonsai CFG scale must be finite."); + } + if (request.width <= 0 || request.height <= 0 || + request.width % 32 != 0 || request.height % 32 != 0) { + throw std::runtime_error("Bonsai image size must be positive and divisible by 32."); + } + if (request.batch_count != 1) { + throw std::runtime_error("Android Bonsai runtime supports batch_count=1 only."); + } +} + +BonsaiVulkanBackendMode parse_backend_mode(const std::string& value) { + if (value == "cpu" || value == "CPU") { + return BonsaiVulkanBackendMode::Cpu; + } + if (value == "vulkan" || value == "VULKAN") { + return BonsaiVulkanBackendMode::Vulkan; + } + return BonsaiVulkanBackendMode::Auto; +} + +std::string trim_ascii(const std::string& value) { + size_t start = 0; + while (start < value.size() && + std::isspace(static_cast(value[start])) != 0) { + start++; + } + + size_t end = value.size(); + while (end > start && + std::isspace(static_cast(value[end - 1])) != 0) { + end--; + } + return value.substr(start, end - start); +} + +int64_t parse_seed(const std::string& seed) { + const std::string trimmed = 
trim_ascii(seed); + if (trimmed.empty()) { + std::random_device random_device; + std::mt19937 generator(random_device()); + std::uniform_int_distribution distribution( + 0, + std::numeric_limits::max() + ); + return distribution(generator); + } + + try { + size_t consumed = 0; + const int64_t value = std::stoll(trimmed, &consumed, 10); + if (consumed != trimmed.size()) { + throw std::invalid_argument("trailing seed characters"); + } + return value; + } catch (const std::exception&) { + throw std::runtime_error("Bonsai seed is not a valid integer: " + seed); + } +} + +std::vector> float4_ids( + const std::vector& ids, + const char* label +) { + if (ids.empty() || ids.size() % 4U != 0) { + throw std::runtime_error(std::string("Bonsai ") + label + " ids shape mismatch."); + } + std::vector> output; + output.reserve(ids.size() / 4U); + for (size_t index = 0; index < ids.size(); index += 4U) { + output.push_back({ + static_cast(ids[index]), + static_cast(ids[index + 1U]), + static_cast(ids[index + 2U]), + static_cast(ids[index + 3U]), + }); + } + return output; +} + +std::vector classifier_free_guidance( + const std::vector& conditional, + const std::vector& unconditional, + float guidance +) { + if (conditional.size() != unconditional.size()) { + throw std::runtime_error("Bonsai CFG noise shape mismatch."); + } + std::vector output; + output.reserve(conditional.size()); + for (size_t index = 0; index < conditional.size(); index++) { + output.push_back( + unconditional[index] + guidance * (conditional[index] - unconditional[index]) + ); + } + return output; +} + +BonsaiNchwTensor packed_latents_tensor( + const std::vector& packed_latents, + const BonsaiLatentShape& latent_shape, + uint64_t image_height, + uint64_t image_width +) { + return BonsaiNchwTensor { + latent_shape.batch_size, + latent_shape.channels, + latent_shape.latent_height, + latent_shape.latent_width, + bonsai_unpack_packed_latents( + packed_latents, + latent_shape.batch_size, + 
latent_shape.sequence_length, + latent_shape.channels, + image_height, + image_width, + 8 + ), + }; +} + +BonsaiTextPhaseOutput run_text_encoder_phase( + const BonsaiGenerationRequest& effective_request, + const BonsaiTokenizerData& tokenizer_data, + const BonsaiQwenTextEncoderViews& text_encoder_views, + const std::atomic_bool& cancel_requested +) { + require_not_cancelled(cancel_requested); + log_phase("prompt_plan_start"); + const BonsaiPromptEncodingPlan prompt_plan = bonsai_prepare_qwen_prompt_encoding_plan( + effective_request.prompt, + effective_request.negative_prompt, + effective_request.cfg_scale, + tokenizer_data.metadata + ); + log_phase("qwen_input_start"); + BonsaiTextPhaseOutput output; + output.input_shell = bonsai_prepare_qwen_input_shell( + prompt_plan, + text_encoder_views, + tokenizer_data + ); + require_not_cancelled(cancel_requested); + log_phase("qwen_prompt_forward_start"); + output.prompt_embeddings = bonsai_qwen_text_encoder_forward( + text_encoder_views, + output.input_shell.prompt_input_ids, + output.input_shell.prompt_attention_mask + ); + log_phase("qwen_prompt_forward_done"); + if (output.input_shell.has_negative_prompt) { + log_phase("qwen_negative_forward_start"); + output.negative_prompt_embeddings = bonsai_qwen_text_encoder_forward( + text_encoder_views, + output.input_shell.negative_input_ids, + output.input_shell.negative_attention_mask + ); + log_phase("qwen_negative_forward_done"); + } + require_not_cancelled(cancel_requested); + return output; +} + +std::vector run_denoise_phase( + const BonsaiGenerationRequest& effective_request, + int64_t seed, + const BonsaiTextPhaseOutput& text_phase, + const BonsaiFluxTransformerViews& transformer_views, + const BonsaiProgressCallback& progress_callback, + const std::atomic_bool& cancel_requested +) { + log_phase("latents_start"); + const BonsaiLatentShape latent_shape = bonsai_packed_latent_shape( + static_cast(effective_request.height), + static_cast(effective_request.width), + 
static_cast(effective_request.batch_count), + 128, + 8 + ); + const BonsaiFlowMatchEulerSchedule schedule = bonsai_flow_match_euler_schedule( + latent_shape.sequence_length, + static_cast(effective_request.sampling_steps) + ); + std::vector latents = bonsai_pack_latents_nchw( + bonsai_random_latents_nchw( + latent_shape.batch_size, + latent_shape.channels, + latent_shape.latent_height, + latent_shape.latent_width, + seed + ), + latent_shape.batch_size, + latent_shape.channels, + latent_shape.latent_height, + latent_shape.latent_width + ); + const std::vector> image_ids = float4_ids( + bonsai_latent_grid_ids( + latent_shape.batch_size, + latent_shape.latent_height, + latent_shape.latent_width + ), + "image" + ); + const std::vector> text_ids = float4_ids( + bonsai_qwen_text_ids(text_phase.prompt_embeddings.sequence_length), + "text" + ); + const std::vector> negative_text_ids = + text_phase.input_shell.has_negative_prompt + ? float4_ids( + bonsai_qwen_text_ids(text_phase.negative_prompt_embeddings.sequence_length), + "negative text" + ) + : std::vector> {}; + + if (progress_callback) { + progress_callback(0, effective_request.sampling_steps); + } + require_not_cancelled(cancel_requested); + + log_phase("denoise_start"); + for (size_t index = 0; index < schedule.timesteps.size(); index++) { + require_not_cancelled(cancel_requested); + __android_log_print( + ANDROID_LOG_INFO, + LOG_TAG, + "phase=denoise_step_start step=%zu/%zu", + index + 1U, + schedule.timesteps.size() + ); + const BonsaiFluxTransformerOutput conditional = bonsai_flux_transformer_forward( + transformer_views, + latents, + text_phase.prompt_embeddings.values, + image_ids, + text_ids, + schedule.timesteps[index] + ); + std::vector noise = conditional.values; + if (text_phase.input_shell.has_negative_prompt) { + const BonsaiFluxTransformerOutput unconditional = bonsai_flux_transformer_forward( + transformer_views, + latents, + text_phase.negative_prompt_embeddings.values, + image_ids, + 
negative_text_ids, + schedule.timesteps[index] + ); + noise = classifier_free_guidance( + conditional.values, + unconditional.values, + effective_request.cfg_scale + ); + } + latents = bonsai_flow_match_euler_step( + noise, + static_cast(index), + latents, + schedule + ); + if (progress_callback) { + progress_callback(static_cast(index + 1U), effective_request.sampling_steps); + } + } + require_not_cancelled(cancel_requested); + return latents; +} + +std::string run_vae_phase( + const BonsaiGenerationRequest& effective_request, + const std::vector& latents, + const BonsaiVaeDecodeViews& vae_views +) { + const BonsaiLatentShape latent_shape = bonsai_packed_latent_shape( + static_cast(effective_request.height), + static_cast(effective_request.width), + static_cast(effective_request.batch_count), + 128, + 8 + ); + log_phase("vae_decode_start"); + return bonsai_encode_nchw_tensor_as_png_base64( + bonsai_vae_decode_packed_view_nchw( + packed_latents_tensor( + latents, + latent_shape, + static_cast(effective_request.height), + static_cast(effective_request.width) + ), + vae_views + ) + ); +} + +std::string bonsai_generate_image_performance( + const BonsaiGenerationRequest& effective_request, + int64_t seed, + const BonsaiProgressCallback& progress_callback, + const std::atomic_bool& cancel_requested, + BonsaiMemoryPolicy policy +) { + log_phase("load_context_start"); + std::unique_ptr model_context = + bonsai_load_runtime_model_context(effective_request.model_paths); + log_phase("load_context_done"); + log_memory_snapshot("full_context_loaded", policy); + BonsaiTextPhaseOutput text_phase = run_text_encoder_phase( + effective_request, + model_context->tokenizer_data, + model_context->text_encoder_views, + cancel_requested + ); + std::vector latents = run_denoise_phase( + effective_request, + seed, + text_phase, + model_context->transformer_views, + progress_callback, + cancel_requested + ); + return run_vae_phase(effective_request, latents, model_context->vae_views); +} + 
+std::string bonsai_generate_image_staged( + const BonsaiGenerationRequest& effective_request, + int64_t seed, + const BonsaiProgressCallback& progress_callback, + const std::atomic_bool& cancel_requested, + BonsaiMemoryPolicy policy +) { + log_phase("load_text_context_start"); + std::unique_ptr text_context = + bonsai_load_text_encoder_runtime_context(effective_request.model_paths); + log_phase("load_text_context_done"); + log_memory_snapshot("text_context_loaded", policy); + BonsaiTextPhaseOutput text_phase = run_text_encoder_phase( + effective_request, + text_context->tokenizer_data, + text_context->text_encoder_views, + cancel_requested + ); + text_context.reset(); + purge_allocator("text_context_released", policy); + + log_phase("load_flux_context_start"); + std::unique_ptr flux_context = + bonsai_load_flux_transformer_runtime_context(effective_request.model_paths); + log_phase("load_flux_context_done"); + log_memory_snapshot("flux_context_loaded", policy); + std::vector latents = run_denoise_phase( + effective_request, + seed, + text_phase, + flux_context->transformer_views, + progress_callback, + cancel_requested + ); + flux_context.reset(); + purge_allocator("flux_context_released", policy); + + log_phase("load_vae_context_start"); + std::unique_ptr vae_context = + bonsai_load_vae_runtime_context(effective_request.model_paths); + log_phase("load_vae_context_done"); + log_memory_snapshot("vae_context_loaded", policy); + std::string output = run_vae_phase(effective_request, latents, vae_context->vae_views); + vae_context.reset(); + purge_allocator("vae_context_released", policy); + return output; +} + +} // namespace + +BonsaiGenerationCancelled::BonsaiGenerationCancelled() : + std::runtime_error("Bonsai generation cancelled.") {} + +std::string bonsai_generate_image( + const BonsaiGenerationRequest& request, + const BonsaiProgressCallback& progress_callback, + const std::atomic_bool& cancel_requested +) { + require_not_cancelled(cancel_requested); + 
validate_request(request); + const BonsaiVulkanBackendMode backend_mode = parse_backend_mode(request.backend); + bonsai_vulkan_set_backend_mode(backend_mode); + const uint64_t total_ram_kb = android_total_ram_kb(); + const BonsaiMemoryPolicy policy = auto_memory_policy(total_ram_kb); + __android_log_print( + ANDROID_LOG_INFO, + LOG_TAG, + "phase=memory_policy memory_policy=%s total_ram_mb=%llu", + memory_policy_name(policy), + static_cast(total_ram_kb / 1024ULL) + ); + log_memory_snapshot("request_start", policy); + if (backend_mode != BonsaiVulkanBackendMode::Cpu) { + bonsai_vulkan_runtime_available(); + } + log_phase("parse_seed"); + const int64_t seed = parse_seed(request.seed); + if (policy == BonsaiMemoryPolicy::Performance) { + return bonsai_generate_image_performance( + request, + seed, + progress_callback, + cancel_requested, + policy + ); + } + return bonsai_generate_image_staged( + request, + seed, + progress_callback, + cancel_requested, + policy + ); +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_runtime.h b/feature/bonsai/src/androidMain/cpp/bonsai_runtime.h new file mode 100644 index 000000000..176178e25 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_runtime.h @@ -0,0 +1,35 @@ +#pragma once + +#include +#include +#include +#include + +#include "bonsai_model_probe.h" + +struct BonsaiGenerationRequest { + BonsaiModelPaths model_paths; + std::string prompt; + std::string negative_prompt; + int sampling_steps = 0; + float cfg_scale = 0.0F; + int width = 0; + int height = 0; + std::string seed; + int batch_count = 0; + bool allow_nsfw = false; + std::string backend = "auto"; +}; + +class BonsaiGenerationCancelled : public std::runtime_error { +public: + BonsaiGenerationCancelled(); +}; + +using BonsaiProgressCallback = std::function; + +std::string bonsai_generate_image( + const BonsaiGenerationRequest& request, + const BonsaiProgressCallback& progress_callback, + const std::atomic_bool& cancel_requested +); diff --git 
a/feature/bonsai/src/androidMain/cpp/bonsai_runtime_context.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_runtime_context.cpp new file mode 100644 index 000000000..887933925 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_runtime_context.cpp @@ -0,0 +1,237 @@ +#include "bonsai_runtime_context.h" + +#include + +namespace { + +void require_runtime_layout(const BonsaiModelPaths& paths) { + bonsai_require_directory(paths.root_path, "root"); + bonsai_require_directory(paths.tokenizer_path, "tokenizer"); + bonsai_require_directory(paths.scheduler_path, "scheduler"); + bonsai_require_file(bonsai_join_path(paths.tokenizer_path, "tokenizer.json"), "tokenizer"); + bonsai_require_file( + bonsai_join_path(paths.tokenizer_path, "tokenizer_config.json"), + "tokenizer config" + ); +} + +void require_text_encoder_inventory(const BonsaiSafetensorsIndex& index) { + const std::string embedding_key = index.resolve_model_prefixed_key("embed_tokens.weight"); + index.require_packed_weight(embedding_key, 4, 64); +} + +std::string build_context_summary(const BonsaiRuntimeModelContext& context) { + std::ostringstream output; + output + << "quantization_bits=" << context.quantization.bits + << " group_size=" << context.quantization.group_size + << " tokenizer_class=" << context.tokenizer_data.metadata.runtime_tokenizer_class + << " tokenizer_vocab=" << context.tokenizer_data.metadata.vocab_size + << " tokenizer_merges=" << context.tokenizer_data.metadata.merge_count + << " tokenizer_pad_id=" << context.tokenizer_data.metadata.pad_token_id + << " tokenizer_eos_id=" << context.tokenizer_data.metadata.eos_token_id + << " tokenizer_checksum=" + << bonsai_tokenizer_data_checksum(context.tokenizer_data) + << " transformer_tensors=" << context.transformer_index.tensor_count() + << " transformer_files=" << context.transformer_index.file_count() + << " transformer_logical_tensors=" + << context.transformer_inventory.logical_tensor_count + << " flux_double_blocks=" << 
context.transformer_inventory.double_block_count + << " flux_single_blocks=" << context.transformer_inventory.single_block_count + << " flux_double_view_blocks=" << context.transformer_views.double_blocks.size() + << " flux_single_view_blocks=" << context.transformer_views.single_blocks.size() + << " flux_transformer_view_bytes=" + << bonsai_flux_transformer_byte_count(context.transformer_views) + << " text_encoder_tensors=" << context.text_encoder_index.tensor_count() + << " text_encoder_files=" << context.text_encoder_index.file_count() + << " qwen_layers=" << context.text_encoder_inventory.layer_count + << " qwen_logical_tensors=" << context.text_encoder_inventory.logical_tensor_count + << " qwen_view_layers=" << context.text_encoder_views.layers.size() + << " qwen_view_bytes=" << bonsai_qwen_text_encoder_byte_count( + context.text_encoder_views + ) + << " vae_tensors=" << context.vae_index.tensor_count() + << " vae_files=" << context.vae_index.file_count() + << " vae_up_blocks=" << context.vae_inventory.up_block_count + << " vae_resnet_blocks=" << context.vae_inventory.resnet_block_count + << " vae_attention_blocks=" << context.vae_inventory.attention_block_count + << " vae_decode_bytes=" << bonsai_vae_decode_byte_count(context.vae_views); + return output.str(); +} + +std::string build_text_encoder_context_summary( + const BonsaiTextEncoderRuntimeContext& context +) { + std::ostringstream output; + output + << "tokenizer_class=" << context.tokenizer_data.metadata.runtime_tokenizer_class + << " tokenizer_vocab=" << context.tokenizer_data.metadata.vocab_size + << " tokenizer_merges=" << context.tokenizer_data.metadata.merge_count + << " tokenizer_pad_id=" << context.tokenizer_data.metadata.pad_token_id + << " tokenizer_eos_id=" << context.tokenizer_data.metadata.eos_token_id + << " tokenizer_checksum=" + << bonsai_tokenizer_data_checksum(context.tokenizer_data) + << " text_encoder_tensors=" << context.text_encoder_index.tensor_count() + << " text_encoder_files=" 
<< context.text_encoder_index.file_count() + << " qwen_layers=" << context.text_encoder_inventory.layer_count + << " qwen_logical_tensors=" << context.text_encoder_inventory.logical_tensor_count + << " qwen_view_layers=" << context.text_encoder_views.layers.size() + << " qwen_view_bytes=" << bonsai_qwen_text_encoder_byte_count( + context.text_encoder_views + ); + return output.str(); +} + +std::string build_flux_transformer_context_summary( + const BonsaiFluxTransformerRuntimeContext& context +) { + std::ostringstream output; + output + << "quantization_bits=" << context.quantization.bits + << " group_size=" << context.quantization.group_size + << " transformer_tensors=" << context.transformer_index.tensor_count() + << " transformer_files=" << context.transformer_index.file_count() + << " transformer_logical_tensors=" + << context.transformer_inventory.logical_tensor_count + << " flux_double_blocks=" << context.transformer_inventory.double_block_count + << " flux_single_blocks=" << context.transformer_inventory.single_block_count + << " flux_double_view_blocks=" << context.transformer_views.double_blocks.size() + << " flux_single_view_blocks=" << context.transformer_views.single_blocks.size() + << " flux_transformer_view_bytes=" + << bonsai_flux_transformer_byte_count(context.transformer_views); + return output.str(); +} + +std::string build_vae_context_summary(const BonsaiVaeRuntimeContext& context) { + std::ostringstream output; + output + << "vae_tensors=" << context.vae_index.tensor_count() + << " vae_files=" << context.vae_index.file_count() + << " vae_up_blocks=" << context.vae_inventory.up_block_count + << " vae_resnet_blocks=" << context.vae_inventory.resnet_block_count + << " vae_attention_blocks=" << context.vae_inventory.attention_block_count + << " vae_decode_bytes=" << bonsai_vae_decode_byte_count(context.vae_views); + return output.str(); +} + +} // namespace + +BonsaiTextEncoderRuntimeContext::BonsaiTextEncoderRuntimeContext( + const 
BonsaiModelPaths& model_paths +) : + paths(model_paths), + tokenizer_data(bonsai_load_tokenizer_data(model_paths.tokenizer_path)), + text_encoder_index(BonsaiSafetensorsIndex::load_directory( + model_paths.text_encoder_path, + "text encoder" + )), + text_encoder_storage(text_encoder_index), + text_encoder_views(bonsai_require_qwen_text_encoder_views( + text_encoder_storage, + text_encoder_index + )) { + require_text_encoder_inventory(text_encoder_index); + text_encoder_inventory = bonsai_require_qwen_text_encoder_tensors(text_encoder_index); + summary = build_text_encoder_context_summary(*this); +} + +BonsaiFluxTransformerRuntimeContext::BonsaiFluxTransformerRuntimeContext( + const BonsaiModelPaths& model_paths +) : + paths(model_paths), + quantization(bonsai_read_quantization_config(model_paths.packed_transformer_path)), + transformer_index(BonsaiSafetensorsIndex::load_directory( + model_paths.packed_transformer_path, + "transformer" + )), + transformer_storage(transformer_index), + transformer_views(bonsai_require_flux_transformer_views( + transformer_storage, + transformer_index, + quantization.bits, + quantization.group_size + )) { + transformer_inventory = bonsai_require_flux_transformer_tensors( + transformer_index, + quantization.bits, + quantization.group_size + ); + summary = build_flux_transformer_context_summary(*this); +} + +BonsaiVaeRuntimeContext::BonsaiVaeRuntimeContext(const BonsaiModelPaths& model_paths) : + paths(model_paths), + vae_config(bonsai_read_vae_config(model_paths.vae_path)), + vae_index(BonsaiSafetensorsIndex::load_directory(model_paths.vae_path, "vae")), + vae_storage(vae_index), + vae_views(bonsai_vae_require_decode_views(vae_storage, vae_index, vae_config)) { + vae_inventory = bonsai_require_flux_vae_tensors(vae_index, vae_config); + summary = build_vae_context_summary(*this); +} + +BonsaiRuntimeModelContext::BonsaiRuntimeModelContext(const BonsaiModelPaths& model_paths) : + paths(model_paths), + 
quantization(bonsai_read_quantization_config(model_paths.packed_transformer_path)), + vae_config(bonsai_read_vae_config(model_paths.vae_path)), + tokenizer_data(bonsai_load_tokenizer_data(model_paths.tokenizer_path)), + transformer_index(BonsaiSafetensorsIndex::load_directory( + model_paths.packed_transformer_path, + "transformer" + )), + text_encoder_index(BonsaiSafetensorsIndex::load_directory( + model_paths.text_encoder_path, + "text encoder" + )), + vae_index(BonsaiSafetensorsIndex::load_directory(model_paths.vae_path, "vae")), + transformer_storage(transformer_index), + text_encoder_storage(text_encoder_index), + vae_storage(vae_index), + transformer_views(bonsai_require_flux_transformer_views( + transformer_storage, + transformer_index, + quantization.bits, + quantization.group_size + )), + text_encoder_views(bonsai_require_qwen_text_encoder_views( + text_encoder_storage, + text_encoder_index + )), + vae_views(bonsai_vae_require_decode_views(vae_storage, vae_index, vae_config)) { + transformer_inventory = bonsai_require_flux_transformer_tensors( + transformer_index, + quantization.bits, + quantization.group_size + ); + require_text_encoder_inventory(text_encoder_index); + text_encoder_inventory = bonsai_require_qwen_text_encoder_tensors(text_encoder_index); + vae_inventory = bonsai_require_flux_vae_tensors(vae_index, vae_config); + summary = build_context_summary(*this); +} + +std::unique_ptr bonsai_load_runtime_model_context( + const BonsaiModelPaths& paths +) { + require_runtime_layout(paths); + return std::make_unique(paths); +} + +std::unique_ptr bonsai_load_text_encoder_runtime_context( + const BonsaiModelPaths& paths +) { + require_runtime_layout(paths); + return std::make_unique(paths); +} + +std::unique_ptr bonsai_load_flux_transformer_runtime_context( + const BonsaiModelPaths& paths +) { + require_runtime_layout(paths); + return std::make_unique(paths); +} + +std::unique_ptr bonsai_load_vae_runtime_context( + const BonsaiModelPaths& paths +) { + 
require_runtime_layout(paths); + return std::make_unique(paths); +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_runtime_context.h b/feature/bonsai/src/androidMain/cpp/bonsai_runtime_context.h new file mode 100644 index 000000000..37ce904a6 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_runtime_context.h @@ -0,0 +1,93 @@ +#pragma once + +#include "bonsai_flux_transformer.h" +#include "bonsai_flux_vae.h" +#include "bonsai_model_config.h" +#include "bonsai_model_probe.h" +#include "bonsai_qwen.h" +#include "bonsai_safetensors.h" +#include "bonsai_tensor_storage.h" +#include "bonsai_tokenizer.h" +#include "bonsai_vae_decoder.h" + +#include +#include + +struct BonsaiTextEncoderRuntimeContext { + explicit BonsaiTextEncoderRuntimeContext(const BonsaiModelPaths& model_paths); + + BonsaiModelPaths paths; + BonsaiTokenizerData tokenizer_data; + BonsaiSafetensorsIndex text_encoder_index; + BonsaiTensorStorage text_encoder_storage; + BonsaiQwenTextEncoderViews text_encoder_views; + BonsaiQwenInventorySummary text_encoder_inventory; + std::string summary; +}; + +struct BonsaiFluxTransformerRuntimeContext { + explicit BonsaiFluxTransformerRuntimeContext(const BonsaiModelPaths& model_paths); + + BonsaiModelPaths paths; + BonsaiQuantizationConfig quantization; + BonsaiSafetensorsIndex transformer_index; + BonsaiTensorStorage transformer_storage; + BonsaiFluxTransformerViews transformer_views; + BonsaiFluxTransformerInventorySummary transformer_inventory; + std::string summary; +}; + +struct BonsaiVaeRuntimeContext { + explicit BonsaiVaeRuntimeContext(const BonsaiModelPaths& model_paths); + + BonsaiModelPaths paths; + BonsaiFluxVaeConfig vae_config; + BonsaiSafetensorsIndex vae_index; + BonsaiTensorStorage vae_storage; + BonsaiVaeDecodeViews vae_views; + BonsaiFluxVaeInventorySummary vae_inventory; + std::string summary; +}; + +struct BonsaiRuntimeModelContext { + explicit BonsaiRuntimeModelContext(const BonsaiModelPaths& model_paths); + + BonsaiModelPaths 
paths; + BonsaiQuantizationConfig quantization; + BonsaiFluxVaeConfig vae_config; + BonsaiTokenizerData tokenizer_data; + + BonsaiSafetensorsIndex transformer_index; + BonsaiSafetensorsIndex text_encoder_index; + BonsaiSafetensorsIndex vae_index; + + BonsaiTensorStorage transformer_storage; + BonsaiTensorStorage text_encoder_storage; + BonsaiTensorStorage vae_storage; + + BonsaiFluxTransformerViews transformer_views; + BonsaiQwenTextEncoderViews text_encoder_views; + BonsaiVaeDecodeViews vae_views; + + BonsaiFluxTransformerInventorySummary transformer_inventory; + BonsaiQwenInventorySummary text_encoder_inventory; + BonsaiFluxVaeInventorySummary vae_inventory; + + std::string summary; +}; + +std::unique_ptr bonsai_load_runtime_model_context( + const BonsaiModelPaths& paths +); + +std::unique_ptr bonsai_load_text_encoder_runtime_context( + const BonsaiModelPaths& paths +); + +std::unique_ptr bonsai_load_flux_transformer_runtime_context( + const BonsaiModelPaths& paths +); + +std::unique_ptr bonsai_load_vae_runtime_context( + const BonsaiModelPaths& paths +); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_safetensors.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_safetensors.cpp new file mode 100644 index 000000000..e2e59e3f9 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_safetensors.cpp @@ -0,0 +1,559 @@ +#include "bonsai_safetensors.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { + +constexpr uint64_t MAX_SAFETENSORS_HEADER_BYTES = 128ULL * 1024ULL * 1024ULL; + +struct DirectoryHandle { + explicit DirectoryHandle(const std::string& path) : value(opendir(path.c_str())) {} + + ~DirectoryHandle() { + if (value != nullptr) { + closedir(value); + } + } + + DIR* value = nullptr; +}; + +bool ends_with(const std::string& value, const std::string& suffix) { + return value.size() >= suffix.size() && + value.compare(value.size() - suffix.size(), suffix.size(), suffix) == 0; +} + +bool stat_path(const 
std::string& path, struct stat* output) { + return lstat(path.c_str(), output) == 0; +} + +bool is_directory(const std::string& path) { + struct stat info {}; + return stat_path(path, &info) && S_ISDIR(info.st_mode); +} + +uint64_t file_size(const std::string& path) { + struct stat info {}; + if (!stat_path(path, &info) || !S_ISREG(info.st_mode)) { + throw std::runtime_error("missing Bonsai safetensors file: " + path); + } + return static_cast(info.st_size); +} + +std::string join_path(const std::string& parent, const std::string& child) { + if (parent.empty() || parent.back() == '/') { + return parent + child; + } + return parent + "/" + child; +} + +void collect_safetensors( + const std::string& directory, + std::vector* output +) { + DirectoryHandle handle(directory); + if (handle.value == nullptr) { + throw std::runtime_error("could not read Bonsai directory: " + directory); + } + + while (dirent* entry = readdir(handle.value)) { + const std::string name(entry->d_name); + if (name == "." 
|| name == "..") { + continue; + } + + const std::string path = join_path(directory, name); + struct stat info {}; + if (!stat_path(path, &info)) { + continue; + } + + if (S_ISDIR(info.st_mode)) { + collect_safetensors(path, output); + } else if (S_ISREG(info.st_mode) && ends_with(name, ".safetensors")) { + output->push_back(path); + } + } +} + +uint64_t read_little_endian_u64(const unsigned char* bytes) { + uint64_t value = 0; + for (int index = 7; index >= 0; index--) { + value = (value << 8U) | bytes[index]; + } + return value; +} + +class JsonCursor { +public: + explicit JsonCursor(const std::string& json) : json_(json) {} + + bool at_end() { + skip_whitespace(); + return index_ >= json_.size(); + } + + bool consume(char expected) { + skip_whitespace(); + if (index_ < json_.size() && json_[index_] == expected) { + index_++; + return true; + } + return false; + } + + void expect(char expected) { + if (!consume(expected)) { + throw std::runtime_error("invalid Bonsai safetensors JSON header"); + } + } + + std::string parse_string() { + skip_whitespace(); + expect('"'); + std::string value; + while (index_ < json_.size()) { + const char current = json_[index_++]; + if (current == '"') { + return value; + } + if (current != '\\') { + value.push_back(current); + continue; + } + if (index_ >= json_.size()) { + throw std::runtime_error("invalid Bonsai safetensors string escape"); + } + const char escaped = json_[index_++]; + switch (escaped) { + case '"': + case '\\': + case '/': + value.push_back(escaped); + break; + case 'b': + value.push_back('\b'); + break; + case 'f': + value.push_back('\f'); + break; + case 'n': + value.push_back('\n'); + break; + case 'r': + value.push_back('\r'); + break; + case 't': + value.push_back('\t'); + break; + case 'u': + if (index_ + 4 > json_.size()) { + throw std::runtime_error("invalid Bonsai safetensors unicode escape"); + } + index_ += 4; + value.push_back('?'); + break; + default: + throw std::runtime_error("invalid Bonsai 
safetensors string escape"); + } + } + throw std::runtime_error("unterminated Bonsai safetensors string"); + } + + uint64_t parse_uint64() { + skip_whitespace(); + const size_t start = index_; + while (index_ < json_.size() && + std::isdigit(static_cast(json_[index_])) != 0) { + index_++; + } + if (start == index_) { + throw std::runtime_error("invalid Bonsai safetensors integer"); + } + return static_cast(std::stoull(json_.substr(start, index_ - start))); + } + + std::vector parse_uint64_array() { + std::vector values; + expect('['); + if (consume(']')) { + return values; + } + while (true) { + values.push_back(parse_uint64()); + if (consume(']')) { + return values; + } + expect(','); + } + } + + void skip_value() { + skip_whitespace(); + if (index_ >= json_.size()) { + throw std::runtime_error("invalid Bonsai safetensors JSON value"); + } + + const char current = json_[index_]; + if (current == '"') { + parse_string(); + return; + } + if (current == '{') { + skip_object(); + return; + } + if (current == '[') { + skip_array(); + return; + } + if (std::isdigit(static_cast(current)) != 0 || current == '-') { + skip_number(); + return; + } + if (try_skip_literal("true") || + try_skip_literal("false") || + try_skip_literal("null") + ) { + return; + } + throw std::runtime_error("invalid Bonsai safetensors JSON literal"); + } + +private: + void skip_whitespace() { + while (index_ < json_.size() && + std::isspace(static_cast(json_[index_])) != 0) { + index_++; + } + } + + void skip_object() { + expect('{'); + if (consume('}')) { + return; + } + while (true) { + parse_string(); + expect(':'); + skip_value(); + if (consume('}')) { + return; + } + expect(','); + } + } + + void skip_array() { + expect('['); + if (consume(']')) { + return; + } + while (true) { + skip_value(); + if (consume(']')) { + return; + } + expect(','); + } + } + + void skip_number() { + if (index_ < json_.size() && json_[index_] == '-') { + index_++; + } + while (index_ < json_.size() && + 
std::isdigit(static_cast(json_[index_])) != 0) { + index_++; + } + if (index_ < json_.size() && json_[index_] == '.') { + index_++; + while (index_ < json_.size() && + std::isdigit(static_cast(json_[index_])) != 0) { + index_++; + } + } + if (index_ < json_.size() && (json_[index_] == 'e' || json_[index_] == 'E')) { + index_++; + if (index_ < json_.size() && (json_[index_] == '+' || json_[index_] == '-')) { + index_++; + } + while (index_ < json_.size() && + std::isdigit(static_cast(json_[index_])) != 0) { + index_++; + } + } + } + + bool try_skip_literal(const std::string& literal) { + if (json_.compare(index_, literal.size(), literal) == 0) { + index_ += literal.size(); + return true; + } + return false; + } + + const std::string& json_; + size_t index_ = 0; +}; + +BonsaiTensorDescriptor parse_tensor_descriptor( + JsonCursor* cursor, + const std::string& key +) { + BonsaiTensorDescriptor descriptor; + descriptor.key = key; + bool has_dtype = false; + bool has_shape = false; + bool has_offsets = false; + + cursor->expect('{'); + if (!cursor->consume('}')) { + while (true) { + const std::string property = cursor->parse_string(); + cursor->expect(':'); + if (property == "dtype") { + descriptor.dtype = cursor->parse_string(); + has_dtype = true; + } else if (property == "shape") { + descriptor.shape = cursor->parse_uint64_array(); + has_shape = true; + } else if (property == "data_offsets") { + const std::vector offsets = cursor->parse_uint64_array(); + if (offsets.size() != 2 || offsets[0] > offsets[1]) { + throw std::runtime_error( + "invalid Bonsai safetensors data_offsets for tensor: " + key + ); + } + descriptor.data_start = offsets[0]; + descriptor.data_end = offsets[1]; + has_offsets = true; + } else { + cursor->skip_value(); + } + if (cursor->consume('}')) { + break; + } + cursor->expect(','); + } + } + + if (!has_dtype || !has_shape || !has_offsets) { + throw std::runtime_error("invalid Bonsai safetensors metadata for tensor: " + key); + } + try { + 
descriptor.dtype_type = bonsai_dtype_from_safetensors(descriptor.dtype); + } catch (const std::runtime_error&) { + throw std::runtime_error( + "unsupported Bonsai safetensors dtype " + descriptor.dtype + " for tensor: " + key + ); + } + return descriptor; +} + +std::vector parse_safetensors_header( + const std::string& header, + uint64_t data_base_offset, + uint64_t file_byte_count, + const std::string& file_path +) { + JsonCursor cursor(header); + std::vector descriptors; + cursor.expect('{'); + if (!cursor.consume('}')) { + while (true) { + const std::string key = cursor.parse_string(); + cursor.expect(':'); + if (key == "__metadata__") { + cursor.skip_value(); + } else { + BonsaiTensorDescriptor descriptor = parse_tensor_descriptor(&cursor, key); + if (data_base_offset > file_byte_count || descriptor.data_end > file_byte_count - data_base_offset) { /* subtraction form: summing the base offset and a hostile data_end could wrap uint64_t and slip past this bound */ + throw std::runtime_error( + "Bonsai safetensors tensor data exceeds file size: " + key + ); + } + descriptor.data_start += data_base_offset; + descriptor.data_end += data_base_offset; + descriptor.file_path = file_path; + descriptors.push_back(descriptor); + } + + if (cursor.consume('}')) { + break; + } + cursor.expect(','); + } + } + if (!cursor.at_end()) { + throw std::runtime_error("invalid trailing data in Bonsai safetensors header"); + } + return descriptors; +} + +std::vector read_safetensors_file(const std::string& path) { + std::ifstream input(path, std::ios::binary); + if (!input) { + throw std::runtime_error("could not read Bonsai safetensors file: " + path); + } + + unsigned char header_length_bytes[8] {}; + input.read(reinterpret_cast(header_length_bytes), sizeof(header_length_bytes)); + if (input.gcount() != static_cast(sizeof(header_length_bytes))) { + throw std::runtime_error("invalid Bonsai safetensors header: " + path); + } + + const uint64_t header_length = read_little_endian_u64(header_length_bytes); + if (header_length == 0 || header_length > MAX_SAFETENSORS_HEADER_BYTES) { + throw std::runtime_error("unsupported Bonsai safetensors 
header length: " + path); + } + + std::string header(static_cast(header_length), '\0'); + input.read(header.data(), static_cast(header.size())); + if (input.gcount() != static_cast(header.size())) { + throw std::runtime_error("truncated Bonsai safetensors header: " + path); + } + + const uint64_t data_base_offset = 8ULL + header_length; + return parse_safetensors_header(header, data_base_offset, file_size(path), path); +} + +bool has_suffix(const std::string& value, const std::string& suffix) { + return value.size() >= suffix.size() && + value.compare(value.size() - suffix.size(), suffix.size(), suffix) == 0; +} + +std::string key_prefix_for_weight(const std::string& weight_key) { + if (!has_suffix(weight_key, ".weight")) { + throw std::runtime_error("expected tensor key ending in .weight: " + weight_key); + } + return weight_key.substr(0, weight_key.size() - std::string(".weight").size()); +} + +} // namespace + +BonsaiSafetensorsIndex BonsaiSafetensorsIndex::load_directory( + const std::string& directory, + const std::string& label +) { + if (!is_directory(directory)) { + throw std::runtime_error("missing Bonsai " + label + " directory: " + directory); + } + + std::vector files; + collect_safetensors(directory, &files); + std::sort(files.begin(), files.end()); + if (files.empty()) { + throw std::runtime_error("no .safetensors files in Bonsai " + label + " directory"); + } + + BonsaiSafetensorsIndex index; + index.files_ = files; + for (const std::string& file : files) { + index.append(file, read_safetensors_file(file)); + } + if (index.descriptors_.empty()) { + throw std::runtime_error("Bonsai " + label + " checkpoint has no tensor metadata"); + } + return index; +} + +bool BonsaiSafetensorsIndex::contains(const std::string& key) const { + return index_.find(key) != index_.end(); +} + +const BonsaiTensorDescriptor& BonsaiSafetensorsIndex::require(const std::string& key) const { + const BonsaiTensorDescriptor* descriptor = optional(key); + if (descriptor == 
nullptr) { + throw std::runtime_error("Bonsai checkpoint is missing tensor: " + key); + } + return *descriptor; +} + +const BonsaiTensorDescriptor* BonsaiSafetensorsIndex::optional(const std::string& key) const { + const auto found = index_.find(key); + if (found == index_.end()) { + return nullptr; + } + return &descriptors_[found->second]; +} + +std::string BonsaiSafetensorsIndex::resolve_model_prefixed_key( + const std::string& suffix +) const { + if (contains(suffix)) { + return suffix; + } + return "model." + suffix; +} + +BonsaiPackedWeightDescriptor BonsaiSafetensorsIndex::require_packed_weight( + const std::string& weight_key, + int bits, + int group_size +) const { + const std::string prefix = key_prefix_for_weight(weight_key); + const std::string scales_key = prefix + ".scales"; + const BonsaiTensorDescriptor* scales = optional(scales_key); + if (scales == nullptr) { + require(weight_key); + return BonsaiPackedWeightDescriptor { + false, + weight_key, + "", + "", + bits, + group_size, + }; + } + + const std::string biases_key = prefix + ".biases"; + require(biases_key); + const BonsaiTensorDescriptor& packed = require(weight_key); + if (packed.dtype_type != BonsaiDType::U32) { + throw std::runtime_error("packed tensor " + weight_key + " must be uint32"); + } + + return BonsaiPackedWeightDescriptor { + true, + weight_key, + scales_key, + biases_key, + bits, + group_size, + }; +} + +size_t BonsaiSafetensorsIndex::tensor_count() const { + return descriptors_.size(); +} + +size_t BonsaiSafetensorsIndex::file_count() const { + return files_.size(); +} + +const std::vector& BonsaiSafetensorsIndex::files() const { + return files_; +} + +const std::vector& BonsaiSafetensorsIndex::descriptors() const { + return descriptors_; +} + +void BonsaiSafetensorsIndex::append( + const std::string& file_path, + std::vector descriptors +) { + for (BonsaiTensorDescriptor& descriptor : descriptors) { + descriptor.file_path = file_path; + descriptors_.push_back(descriptor); + 
index_[descriptor.key] = descriptors_.size() - 1; + } +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_safetensors.h b/feature/bonsai/src/androidMain/cpp/bonsai_safetensors.h new file mode 100644 index 000000000..eedffeaf2 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_safetensors.h @@ -0,0 +1,60 @@ +#pragma once + +#include "bonsai_tensor.h" + +#include +#include +#include +#include + +struct BonsaiTensorDescriptor { + std::string key; + std::string file_path; + std::string dtype; + BonsaiDType dtype_type = BonsaiDType::F32; + std::vector shape; + uint64_t data_start = 0; + uint64_t data_end = 0; +}; + +struct BonsaiPackedWeightDescriptor { + bool packed = false; + std::string weight_key; + std::string scales_key; + std::string biases_key; + int bits = 0; + int group_size = 0; +}; + +class BonsaiSafetensorsIndex { +public: + static BonsaiSafetensorsIndex load_directory( + const std::string& directory, + const std::string& label + ); + + bool contains(const std::string& key) const; + const BonsaiTensorDescriptor& require(const std::string& key) const; + const BonsaiTensorDescriptor* optional(const std::string& key) const; + std::string resolve_model_prefixed_key(const std::string& suffix) const; + BonsaiPackedWeightDescriptor require_packed_weight( + const std::string& weight_key, + int bits, + int group_size + ) const; + + size_t tensor_count() const; + size_t file_count() const; + const std::vector& files() const; + const std::vector& descriptors() const; + +private: + void append( + const std::string& file_path, + std::vector descriptors + ); + + std::vector files_; + std::vector descriptors_; + std::unordered_map index_; +}; diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_scheduler.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_scheduler.cpp new file mode 100644 index 000000000..d08326dfb --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_scheduler.cpp @@ -0,0 +1,84 @@ +#include "bonsai_scheduler.h" + +#include 
+#include +#include + +namespace { + +float empirical_mu(uint64_t image_sequence_length, uint64_t steps) { + const float a1 = 8.73809524e-05F; + const float b1 = 1.89833333F; + const float a2 = 0.00016927F; + const float b2 = 0.45666666F; + const float sequence_length = static_cast(image_sequence_length); + if (image_sequence_length > 4300) { + return a2 * sequence_length + b2; + } + + const float m200 = a2 * sequence_length + b2; + const float m10 = a1 * sequence_length + b1; + const float a = (m200 - m10) / 190.0F; + const float b = m200 - 200.0F * a; + return a * static_cast(steps) + b; +} + +float time_shift(float mu, float sigma_power, float timestep) { + const float numerator = std::exp(mu); + return numerator / (numerator + std::pow(1.0F / timestep - 1.0F, sigma_power)); +} + +} // namespace + +BonsaiFlowMatchEulerSchedule bonsai_flow_match_euler_schedule( + uint64_t image_sequence_length, + uint64_t steps +) { + if (image_sequence_length == 0) { + throw std::runtime_error("Bonsai scheduler image sequence length must be positive."); + } + + const uint64_t step_count = std::max(1, steps); + const float mu = empirical_mu(image_sequence_length, step_count); + + BonsaiFlowMatchEulerSchedule output; + output.timesteps.reserve(static_cast(step_count)); + output.sigmas.reserve(static_cast(step_count + 1)); + for (uint64_t index = 0; index < step_count; index++) { + const float linear = 1.0F - + static_cast(index) * + (1.0F - 1.0F / static_cast(step_count)) / + static_cast(std::max(1, step_count - 1)); + const float sigma = time_shift(mu, 1.0F, linear); + output.sigmas.push_back(sigma); + output.timesteps.push_back(sigma * 1000.0F); + } + output.sigmas.push_back(0.0F); + return output; +} + +std::vector bonsai_flow_match_euler_step( + const std::vector& noise, + uint64_t timestep_index, + const std::vector& latents, + const BonsaiFlowMatchEulerSchedule& schedule +) { + if (noise.size() != latents.size()) { + throw std::runtime_error("Bonsai scheduler noise/latents 
size mismatch."); + } + if (schedule.sigmas.size() != schedule.timesteps.size() + 1) { + throw std::runtime_error("Bonsai scheduler shape mismatch."); + } + if (timestep_index + 1 >= schedule.sigmas.size()) { + throw std::runtime_error("Bonsai scheduler timestep is out of range."); + } + + const float delta = schedule.sigmas[static_cast(timestep_index + 1)] - + schedule.sigmas[static_cast(timestep_index)]; + std::vector output; + output.reserve(latents.size()); + for (size_t index = 0; index < latents.size(); index++) { + output.push_back(latents[index] + delta * noise[index]); + } + return output; +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_scheduler.h b/feature/bonsai/src/androidMain/cpp/bonsai_scheduler.h new file mode 100644 index 000000000..4d9745e24 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_scheduler.h @@ -0,0 +1,21 @@ +#pragma once + +#include +#include + +struct BonsaiFlowMatchEulerSchedule { + std::vector timesteps; + std::vector sigmas; +}; + +BonsaiFlowMatchEulerSchedule bonsai_flow_match_euler_schedule( + uint64_t image_sequence_length, + uint64_t steps +); + +std::vector bonsai_flow_match_euler_step( + const std::vector& noise, + uint64_t timestep_index, + const std::vector& latents, + const BonsaiFlowMatchEulerSchedule& schedule +); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_tensor.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_tensor.cpp new file mode 100644 index 000000000..645e0f743 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_tensor.cpp @@ -0,0 +1,165 @@ +#include "bonsai_tensor.h" + +#include +#include +#include + +namespace { + +uint64_t checked_multiply(uint64_t left, uint64_t right, const std::string& tensor_key) { + if (left != 0 && right > std::numeric_limits::max() / left) { + throw std::runtime_error("Bonsai tensor shape is too large: " + tensor_key); + } + return left * right; +} + +float f16_to_f32(uint16_t value) { + const uint32_t sign = (static_cast(value & 0x8000U)) << 
16U; + uint32_t exponent = (value >> 10U) & 0x1FU; + uint32_t mantissa = value & 0x03FFU; + + uint32_t bits = 0; + if (exponent == 0) { + if (mantissa == 0) { + bits = sign; + } else { + exponent = 1; + while ((mantissa & 0x0400U) == 0) { + mantissa <<= 1U; + exponent--; + } + mantissa &= 0x03FFU; + bits = sign | ((exponent + 112U) << 23U) | (mantissa << 13U); + } + } else if (exponent == 0x1FU) { + bits = sign | 0x7F800000U | (mantissa << 13U); + } else { + bits = sign | ((exponent + 112U) << 23U) | (mantissa << 13U); + } + + float output = 0.0F; + std::memcpy(&output, &bits, sizeof(output)); + return output; +} + +float bf16_to_f32(uint16_t value) { + const uint32_t bits = static_cast(value) << 16U; + float output = 0.0F; + std::memcpy(&output, &bits, sizeof(output)); + return output; +} + +template +T read_unaligned(const uint8_t* data) { + T value {}; + std::memcpy(&value, data, sizeof(T)); + return value; +} + +} // namespace + +BonsaiDType bonsai_dtype_from_safetensors(const std::string& dtype) { + if (dtype == "BOOL") return BonsaiDType::Bool; + if (dtype == "U8") return BonsaiDType::U8; + if (dtype == "I8") return BonsaiDType::I8; + if (dtype == "U16") return BonsaiDType::U16; + if (dtype == "I16") return BonsaiDType::I16; + if (dtype == "U32") return BonsaiDType::U32; + if (dtype == "I32") return BonsaiDType::I32; + if (dtype == "U64") return BonsaiDType::U64; + if (dtype == "I64") return BonsaiDType::I64; + if (dtype == "F16") return BonsaiDType::F16; + if (dtype == "BF16") return BonsaiDType::BF16; + if (dtype == "F32") return BonsaiDType::F32; + if (dtype == "F64") return BonsaiDType::F64; + throw std::runtime_error("unsupported Bonsai tensor dtype: " + dtype); +} + +std::string bonsai_dtype_name(BonsaiDType dtype) { + switch (dtype) { + case BonsaiDType::Bool: return "BOOL"; + case BonsaiDType::U8: return "U8"; + case BonsaiDType::I8: return "I8"; + case BonsaiDType::U16: return "U16"; + case BonsaiDType::I16: return "I16"; + case BonsaiDType::U32: 
return "U32"; + case BonsaiDType::I32: return "I32"; + case BonsaiDType::U64: return "U64"; + case BonsaiDType::I64: return "I64"; + case BonsaiDType::F16: return "F16"; + case BonsaiDType::BF16: return "BF16"; + case BonsaiDType::F32: return "F32"; + case BonsaiDType::F64: return "F64"; + } +} + +bool bonsai_dtype_is_floating_point(BonsaiDType dtype) { + return dtype == BonsaiDType::F16 || + dtype == BonsaiDType::BF16 || + dtype == BonsaiDType::F32 || + dtype == BonsaiDType::F64; +} + +uint64_t bonsai_dtype_byte_count(BonsaiDType dtype) { + switch (dtype) { + case BonsaiDType::Bool: + case BonsaiDType::U8: + case BonsaiDType::I8: + return 1; + case BonsaiDType::U16: + case BonsaiDType::I16: + case BonsaiDType::F16: + case BonsaiDType::BF16: + return 2; + case BonsaiDType::U32: + case BonsaiDType::I32: + case BonsaiDType::F32: + return 4; + case BonsaiDType::U64: + case BonsaiDType::I64: + case BonsaiDType::F64: + return 8; + } +} + +uint64_t bonsai_shape_element_count( + const std::vector& shape, + const std::string& tensor_key +) { + uint64_t count = 1; + for (uint64_t dimension : shape) { + count = checked_multiply(count, dimension, tensor_key); + } + return count; +} + +float bonsai_read_scalar_as_f32(const uint8_t* data, BonsaiDType dtype) { + switch (dtype) { + case BonsaiDType::Bool: + return *data == 0 ? 
0.0F : 1.0F; + case BonsaiDType::U8: + return static_cast(read_unaligned(data)); + case BonsaiDType::I8: + return static_cast(read_unaligned(data)); + case BonsaiDType::U16: + return static_cast(read_unaligned(data)); + case BonsaiDType::I16: + return static_cast(read_unaligned(data)); + case BonsaiDType::U32: + return static_cast(read_unaligned(data)); + case BonsaiDType::I32: + return static_cast(read_unaligned(data)); + case BonsaiDType::U64: + return static_cast(read_unaligned(data)); + case BonsaiDType::I64: + return static_cast(read_unaligned(data)); + case BonsaiDType::F16: + return f16_to_f32(read_unaligned(data)); + case BonsaiDType::BF16: + return bf16_to_f32(read_unaligned(data)); + case BonsaiDType::F32: + return read_unaligned(data); + case BonsaiDType::F64: + return static_cast(read_unaligned(data)); + } +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_tensor.h b/feature/bonsai/src/androidMain/cpp/bonsai_tensor.h new file mode 100644 index 000000000..31ee8c4d9 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_tensor.h @@ -0,0 +1,31 @@ +#pragma once + +#include +#include +#include + +enum class BonsaiDType { + Bool, + U8, + I8, + U16, + I16, + U32, + I32, + U64, + I64, + F16, + BF16, + F32, + F64, +}; + +BonsaiDType bonsai_dtype_from_safetensors(const std::string& dtype); +std::string bonsai_dtype_name(BonsaiDType dtype); +bool bonsai_dtype_is_floating_point(BonsaiDType dtype); +uint64_t bonsai_dtype_byte_count(BonsaiDType dtype); +uint64_t bonsai_shape_element_count( + const std::vector& shape, + const std::string& tensor_key +); +float bonsai_read_scalar_as_f32(const uint8_t* data, BonsaiDType dtype); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_tensor_storage.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_tensor_storage.cpp new file mode 100644 index 000000000..22846f04e --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_tensor_storage.cpp @@ -0,0 +1,176 @@ +#include "bonsai_tensor_storage.h" + +#include 
+#include +#include +#include +#include +#include +#include +#include + +namespace { + +uint64_t checked_multiply(uint64_t left, uint64_t right, const std::string& tensor_key) { + if (left != 0 && right > std::numeric_limits::max() / left) { + throw std::runtime_error("Bonsai tensor shape is too large: " + tensor_key); + } + return left * right; +} + +size_t checked_size_t(uint64_t value, const std::string& label) { + if (value > static_cast(std::numeric_limits::max())) { + throw std::runtime_error("Bonsai " + label + " is too large for this runtime."); + } + return static_cast(value); +} + +} // namespace + +class BonsaiMappedFile { +public: + explicit BonsaiMappedFile(const std::string& path) : path_(path) { + fd_ = open(path.c_str(), O_RDONLY | O_CLOEXEC); + if (fd_ < 0) { + throw std::runtime_error("could not open Bonsai safetensors file: " + path); + } + + struct stat info {}; + if (fstat(fd_, &info) != 0 || !S_ISREG(info.st_mode)) { + close(fd_); + fd_ = -1; + throw std::runtime_error("invalid Bonsai safetensors file: " + path); + } + + size_ = static_cast(info.st_size); + if (size_ == 0) { + close(fd_); + fd_ = -1; + throw std::runtime_error("empty Bonsai safetensors file: " + path); + } + + size_t map_size = 0; + try { + map_size = checked_size_t(size_, "safetensors file"); + } catch (...) 
{ + close(fd_); + fd_ = -1; + throw; + } + + mapped_ = mmap( + nullptr, + map_size, + PROT_READ, + MAP_PRIVATE, + fd_, + 0 + ); + if (mapped_ == MAP_FAILED) { + mapped_ = nullptr; + close(fd_); + fd_ = -1; + throw std::runtime_error("could not map Bonsai safetensors file: " + path); + } + } + + ~BonsaiMappedFile() { + if (mapped_ != nullptr) { + munmap(mapped_, static_cast(size_)); + } + if (fd_ >= 0) { + close(fd_); + } + } + + BonsaiMappedFile(const BonsaiMappedFile&) = delete; + BonsaiMappedFile& operator=(const BonsaiMappedFile&) = delete; + + const uint8_t* data() const { + return static_cast(mapped_); + } + + uint64_t size() const { + return size_; + } + +private: + std::string path_; + int fd_ = -1; + void* mapped_ = nullptr; + uint64_t size_ = 0; +}; + +BonsaiTensorStorage::BonsaiTensorStorage(const BonsaiSafetensorsIndex& index) { + for (const std::string& file : index.files()) { + mapped_files_.emplace(file, std::make_unique(file)); + } +} + +BonsaiTensorStorage::~BonsaiTensorStorage() = default; + +BonsaiTensorStorage::BonsaiTensorStorage(BonsaiTensorStorage&&) noexcept = default; + +BonsaiTensorStorage& BonsaiTensorStorage::operator=(BonsaiTensorStorage&&) noexcept = default; + +std::vector bonsai_tensor_view_to_f32_vector(const BonsaiTensorView& view) { + if (view.descriptor == nullptr || view.data == nullptr) { + throw std::runtime_error("Bonsai tensor view is empty."); + } + if (view.element_count > static_cast(std::numeric_limits::max())) { + throw std::runtime_error("Bonsai tensor view is too large: " + view.descriptor->key); + } + + std::vector output; + output.reserve(static_cast(view.element_count)); + for (uint64_t index = 0; index < view.element_count; index++) { + output.push_back(bonsai_read_scalar_as_f32( + view.data + index * view.dtype_byte_count, + view.dtype + )); + } + return output; +} + +BonsaiTensorView BonsaiTensorStorage::view(const BonsaiTensorDescriptor& descriptor) const { + const auto found = 
mapped_files_.find(descriptor.file_path); + if (found == mapped_files_.end()) { + throw std::runtime_error("Bonsai tensor file is not mapped: " + descriptor.file_path); + } + + const BonsaiMappedFile& file = *found->second; + if (descriptor.data_start > descriptor.data_end || descriptor.data_end > file.size()) { + throw std::runtime_error("Bonsai tensor data range is invalid: " + descriptor.key); + } + + const uint64_t bytes = descriptor.data_end - descriptor.data_start; + const uint64_t elements = bonsai_shape_element_count(descriptor.shape, descriptor.key); + const uint64_t dtype_bytes = bonsai_dtype_byte_count(descriptor.dtype_type); + const uint64_t expected_bytes = checked_multiply(elements, dtype_bytes, descriptor.key); + if (expected_bytes != bytes) { + throw std::runtime_error( + "Bonsai tensor byte size mismatch: " + + descriptor.key + + " expected " + + std::to_string(expected_bytes) + + " got " + + std::to_string(bytes) + ); + } + + return BonsaiTensorView { + &descriptor, + file.data() + checked_size_t(descriptor.data_start, "tensor offset"), + bytes, + elements, + dtype_bytes, + descriptor.dtype_type, + }; +} + +BonsaiTensorView BonsaiTensorStorage::require_view( + const BonsaiSafetensorsIndex& index, + const std::string& key +) const { + return view(index.require(key)); +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_tensor_storage.h b/feature/bonsai/src/androidMain/cpp/bonsai_tensor_storage.h new file mode 100644 index 000000000..df5a05974 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_tensor_storage.h @@ -0,0 +1,42 @@ +#pragma once + +#include "bonsai_safetensors.h" + +#include +#include +#include +#include +#include + +class BonsaiMappedFile; + +struct BonsaiTensorView { + const BonsaiTensorDescriptor* descriptor = nullptr; + const uint8_t* data = nullptr; + uint64_t byte_count = 0; + uint64_t element_count = 0; + uint64_t dtype_byte_count = 0; + BonsaiDType dtype = BonsaiDType::F32; +}; + +std::vector 
bonsai_tensor_view_to_f32_vector(const BonsaiTensorView& view); + +class BonsaiTensorStorage { +public: + explicit BonsaiTensorStorage(const BonsaiSafetensorsIndex& index); + ~BonsaiTensorStorage(); + + BonsaiTensorStorage(const BonsaiTensorStorage&) = delete; + BonsaiTensorStorage& operator=(const BonsaiTensorStorage&) = delete; + BonsaiTensorStorage(BonsaiTensorStorage&&) noexcept; + BonsaiTensorStorage& operator=(BonsaiTensorStorage&&) noexcept; + + BonsaiTensorView view(const BonsaiTensorDescriptor& descriptor) const; + BonsaiTensorView require_view( + const BonsaiSafetensorsIndex& index, + const std::string& key + ) const; + +private: + std::unordered_map> mapped_files_; +}; diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_tokenizer.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_tokenizer.cpp new file mode 100644 index 000000000..d533b688e --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_tokenizer.cpp @@ -0,0 +1,919 @@ +#include "bonsai_tokenizer.h" + +#include "bonsai_model_config.h" +#include "bonsai_prompt.h" + +#include +#include +#include +#include +#include +#include + +namespace { + +void skip_whitespace(const std::string& json, size_t* index) { + while (*index < json.size() && + std::isspace(static_cast(json[*index])) != 0) { + (*index)++; + } +} + +void require_index(const std::string& json, size_t index, const std::string& label) { + if (index >= json.size()) { + throw std::runtime_error("invalid Bonsai tokenizer JSON while reading " + label); + } +} + +uint32_t hex_value(char value) { + if (value >= '0' && value <= '9') { + return static_cast(value - '0'); + } + if (value >= 'a' && value <= 'f') { + return static_cast(value - 'a' + 10); + } + if (value >= 'A' && value <= 'F') { + return static_cast(value - 'A' + 10); + } + throw std::runtime_error("invalid Bonsai tokenizer JSON unicode escape"); +} + +uint32_t parse_hex4(const std::string& json, size_t index) { + if (index + 4U > json.size()) { + throw std::runtime_error("invalid 
Bonsai tokenizer JSON unicode escape"); + } + return (hex_value(json[index]) << 12U) | + (hex_value(json[index + 1U]) << 8U) | + (hex_value(json[index + 2U]) << 4U) | + hex_value(json[index + 3U]); +} + +std::string utf8_codepoint(uint32_t codepoint) { + std::string output; + if (codepoint <= 0x7FU) { + output.push_back(static_cast(codepoint)); + } else if (codepoint <= 0x7FFU) { + output.push_back(static_cast(0xC0U | (codepoint >> 6U))); + output.push_back(static_cast(0x80U | (codepoint & 0x3FU))); + } else if (codepoint <= 0xFFFFU) { + output.push_back(static_cast(0xE0U | (codepoint >> 12U))); + output.push_back(static_cast(0x80U | ((codepoint >> 6U) & 0x3FU))); + output.push_back(static_cast(0x80U | (codepoint & 0x3FU))); + } else if (codepoint <= 0x10FFFFU) { + output.push_back(static_cast(0xF0U | (codepoint >> 18U))); + output.push_back(static_cast(0x80U | ((codepoint >> 12U) & 0x3FU))); + output.push_back(static_cast(0x80U | ((codepoint >> 6U) & 0x3FU))); + output.push_back(static_cast(0x80U | (codepoint & 0x3FU))); + } else { + throw std::runtime_error("invalid Bonsai tokenizer unicode codepoint"); + } + return output; +} + +std::string parse_string_at(const std::string& json, size_t* index) { + require_index(json, *index, "string"); + if (json[*index] != '"') { + throw std::runtime_error("invalid Bonsai tokenizer JSON string"); + } + (*index)++; + + std::string output; + while (*index < json.size()) { + const char value = json[*index]; + (*index)++; + if (value == '"') { + return output; + } + if (value != '\\') { + output.push_back(value); + continue; + } + + require_index(json, *index, "escape"); + const char escaped = json[*index]; + (*index)++; + switch (escaped) { + case '"': + case '\\': + case '/': + output.push_back(escaped); + break; + case 'b': + output.push_back('\b'); + break; + case 'f': + output.push_back('\f'); + break; + case 'n': + output.push_back('\n'); + break; + case 'r': + output.push_back('\r'); + break; + case 't': + 
output.push_back('\t'); + break; + case 'u': + if (*index + 4U > json.size()) { + throw std::runtime_error("invalid Bonsai tokenizer JSON unicode escape"); + } + output += utf8_codepoint(parse_hex4(json, *index)); + *index += 4U; + break; + default: + throw std::runtime_error("invalid Bonsai tokenizer JSON escape"); + } + } + + throw std::runtime_error("unterminated Bonsai tokenizer JSON string"); +} + +size_t find_value_start(const std::string& json, const std::string& key) { + const std::string quoted_key = "\"" + key + "\""; + const size_t key_index = json.find(quoted_key); + if (key_index == std::string::npos) { + return std::string::npos; + } + + const size_t colon_index = json.find(':', key_index + quoted_key.size()); + if (colon_index == std::string::npos) { + throw std::runtime_error("invalid Bonsai tokenizer JSON key: " + key); + } + + size_t value_index = colon_index + 1U; + skip_whitespace(json, &value_index); + return value_index; +} + +bool optional_string_value( + const std::string& json, + const std::string& key, + std::string* output +) { + size_t value_index = find_value_start(json, key); + if (value_index == std::string::npos) { + return false; + } + require_index(json, value_index, key); + if (json[value_index] != '"') { + return false; + } + *output = parse_string_at(json, &value_index); + return true; +} + +bool optional_bool_value( + const std::string& json, + const std::string& key, + bool* output +) { + size_t value_index = find_value_start(json, key); + if (value_index == std::string::npos) { + return false; + } + if (json.compare(value_index, 4U, "true") == 0) { + *output = true; + return true; + } + if (json.compare(value_index, 5U, "false") == 0) { + *output = false; + return true; + } + throw std::runtime_error("invalid Bonsai tokenizer boolean JSON key: " + key); +} + +void skip_string(const std::string& json, size_t* index) { + (void) parse_string_at(json, index); +} + +void skip_container(const std::string& json, size_t* index, char 
open, char close) { + require_index(json, *index, "container"); + if (json[*index] != open) { + throw std::runtime_error("invalid Bonsai tokenizer JSON container"); + } + + uint64_t depth = 0; + while (*index < json.size()) { + const char value = json[*index]; + if (value == '"') { + skip_string(json, index); + continue; + } + if (value == open) { + depth++; + } else if (value == close) { + if (depth == 0) { + throw std::runtime_error("invalid Bonsai tokenizer JSON container depth"); + } + depth--; + (*index)++; + if (depth == 0) { + return; + } + continue; + } + (*index)++; + } + + throw std::runtime_error("unterminated Bonsai tokenizer JSON container"); +} + +size_t container_end(const std::string& json, size_t start, char open, char close) { + size_t index = start; + skip_container(json, &index, open, close); + return index; +} + +std::string extract_container_for_key( + const std::string& json, + const std::string& key, + char open, + char close +) { + const size_t start = find_value_start(json, key); + if (start == std::string::npos) { + throw std::runtime_error("missing " + key + " in Bonsai tokenizer JSON"); + } + require_index(json, start, key); + if (json[start] != open) { + throw std::runtime_error("invalid " + key + " in Bonsai tokenizer JSON"); + } + const size_t end = container_end(json, start, open, close); + return json.substr(start, end - start); +} + +int64_t parse_integer_value(const std::string& json, size_t* index) { + skip_whitespace(json, index); + require_index(json, *index, "integer"); + const size_t start = *index; + if (json[*index] == '-') { + (*index)++; + } + while (*index < json.size() && + std::isdigit(static_cast(json[*index])) != 0) { + (*index)++; + } + if (start == *index || (json[start] == '-' && start + 1U == *index)) { + throw std::runtime_error("invalid Bonsai tokenizer integer"); + } + return std::stoll(json.substr(start, *index - start)); +} + +bool optional_token_string( + const std::string& json, + const std::string& key, 
+ std::string* output +) { + size_t value_index = find_value_start(json, key); + if (value_index == std::string::npos) { + return false; + } + require_index(json, value_index, key); + if (json[value_index] == '"') { + *output = parse_string_at(json, &value_index); + return true; + } + if (json[value_index] == '{') { + const size_t end = container_end(json, value_index, '{', '}'); + const std::string object = json.substr(value_index, end - value_index); + return optional_string_value(object, "content", output); + } + return false; +} + +int32_t checked_token_id(int64_t value, const std::string& label) { + if (value < 0 || value > static_cast(std::numeric_limits::max())) { + throw std::runtime_error("invalid Bonsai tokenizer " + label + " id"); + } + return static_cast(value); +} + +bool optional_int_value( + const std::string& json, + const std::string& key, + int32_t* output +) { + size_t value_index = find_value_start(json, key); + if (value_index == std::string::npos) { + return false; + } + *output = checked_token_id(parse_integer_value(json, &value_index), key); + return true; +} + +std::unordered_map parse_vocab_entries(const std::string& vocab_json) { + std::unordered_map vocab; + size_t index = 0; + skip_whitespace(vocab_json, &index); + require_index(vocab_json, index, "vocab"); + if (vocab_json[index] != '{') { + throw std::runtime_error("invalid Bonsai tokenizer vocab"); + } + index++; + + while (index < vocab_json.size()) { + skip_whitespace(vocab_json, &index); + if (index < vocab_json.size() && vocab_json[index] == '}') { + return vocab; + } + + const std::string key = parse_string_at(vocab_json, &index); + skip_whitespace(vocab_json, &index); + require_index(vocab_json, index, "vocab colon"); + if (vocab_json[index] != ':') { + throw std::runtime_error("invalid Bonsai tokenizer vocab member"); + } + index++; + vocab.emplace(key, checked_token_id(parse_integer_value(vocab_json, &index), "vocab")); + + skip_whitespace(vocab_json, &index); + if (index < 
vocab_json.size() && vocab_json[index] == ',') { + index++; + continue; + } + if (index < vocab_json.size() && vocab_json[index] == '}') { + return vocab; + } + throw std::runtime_error("invalid Bonsai tokenizer vocab separator"); + } + + throw std::runtime_error("unterminated Bonsai tokenizer vocab"); +} + +BonsaiTokenizerMerge parse_merge_string(const std::string& value) { + const size_t separator = value.find(' '); + if (separator == std::string::npos || separator + 1U >= value.size()) { + throw std::runtime_error("invalid Bonsai tokenizer merge pair"); + } + return BonsaiTokenizerMerge { + value.substr(0, separator), + value.substr(separator + 1U), + }; +} + +BonsaiTokenizerMerge parse_merge_array(const std::string& merges_json, size_t* index) { + require_index(merges_json, *index, "merge array"); + if (merges_json[*index] != '[') { + throw std::runtime_error("invalid Bonsai tokenizer merge array"); + } + (*index)++; + skip_whitespace(merges_json, index); + const std::string first = parse_string_at(merges_json, index); + skip_whitespace(merges_json, index); + require_index(merges_json, *index, "merge separator"); + if (merges_json[*index] != ',') { + throw std::runtime_error("invalid Bonsai tokenizer merge array separator"); + } + (*index)++; + skip_whitespace(merges_json, index); + const std::string second = parse_string_at(merges_json, index); + skip_whitespace(merges_json, index); + require_index(merges_json, *index, "merge close"); + if (merges_json[*index] != ']') { + throw std::runtime_error("invalid Bonsai tokenizer merge array close"); + } + (*index)++; + return BonsaiTokenizerMerge {first, second}; +} + +std::vector parse_merge_entries( + const std::string& merges_json, + bool* first_entry_is_array +) { + std::vector merges; + size_t index = 0; + skip_whitespace(merges_json, &index); + require_index(merges_json, index, "merges"); + if (merges_json[index] != '[') { + throw std::runtime_error("invalid Bonsai tokenizer merges"); + } + index++; + + bool 
first_entry = true; + while (index < merges_json.size()) { + skip_whitespace(merges_json, &index); + if (index < merges_json.size() && merges_json[index] == ']') { + return merges; + } + + if (first_entry) { + *first_entry_is_array = merges_json[index] == '['; + first_entry = false; + } + if (merges_json[index] == '[') { + merges.push_back(parse_merge_array(merges_json, &index)); + } else if (merges_json[index] == '"') { + merges.push_back(parse_merge_string(parse_string_at(merges_json, &index))); + } else { + throw std::runtime_error("invalid Bonsai tokenizer merge entry"); + } + + skip_whitespace(merges_json, &index); + if (index < merges_json.size() && merges_json[index] == ',') { + index++; + continue; + } + if (index < merges_json.size() && merges_json[index] == ']') { + return merges; + } + throw std::runtime_error("invalid Bonsai tokenizer merge separator"); + } + + throw std::runtime_error("unterminated Bonsai tokenizer merges"); +} + +std::vector parse_added_tokens(const std::string& tokenizer_json) { + const size_t start = find_value_start(tokenizer_json, "added_tokens"); + if (start == std::string::npos) { + return {}; + } + require_index(tokenizer_json, start, "added_tokens"); + if (tokenizer_json[start] != '[') { + throw std::runtime_error("invalid Bonsai tokenizer added_tokens"); + } + + const size_t end = container_end(tokenizer_json, start, '[', ']'); + const std::string added_tokens_json = tokenizer_json.substr(start, end - start); + size_t index = 0; + skip_whitespace(added_tokens_json, &index); + index++; + + std::vector added_tokens; + while (index < added_tokens_json.size()) { + skip_whitespace(added_tokens_json, &index); + if (index < added_tokens_json.size() && added_tokens_json[index] == ']') { + return added_tokens; + } + if (added_tokens_json[index] != '{') { + throw std::runtime_error("invalid Bonsai tokenizer added token object"); + } + const size_t object_end = container_end(added_tokens_json, index, '{', '}'); + const std::string 
object = added_tokens_json.substr(index, object_end - index); + + BonsaiTokenizerAddedToken token; + if (!optional_string_value(object, "content", &token.content) || + !optional_int_value(object, "id", &token.id)) { + throw std::runtime_error("invalid Bonsai tokenizer added token"); + } + (void) optional_bool_value(object, "special", &token.special); + added_tokens.push_back(token); + index = object_end; + + skip_whitespace(added_tokens_json, &index); + if (index < added_tokens_json.size() && added_tokens_json[index] == ',') { + index++; + continue; + } + if (index < added_tokens_json.size() && added_tokens_json[index] == ']') { + return added_tokens; + } + throw std::runtime_error("invalid Bonsai tokenizer added_tokens separator"); + } + + throw std::runtime_error("unterminated Bonsai tokenizer added_tokens"); +} + +bool starts_with(const std::string& value, const std::string& prefix) { + return value.size() >= prefix.size() && value.substr(0, prefix.size()) == prefix; +} + +std::unordered_map merge_ranks(const BonsaiTokenizerData& data) { + std::unordered_map ranks; + ranks.reserve(data.merges.size()); + for (size_t index = 0; index < data.merges.size(); index++) { + ranks.emplace( + data.merges[index].first + '\n' + data.merges[index].second, + static_cast(index) + ); + } + return ranks; +} + +std::vector utf8_scalars(const std::string& value) { + std::vector scalars; + size_t index = 0; + while (index < value.size()) { + const unsigned char byte = static_cast(value[index]); + size_t length = 1; + if ((byte & 0x80U) == 0U) { + length = 1; + } else if ((byte & 0xE0U) == 0xC0U) { + length = 2; + } else if ((byte & 0xF0U) == 0xE0U) { + length = 3; + } else if ((byte & 0xF8U) == 0xF0U) { + length = 4; + } else { + throw std::runtime_error("invalid Bonsai tokenizer UTF-8 token"); + } + if (index + length > value.size()) { + throw std::runtime_error("truncated Bonsai tokenizer UTF-8 token"); + } + scalars.push_back(value.substr(index, length)); + index += length; + } 
+ return scalars; +} + +std::vector byte_encoder_table() { + std::vector bytes; + for (uint32_t value = 33; value <= 126; value++) { + bytes.push_back(value); + } + for (uint32_t value = 161; value <= 172; value++) { + bytes.push_back(value); + } + for (uint32_t value = 174; value <= 255; value++) { + bytes.push_back(value); + } + + std::unordered_set present(bytes.begin(), bytes.end()); + std::vector codepoints = bytes; + uint32_t extra = 0; + for (uint32_t value = 0; value <= 255; value++) { + if (present.find(value) == present.end()) { + bytes.push_back(value); + codepoints.push_back(256U + extra); + extra++; + } + } + + std::vector table(256); + for (size_t index = 0; index < bytes.size(); index++) { + table[bytes[index]] = utf8_codepoint(codepoints[index]); + } + return table; +} + +const std::vector& byte_encoder() { + static const std::vector table = byte_encoder_table(); + return table; +} + +std::string byte_encode_token(const std::string& token) { + const std::vector& table = byte_encoder(); + std::string output; + for (unsigned char byte : token) { + output += table[byte]; + } + return output; +} + +bool is_ascii_letter(unsigned char byte) { + return (byte >= 'A' && byte <= 'Z') || (byte >= 'a' && byte <= 'z'); +} + +bool is_ascii_digit(unsigned char byte) { + return byte >= '0' && byte <= '9'; +} + +bool is_ascii_space(unsigned char byte) { + return byte == ' ' || byte == '\n' || byte == '\t' || byte == '\r' || byte == '\f'; +} + +bool starts_with_at(const std::string& value, size_t index, const std::string& prefix) { + return index + prefix.size() <= value.size() && + value.compare(index, prefix.size(), prefix) == 0; +} + +std::vector byte_level_pretokens(const std::string& text) { + std::vector tokens; + size_t index = 0; + const std::vector contractions = {"'s", "'t", "'re", "'ve", "'m", "'ll", "'d"}; + while (index < text.size()) { + bool matched_contraction = false; + for (const std::string& contraction : contractions) { + if (starts_with_at(text, 
index, contraction)) { + tokens.push_back(byte_encode_token(contraction)); + index += contraction.size(); + matched_contraction = true; + break; + } + } + if (matched_contraction) { + continue; + } + + const size_t start = index; + const unsigned char current = static_cast(text[index]); + if (current == ' ' && index + 1U < text.size()) { + const unsigned char next = static_cast(text[index + 1U]); + if (is_ascii_letter(next)) { + index += 2U; + while (index < text.size() && + is_ascii_letter(static_cast(text[index]))) { + index++; + } + tokens.push_back(byte_encode_token(text.substr(start, index - start))); + continue; + } + if (is_ascii_digit(next)) { + index += 2U; + while (index < text.size() && + is_ascii_digit(static_cast(text[index]))) { + index++; + } + tokens.push_back(byte_encode_token(text.substr(start, index - start))); + continue; + } + if (!is_ascii_space(next)) { + index += 2U; + while (index < text.size()) { + const unsigned char value = static_cast(text[index]); + if (is_ascii_space(value) || is_ascii_letter(value) || is_ascii_digit(value)) { + break; + } + index++; + } + tokens.push_back(byte_encode_token(text.substr(start, index - start))); + continue; + } + } + + if (is_ascii_letter(current)) { + index++; + while (index < text.size() && + is_ascii_letter(static_cast(text[index]))) { + index++; + } + } else if (is_ascii_digit(current)) { + index++; + while (index < text.size() && + is_ascii_digit(static_cast(text[index]))) { + index++; + } + } else if (is_ascii_space(current)) { + index++; + while (index < text.size() && + is_ascii_space(static_cast(text[index]))) { + index++; + } + if (index < text.size() && index > start + 1U) { + index--; + } + } else { + index++; + while (index < text.size()) { + const unsigned char value = static_cast(text[index]); + if (is_ascii_space(value) || is_ascii_letter(value) || is_ascii_digit(value)) { + break; + } + index++; + } + } + tokens.push_back(byte_encode_token(text.substr(start, index - start))); + } + 
return tokens; +} + +std::vector bpe_tokens( + const std::string& token, + const std::unordered_map& ranks +) { + std::vector word = utf8_scalars(token); + if (word.size() <= 1U) { + return word; + } + + while (true) { + int best_rank = std::numeric_limits::max(); + size_t best_index = std::numeric_limits::max(); + for (size_t index = 0; index + 1U < word.size(); index++) { + const auto found = ranks.find(word[index] + '\n' + word[index + 1U]); + if (found != ranks.end() && found->second < best_rank) { + best_rank = found->second; + best_index = index; + } + } + if (best_index == std::numeric_limits::max()) { + break; + } + + std::vector merged; + merged.reserve(word.size() - 1U); + for (size_t index = 0; index < word.size(); index++) { + if (index == best_index) { + merged.push_back(word[index] + word[index + 1U]); + index++; + } else { + merged.push_back(word[index]); + } + } + word = merged; + if (word.size() == 1U) { + break; + } + } + return word; +} + +void append_token_id( + const BonsaiTokenizerData& data, + const std::string& token, + std::vector* ids +) { + const auto added = data.added_token_ids.find(token); + if (added != data.added_token_ids.end()) { + ids->push_back(added->second); + return; + } + const auto vocab = data.vocab.find(token); + if (vocab != data.vocab.end()) { + ids->push_back(vocab->second); + return; + } + + for (unsigned char byte : token) { + std::ostringstream key; + key << "<0x"; + const char* digits = "0123456789ABCDEF"; + key << digits[(byte >> 4U) & 0x0FU] << digits[byte & 0x0FU] << ">"; + const auto fallback = data.vocab.find(key.str()); + if (fallback == data.vocab.end()) { + throw std::runtime_error("Bonsai tokenizer token is missing from vocab: " + token); + } + ids->push_back(fallback->second); + } +} + +void encode_regular_segment( + const BonsaiTokenizerData& data, + const std::unordered_map& ranks, + const std::string& segment, + std::vector* ids +) { + for (const std::string& pretoken : byte_level_pretokens(segment)) { 
+ for (const std::string& token : bpe_tokens(pretoken, ranks)) { + append_token_id(data, token, ids); + } + } +} + +} // namespace + +BonsaiTokenizerMetadata bonsai_load_tokenizer_metadata(const std::string& tokenizer_path) { + return bonsai_load_tokenizer_data(tokenizer_path).metadata; +} + +BonsaiTokenizerData bonsai_load_tokenizer_data(const std::string& tokenizer_path) { + const std::string tokenizer_config_path = bonsai_join_path( + tokenizer_path, + "tokenizer_config.json" + ); + const std::string tokenizer_json_path = bonsai_join_path(tokenizer_path, "tokenizer.json"); + bonsai_require_file(tokenizer_config_path, "tokenizer config"); + bonsai_require_file(tokenizer_json_path, "tokenizer"); + + const std::string tokenizer_config_json = bonsai_read_text_file(tokenizer_config_path); + const std::string tokenizer_json = bonsai_read_text_file(tokenizer_json_path); + const std::string model_json = extract_container_for_key(tokenizer_json, "model", '{', '}'); + const std::string vocab_json = extract_container_for_key(model_json, "vocab", '{', '}'); + const std::string merges_json = extract_container_for_key(model_json, "merges", '[', ']'); + + BonsaiTokenizerData data; + BonsaiTokenizerMetadata& metadata = data.metadata; + (void) optional_string_value( + tokenizer_config_json, + "tokenizer_class", + &metadata.original_tokenizer_class + ); + metadata.qwen_tokenizer_class_rewritten = starts_with( + metadata.original_tokenizer_class, + "Qwen" + ); + metadata.runtime_tokenizer_class = metadata.qwen_tokenizer_class_rewritten + ? 
"GPT2Tokenizer" + : metadata.original_tokenizer_class; + (void) optional_string_value(model_json, "type", &metadata.model_type); + + bool merges_are_pair_arrays = false; + data.vocab = parse_vocab_entries(vocab_json); + data.merges = parse_merge_entries(merges_json, &merges_are_pair_arrays); + data.added_tokens = parse_added_tokens(tokenizer_json); + for (const BonsaiTokenizerAddedToken& token : data.added_tokens) { + data.added_token_ids[token.content] = token.id; + } + metadata.vocab_size = static_cast(data.vocab.size()); + metadata.merge_count = static_cast(data.merges.size()); + metadata.merge_pairs_need_normalization = merges_are_pair_arrays; + + const BonsaiQwenPromptSpec prompt_spec = bonsai_qwen_prompt_spec(); + (void) optional_token_string(tokenizer_config_json, "eos_token", &metadata.eos_token); + if (!metadata.eos_token.empty()) { + const auto eos_added_id = data.added_token_ids.find(metadata.eos_token); + if (eos_added_id != data.added_token_ids.end()) { + metadata.eos_token_id = eos_added_id->second; + } else if (const auto eos_id = data.vocab.find(metadata.eos_token); + eos_id != data.vocab.end()) { + metadata.eos_token_id = eos_id->second; + } + } + if (metadata.eos_token_id == 0) { + metadata.eos_token_id = prompt_spec.eos_token_id; + } + + (void) optional_token_string(tokenizer_config_json, "pad_token", &metadata.pad_token); + if (metadata.pad_token.empty()) { + const std::string special_tokens_map_path = bonsai_join_path( + tokenizer_path, + "special_tokens_map.json" + ); + if (bonsai_path_is_regular_file(special_tokens_map_path)) { + const std::string special_tokens_json = bonsai_read_text_file(special_tokens_map_path); + (void) optional_token_string(special_tokens_json, "pad_token", &metadata.pad_token); + } + } + if (!metadata.pad_token.empty()) { + const auto pad_added_id = data.added_token_ids.find(metadata.pad_token); + if (pad_added_id != data.added_token_ids.end()) { + metadata.pad_token_id = pad_added_id->second; + } else if (const auto 
pad_id = data.vocab.find(metadata.pad_token); + pad_id != data.vocab.end()) { + metadata.pad_token_id = pad_id->second; + } + } + if (metadata.pad_token_id == 0) { + metadata.pad_token_id = metadata.eos_token_id != 0 + ? metadata.eos_token_id + : prompt_spec.pad_token_id; + } + + return data; +} + +uint64_t bonsai_tokenizer_metadata_checksum(const BonsaiTokenizerMetadata& metadata) { + uint64_t checksum = metadata.vocab_size * 3U + metadata.merge_count * 5U; + checksum += static_cast(metadata.pad_token_id) * 7U; + checksum += static_cast(metadata.eos_token_id) * 11U; + checksum += metadata.qwen_tokenizer_class_rewritten ? 13U : 0U; + checksum += metadata.merge_pairs_need_normalization ? 17U : 0U; + for (char value : metadata.runtime_tokenizer_class) { + checksum += static_cast(static_cast(value)); + } + for (char value : metadata.model_type) { + checksum += static_cast(static_cast(value)) * 2U; + } + return checksum; +} + +uint64_t bonsai_tokenizer_data_checksum(const BonsaiTokenizerData& data) { + uint64_t checksum = bonsai_tokenizer_metadata_checksum(data.metadata); + checksum += static_cast(data.vocab.size()) * 29U; + checksum += static_cast(data.merges.size()) * 31U; + checksum += static_cast(data.added_tokens.size()) * 37U; + const size_t merge_limit = std::min(data.merges.size(), 16); + for (size_t index = 0; index < merge_limit; index++) { + for (char value : data.merges[index].first) { + checksum += static_cast(static_cast(value)); + } + for (char value : data.merges[index].second) { + checksum += static_cast(static_cast(value)) * 2U; + } + } + return checksum; +} + +std::vector bonsai_tokenizer_encode( + const BonsaiTokenizerData& data, + const std::string& text +) { + const std::unordered_map ranks = merge_ranks(data); + std::vector ids; + size_t index = 0; + + while (index < text.size()) { + size_t best_position = std::string::npos; + const BonsaiTokenizerAddedToken* best_token = nullptr; + for (const BonsaiTokenizerAddedToken& token : data.added_tokens) 
{ + if (token.content.empty()) { + continue; + } + const size_t position = text.find(token.content, index); + if (position != std::string::npos && + (best_token == nullptr || + position < best_position || + (position == best_position && token.content.size() > best_token->content.size()))) { + best_position = position; + best_token = &token; + } + } + + if (best_token == nullptr) { + encode_regular_segment(data, ranks, text.substr(index), &ids); + break; + } + + if (best_position > index) { + encode_regular_segment( + data, + ranks, + text.substr(index, best_position - index), + &ids + ); + } + ids.push_back(best_token->id); + index = best_position + best_token->content.size(); + } + + return ids; +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_tokenizer.h b/feature/bonsai/src/androidMain/cpp/bonsai_tokenizer.h new file mode 100644 index 000000000..e664afbf1 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_tokenizer.h @@ -0,0 +1,52 @@ +#pragma once + +#include +#include +#include +#include + +struct BonsaiTokenizerMetadata { + std::string original_tokenizer_class; + std::string runtime_tokenizer_class; + std::string model_type; + std::string pad_token; + std::string eos_token; + int32_t pad_token_id = 0; + int32_t eos_token_id = 0; + uint64_t vocab_size = 0; + uint64_t merge_count = 0; + bool qwen_tokenizer_class_rewritten = false; + bool merge_pairs_need_normalization = false; +}; + +struct BonsaiTokenizerMerge { + std::string first; + std::string second; +}; + +struct BonsaiTokenizerAddedToken { + std::string content; + int32_t id = 0; + bool special = false; +}; + +struct BonsaiTokenizerData { + BonsaiTokenizerMetadata metadata; + std::unordered_map vocab; + std::unordered_map added_token_ids; + std::vector merges; + std::vector added_tokens; +}; + +BonsaiTokenizerMetadata bonsai_load_tokenizer_metadata(const std::string& tokenizer_path); + +BonsaiTokenizerData bonsai_load_tokenizer_data(const std::string& tokenizer_path); + +uint64_t 
bonsai_tokenizer_metadata_checksum(const BonsaiTokenizerMetadata& metadata); + +uint64_t bonsai_tokenizer_data_checksum(const BonsaiTokenizerData& data); + +std::vector bonsai_tokenizer_encode( + const BonsaiTokenizerData& data, + const std::string& text +); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_vae_decoder.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_vae_decoder.cpp new file mode 100644 index 000000000..478417803 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_vae_decoder.cpp @@ -0,0 +1,335 @@ +#include "bonsai_vae_decoder.h" + +#include "bonsai_activation.h" +#include "bonsai_tensor.h" + +#include + +#include +#include +#include + +namespace { + +constexpr const char* LOG_TAG = "SDAI-Bonsai"; + +void log_vae_phase(const char* phase, const BonsaiNchwTensor& tensor) { + __android_log_print( + ANDROID_LOG_INFO, + LOG_TAG, + "phase=%s batch=%llu channels=%llu height=%llu width=%llu", + phase, + static_cast(tensor.batch_size), + static_cast(tensor.channels), + static_cast(tensor.height), + static_cast(tensor.width) + ); +} + +void log_vae_block_phase(const char* phase, uint64_t block, const BonsaiNchwTensor& tensor) { + __android_log_print( + ANDROID_LOG_INFO, + LOG_TAG, + "phase=%s block=%llu batch=%llu channels=%llu height=%llu width=%llu", + phase, + static_cast(block), + static_cast(tensor.batch_size), + static_cast(tensor.channels), + static_cast(tensor.height), + static_cast(tensor.width) + ); +} + +void add_bytes(uint64_t* bytes, uint64_t extra, const char* label) { + if (*bytes > std::numeric_limits::max() - extra) { + throw std::runtime_error(std::string("Bonsai VAE decoder byte count overflow: ") + label); + } + *bytes += extra; +} + +void require_conv_shape( + const BonsaiVaeConv2dViews& views, + uint64_t input_channels, + uint64_t output_channels, + const std::string& prefix +) { + if (input_channels != 0 && views.input_channels != input_channels) { + throw std::runtime_error("Bonsai VAE decoder conv input channel 
mismatch: " + prefix); + } + if (output_channels != 0 && views.output_channels != output_channels) { + throw std::runtime_error("Bonsai VAE decoder conv output channel mismatch: " + prefix); + } +} + +void require_tensor_vector( + const BonsaiTensorView& view, + const std::string& key +) { + if (!bonsai_dtype_is_floating_point(view.dtype)) { + throw std::runtime_error("Bonsai VAE decoder tensor must be floating point: " + key); + } + if (view.element_count == 0) { + throw std::runtime_error("Bonsai VAE decoder tensor must be non-empty: " + key); + } +} + +BonsaiVaeMidBlockViews require_mid_block_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + uint64_t channels, + uint64_t group_count, + float epsilon +) { + if (channels == 0) { + throw std::runtime_error("Bonsai VAE mid block channels must be positive."); + } + return BonsaiVaeMidBlockViews { + bonsai_vae_require_resnet_views( + storage, + index, + "decoder.mid_block.resnets.0", + channels, + channels, + group_count, + epsilon + ), + bonsai_vae_require_attention_views( + storage, + index, + "decoder.mid_block.attentions.0", + channels, + group_count, + epsilon + ), + bonsai_vae_require_resnet_views( + storage, + index, + "decoder.mid_block.resnets.1", + channels, + channels, + group_count, + epsilon + ), + channels, + }; +} + +BonsaiVaeDecoderViews require_decoder_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + const BonsaiFluxVaeConfig& config, + uint64_t decoder_input_channels +) { + if (config.block_out_channels_count == 0 || + config.block_out_channels.size() != static_cast(config.block_out_channels_count) || + config.layers_per_block == 0 || + config.norm_num_groups == 0) { + throw std::runtime_error("invalid Bonsai VAE decoder config."); + } + + BonsaiVaeDecoderViews views; + views.input_channels = decoder_input_channels; + views.conv_in = bonsai_vae_require_conv2d_views(storage, index, "decoder.conv_in"); + require_conv_shape( + 
views.conv_in, + decoder_input_channels, + config.block_out_channels.back(), + "decoder.conv_in" + ); + views.mid_block = require_mid_block_views( + storage, + index, + config.block_out_channels.back(), + config.norm_num_groups, + 1e-6F + ); + + views.up_blocks.reserve(static_cast(config.block_out_channels_count)); + const uint64_t layer_count = config.layers_per_block + 1U; + for (uint64_t block = 0; block < config.block_out_channels_count; block++) { + const uint64_t output_channels = config.block_out_channels[ + static_cast(config.block_out_channels_count - 1U - block) + ]; + const uint64_t input_channels = block == 0 + ? output_channels + : config.block_out_channels[static_cast(config.block_out_channels_count - block)]; + views.up_blocks.push_back(bonsai_vae_require_up_block_views( + storage, + index, + "decoder.up_blocks." + std::to_string(block), + input_channels, + output_channels, + layer_count, + config.norm_num_groups, + block + 1U < config.block_out_channels_count, + 1e-6F + )); + } + + views.norm_out = bonsai_vae_require_group_norm_views( + storage, + index, + "decoder.conv_norm_out", + config.block_out_channels.front(), + config.norm_num_groups, + 1e-6F + ); + views.conv_out = bonsai_vae_require_conv2d_views(storage, index, "decoder.conv_out"); + require_conv_shape( + views.conv_out, + config.block_out_channels.front(), + 0, + "decoder.conv_out" + ); + views.output_channels = views.conv_out.output_channels; + return views; +} + +} // namespace + +BonsaiVaeDecodeViews bonsai_vae_require_decode_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + const BonsaiFluxVaeConfig& config +) { + BonsaiTensorView mean = storage.require_view(index, "bn.running_mean"); + BonsaiTensorView variance = storage.require_view(index, "bn.running_var"); + require_tensor_vector(mean, "bn.running_mean"); + require_tensor_vector(variance, "bn.running_var"); + if (mean.element_count != variance.element_count || mean.element_count % 4U != 0) { + 
throw std::runtime_error("Bonsai VAE decoder batch norm channel mismatch."); + } + if (config.batch_norm_eps <= 0.0F) { + throw std::runtime_error("Bonsai VAE decoder batch norm epsilon must be positive."); + } + + BonsaiVaeConv2dViews post_quant_conv = bonsai_vae_require_conv2d_views( + storage, + index, + "post_quant_conv" + ); + require_conv_shape( + post_quant_conv, + mean.element_count / 4U, + 0, + "post_quant_conv" + ); + + return BonsaiVaeDecodeViews { + mean, + variance, + post_quant_conv, + require_decoder_views(storage, index, config, post_quant_conv.output_channels), + mean.element_count, + config.batch_norm_eps, + }; +} + +BonsaiNchwTensor bonsai_vae_mid_block_view_nchw( + const BonsaiNchwTensor& input, + const BonsaiVaeMidBlockViews& views +) { + if (input.channels != views.channels) { + throw std::runtime_error("Bonsai VAE mid block input channel mismatch."); + } + log_vae_phase("vae_mid_resnet0_start", input); + BonsaiNchwTensor output = bonsai_vae_resnet_view_nchw(input, views.resnet0); + log_vae_phase("vae_mid_resnet0_done", output); + log_vae_phase("vae_mid_attention_start", output); + output = bonsai_vae_attention_view_nchw(output, views.attention); + log_vae_phase("vae_mid_attention_done", output); + log_vae_phase("vae_mid_resnet1_start", output); + output = bonsai_vae_resnet_view_nchw(output, views.resnet1); + log_vae_phase("vae_mid_resnet1_done", output); + return output; +} + +BonsaiNchwTensor bonsai_vae_decoder_view_nchw( + const BonsaiNchwTensor& input, + const BonsaiVaeDecoderViews& views +) { + if (input.channels != views.input_channels) { + throw std::runtime_error("Bonsai VAE decoder input channel mismatch."); + } + + log_vae_phase("vae_decoder_conv_in_start", input); + BonsaiNchwTensor output = bonsai_vae_conv2d_view_nchw(input, views.conv_in); + log_vae_phase("vae_decoder_conv_in_done", output); + log_vae_phase("vae_decoder_mid_start", output); + output = bonsai_vae_mid_block_view_nchw(output, views.mid_block); + 
log_vae_phase("vae_decoder_mid_done", output); + for (size_t index = 0; index < views.up_blocks.size(); index++) { + const BonsaiVaeUpBlockViews& up_block = views.up_blocks[index]; + log_vae_block_phase("vae_decoder_up_block_start", index + 1U, output); + output = bonsai_vae_up_block_view_nchw(output, up_block); + log_vae_block_phase("vae_decoder_up_block_done", index + 1U, output); + } + log_vae_phase("vae_decoder_norm_out_start", output); + output = bonsai_vae_group_norm_view_nchw(output, views.norm_out); + log_vae_phase("vae_decoder_norm_out_done", output); + output.values = bonsai_silu(output.values); + log_vae_phase("vae_decoder_conv_out_start", output); + return bonsai_vae_conv2d_view_nchw(output, views.conv_out); +} + +BonsaiNchwTensor bonsai_vae_decode_prelude_view_nchw( + const BonsaiNchwTensor& packed_latents, + const BonsaiVaeDecodeViews& views +) { + if (packed_latents.channels != views.packed_channels) { + throw std::runtime_error("Bonsai VAE packed latent channel mismatch."); + } + log_vae_phase("vae_prelude_denormalize_start", packed_latents); + const BonsaiNchwTensor denormalized = bonsai_vae_denormalize_channels_nchw( + packed_latents, + bonsai_tensor_view_to_f32_vector(views.batch_norm_mean), + bonsai_tensor_view_to_f32_vector(views.batch_norm_variance), + views.batch_norm_epsilon + ); + log_vae_phase("vae_prelude_denormalize_done", denormalized); + log_vae_phase("vae_prelude_unpatchify_start", denormalized); + const BonsaiNchwTensor unpatchified = bonsai_vae_unpatchify_nchw(denormalized); + log_vae_phase("vae_prelude_unpatchify_done", unpatchified); + log_vae_phase("vae_prelude_post_quant_conv_start", unpatchified); + return bonsai_vae_conv2d_view_nchw( + unpatchified, + views.post_quant_conv + ); +} + +BonsaiNchwTensor bonsai_vae_decode_packed_view_nchw( + const BonsaiNchwTensor& packed_latents, + const BonsaiVaeDecodeViews& views +) { + BonsaiNchwTensor output = bonsai_vae_decode_prelude_view_nchw(packed_latents, views); + 
log_vae_phase("vae_prelude_done", output); + output = bonsai_vae_decoder_view_nchw(output, views.decoder); + log_vae_phase("vae_decoder_done", output); + return output; +} + +uint64_t bonsai_vae_mid_block_byte_count(const BonsaiVaeMidBlockViews& views) { + uint64_t bytes = bonsai_vae_resnet_byte_count(views.resnet0); + add_bytes(&bytes, bonsai_vae_attention_byte_count(views.attention), "mid attention"); + add_bytes(&bytes, bonsai_vae_resnet_byte_count(views.resnet1), "mid resnet1"); + return bytes; +} + +uint64_t bonsai_vae_decoder_byte_count(const BonsaiVaeDecoderViews& views) { + uint64_t bytes = bonsai_vae_conv2d_byte_count(views.conv_in); + add_bytes(&bytes, bonsai_vae_mid_block_byte_count(views.mid_block), "mid block"); + for (const BonsaiVaeUpBlockViews& up_block : views.up_blocks) { + add_bytes(&bytes, bonsai_vae_up_block_byte_count(up_block), "up block"); + } + add_bytes(&bytes, bonsai_vae_group_norm_byte_count(views.norm_out), "norm out"); + add_bytes(&bytes, bonsai_vae_conv2d_byte_count(views.conv_out), "conv out"); + return bytes; +} + +uint64_t bonsai_vae_decode_byte_count(const BonsaiVaeDecodeViews& views) { + uint64_t bytes = views.batch_norm_mean.byte_count; + add_bytes(&bytes, views.batch_norm_variance.byte_count, "batch norm variance"); + add_bytes(&bytes, bonsai_vae_conv2d_byte_count(views.post_quant_conv), "post quant conv"); + add_bytes(&bytes, bonsai_vae_decoder_byte_count(views.decoder), "decoder"); + return bytes; +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_vae_decoder.h b/feature/bonsai/src/androidMain/cpp/bonsai_vae_decoder.h new file mode 100644 index 000000000..356a2d3c1 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_vae_decoder.h @@ -0,0 +1,67 @@ +#pragma once + +#include "bonsai_flux_vae.h" +#include "bonsai_safetensors.h" +#include "bonsai_tensor_storage.h" +#include "bonsai_vae_ops.h" + +#include +#include + +struct BonsaiVaeMidBlockViews { + BonsaiVaeResnetViews resnet0; + BonsaiVaeAttentionViews 
attention; + BonsaiVaeResnetViews resnet1; + uint64_t channels = 0; +}; + +struct BonsaiVaeDecoderViews { + BonsaiVaeConv2dViews conv_in; + BonsaiVaeMidBlockViews mid_block; + std::vector up_blocks; + BonsaiVaeGroupNormViews norm_out; + BonsaiVaeConv2dViews conv_out; + uint64_t input_channels = 0; + uint64_t output_channels = 0; +}; + +struct BonsaiVaeDecodeViews { + BonsaiTensorView batch_norm_mean; + BonsaiTensorView batch_norm_variance; + BonsaiVaeConv2dViews post_quant_conv; + BonsaiVaeDecoderViews decoder; + uint64_t packed_channels = 0; + float batch_norm_epsilon = 0.0F; +}; + +BonsaiVaeDecodeViews bonsai_vae_require_decode_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + const BonsaiFluxVaeConfig& config +); + +BonsaiNchwTensor bonsai_vae_mid_block_view_nchw( + const BonsaiNchwTensor& input, + const BonsaiVaeMidBlockViews& views +); + +BonsaiNchwTensor bonsai_vae_decoder_view_nchw( + const BonsaiNchwTensor& input, + const BonsaiVaeDecoderViews& views +); + +BonsaiNchwTensor bonsai_vae_decode_prelude_view_nchw( + const BonsaiNchwTensor& packed_latents, + const BonsaiVaeDecodeViews& views +); + +BonsaiNchwTensor bonsai_vae_decode_packed_view_nchw( + const BonsaiNchwTensor& packed_latents, + const BonsaiVaeDecodeViews& views +); + +uint64_t bonsai_vae_mid_block_byte_count(const BonsaiVaeMidBlockViews& views); + +uint64_t bonsai_vae_decoder_byte_count(const BonsaiVaeDecoderViews& views); + +uint64_t bonsai_vae_decode_byte_count(const BonsaiVaeDecodeViews& views); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_vae_ops.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_vae_ops.cpp new file mode 100644 index 000000000..cc5d7d970 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_vae_ops.cpp @@ -0,0 +1,1049 @@ +#include "bonsai_vae_ops.h" + +#include "bonsai_activation.h" +#include "bonsai_attention.h" +#include "bonsai_linear.h" +#include "bonsai_tensor.h" + +#include +#include +#include +#include +#include + 
+namespace { + +uint64_t checked_multiply(uint64_t left, uint64_t right, const char* label) { + if (left != 0 && right > std::numeric_limits::max() / left) { + throw std::runtime_error(std::string("Bonsai VAE tensor shape is too large: ") + label); + } + return left * right; +} + +uint64_t tensor_size( + uint64_t batch_size, + uint64_t channels, + uint64_t height, + uint64_t width, + const char* label +) { + return checked_multiply( + checked_multiply(checked_multiply(batch_size, channels, label), height, label), + width, + label + ); +} + +void add_bytes(uint64_t* bytes, uint64_t extra, const char* label) { + if (*bytes > std::numeric_limits::max() - extra) { + throw std::runtime_error(std::string("Bonsai VAE byte count overflow: ") + label); + } + *bytes += extra; +} + +void require_tensor_shape(const BonsaiNchwTensor& tensor, const char* label) { + if (tensor.batch_size == 0 || tensor.channels == 0 || tensor.height == 0 || tensor.width == 0) { + throw std::runtime_error(std::string("Bonsai VAE tensor has invalid shape: ") + label); + } + const uint64_t expected = tensor_size( + tensor.batch_size, + tensor.channels, + tensor.height, + tensor.width, + label + ); + if (tensor.values.size() != static_cast(expected)) { + throw std::runtime_error(std::string("Bonsai VAE tensor value count mismatch: ") + label); + } +} + +void require_same_shape( + const BonsaiNchwTensor& left, + const BonsaiNchwTensor& right, + const char* label +) { + require_tensor_shape(left, label); + require_tensor_shape(right, label); + if (left.batch_size != right.batch_size || + left.channels != right.channels || + left.height != right.height || + left.width != right.width) { + throw std::runtime_error(std::string("Bonsai VAE tensor shape mismatch: ") + label); + } +} + +size_t nchw_index( + const BonsaiNchwTensor& tensor, + uint64_t batch, + uint64_t channel, + uint64_t row, + uint64_t column +) { + return static_cast( + ((batch * tensor.channels + channel) * tensor.height + row) * 
tensor.width + column + ); +} + +std::vector to_attention_layout(const BonsaiNchwTensor& tensor) { + std::vector output; + output.reserve(static_cast(tensor_size( + tensor.batch_size, + tensor.channels, + tensor.height, + tensor.width, + "attention layout" + ))); + for (uint64_t batch = 0; batch < tensor.batch_size; batch++) { + for (uint64_t row = 0; row < tensor.height; row++) { + for (uint64_t column = 0; column < tensor.width; column++) { + for (uint64_t channel = 0; channel < tensor.channels; channel++) { + output.push_back(tensor.values[nchw_index(tensor, batch, channel, row, column)]); + } + } + } + } + return output; +} + +std::string bias_key_for_weight(const std::string& weight_key) { + const std::string suffix = ".weight"; + if (weight_key.size() < suffix.size() || + weight_key.compare(weight_key.size() - suffix.size(), suffix.size(), suffix) != 0) { + throw std::runtime_error("Bonsai VAE linear weight key must end with .weight: " + weight_key); + } + return weight_key.substr(0, weight_key.size() - suffix.size()) + ".bias"; +} + +BonsaiLinearViews require_vae_dense_linear_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + const std::string& prefix, + const std::string& fallback_prefix, + uint64_t channels +) { + if (channels == 0) { + throw std::runtime_error("Bonsai VAE attention linear channels must be positive: " + prefix); + } + + const std::string preferred_weight_key = prefix + ".weight"; + const std::string fallback_weight_key = fallback_prefix.empty() + ? 
std::string() + : fallback_prefix + ".weight"; + std::string weight_key = preferred_weight_key; + if (index.optional(weight_key) == nullptr) { + if (fallback_weight_key.empty() || index.optional(fallback_weight_key) == nullptr) { + throw std::runtime_error("missing Bonsai VAE attention linear weight: " + weight_key); + } + weight_key = fallback_weight_key; + } + + BonsaiLinearViews views = bonsai_require_dense_linear_views( + storage, + index, + weight_key, + bias_key_for_weight(weight_key) + ); + if (views.input_values != channels || views.output_rows != channels) { + throw std::runtime_error("Bonsai VAE attention linear shape mismatch: " + weight_key); + } + return views; +} + +void require_conv_view_shape( + const BonsaiVaeConv2dViews& views, + uint64_t input_channels, + uint64_t output_channels, + const std::string& prefix +) { + if (input_channels != 0 && views.input_channels != input_channels) { + throw std::runtime_error("Bonsai VAE conv input channel mismatch: " + prefix); + } + if (output_channels != 0 && views.output_channels != output_channels) { + throw std::runtime_error("Bonsai VAE conv output channel mismatch: " + prefix); + } +} + +BonsaiNchwTensor from_attention_layout( + const std::vector& values, + uint64_t batch_size, + uint64_t channels, + uint64_t height, + uint64_t width +) { + BonsaiNchwTensor output { + batch_size, + channels, + height, + width, + {}, + }; + output.values.assign( + static_cast(tensor_size(batch_size, channels, height, width, "attention output")), + 0.0F + ); + const uint64_t spatial_length = checked_multiply(height, width, "attention output"); + if (values.size() != static_cast(tensor_size( + batch_size, + spatial_length, + channels, + 1, + "attention output" + ))) { + throw std::runtime_error("Bonsai VAE attention output shape mismatch."); + } + for (uint64_t batch = 0; batch < batch_size; batch++) { + for (uint64_t spatial = 0; spatial < spatial_length; spatial++) { + const uint64_t row = spatial / width; + const uint64_t 
column = spatial % width; + for (uint64_t channel = 0; channel < channels; channel++) { + const size_t source_index = static_cast( + (batch * spatial_length + spatial) * channels + channel + ); + output.values[nchw_index(output, batch, channel, row, column)] = + values[source_index]; + } + } + } + return output; +} + +size_t weight_index( + uint64_t input_channels, + uint64_t kernel_height, + uint64_t kernel_width, + uint64_t output_channel, + uint64_t input_channel, + uint64_t kernel_row, + uint64_t kernel_column +) { + return static_cast( + ((output_channel * input_channels + input_channel) * kernel_height + kernel_row) * + kernel_width + + kernel_column + ); +} + +} // namespace + +BonsaiVaeConv2dViews bonsai_vae_require_conv2d_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + const std::string& prefix +) { + const std::string weight_key = prefix + ".weight"; + BonsaiTensorView weight = storage.require_view(index, weight_key); + if (!bonsai_dtype_is_floating_point(weight.dtype)) { + throw std::runtime_error("Bonsai VAE conv weight must be floating point: " + weight_key); + } + if (weight.descriptor->shape.size() != 4) { + throw std::runtime_error("Bonsai VAE conv weight must be 4D: " + weight_key); + } + + BonsaiVaeConv2dViews views { + weight, + {}, + false, + weight.descriptor->shape[0], + weight.descriptor->shape[1], + weight.descriptor->shape[2], + weight.descriptor->shape[3], + weight.descriptor->shape[2] == 1 ? 
0U : 1U, + }; + if (views.output_channels == 0 || + views.input_channels == 0 || + views.kernel_height == 0 || + views.kernel_width == 0) { + throw std::runtime_error("Bonsai VAE conv weight shape must be positive: " + weight_key); + } + + const std::string bias_key = prefix + ".bias"; + const BonsaiTensorDescriptor* bias_descriptor = index.optional(bias_key); + if (bias_descriptor != nullptr) { + views.bias = storage.view(*bias_descriptor); + views.has_bias = true; + if (!bonsai_dtype_is_floating_point(views.bias.dtype)) { + throw std::runtime_error("Bonsai VAE conv bias must be floating point: " + bias_key); + } + if (views.bias.element_count != views.output_channels) { + throw std::runtime_error("Bonsai VAE conv bias size mismatch: " + bias_key); + } + } + return views; +} + +BonsaiNchwTensor bonsai_vae_conv2d_nchw( + const BonsaiNchwTensor& input, + const std::vector& weight, + uint64_t output_channels, + uint64_t kernel_height, + uint64_t kernel_width, + uint64_t padding, + const std::vector* bias +) { + require_tensor_shape(input, "conv input"); + if (output_channels == 0 || kernel_height == 0 || kernel_width == 0) { + throw std::runtime_error("Bonsai VAE conv dimensions must be positive."); + } + const uint64_t expected_weight = tensor_size( + output_channels, + input.channels, + kernel_height, + kernel_width, + "conv weight" + ); + if (weight.size() != static_cast(expected_weight)) { + throw std::runtime_error("Bonsai VAE conv weight size mismatch."); + } + if (bias != nullptr && bias->size() != static_cast(output_channels)) { + throw std::runtime_error("Bonsai VAE conv bias size mismatch."); + } + if (input.height + 2 * padding < kernel_height || input.width + 2 * padding < kernel_width) { + throw std::runtime_error("Bonsai VAE conv kernel is larger than padded input."); + } + + BonsaiNchwTensor output { + input.batch_size, + output_channels, + input.height + 2 * padding - kernel_height + 1, + input.width + 2 * padding - kernel_width + 1, + {}, + }; + 
output.values.assign( + static_cast( + tensor_size(output.batch_size, output.channels, output.height, output.width, "conv out") + ), + 0.0F + ); + + const uint64_t input_area = input.height * input.width; + const uint64_t output_area = output.height * output.width; + const uint64_t input_batch_stride = input.channels * input_area; + const uint64_t output_batch_stride = output.channels * output_area; + const float* input_values = input.values.data(); + const float* weight_values = weight.data(); + float* output_values = output.values.data(); + + const int64_t input_height = static_cast(input.height); + const int64_t input_width = static_cast(input.width); + const int64_t output_height = static_cast(output.height); + const int64_t output_width = static_cast(output.width); + const int64_t pad = static_cast(padding); + + for (uint64_t batch = 0; batch < output.batch_size; batch++) { + const uint64_t input_batch_offset = batch * input_batch_stride; + const uint64_t output_batch_offset = batch * output_batch_stride; + for (uint64_t output_channel = 0; output_channel < output.channels; output_channel++) { + float* output_plane = output_values + output_batch_offset + output_channel * output_area; + const float bias_value = bias == nullptr + ? 
0.0F + : (*bias)[static_cast(output_channel)]; + std::fill(output_plane, output_plane + output_area, bias_value); + + for (uint64_t input_channel = 0; input_channel < input.channels; input_channel++) { + const float* input_plane = + input_values + input_batch_offset + input_channel * input_area; + const float* weight_plane = weight_values + weight_index( + input.channels, + kernel_height, + kernel_width, + output_channel, + input_channel, + 0, + 0 + ); + + for (uint64_t kernel_row = 0; kernel_row < kernel_height; kernel_row++) { + const int64_t input_row_offset = static_cast(kernel_row) - pad; + int64_t output_row_begin = 0; + int64_t output_row_end = output_height; + if (input_row_offset < 0) { + output_row_begin = -input_row_offset; + } + if (input_row_offset + output_height > input_height) { + output_row_end = input_height - input_row_offset; + } + if (output_row_begin >= output_row_end) { + continue; + } + + for (uint64_t kernel_column = 0; kernel_column < kernel_width; kernel_column++) { + const int64_t input_column_offset = + static_cast(kernel_column) - pad; + int64_t output_column_begin = 0; + int64_t output_column_end = output_width; + if (input_column_offset < 0) { + output_column_begin = -input_column_offset; + } + if (input_column_offset + output_width > input_width) { + output_column_end = input_width - input_column_offset; + } + if (output_column_begin >= output_column_end) { + continue; + } + + const float weight_value = weight_plane[ + static_cast(kernel_row * kernel_width + kernel_column) + ]; + for (int64_t output_row = output_row_begin; + output_row < output_row_end; + output_row++) { + const int64_t input_row = output_row + input_row_offset; + const int64_t input_column_start = + output_column_begin + input_column_offset; + const float* input_ptr = input_plane + + static_cast(input_row) * input.width + + static_cast(input_column_start); + float* output_ptr = output_plane + + static_cast(output_row) * output.width + + 
static_cast(output_column_begin); + for (int64_t output_column = output_column_begin; + output_column < output_column_end; + output_column++) { + *output_ptr += *input_ptr * weight_value; + output_ptr++; + input_ptr++; + } + } + } + } + } + } + } + return output; +} + +BonsaiNchwTensor bonsai_vae_conv2d_view_nchw( + const BonsaiNchwTensor& input, + const BonsaiVaeConv2dViews& views +) { + require_tensor_shape(input, "conv view input"); + if (input.channels != views.input_channels) { + throw std::runtime_error("Bonsai VAE conv view input channel mismatch."); + } + const std::vector weight = bonsai_tensor_view_to_f32_vector(views.weight); + const std::vector bias = views.has_bias + ? bonsai_tensor_view_to_f32_vector(views.bias) + : std::vector {}; + return bonsai_vae_conv2d_nchw( + input, + weight, + views.output_channels, + views.kernel_height, + views.kernel_width, + views.padding, + views.has_bias ? &bias : nullptr + ); +} + +uint64_t bonsai_vae_conv2d_byte_count(const BonsaiVaeConv2dViews& views) { + uint64_t bytes = views.weight.byte_count; + if (views.has_bias) { + if (bytes > std::numeric_limits::max() - views.bias.byte_count) { + throw std::runtime_error("Bonsai VAE conv byte count overflow."); + } + bytes += views.bias.byte_count; + } + return bytes; +} + +BonsaiVaeGroupNormViews bonsai_vae_require_group_norm_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + const std::string& prefix, + uint64_t channels, + uint64_t group_count, + float epsilon +) { + if (channels == 0 || group_count == 0 || channels % group_count != 0) { + throw std::runtime_error("Bonsai VAE GroupNorm channel/group mismatch: " + prefix); + } + if (epsilon <= 0.0F || !std::isfinite(epsilon)) { + throw std::runtime_error("Bonsai VAE GroupNorm epsilon must be finite and positive: " + prefix); + } + + const std::string weight_key = prefix + ".weight"; + const std::string bias_key = prefix + ".bias"; + BonsaiTensorView weight = storage.require_view(index, 
weight_key); + BonsaiTensorView bias = storage.require_view(index, bias_key); + if (!bonsai_dtype_is_floating_point(weight.dtype) || + !bonsai_dtype_is_floating_point(bias.dtype)) { + throw std::runtime_error("Bonsai VAE GroupNorm affine tensors must be floating point: " + prefix); + } + if (weight.element_count != channels || bias.element_count != channels) { + throw std::runtime_error("Bonsai VAE GroupNorm affine size mismatch: " + prefix); + } + + return BonsaiVaeGroupNormViews { + weight, + bias, + channels, + group_count, + epsilon, + }; +} + +BonsaiNchwTensor bonsai_vae_group_norm_nchw( + const BonsaiNchwTensor& input, + uint64_t group_count, + const std::vector& weight, + const std::vector& bias, + float epsilon +) { + require_tensor_shape(input, "group norm input"); + if (group_count == 0 || input.channels % group_count != 0) { + throw std::runtime_error("Bonsai VAE group norm channel/group mismatch."); + } + if (weight.size() != static_cast(input.channels) || + bias.size() != static_cast(input.channels)) { + throw std::runtime_error("Bonsai VAE group norm affine size mismatch."); + } + if (epsilon <= 0.0F || !std::isfinite(epsilon)) { + throw std::runtime_error("Bonsai VAE group norm epsilon must be finite and positive."); + } + + BonsaiNchwTensor output = input; + output.values.assign(input.values.size(), 0.0F); + const uint64_t group_size = input.channels / group_count; + const uint64_t values_per_group = checked_multiply( + checked_multiply(group_size, input.height, "group norm"), + input.width, + "group norm" + ); + + for (uint64_t batch = 0; batch < input.batch_size; batch++) { + for (uint64_t group = 0; group < group_count; group++) { + double mean = 0.0; + for (uint64_t group_channel = 0; group_channel < group_size; group_channel++) { + const uint64_t channel = group * group_size + group_channel; + for (uint64_t row = 0; row < input.height; row++) { + for (uint64_t column = 0; column < input.width; column++) { + mean += static_cast( + 
input.values[nchw_index(input, batch, channel, row, column)] + ); + } + } + } + mean /= static_cast(values_per_group); + + double variance = 0.0; + for (uint64_t group_channel = 0; group_channel < group_size; group_channel++) { + const uint64_t channel = group * group_size + group_channel; + for (uint64_t row = 0; row < input.height; row++) { + for (uint64_t column = 0; column < input.width; column++) { + const double centered = + static_cast(input.values[nchw_index(input, batch, channel, row, column)]) - + mean; + variance += centered * centered; + } + } + } + variance /= static_cast(values_per_group); + const float scale = 1.0F / std::sqrt(static_cast(variance) + epsilon); + + for (uint64_t group_channel = 0; group_channel < group_size; group_channel++) { + const uint64_t channel = group * group_size + group_channel; + for (uint64_t row = 0; row < input.height; row++) { + for (uint64_t column = 0; column < input.width; column++) { + const size_t index = nchw_index(input, batch, channel, row, column); + output.values[index] = + (input.values[index] - static_cast(mean)) * + scale * + weight[static_cast(channel)] + + bias[static_cast(channel)]; + } + } + } + } + } + return output; +} + +BonsaiNchwTensor bonsai_vae_group_norm_view_nchw( + const BonsaiNchwTensor& input, + const BonsaiVaeGroupNormViews& views +) { + require_tensor_shape(input, "group norm view input"); + if (input.channels != views.channels) { + throw std::runtime_error("Bonsai VAE GroupNorm view input channel mismatch."); + } + return bonsai_vae_group_norm_nchw( + input, + views.group_count, + bonsai_tensor_view_to_f32_vector(views.weight), + bonsai_tensor_view_to_f32_vector(views.bias), + views.epsilon + ); +} + +uint64_t bonsai_vae_group_norm_byte_count(const BonsaiVaeGroupNormViews& views) { + if (views.weight.byte_count > std::numeric_limits::max() - views.bias.byte_count) { + throw std::runtime_error("Bonsai VAE GroupNorm byte count overflow."); + } + return views.weight.byte_count + 
views.bias.byte_count; +} + +BonsaiVaeAttentionViews bonsai_vae_require_attention_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + const std::string& prefix, + uint64_t channels, + uint64_t group_count, + float epsilon +) { + if (channels == 0) { + throw std::runtime_error("Bonsai VAE attention channels must be positive: " + prefix); + } + + return BonsaiVaeAttentionViews { + bonsai_vae_require_group_norm_views( + storage, + index, + prefix + ".group_norm", + channels, + group_count, + epsilon + ), + require_vae_dense_linear_views(storage, index, prefix + ".to_q", "", channels), + require_vae_dense_linear_views(storage, index, prefix + ".to_k", "", channels), + require_vae_dense_linear_views(storage, index, prefix + ".to_v", "", channels), + require_vae_dense_linear_views(storage, index, prefix + ".to_out.0", prefix + ".to_out", channels), + channels, + 1.0F / std::sqrt(static_cast(channels)), + }; +} + +BonsaiNchwTensor bonsai_vae_attention_view_nchw( + const BonsaiNchwTensor& input, + const BonsaiVaeAttentionViews& views +) { + require_tensor_shape(input, "attention view input"); + if (input.channels != views.channels || + views.group_norm.channels != views.channels || + views.to_q.input_values != views.channels || + views.to_q.output_rows != views.channels || + views.to_k.input_values != views.channels || + views.to_k.output_rows != views.channels || + views.to_v.input_values != views.channels || + views.to_v.output_rows != views.channels || + views.to_out.input_values != views.channels || + views.to_out.output_rows != views.channels) { + throw std::runtime_error("Bonsai VAE attention view channel mismatch."); + } + if (!std::isfinite(views.scale) || views.scale <= 0.0F) { + throw std::runtime_error("Bonsai VAE attention view scale must be finite and positive."); + } + + const BonsaiNchwTensor normed = bonsai_vae_group_norm_view_nchw(input, views.group_norm); + const uint64_t spatial_length = checked_multiply(input.height, 
input.width, "attention view"); + const std::vector sequence = to_attention_layout(normed); + const std::vector query_sequence = bonsai_linear_sequence( + views.to_q, + sequence, + input.batch_size, + spatial_length + ); + const std::vector key_sequence = bonsai_linear_sequence( + views.to_k, + sequence, + input.batch_size, + spatial_length + ); + const std::vector value_sequence = bonsai_linear_sequence( + views.to_v, + sequence, + input.batch_size, + spatial_length + ); + const BonsaiNchwTensor queries = from_attention_layout( + query_sequence, + input.batch_size, + views.channels, + input.height, + input.width + ); + const BonsaiNchwTensor keys = from_attention_layout( + key_sequence, + input.batch_size, + views.channels, + input.height, + input.width + ); + const BonsaiNchwTensor values = from_attention_layout( + value_sequence, + input.batch_size, + views.channels, + input.height, + input.width + ); + const BonsaiNchwTensor attended = bonsai_vae_spatial_attention_nchw( + queries, + keys, + values, + views.scale + ); + const std::vector output_sequence = bonsai_linear_sequence( + views.to_out, + to_attention_layout(attended), + input.batch_size, + spatial_length + ); + return bonsai_vae_add_nchw( + input, + from_attention_layout( + output_sequence, + input.batch_size, + views.channels, + input.height, + input.width + ) + ); +} + +uint64_t bonsai_vae_attention_byte_count(const BonsaiVaeAttentionViews& views) { + uint64_t bytes = bonsai_vae_group_norm_byte_count(views.group_norm); + const std::vector parts { + bonsai_linear_byte_count(views.to_q), + bonsai_linear_byte_count(views.to_k), + bonsai_linear_byte_count(views.to_v), + bonsai_linear_byte_count(views.to_out), + }; + for (uint64_t part : parts) { + add_bytes(&bytes, part, "attention"); + } + return bytes; +} + +BonsaiVaeResnetViews bonsai_vae_require_resnet_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + const std::string& prefix, + uint64_t input_channels, + uint64_t 
output_channels, + uint64_t group_count, + float epsilon +) { + if (input_channels == 0 || output_channels == 0) { + throw std::runtime_error("Bonsai VAE resnet channels must be positive: " + prefix); + } + + BonsaiVaeResnetViews views { + bonsai_vae_require_group_norm_views( + storage, + index, + prefix + ".norm1", + input_channels, + group_count, + epsilon + ), + bonsai_vae_require_conv2d_views(storage, index, prefix + ".conv1"), + bonsai_vae_require_group_norm_views( + storage, + index, + prefix + ".norm2", + output_channels, + group_count, + epsilon + ), + bonsai_vae_require_conv2d_views(storage, index, prefix + ".conv2"), + {}, + false, + input_channels, + output_channels, + }; + require_conv_view_shape(views.conv1, input_channels, output_channels, prefix + ".conv1"); + require_conv_view_shape(views.conv2, output_channels, output_channels, prefix + ".conv2"); + + const std::string shortcut_weight_key = prefix + ".conv_shortcut.weight"; + if (index.optional(shortcut_weight_key) != nullptr) { + views.shortcut = bonsai_vae_require_conv2d_views(storage, index, prefix + ".conv_shortcut"); + views.has_shortcut = true; + require_conv_view_shape( + views.shortcut, + input_channels, + output_channels, + prefix + ".conv_shortcut" + ); + } else if (input_channels != output_channels) { + throw std::runtime_error("Bonsai VAE resnet missing required shortcut: " + prefix); + } + + return views; +} + +BonsaiNchwTensor bonsai_vae_resnet_view_nchw( + const BonsaiNchwTensor& input, + const BonsaiVaeResnetViews& views +) { + require_tensor_shape(input, "resnet view input"); + if (input.channels != views.input_channels) { + throw std::runtime_error("Bonsai VAE resnet input channel mismatch."); + } + + BonsaiNchwTensor output = bonsai_vae_group_norm_view_nchw(input, views.norm1); + output.values = bonsai_silu(output.values); + output = bonsai_vae_conv2d_view_nchw(output, views.conv1); + output = bonsai_vae_group_norm_view_nchw(output, views.norm2); + output.values = 
bonsai_silu(output.values); + output = bonsai_vae_conv2d_view_nchw(output, views.conv2); + const BonsaiNchwTensor residual = views.has_shortcut + ? bonsai_vae_conv2d_view_nchw(input, views.shortcut) + : input; + return bonsai_vae_add_nchw(output, residual); +} + +uint64_t bonsai_vae_resnet_byte_count(const BonsaiVaeResnetViews& views) { + uint64_t bytes = bonsai_vae_group_norm_byte_count(views.norm1); + add_bytes(&bytes, bonsai_vae_conv2d_byte_count(views.conv1), "resnet conv1"); + add_bytes(&bytes, bonsai_vae_group_norm_byte_count(views.norm2), "resnet norm2"); + add_bytes(&bytes, bonsai_vae_conv2d_byte_count(views.conv2), "resnet conv2"); + if (views.has_shortcut) { + add_bytes(&bytes, bonsai_vae_conv2d_byte_count(views.shortcut), "resnet shortcut"); + } + return bytes; +} + +BonsaiVaeUpBlockViews bonsai_vae_require_up_block_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + const std::string& prefix, + uint64_t input_channels, + uint64_t output_channels, + uint64_t layer_count, + uint64_t group_count, + bool add_upsample, + float epsilon +) { + if (input_channels == 0 || output_channels == 0 || layer_count == 0) { + throw std::runtime_error("Bonsai VAE up block dimensions must be positive: " + prefix); + } + + BonsaiVaeUpBlockViews views; + views.input_channels = input_channels; + views.output_channels = output_channels; + views.resnets.reserve(static_cast(layer_count)); + for (uint64_t layer = 0; layer < layer_count; layer++) { + views.resnets.push_back(bonsai_vae_require_resnet_views( + storage, + index, + prefix + ".resnets." + std::to_string(layer), + layer == 0 ? 
input_channels : output_channels, + output_channels, + group_count, + epsilon + )); + } + + if (add_upsample) { + views.upsample = bonsai_vae_require_conv2d_views( + storage, + index, + prefix + ".upsamplers.0.conv" + ); + views.has_upsample = true; + require_conv_view_shape( + views.upsample, + output_channels, + output_channels, + prefix + ".upsamplers.0.conv" + ); + } + return views; +} + +BonsaiNchwTensor bonsai_vae_up_block_view_nchw( + const BonsaiNchwTensor& input, + const BonsaiVaeUpBlockViews& views +) { + require_tensor_shape(input, "up block view input"); + if (input.channels != views.input_channels) { + throw std::runtime_error("Bonsai VAE up block input channel mismatch."); + } + + BonsaiNchwTensor output = input; + for (const BonsaiVaeResnetViews& resnet : views.resnets) { + output = bonsai_vae_resnet_view_nchw(output, resnet); + } + if (views.has_upsample) { + output = bonsai_vae_upsample_nearest2x_nchw(output); + output = bonsai_vae_conv2d_view_nchw(output, views.upsample); + } + return output; +} + +uint64_t bonsai_vae_up_block_byte_count(const BonsaiVaeUpBlockViews& views) { + uint64_t bytes = 0; + for (const BonsaiVaeResnetViews& resnet : views.resnets) { + add_bytes(&bytes, bonsai_vae_resnet_byte_count(resnet), "up block resnet"); + } + if (views.has_upsample) { + add_bytes(&bytes, bonsai_vae_conv2d_byte_count(views.upsample), "up block upsample"); + } + return bytes; +} + +BonsaiNchwTensor bonsai_vae_spatial_attention_nchw( + const BonsaiNchwTensor& queries, + const BonsaiNchwTensor& keys, + const BonsaiNchwTensor& values, + float scale +) { + require_same_shape(queries, keys, "attention key"); + require_same_shape(queries, values, "attention value"); + if (!std::isfinite(scale) || scale <= 0.0F) { + throw std::runtime_error("Bonsai VAE attention scale must be finite and positive."); + } + + const uint64_t spatial_length = checked_multiply( + queries.height, + queries.width, + "attention" + ); + const std::vector output = 
bonsai_scaled_dot_product_attention( + to_attention_layout(queries), + to_attention_layout(keys), + to_attention_layout(values), + {}, + queries.batch_size, + spatial_length, + queries.channels, + scale + ); + return from_attention_layout( + output, + queries.batch_size, + queries.channels, + queries.height, + queries.width + ); +} + +BonsaiNchwTensor bonsai_vae_add_nchw( + const BonsaiNchwTensor& left, + const BonsaiNchwTensor& right +) { + require_same_shape(left, right, "add"); + BonsaiNchwTensor output = left; + for (size_t index = 0; index < output.values.size(); index++) { + output.values[index] = left.values[index] + right.values[index]; + } + return output; +} + +BonsaiNchwTensor bonsai_vae_upsample_nearest2x_nchw( + const BonsaiNchwTensor& input +) { + require_tensor_shape(input, "upsample input"); + BonsaiNchwTensor output { + input.batch_size, + input.channels, + input.height * 2, + input.width * 2, + {}, + }; + output.values.assign( + static_cast( + tensor_size(output.batch_size, output.channels, output.height, output.width, "upsample") + ), + 0.0F + ); + + for (uint64_t batch = 0; batch < input.batch_size; batch++) { + for (uint64_t channel = 0; channel < input.channels; channel++) { + for (uint64_t row = 0; row < output.height; row++) { + for (uint64_t column = 0; column < output.width; column++) { + output.values[nchw_index(output, batch, channel, row, column)] = + input.values[nchw_index(input, batch, channel, row / 2, column / 2)]; + } + } + } + } + return output; +} + +BonsaiNchwTensor bonsai_vae_denormalize_channels_nchw( + const BonsaiNchwTensor& input, + const std::vector& mean, + const std::vector& variance, + float epsilon +) { + require_tensor_shape(input, "denormalize input"); + if (mean.size() != static_cast(input.channels) || + variance.size() != static_cast(input.channels)) { + throw std::runtime_error("Bonsai VAE denormalize channel size mismatch."); + } + if (epsilon <= 0.0F || !std::isfinite(epsilon)) { + throw 
std::runtime_error("Bonsai VAE denormalize epsilon must be finite and positive."); + } + + BonsaiNchwTensor output = input; + for (uint64_t batch = 0; batch < input.batch_size; batch++) { + for (uint64_t channel = 0; channel < input.channels; channel++) { + const float stddev = std::sqrt(variance[static_cast(channel)] + epsilon); + const float offset = mean[static_cast(channel)]; + for (uint64_t row = 0; row < input.height; row++) { + for (uint64_t column = 0; column < input.width; column++) { + const size_t index = nchw_index(input, batch, channel, row, column); + output.values[index] = input.values[index] * stddev + offset; + } + } + } + } + return output; +} + +BonsaiNchwTensor bonsai_vae_unpatchify_nchw( + const BonsaiNchwTensor& input +) { + require_tensor_shape(input, "unpatchify input"); + if (input.channels % 4 != 0) { + throw std::runtime_error("Bonsai VAE unpatchify channels must be divisible by 4."); + } + + BonsaiNchwTensor output { + input.batch_size, + input.channels / 4, + input.height * 2, + input.width * 2, + {}, + }; + output.values.assign( + static_cast( + tensor_size( + output.batch_size, + output.channels, + output.height, + output.width, + "unpatchify" + ) + ), + 0.0F + ); + + for (uint64_t batch = 0; batch < input.batch_size; batch++) { + for (uint64_t channel = 0; channel < output.channels; channel++) { + for (uint64_t patch_row = 0; patch_row < 2; patch_row++) { + for (uint64_t patch_column = 0; patch_column < 2; patch_column++) { + const uint64_t input_channel = ((channel * 2) + patch_row) * 2 + patch_column; + for (uint64_t row = 0; row < input.height; row++) { + for (uint64_t column = 0; column < input.width; column++) { + output.values[nchw_index( + output, + batch, + channel, + row * 2 + patch_row, + column * 2 + patch_column + )] = input.values[nchw_index( + input, + batch, + input_channel, + row, + column + )]; + } + } + } + } + } + } + return output; +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_vae_ops.h 
b/feature/bonsai/src/androidMain/cpp/bonsai_vae_ops.h new file mode 100644 index 000000000..a8356a95b --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_vae_ops.h @@ -0,0 +1,191 @@ +#pragma once + +#include "bonsai_linear.h" +#include "bonsai_safetensors.h" +#include "bonsai_tensor_storage.h" + +#include +#include +#include + +struct BonsaiNchwTensor { + uint64_t batch_size = 0; + uint64_t channels = 0; + uint64_t height = 0; + uint64_t width = 0; + std::vector values; +}; + +struct BonsaiVaeConv2dViews { + BonsaiTensorView weight; + BonsaiTensorView bias; + bool has_bias = false; + uint64_t output_channels = 0; + uint64_t input_channels = 0; + uint64_t kernel_height = 0; + uint64_t kernel_width = 0; + uint64_t padding = 0; +}; + +struct BonsaiVaeGroupNormViews { + BonsaiTensorView weight; + BonsaiTensorView bias; + uint64_t channels = 0; + uint64_t group_count = 0; + float epsilon = 1e-6F; +}; + +struct BonsaiVaeAttentionViews { + BonsaiVaeGroupNormViews group_norm; + BonsaiLinearViews to_q; + BonsaiLinearViews to_k; + BonsaiLinearViews to_v; + BonsaiLinearViews to_out; + uint64_t channels = 0; + float scale = 0.0F; +}; + +struct BonsaiVaeResnetViews { + BonsaiVaeGroupNormViews norm1; + BonsaiVaeConv2dViews conv1; + BonsaiVaeGroupNormViews norm2; + BonsaiVaeConv2dViews conv2; + BonsaiVaeConv2dViews shortcut; + bool has_shortcut = false; + uint64_t input_channels = 0; + uint64_t output_channels = 0; +}; + +struct BonsaiVaeUpBlockViews { + std::vector resnets; + BonsaiVaeConv2dViews upsample; + bool has_upsample = false; + uint64_t input_channels = 0; + uint64_t output_channels = 0; +}; + +BonsaiVaeConv2dViews bonsai_vae_require_conv2d_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + const std::string& prefix +); + +BonsaiNchwTensor bonsai_vae_conv2d_nchw( + const BonsaiNchwTensor& input, + const std::vector& weight, + uint64_t output_channels, + uint64_t kernel_height, + uint64_t kernel_width, + uint64_t padding, + 
const std::vector* bias +); + +BonsaiNchwTensor bonsai_vae_conv2d_view_nchw( + const BonsaiNchwTensor& input, + const BonsaiVaeConv2dViews& views +); + +uint64_t bonsai_vae_conv2d_byte_count(const BonsaiVaeConv2dViews& views); + +BonsaiVaeGroupNormViews bonsai_vae_require_group_norm_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + const std::string& prefix, + uint64_t channels, + uint64_t group_count, + float epsilon +); + +BonsaiNchwTensor bonsai_vae_group_norm_nchw( + const BonsaiNchwTensor& input, + uint64_t group_count, + const std::vector& weight, + const std::vector& bias, + float epsilon +); + +BonsaiNchwTensor bonsai_vae_group_norm_view_nchw( + const BonsaiNchwTensor& input, + const BonsaiVaeGroupNormViews& views +); + +uint64_t bonsai_vae_group_norm_byte_count(const BonsaiVaeGroupNormViews& views); + +BonsaiVaeAttentionViews bonsai_vae_require_attention_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + const std::string& prefix, + uint64_t channels, + uint64_t group_count, + float epsilon +); + +BonsaiNchwTensor bonsai_vae_attention_view_nchw( + const BonsaiNchwTensor& input, + const BonsaiVaeAttentionViews& views +); + +uint64_t bonsai_vae_attention_byte_count(const BonsaiVaeAttentionViews& views); + +BonsaiVaeResnetViews bonsai_vae_require_resnet_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + const std::string& prefix, + uint64_t input_channels, + uint64_t output_channels, + uint64_t group_count, + float epsilon +); + +BonsaiNchwTensor bonsai_vae_resnet_view_nchw( + const BonsaiNchwTensor& input, + const BonsaiVaeResnetViews& views +); + +uint64_t bonsai_vae_resnet_byte_count(const BonsaiVaeResnetViews& views); + +BonsaiVaeUpBlockViews bonsai_vae_require_up_block_views( + const BonsaiTensorStorage& storage, + const BonsaiSafetensorsIndex& index, + const std::string& prefix, + uint64_t input_channels, + uint64_t output_channels, + uint64_t 
layer_count, + uint64_t group_count, + bool add_upsample, + float epsilon +); + +BonsaiNchwTensor bonsai_vae_up_block_view_nchw( + const BonsaiNchwTensor& input, + const BonsaiVaeUpBlockViews& views +); + +uint64_t bonsai_vae_up_block_byte_count(const BonsaiVaeUpBlockViews& views); + +BonsaiNchwTensor bonsai_vae_spatial_attention_nchw( + const BonsaiNchwTensor& queries, + const BonsaiNchwTensor& keys, + const BonsaiNchwTensor& values, + float scale +); + +BonsaiNchwTensor bonsai_vae_add_nchw( + const BonsaiNchwTensor& left, + const BonsaiNchwTensor& right +); + +BonsaiNchwTensor bonsai_vae_upsample_nearest2x_nchw( + const BonsaiNchwTensor& input +); + +BonsaiNchwTensor bonsai_vae_denormalize_channels_nchw( + const BonsaiNchwTensor& input, + const std::vector& mean, + const std::vector& variance, + float epsilon +); + +BonsaiNchwTensor bonsai_vae_unpatchify_nchw( + const BonsaiNchwTensor& input +); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_vulkan.cpp b/feature/bonsai/src/androidMain/cpp/bonsai_vulkan.cpp new file mode 100644 index 000000000..6333611b4 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_vulkan.cpp @@ -0,0 +1,1232 @@ +#include "bonsai_vulkan.h" + +#include "bonsai_tensor.h" +#include "bonsai_tensor_storage.h" +#include "bonsai_vulkan_quantized_matvec_spv.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { + +constexpr const char* LOG_TAG = "SDAI-Bonsai"; +constexpr uint64_t BYTES_IN_MB = 1024ULL * 1024ULL; +constexpr uint32_t MIN_STORAGE_BUFFER_RANGE = 64U * 1024U * 1024U; +constexpr uint64_t STATIC_CACHE_MAX_BYTES = 96ULL * BYTES_IN_MB; +constexpr uint64_t SHARED_QUEUE_MAX_SEQUENCE_TOKENS = 16U; +constexpr float COMPUTE_QUEUE_PRIORITY = 0.35F; + +struct BonsaiVulkanBuffer { + VkBuffer buffer = VK_NULL_HANDLE; + VkDeviceMemory memory = VK_NULL_HANDLE; + VkDeviceSize size = 0; + bool coherent = false; +}; + +struct 
BonsaiVulkanParams { + uint32_t input_values = 0; + uint32_t packed_columns = 0; + uint32_t scale_groups = 0; + uint32_t bits = 0; + uint32_t group_size = 0; + uint32_t output_rows = 0; +}; + +enum class BonsaiVulkanStaticBufferKind { + RawView, + F32View, +}; + +struct BonsaiVulkanStaticBufferKey { + const uint8_t* data = nullptr; + uint64_t source_byte_count = 0; + uint64_t element_count = 0; + BonsaiDType dtype = BonsaiDType::Bool; + BonsaiVulkanStaticBufferKind kind = BonsaiVulkanStaticBufferKind::RawView; +}; + +struct BonsaiVulkanStaticBufferEntry { + BonsaiVulkanStaticBufferKey key; + BonsaiVulkanBuffer buffer; + VkDeviceSize byte_count = 0; + uint64_t last_used = 0; +}; + +std::atomic_bool g_logged_available(false); +std::atomic_bool g_logged_cache(false); +std::atomic_bool g_logged_success(false); +std::atomic_bool g_logged_fallback(false); +std::atomic_bool g_logged_runtime_disabled(false); +std::atomic g_backend_mode(static_cast(BonsaiVulkanBackendMode::Auto)); + +std::string bool_value(bool value) { + return value ? "true" : "false"; +} + +std::string version_string(uint32_t version) { + std::ostringstream output; + output << VK_VERSION_MAJOR(version) << "." + << VK_VERSION_MINOR(version) << "." 
+ << VK_VERSION_PATCH(version); + return output.str(); +} + +std::string vk_result_name(VkResult result) { + switch (result) { + case VK_SUCCESS: + return "VK_SUCCESS"; + case VK_TIMEOUT: + return "VK_TIMEOUT"; + case VK_ERROR_OUT_OF_HOST_MEMORY: + return "VK_ERROR_OUT_OF_HOST_MEMORY"; + case VK_ERROR_OUT_OF_DEVICE_MEMORY: + return "VK_ERROR_OUT_OF_DEVICE_MEMORY"; + case VK_ERROR_INITIALIZATION_FAILED: + return "VK_ERROR_INITIALIZATION_FAILED"; + case VK_ERROR_DEVICE_LOST: + return "VK_ERROR_DEVICE_LOST"; + case VK_ERROR_MEMORY_MAP_FAILED: + return "VK_ERROR_MEMORY_MAP_FAILED"; + case VK_ERROR_FEATURE_NOT_PRESENT: + return "VK_ERROR_FEATURE_NOT_PRESENT"; + case VK_ERROR_INCOMPATIBLE_DRIVER: + return "VK_ERROR_INCOMPATIBLE_DRIVER"; + case VK_ERROR_FORMAT_NOT_SUPPORTED: + return "VK_ERROR_FORMAT_NOT_SUPPORTED"; + default: + return "VK_RESULT_" + std::to_string(static_cast(result)); + } +} + +std::string lowercase_ascii(std::string value) { + std::transform(value.begin(), value.end(), value.begin(), [](unsigned char character) { + return static_cast(std::tolower(character)); + }); + return value; +} + +bool is_software_vulkan_device(const VkPhysicalDeviceProperties& properties) { + const std::string device_name = lowercase_ascii(properties.deviceName); + return properties.deviceType == VK_PHYSICAL_DEVICE_TYPE_CPU || + device_name.find("llvmpipe") != std::string::npos || + device_name.find("swiftshader") != std::string::npos || + device_name.find("software") != std::string::npos; +} + +BonsaiVulkanBackendMode current_backend_mode() { + return static_cast(g_backend_mode.load()); +} + +const char* backend_mode_name(BonsaiVulkanBackendMode mode) { + switch (mode) { + case BonsaiVulkanBackendMode::Auto: + return "auto"; + case BonsaiVulkanBackendMode::Cpu: + return "cpu"; + case BonsaiVulkanBackendMode::Vulkan: + return "vulkan"; + } + return "auto"; +} + +uint32_t loader_api_version() { + uint32_t version = VK_API_VERSION_1_0; + const auto enumerate_instance_version = + 
reinterpret_cast( + vkGetInstanceProcAddr(nullptr, "vkEnumerateInstanceVersion") + ); + if (enumerate_instance_version != nullptr && + enumerate_instance_version(&version) != VK_SUCCESS) { + return VK_API_VERSION_1_0; + } + return version; +} + +bool find_compute_queue_family( + VkPhysicalDevice physical_device, + uint32_t& queue_family_index, + bool& compute_only +) { + uint32_t queue_count = 0; + vkGetPhysicalDeviceQueueFamilyProperties(physical_device, &queue_count, nullptr); + if (queue_count == 0) { + return false; + } + + std::vector queues(queue_count); + vkGetPhysicalDeviceQueueFamilyProperties(physical_device, &queue_count, queues.data()); + int32_t fallback_index = -1; + for (uint32_t index = 0; index < queue_count; ++index) { + const VkQueueFamilyProperties& queue = queues[index]; + if ((queue.queueFlags & VK_QUEUE_COMPUTE_BIT) != 0 && queue.queueCount > 0) { + if ((queue.queueFlags & VK_QUEUE_GRAPHICS_BIT) == 0) { + queue_family_index = index; + compute_only = true; + return true; + } + if (fallback_index < 0) { + fallback_index = static_cast(index); + } + } + } + if (fallback_index >= 0) { + queue_family_index = static_cast(fallback_index); + compute_only = false; + return true; + } + return false; +} + +int32_t find_memory_type( + VkPhysicalDevice physical_device, + uint32_t type_bits, + VkMemoryPropertyFlags required_properties +) { + VkPhysicalDeviceMemoryProperties memory_properties {}; + vkGetPhysicalDeviceMemoryProperties(physical_device, &memory_properties); + for (uint32_t index = 0; index < memory_properties.memoryTypeCount; ++index) { + const bool type_supported = (type_bits & (1U << index)) != 0; + const bool properties_supported = + (memory_properties.memoryTypes[index].propertyFlags & required_properties) == + required_properties; + if (type_supported && properties_supported) { + return static_cast(index); + } + } + return -1; +} + +uint64_t last_dimension(const BonsaiTensorView& view) { + if (view.descriptor == nullptr || 
view.descriptor->shape.empty()) { + return 0; + } + return view.descriptor->shape.back(); +} + +bool fits_u32(uint64_t value) { + return value <= static_cast(std::numeric_limits::max()); +} + +bool same_static_buffer_key( + const BonsaiVulkanStaticBufferKey& left, + const BonsaiVulkanStaticBufferKey& right +) { + return left.data == right.data && + left.source_byte_count == right.source_byte_count && + left.element_count == right.element_count && + left.dtype == right.dtype && + left.kind == right.kind; +} + +void destroy_buffer(VkDevice device, BonsaiVulkanBuffer& buffer) { + if (buffer.buffer != VK_NULL_HANDLE) { + vkDestroyBuffer(device, buffer.buffer, nullptr); + buffer.buffer = VK_NULL_HANDLE; + } + if (buffer.memory != VK_NULL_HANDLE) { + vkFreeMemory(device, buffer.memory, nullptr); + buffer.memory = VK_NULL_HANDLE; + } +} + +class BonsaiVulkanRuntime { +public: + ~BonsaiVulkanRuntime() { + if (device_ != VK_NULL_HANDLE) { + clear_static_cache(); + if (command_pool_ != VK_NULL_HANDLE) { + vkDestroyCommandPool(device_, command_pool_, nullptr); + } + if (pipeline_ != VK_NULL_HANDLE) { + vkDestroyPipeline(device_, pipeline_, nullptr); + } + if (shader_module_ != VK_NULL_HANDLE) { + vkDestroyShaderModule(device_, shader_module_, nullptr); + } + if (pipeline_layout_ != VK_NULL_HANDLE) { + vkDestroyPipelineLayout(device_, pipeline_layout_, nullptr); + } + if (descriptor_set_layout_ != VK_NULL_HANDLE) { + vkDestroyDescriptorSetLayout(device_, descriptor_set_layout_, nullptr); + } + vkDestroyDevice(device_, nullptr); + } + if (instance_ != VK_NULL_HANDLE) { + vkDestroyInstance(instance_, nullptr); + } + } + + bool available() { + std::call_once(init_once_, [this]() { + available_ = initialize(); + __android_log_print( + ANDROID_LOG_INFO, + LOG_TAG, + "phase=vulkan_runtime available=%s reason=%s device=%s api=%s maxStorageBufferRangeMb=%llu queueFamily=%u computeOnly=%s queuePriority=%.2f", + bool_value(available_).c_str(), + init_reason_.c_str(), + 
device_name_.c_str(), + api_version_.c_str(), + static_cast(max_storage_buffer_range_ / BYTES_IN_MB), + queue_family_index_, + bool_value(queue_family_compute_only_).c_str(), + COMPUTE_QUEUE_PRIORITY + ); + }); + return available_; + } + + bool quantized_matvec_into( + const BonsaiPackedWeightViews& views, + const float* input, + float* output, + std::string& reason + ) { + return quantized_matvec_sequence_into(views, input, output, 1, reason); + } + + bool quantized_matvec_sequence_into( + const BonsaiPackedWeightViews& views, + const float* input, + float* output, + uint64_t token_count, + std::string& reason + ) { + if (current_backend_mode() == BonsaiVulkanBackendMode::Cpu) { + reason = "backend_cpu"; + return false; + } + if (!available()) { + reason = init_reason_; + return false; + } + if (disabled_after_device_loss_.load()) { + reason = "vulkan_disabled_after_device_lost"; + return false; + } + if (!supports_views(views, reason)) { + return false; + } + if (!supports_sequence(views, token_count, reason)) { + return false; + } + + const uint64_t scale_groups = last_dimension(views.scales); + const BonsaiVulkanParams params { + static_cast(views.input_values), + static_cast(last_dimension(views.weight)), + static_cast(scale_groups), + static_cast(views.bits), + static_cast(views.group_size), + static_cast(views.leading_rows), + }; + + std::lock_guard lock(queue_mutex_); + return dispatch_quantized_matvec( + views, + input, + output, + token_count, + params, + reason + ); + } + +private: + bool initialize() { + const uint32_t loader_version = loader_api_version(); + VkApplicationInfo app_info {}; + app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; + app_info.pApplicationName = "SDAI Bonsai Runtime"; + app_info.applicationVersion = 1; + app_info.pEngineName = "SDAI"; + app_info.engineVersion = 1; + app_info.apiVersion = loader_version >= VK_API_VERSION_1_1 + ? 
VK_API_VERSION_1_1 + : VK_API_VERSION_1_0; + + VkInstanceCreateInfo instance_info {}; + instance_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; + instance_info.pApplicationInfo = &app_info; + VkResult result = vkCreateInstance(&instance_info, nullptr, &instance_); + if (result != VK_SUCCESS) { + init_reason_ = "create_instance_failed:" + vk_result_name(result); + return false; + } + + uint32_t device_count = 0; + result = vkEnumeratePhysicalDevices(instance_, &device_count, nullptr); + if (result != VK_SUCCESS || device_count == 0) { + init_reason_ = "no_physical_devices:" + vk_result_name(result); + return false; + } + + std::vector devices(device_count); + result = vkEnumeratePhysicalDevices(instance_, &device_count, devices.data()); + if (result != VK_SUCCESS) { + init_reason_ = "enumerate_devices_failed:" + vk_result_name(result); + return false; + } + + for (VkPhysicalDevice candidate : devices) { + if (select_physical_device(candidate)) { + break; + } + } + if (physical_device_ == VK_NULL_HANDLE) { + init_reason_ = last_candidate_rejection_.empty() + ? 
"no_usable_compute_device" + : last_candidate_rejection_; + return false; + } + + if (!create_device()) { + return false; + } + if (!create_pipeline()) { + return false; + } + init_reason_ = "ok"; + return true; + } + + bool select_physical_device(VkPhysicalDevice candidate) { + VkPhysicalDeviceProperties properties {}; + vkGetPhysicalDeviceProperties(candidate, &properties); + if (is_software_vulkan_device(properties)) { + last_candidate_rejection_ = "software_vulkan_device:"; + last_candidate_rejection_ += properties.deviceName; + return false; + } + uint32_t candidate_queue_family = 0; + bool candidate_compute_only = false; + if (properties.apiVersion < VK_API_VERSION_1_1 || + properties.limits.maxStorageBufferRange < MIN_STORAGE_BUFFER_RANGE || + !find_compute_queue_family( + candidate, + candidate_queue_family, + candidate_compute_only + )) { + return false; + } + physical_device_ = candidate; + queue_family_index_ = candidate_queue_family; + queue_family_compute_only_ = candidate_compute_only; + max_storage_buffer_range_ = properties.limits.maxStorageBufferRange; + max_workgroup_count_x_ = properties.limits.maxComputeWorkGroupCount[0]; + max_workgroup_count_y_ = properties.limits.maxComputeWorkGroupCount[1]; + device_name_ = properties.deviceName; + api_version_ = version_string(properties.apiVersion); + return true; + } + + bool create_device() { + const float priority = COMPUTE_QUEUE_PRIORITY; + VkDeviceQueueCreateInfo queue_info {}; + queue_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO; + queue_info.queueFamilyIndex = queue_family_index_; + queue_info.queueCount = 1; + queue_info.pQueuePriorities = &priority; + + VkDeviceCreateInfo device_info {}; + device_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO; + device_info.queueCreateInfoCount = 1; + device_info.pQueueCreateInfos = &queue_info; + VkResult result = vkCreateDevice(physical_device_, &device_info, nullptr, &device_); + if (result != VK_SUCCESS) { + init_reason_ = "create_device_failed:" 
+ vk_result_name(result); + return false; + } + vkGetDeviceQueue(device_, queue_family_index_, 0, &queue_); + if (queue_ == VK_NULL_HANDLE) { + init_reason_ = "get_queue_failed"; + return false; + } + + VkCommandPoolCreateInfo command_pool_info {}; + command_pool_info.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO; + command_pool_info.flags = VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT; + command_pool_info.queueFamilyIndex = queue_family_index_; + result = vkCreateCommandPool(device_, &command_pool_info, nullptr, &command_pool_); + if (result != VK_SUCCESS) { + init_reason_ = "create_command_pool_failed:" + vk_result_name(result); + return false; + } + return true; + } + + bool create_pipeline() { + VkDescriptorSetLayoutBinding bindings[5] {}; + for (uint32_t index = 0; index < 5U; ++index) { + bindings[index].binding = index; + bindings[index].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; + bindings[index].descriptorCount = 1; + bindings[index].stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; + } + + VkDescriptorSetLayoutCreateInfo set_layout_info {}; + set_layout_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO; + set_layout_info.bindingCount = 5; + set_layout_info.pBindings = bindings; + VkResult result = vkCreateDescriptorSetLayout( + device_, + &set_layout_info, + nullptr, + &descriptor_set_layout_ + ); + if (result != VK_SUCCESS) { + init_reason_ = "create_descriptor_set_layout_failed:" + vk_result_name(result); + return false; + } + + VkPushConstantRange push_constant {}; + push_constant.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; + push_constant.offset = 0; + push_constant.size = sizeof(BonsaiVulkanParams); + + VkPipelineLayoutCreateInfo pipeline_layout_info {}; + pipeline_layout_info.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO; + pipeline_layout_info.setLayoutCount = 1; + pipeline_layout_info.pSetLayouts = &descriptor_set_layout_; + pipeline_layout_info.pushConstantRangeCount = 1; + pipeline_layout_info.pPushConstantRanges = 
&push_constant; + result = vkCreatePipelineLayout( + device_, + &pipeline_layout_info, + nullptr, + &pipeline_layout_ + ); + if (result != VK_SUCCESS) { + init_reason_ = "create_pipeline_layout_failed:" + vk_result_name(result); + return false; + } + + VkShaderModuleCreateInfo shader_info {}; + shader_info.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO; + shader_info.codeSize = kBonsaiVulkanQuantizedMatvecSpvSize; + shader_info.pCode = reinterpret_cast(kBonsaiVulkanQuantizedMatvecSpv); + result = vkCreateShaderModule(device_, &shader_info, nullptr, &shader_module_); + if (result != VK_SUCCESS) { + init_reason_ = "create_shader_module_failed:" + vk_result_name(result); + return false; + } + + VkComputePipelineCreateInfo pipeline_info {}; + pipeline_info.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO; + pipeline_info.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; + pipeline_info.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT; + pipeline_info.stage.module = shader_module_; + pipeline_info.stage.pName = "main"; + pipeline_info.layout = pipeline_layout_; + result = vkCreateComputePipelines( + device_, + VK_NULL_HANDLE, + 1, + &pipeline_info, + nullptr, + &pipeline_ + ); + if (result != VK_SUCCESS) { + init_reason_ = "create_compute_pipeline_failed:" + vk_result_name(result); + return false; + } + return true; + } + + bool supports_views(const BonsaiPackedWeightViews& views, std::string& reason) const { + if (!views.packed || views.weight.dtype != BonsaiDType::U32) { + reason = "unsupported_weight_layout"; + return false; + } + if (views.bits != 1 && views.bits != 2 && views.bits != 4) { + reason = "unsupported_bits"; + return false; + } + if (views.group_size <= 0 || views.input_values == 0 || views.leading_rows == 0) { + reason = "invalid_shape"; + return false; + } + if (!fits_u32(views.input_values) || + !fits_u32(views.leading_rows) || + !fits_u32(last_dimension(views.weight)) || + !fits_u32(last_dimension(views.scales)) || + 
!fits_u32(static_cast(views.group_size))) { + reason = "shape_too_large"; + return false; + } + if (views.leading_rows > max_workgroup_count_x_) { + reason = "too_many_rows_for_dispatch"; + return false; + } + if (views.weight.byte_count > max_storage_buffer_range_ || + views.scales.element_count * sizeof(float) > max_storage_buffer_range_ || + views.biases.element_count * sizeof(float) > max_storage_buffer_range_ || + views.input_values * sizeof(float) > max_storage_buffer_range_ || + views.leading_rows * sizeof(float) > max_storage_buffer_range_) { + reason = "buffer_range_too_large"; + return false; + } + return true; + } + + bool supports_sequence( + const BonsaiPackedWeightViews& views, + uint64_t token_count, + std::string& reason + ) const { + if (token_count == 0 || !fits_u32(token_count)) { + reason = "invalid_token_count"; + return false; + } + if (token_count > max_workgroup_count_y_) { + reason = "too_many_tokens_for_dispatch"; + return false; + } + if (!queue_family_compute_only_ && token_count > SHARED_QUEUE_MAX_SEQUENCE_TOKENS) { + reason = "shared_queue_token_limit"; + return false; + } + if (views.input_values > std::numeric_limits::max() / token_count || + views.leading_rows > std::numeric_limits::max() / token_count) { + reason = "sequence_shape_too_large"; + return false; + } + + const uint64_t input_elements = views.input_values * token_count; + const uint64_t output_elements = views.leading_rows * token_count; + if (input_elements > max_storage_buffer_range_ / sizeof(float) || + output_elements > max_storage_buffer_range_ / sizeof(float)) { + reason = "sequence_buffer_range_too_large"; + return false; + } + return true; + } + + bool create_host_storage_buffer( + VkDeviceSize size, + BonsaiVulkanBuffer& output, + std::string& reason + ) { + output.size = size; + + VkBufferCreateInfo buffer_info {}; + buffer_info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; + buffer_info.size = size; + buffer_info.usage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; + 
buffer_info.sharingMode = VK_SHARING_MODE_EXCLUSIVE; + VkResult result = vkCreateBuffer(device_, &buffer_info, nullptr, &output.buffer); + if (result != VK_SUCCESS) { + reason = "create_buffer_failed:" + vk_result_name(result); + return false; + } + + VkMemoryRequirements requirements {}; + vkGetBufferMemoryRequirements(device_, output.buffer, &requirements); + int32_t memory_type = find_memory_type( + physical_device_, + requirements.memoryTypeBits, + VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT + ); + output.coherent = memory_type >= 0; + if (memory_type < 0) { + memory_type = find_memory_type( + physical_device_, + requirements.memoryTypeBits, + VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT + ); + } + if (memory_type < 0) { + reason = "missing_host_visible_memory"; + return false; + } + + VkMemoryAllocateInfo allocate_info {}; + allocate_info.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO; + allocate_info.allocationSize = requirements.size; + allocate_info.memoryTypeIndex = static_cast(memory_type); + result = vkAllocateMemory(device_, &allocate_info, nullptr, &output.memory); + if (result != VK_SUCCESS) { + reason = "allocate_memory_failed:" + vk_result_name(result); + return false; + } + + result = vkBindBufferMemory(device_, output.buffer, output.memory, 0); + if (result != VK_SUCCESS) { + reason = "bind_buffer_failed:" + vk_result_name(result); + return false; + } + return true; + } + + bool write_buffer( + const BonsaiVulkanBuffer& buffer, + const void* source, + size_t byte_count, + std::string& reason + ) { + void* mapped = nullptr; + VkResult result = vkMapMemory(device_, buffer.memory, 0, byte_count, 0, &mapped); + if (result != VK_SUCCESS) { + reason = "map_write_failed:" + vk_result_name(result); + return false; + } + std::memcpy(mapped, source, byte_count); + if (!buffer.coherent) { + VkMappedMemoryRange range {}; + range.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE; + range.memory = buffer.memory; + range.offset = 0; + 
range.size = VK_WHOLE_SIZE; + result = vkFlushMappedMemoryRanges(device_, 1, &range); + if (result != VK_SUCCESS) { + vkUnmapMemory(device_, buffer.memory); + reason = "flush_write_failed:" + vk_result_name(result); + return false; + } + } + vkUnmapMemory(device_, buffer.memory); + return true; + } + + bool read_buffer( + const BonsaiVulkanBuffer& buffer, + void* destination, + size_t byte_count, + std::string& reason + ) { + if (!buffer.coherent) { + VkMappedMemoryRange range {}; + range.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE; + range.memory = buffer.memory; + range.offset = 0; + range.size = VK_WHOLE_SIZE; + VkResult result = vkInvalidateMappedMemoryRanges(device_, 1, &range); + if (result != VK_SUCCESS) { + reason = "invalidate_read_failed:" + vk_result_name(result); + return false; + } + } + + void* mapped = nullptr; + VkResult result = vkMapMemory(device_, buffer.memory, 0, byte_count, 0, &mapped); + if (result != VK_SUCCESS) { + reason = "map_read_failed:" + vk_result_name(result); + return false; + } + std::memcpy(destination, mapped, byte_count); + vkUnmapMemory(device_, buffer.memory); + return true; + } + + BonsaiVulkanStaticBufferEntry* find_static_buffer( + const BonsaiVulkanStaticBufferKey& key + ) { + for (const std::unique_ptr& entry : static_cache_) { + if (same_static_buffer_key(entry->key, key)) { + entry->last_used = ++static_cache_tick_; + return entry.get(); + } + } + return nullptr; + } + + BonsaiVulkanStaticBufferEntry* cached_static_buffer( + const BonsaiVulkanStaticBufferKey& key, + VkDeviceSize byte_count, + const void* source, + std::string& reason + ) { + if (byte_count == 0 || source == nullptr) { + reason = "empty_static_buffer"; + return nullptr; + } + + BonsaiVulkanStaticBufferEntry* cached = find_static_buffer(key); + if (cached != nullptr) { + return cached; + } + + auto entry = std::make_unique(); + entry->key = key; + entry->byte_count = byte_count; + entry->last_used = ++static_cache_tick_; + if 
(!create_host_storage_buffer(byte_count, entry->buffer, reason) || + !write_buffer(entry->buffer, source, static_cast(byte_count), reason)) { + destroy_buffer(device_, entry->buffer); + return nullptr; + } + + BonsaiVulkanStaticBufferEntry* pointer = entry.get(); + static_cache_bytes_ += static_cast(byte_count); + static_cache_.push_back(std::move(entry)); + log_static_cache_once(); + return pointer; + } + + BonsaiVulkanStaticBufferEntry* cached_raw_view_buffer( + const BonsaiTensorView& view, + std::string& reason + ) { + BonsaiVulkanStaticBufferKey key { + view.data, + view.byte_count, + view.element_count, + view.dtype, + BonsaiVulkanStaticBufferKind::RawView, + }; + return cached_static_buffer( + key, + static_cast(view.byte_count), + view.data, + reason + ); + } + + BonsaiVulkanStaticBufferEntry* cached_f32_view_buffer( + const BonsaiTensorView& view, + std::string& reason + ) { + BonsaiVulkanStaticBufferKey key { + view.data, + view.byte_count, + view.element_count, + view.dtype, + BonsaiVulkanStaticBufferKind::F32View, + }; + BonsaiVulkanStaticBufferEntry* cached = find_static_buffer(key); + if (cached != nullptr) { + return cached; + } + + std::vector values; + try { + values = bonsai_tensor_view_to_f32_vector(view); + } catch (const std::exception& error) { + reason = "f32_static_buffer_conversion_failed:"; + reason += error.what(); + return nullptr; + } catch (...) 
{ + reason = "f32_static_buffer_conversion_failed"; + return nullptr; + } + return cached_static_buffer( + key, + static_cast(values.size() * sizeof(float)), + values.data(), + reason + ); + } + + void trim_static_cache() { + while (static_cache_bytes_ > STATIC_CACHE_MAX_BYTES && static_cache_.size() > 1U) { + auto victim = std::min_element( + static_cache_.begin(), + static_cache_.end(), + [](const auto& left, const auto& right) { + return left->last_used < right->last_used; + } + ); + if (victim == static_cache_.end()) { + return; + } + static_cache_bytes_ -= static_cast((*victim)->byte_count); + destroy_buffer(device_, (*victim)->buffer); + static_cache_.erase(victim); + } + } + + void clear_static_cache() { + for (const std::unique_ptr& entry : static_cache_) { + destroy_buffer(device_, entry->buffer); + } + static_cache_.clear(); + static_cache_bytes_ = 0; + } + + void log_static_cache_once() const { + bool expected = false; + if (g_logged_cache.compare_exchange_strong(expected, true)) { + __android_log_print( + ANDROID_LOG_INFO, + LOG_TAG, + "phase=vulkan_static_cache enabled=true maxMb=%llu", + static_cast(STATIC_CACHE_MAX_BYTES / BYTES_IN_MB) + ); + } + } + + bool dispatch_quantized_matvec( + const BonsaiPackedWeightViews& views, + const float* input, + float* output, + uint64_t token_count, + const BonsaiVulkanParams& params, + std::string& reason + ) { + BonsaiVulkanBuffer input_buffer; + BonsaiVulkanBuffer output_buffer; + VkDescriptorPool descriptor_pool = VK_NULL_HANDLE; + VkCommandBuffer command_buffer = VK_NULL_HANDLE; + VkFence fence = VK_NULL_HANDLE; + + const auto cleanup = [&]() { + if (fence != VK_NULL_HANDLE) { + vkDestroyFence(device_, fence, nullptr); + } + if (command_buffer != VK_NULL_HANDLE) { + vkFreeCommandBuffers(device_, command_pool_, 1, &command_buffer); + } + if (descriptor_pool != VK_NULL_HANDLE) { + vkDestroyDescriptorPool(device_, descriptor_pool, nullptr); + } + destroy_buffer(device_, input_buffer); + destroy_buffer(device_, 
output_buffer); + }; + + BonsaiVulkanStaticBufferEntry* weight_entry = cached_raw_view_buffer(views.weight, reason); + BonsaiVulkanStaticBufferEntry* scale_entry = cached_f32_view_buffer(views.scales, reason); + BonsaiVulkanStaticBufferEntry* bias_entry = cached_f32_view_buffer(views.biases, reason); + if (weight_entry == nullptr || scale_entry == nullptr || bias_entry == nullptr) { + trim_static_cache(); + return false; + } + + const VkDeviceSize weight_bytes = weight_entry->byte_count; + const VkDeviceSize scale_bytes = scale_entry->byte_count; + const VkDeviceSize bias_bytes = bias_entry->byte_count; + const VkDeviceSize input_bytes = views.input_values * token_count * sizeof(float); + const VkDeviceSize output_bytes = views.leading_rows * token_count * sizeof(float); + + bool ok = create_host_storage_buffer(input_bytes, input_buffer, reason) && + create_host_storage_buffer(output_bytes, output_buffer, reason); + if (!ok) { + cleanup(); + trim_static_cache(); + return false; + } + + ok = write_buffer(input_buffer, input, static_cast(input_bytes), reason); + if (!ok) { + cleanup(); + trim_static_cache(); + return false; + } + + VkDescriptorPoolSize pool_size {}; + pool_size.type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; + pool_size.descriptorCount = 5; + VkDescriptorPoolCreateInfo pool_info {}; + pool_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO; + pool_info.maxSets = 1; + pool_info.poolSizeCount = 1; + pool_info.pPoolSizes = &pool_size; + VkResult result = vkCreateDescriptorPool(device_, &pool_info, nullptr, &descriptor_pool); + if (result != VK_SUCCESS) { + reason = "create_descriptor_pool_failed:" + vk_result_name(result); + cleanup(); + trim_static_cache(); + return false; + } + + VkDescriptorSet descriptor_set = VK_NULL_HANDLE; + VkDescriptorSetAllocateInfo descriptor_allocate {}; + descriptor_allocate.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO; + descriptor_allocate.descriptorPool = descriptor_pool; + 
descriptor_allocate.descriptorSetCount = 1; + descriptor_allocate.pSetLayouts = &descriptor_set_layout_; + result = vkAllocateDescriptorSets(device_, &descriptor_allocate, &descriptor_set); + if (result != VK_SUCCESS) { + reason = "allocate_descriptor_set_failed:" + vk_result_name(result); + cleanup(); + trim_static_cache(); + return false; + } + + VkDescriptorBufferInfo buffer_infos[5] {}; + buffer_infos[0] = { weight_entry->buffer.buffer, 0, weight_bytes }; + buffer_infos[1] = { scale_entry->buffer.buffer, 0, scale_bytes }; + buffer_infos[2] = { bias_entry->buffer.buffer, 0, bias_bytes }; + buffer_infos[3] = { input_buffer.buffer, 0, input_bytes }; + buffer_infos[4] = { output_buffer.buffer, 0, output_bytes }; + + VkWriteDescriptorSet descriptor_writes[5] {}; + for (uint32_t index = 0; index < 5U; ++index) { + descriptor_writes[index].sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET; + descriptor_writes[index].dstSet = descriptor_set; + descriptor_writes[index].dstBinding = index; + descriptor_writes[index].descriptorCount = 1; + descriptor_writes[index].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; + descriptor_writes[index].pBufferInfo = &buffer_infos[index]; + } + vkUpdateDescriptorSets(device_, 5, descriptor_writes, 0, nullptr); + + VkCommandBufferAllocateInfo command_allocate {}; + command_allocate.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO; + command_allocate.commandPool = command_pool_; + command_allocate.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY; + command_allocate.commandBufferCount = 1; + result = vkAllocateCommandBuffers(device_, &command_allocate, &command_buffer); + if (result != VK_SUCCESS) { + reason = "allocate_command_buffer_failed:" + vk_result_name(result); + cleanup(); + trim_static_cache(); + return false; + } + + VkCommandBufferBeginInfo begin_info {}; + begin_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO; + begin_info.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT; + result = 
vkBeginCommandBuffer(command_buffer, &begin_info); + if (result != VK_SUCCESS) { + reason = "begin_command_buffer_failed:" + vk_result_name(result); + cleanup(); + trim_static_cache(); + return false; + } + + vkCmdBindPipeline(command_buffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline_); + vkCmdBindDescriptorSets( + command_buffer, + VK_PIPELINE_BIND_POINT_COMPUTE, + pipeline_layout_, + 0, + 1, + &descriptor_set, + 0, + nullptr + ); + vkCmdPushConstants( + command_buffer, + pipeline_layout_, + VK_SHADER_STAGE_COMPUTE_BIT, + 0, + sizeof(BonsaiVulkanParams), + ¶ms + ); + vkCmdDispatch( + command_buffer, + static_cast(views.leading_rows), + static_cast(token_count), + 1 + ); + + VkBufferMemoryBarrier barrier {}; + barrier.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER; + barrier.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT; + barrier.dstAccessMask = VK_ACCESS_HOST_READ_BIT; + barrier.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED; + barrier.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED; + barrier.buffer = output_buffer.buffer; + barrier.offset = 0; + barrier.size = output_bytes; + vkCmdPipelineBarrier( + command_buffer, + VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, + VK_PIPELINE_STAGE_HOST_BIT, + 0, + 0, + nullptr, + 1, + &barrier, + 0, + nullptr + ); + + result = vkEndCommandBuffer(command_buffer); + if (result != VK_SUCCESS) { + reason = "end_command_buffer_failed:" + vk_result_name(result); + cleanup(); + trim_static_cache(); + return false; + } + + VkFenceCreateInfo fence_info {}; + fence_info.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO; + result = vkCreateFence(device_, &fence_info, nullptr, &fence); + if (result != VK_SUCCESS) { + reason = "create_fence_failed:" + vk_result_name(result); + cleanup(); + trim_static_cache(); + return false; + } + + VkSubmitInfo submit_info {}; + submit_info.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO; + submit_info.commandBufferCount = 1; + submit_info.pCommandBuffers = &command_buffer; + result = vkQueueSubmit(queue_, 1, &submit_info, fence); + 
if (result != VK_SUCCESS) { + reason = "queue_submit_failed:" + vk_result_name(result); + disable_after_device_loss(result, reason); + cleanup(); + trim_static_cache(); + return false; + } + + result = vkWaitForFences(device_, 1, &fence, VK_TRUE, 30'000'000'000ULL); + if (result != VK_SUCCESS) { + reason = "wait_fence_failed:" + vk_result_name(result); + disable_after_device_loss(result, reason); + if (result != VK_ERROR_DEVICE_LOST) { + vkDeviceWaitIdle(device_); + } + cleanup(); + trim_static_cache(); + return false; + } + + ok = read_buffer(output_buffer, output, static_cast(output_bytes), reason); + cleanup(); + trim_static_cache(); + return ok; + } + + void disable_after_device_loss(VkResult result, const std::string& reason) { + if (result != VK_ERROR_DEVICE_LOST) { + return; + } + disabled_after_device_loss_.store(true); + bool expected = false; + if (g_logged_runtime_disabled.compare_exchange_strong(expected, true)) { + __android_log_print( + ANDROID_LOG_INFO, + LOG_TAG, + "phase=vulkan_runtime_disabled reason=%s", + reason.c_str() + ); + } + } + + std::once_flag init_once_; + bool available_ = false; + std::atomic_bool disabled_after_device_loss_ { false }; + std::string init_reason_ = "not_initialized"; + std::string last_candidate_rejection_; + std::string device_name_ = "unknown"; + std::string api_version_ = "0.0.0"; + uint64_t max_storage_buffer_range_ = 0; + uint32_t max_workgroup_count_x_ = 0; + uint32_t max_workgroup_count_y_ = 0; + uint32_t queue_family_index_ = 0; + bool queue_family_compute_only_ = false; + VkInstance instance_ = VK_NULL_HANDLE; + VkPhysicalDevice physical_device_ = VK_NULL_HANDLE; + VkDevice device_ = VK_NULL_HANDLE; + VkQueue queue_ = VK_NULL_HANDLE; + VkDescriptorSetLayout descriptor_set_layout_ = VK_NULL_HANDLE; + VkPipelineLayout pipeline_layout_ = VK_NULL_HANDLE; + VkShaderModule shader_module_ = VK_NULL_HANDLE; + VkPipeline pipeline_ = VK_NULL_HANDLE; + VkCommandPool command_pool_ = VK_NULL_HANDLE; + std::vector> 
static_cache_; + uint64_t static_cache_bytes_ = 0; + uint64_t static_cache_tick_ = 0; + std::mutex queue_mutex_; +}; + +BonsaiVulkanRuntime& runtime() { + static BonsaiVulkanRuntime value; + return value; +} + +void log_dispatch_success_once(const BonsaiPackedWeightViews& views, uint64_t token_count) { + bool expected = false; + if (g_logged_success.compare_exchange_strong(expected, true)) { + __android_log_print( + ANDROID_LOG_INFO, + LOG_TAG, + "phase=vulkan_matvec first_dispatch=true rows=%llu input_values=%llu tokens=%llu bits=%d group_size=%d", + static_cast(views.leading_rows), + static_cast(views.input_values), + static_cast(token_count), + views.bits, + views.group_size + ); + } +} + +void log_fallback_once(const std::string& reason) { + bool expected = false; + if (g_logged_fallback.compare_exchange_strong(expected, true)) { + __android_log_print( + ANDROID_LOG_INFO, + LOG_TAG, + "phase=vulkan_matvec fallback=true reason=%s", + reason.c_str() + ); + } +} + +} // namespace + +void bonsai_vulkan_set_backend_mode(BonsaiVulkanBackendMode mode) { + g_backend_mode.store(static_cast(mode)); + __android_log_print( + ANDROID_LOG_INFO, + LOG_TAG, + "phase=vulkan_backend mode=%s", + backend_mode_name(mode) + ); +} + +bool bonsai_vulkan_runtime_available() { + const bool available = runtime().available(); + bool expected = false; + if (g_logged_available.compare_exchange_strong(expected, true)) { + __android_log_print( + ANDROID_LOG_INFO, + LOG_TAG, + "phase=vulkan_runtime_checked available=%s", + bool_value(available).c_str() + ); + } + return available; +} + +bool bonsai_vulkan_quantized_matvec_into( + const BonsaiPackedWeightViews& views, + const float* input, + float* output +) { + return bonsai_vulkan_quantized_matvec_sequence_into(views, input, output, 1); +} + +bool bonsai_vulkan_quantized_matvec_sequence_into( + const BonsaiPackedWeightViews& views, + const float* input, + float* output, + uint64_t token_count +) { + std::string reason; + const bool ok = 
runtime().quantized_matvec_sequence_into( + views, + input, + output, + token_count, + reason + ); + if (ok) { + log_dispatch_success_once(views, token_count); + return true; + } + if (!reason.empty() && + reason != "unsupported_weight_layout" && + reason != "unsupported_bits" && + reason != "invalid_shape") { + log_fallback_once(reason); + } + return false; +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_vulkan.h b/feature/bonsai/src/androidMain/cpp/bonsai_vulkan.h new file mode 100644 index 000000000..346e4ed2d --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_vulkan.h @@ -0,0 +1,28 @@ +#pragma once + +#include "bonsai_packed_weight.h" + +#include + +enum class BonsaiVulkanBackendMode { + Auto, + Cpu, + Vulkan, +}; + +void bonsai_vulkan_set_backend_mode(BonsaiVulkanBackendMode mode); + +bool bonsai_vulkan_runtime_available(); + +bool bonsai_vulkan_quantized_matvec_into( + const BonsaiPackedWeightViews& views, + const float* input, + float* output +); + +bool bonsai_vulkan_quantized_matvec_sequence_into( + const BonsaiPackedWeightViews& views, + const float* input, + float* output, + uint64_t token_count +); diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_vulkan_quantized_matvec.comp b/feature/bonsai/src/androidMain/cpp/bonsai_vulkan_quantized_matvec.comp new file mode 100644 index 000000000..c0ba7928e --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_vulkan_quantized_matvec.comp @@ -0,0 +1,70 @@ +#version 450 + +layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; + +layout(set = 0, binding = 0) readonly buffer WeightBuffer { + uint values[]; +} weights; + +layout(set = 0, binding = 1) readonly buffer ScaleBuffer { + float values[]; +} scales; + +layout(set = 0, binding = 2) readonly buffer BiasBuffer { + float values[]; +} biases; + +layout(set = 0, binding = 3) readonly buffer InputBuffer { + float values[]; +} input_values; + +layout(set = 0, binding = 4) writeonly buffer OutputBuffer { + float values[]; +} 
output_values; + +layout(push_constant) uniform Params { + uint input_values_count; + uint packed_columns; + uint scale_groups; + uint bits; + uint group_size; + uint output_rows; +} params; + +shared float partials[64]; + +void main() { + const uint row = gl_WorkGroupID.x; + const uint token = gl_WorkGroupID.y; + const uint lane = gl_LocalInvocationID.x; + const uint values_per_word = 32u / params.bits; + const uint mask = (1u << params.bits) - 1u; + const uint input_start = token * params.input_values_count; + + float sum = 0.0; + for (uint column = lane; column < params.input_values_count; column += 64u) { + const uint word_column = column / values_per_word; + const uint offset = column - word_column * values_per_word; + const uint word = weights.values[row * params.packed_columns + word_column]; + const uint quantized = (word >> (offset * params.bits)) & mask; + const uint group = column / params.group_size; + const uint scale_index = row * params.scale_groups + group; + const float weight = float(quantized) * scales.values[scale_index] + + biases.values[scale_index]; + sum += input_values.values[input_start + column] * weight; + } + + partials[lane] = sum; + barrier(); + + for (uint stride = 32u; stride > 0u; stride >>= 1u) { + if (lane < stride) { + partials[lane] += partials[lane + stride]; + } + barrier(); + } + + if (lane == 0u) { + output_values.values[token * params.output_rows + row] = partials[0]; + } +} diff --git a/feature/bonsai/src/androidMain/cpp/bonsai_vulkan_quantized_matvec_spv.h b/feature/bonsai/src/androidMain/cpp/bonsai_vulkan_quantized_matvec_spv.h new file mode 100644 index 000000000..c855e6860 --- /dev/null +++ b/feature/bonsai/src/androidMain/cpp/bonsai_vulkan_quantized_matvec_spv.h @@ -0,0 +1,463 @@ +#pragma once + +#include +#include + +alignas(4) static const uint8_t kBonsaiVulkanQuantizedMatvecSpv[] = { + 0x03, 0x02, 0x23, 0x07, 0x00, 0x03, 0x01, 0x00, 0x0a, 0x00, 0x0d, 0x00, + 0xc7, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x11, 
0x00, 0x02, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x47, 0x4c, 0x53, 0x4c, 0x2e, 0x73, 0x74, 0x64, 0x2e, 0x34, 0x35, 0x30, + 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x07, 0x00, 0x05, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, + 0x0b, 0x00, 0x00, 0x00, 0x15, 0x00, 0x00, 0x00, 0x10, 0x00, 0x06, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x11, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x03, 0x00, + 0x02, 0x00, 0x00, 0x00, 0xc2, 0x01, 0x00, 0x00, 0x04, 0x00, 0x0a, 0x00, + 0x47, 0x4c, 0x5f, 0x47, 0x4f, 0x4f, 0x47, 0x4c, 0x45, 0x5f, 0x63, 0x70, + 0x70, 0x5f, 0x73, 0x74, 0x79, 0x6c, 0x65, 0x5f, 0x6c, 0x69, 0x6e, 0x65, + 0x5f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x76, 0x65, 0x00, 0x00, + 0x04, 0x00, 0x08, 0x00, 0x47, 0x4c, 0x5f, 0x47, 0x4f, 0x4f, 0x47, 0x4c, + 0x45, 0x5f, 0x69, 0x6e, 0x63, 0x6c, 0x75, 0x64, 0x65, 0x5f, 0x64, 0x69, + 0x72, 0x65, 0x63, 0x74, 0x69, 0x76, 0x65, 0x00, 0x05, 0x00, 0x04, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x03, 0x00, 0x08, 0x00, 0x00, 0x00, 0x72, 0x6f, 0x77, 0x00, + 0x05, 0x00, 0x06, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x67, 0x6c, 0x5f, 0x57, + 0x6f, 0x72, 0x6b, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x49, 0x44, 0x00, 0x00, + 0x05, 0x00, 0x04, 0x00, 0x10, 0x00, 0x00, 0x00, 0x74, 0x6f, 0x6b, 0x65, + 0x6e, 0x00, 0x00, 0x00, 0x05, 0x00, 0x04, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x6c, 0x61, 0x6e, 0x65, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x08, 0x00, + 0x15, 0x00, 0x00, 0x00, 0x67, 0x6c, 0x5f, 0x4c, 0x6f, 0x63, 0x61, 0x6c, + 0x49, 0x6e, 0x76, 0x6f, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x49, 0x44, + 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x06, 0x00, 0x18, 0x00, 0x00, 0x00, + 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x5f, 0x70, 0x65, 0x72, 0x5f, 0x77, + 0x6f, 0x72, 0x64, 0x00, 0x05, 0x00, 0x04, 0x00, 0x1a, 
0x00, 0x00, 0x00, + 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, + 0x1a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x69, 0x6e, 0x70, 0x75, + 0x74, 0x5f, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x5f, 0x63, 0x6f, 0x75, + 0x6e, 0x74, 0x00, 0x00, 0x06, 0x00, 0x07, 0x00, 0x1a, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x70, 0x61, 0x63, 0x6b, 0x65, 0x64, 0x5f, 0x63, + 0x6f, 0x6c, 0x75, 0x6d, 0x6e, 0x73, 0x00, 0x00, 0x06, 0x00, 0x07, 0x00, + 0x1a, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x73, 0x63, 0x61, 0x6c, + 0x65, 0x5f, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x00, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x05, 0x00, 0x1a, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x62, 0x69, 0x74, 0x73, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x06, 0x00, + 0x1a, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x67, 0x72, 0x6f, 0x75, + 0x70, 0x5f, 0x73, 0x69, 0x7a, 0x65, 0x00, 0x00, 0x06, 0x00, 0x06, 0x00, + 0x1a, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x6f, 0x75, 0x74, 0x70, + 0x75, 0x74, 0x5f, 0x72, 0x6f, 0x77, 0x73, 0x00, 0x05, 0x00, 0x04, 0x00, + 0x1c, 0x00, 0x00, 0x00, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x00, 0x00, + 0x05, 0x00, 0x04, 0x00, 0x23, 0x00, 0x00, 0x00, 0x6d, 0x61, 0x73, 0x6b, + 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x05, 0x00, 0x28, 0x00, 0x00, 0x00, + 0x69, 0x6e, 0x70, 0x75, 0x74, 0x5f, 0x73, 0x74, 0x61, 0x72, 0x74, 0x00, + 0x05, 0x00, 0x03, 0x00, 0x30, 0x00, 0x00, 0x00, 0x73, 0x75, 0x6d, 0x00, + 0x05, 0x00, 0x04, 0x00, 0x32, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6c, 0x75, + 0x6d, 0x6e, 0x00, 0x00, 0x05, 0x00, 0x05, 0x00, 0x3e, 0x00, 0x00, 0x00, + 0x77, 0x6f, 0x72, 0x64, 0x5f, 0x63, 0x6f, 0x6c, 0x75, 0x6d, 0x6e, 0x00, + 0x05, 0x00, 0x04, 0x00, 0x42, 0x00, 0x00, 0x00, 0x6f, 0x66, 0x66, 0x73, + 0x65, 0x74, 0x00, 0x00, 0x05, 0x00, 0x04, 0x00, 0x48, 0x00, 0x00, 0x00, + 0x77, 0x6f, 0x72, 0x64, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x06, 0x00, + 0x4a, 0x00, 0x00, 0x00, 0x57, 0x65, 0x69, 0x67, 0x68, 0x74, 0x42, 0x75, + 0x66, 0x66, 0x65, 0x72, 0x00, 0x00, 0x00, 0x00, 0x06, 
0x00, 0x05, 0x00, + 0x4a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x76, 0x61, 0x6c, 0x75, + 0x65, 0x73, 0x00, 0x00, 0x05, 0x00, 0x04, 0x00, 0x4c, 0x00, 0x00, 0x00, + 0x77, 0x65, 0x69, 0x67, 0x68, 0x74, 0x73, 0x00, 0x05, 0x00, 0x05, 0x00, + 0x57, 0x00, 0x00, 0x00, 0x71, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, + 0x64, 0x00, 0x00, 0x00, 0x05, 0x00, 0x04, 0x00, 0x60, 0x00, 0x00, 0x00, + 0x67, 0x72, 0x6f, 0x75, 0x70, 0x00, 0x00, 0x00, 0x05, 0x00, 0x05, 0x00, + 0x66, 0x00, 0x00, 0x00, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x5f, 0x69, 0x6e, + 0x64, 0x65, 0x78, 0x00, 0x05, 0x00, 0x04, 0x00, 0x6e, 0x00, 0x00, 0x00, + 0x77, 0x65, 0x69, 0x67, 0x68, 0x74, 0x00, 0x00, 0x05, 0x00, 0x05, 0x00, + 0x72, 0x00, 0x00, 0x00, 0x53, 0x63, 0x61, 0x6c, 0x65, 0x42, 0x75, 0x66, + 0x66, 0x65, 0x72, 0x00, 0x06, 0x00, 0x05, 0x00, 0x72, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x00, 0x00, + 0x05, 0x00, 0x04, 0x00, 0x74, 0x00, 0x00, 0x00, 0x73, 0x63, 0x61, 0x6c, + 0x65, 0x73, 0x00, 0x00, 0x05, 0x00, 0x05, 0x00, 0x7b, 0x00, 0x00, 0x00, + 0x42, 0x69, 0x61, 0x73, 0x42, 0x75, 0x66, 0x66, 0x65, 0x72, 0x00, 0x00, + 0x06, 0x00, 0x05, 0x00, 0x7b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x00, 0x00, 0x05, 0x00, 0x04, 0x00, + 0x7d, 0x00, 0x00, 0x00, 0x62, 0x69, 0x61, 0x73, 0x65, 0x73, 0x00, 0x00, + 0x05, 0x00, 0x05, 0x00, 0x83, 0x00, 0x00, 0x00, 0x49, 0x6e, 0x70, 0x75, + 0x74, 0x42, 0x75, 0x66, 0x66, 0x65, 0x72, 0x00, 0x06, 0x00, 0x05, 0x00, + 0x83, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x76, 0x61, 0x6c, 0x75, + 0x65, 0x73, 0x00, 0x00, 0x05, 0x00, 0x06, 0x00, 0x85, 0x00, 0x00, 0x00, + 0x69, 0x6e, 0x70, 0x75, 0x74, 0x5f, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, + 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x05, 0x00, 0x94, 0x00, 0x00, 0x00, + 0x70, 0x61, 0x72, 0x74, 0x69, 0x61, 0x6c, 0x73, 0x00, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x04, 0x00, 0x9b, 0x00, 0x00, 0x00, 0x73, 0x74, 0x72, 0x69, + 0x64, 0x65, 0x00, 0x00, 0x05, 0x00, 0x06, 0x00, 0xb9, 
0x00, 0x00, 0x00, + 0x4f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x42, 0x75, 0x66, 0x66, 0x65, 0x72, + 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x05, 0x00, 0xb9, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x00, 0x00, + 0x05, 0x00, 0x06, 0x00, 0xbb, 0x00, 0x00, 0x00, 0x6f, 0x75, 0x74, 0x70, + 0x75, 0x74, 0x5f, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x04, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, + 0x1a, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, 0x15, 0x00, 0x00, 0x00, + 0x0b, 0x00, 0x00, 0x00, 0x1b, 0x00, 0x00, 0x00, 0x48, 0x00, 0x05, 0x00, + 0x1a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x48, 0x00, 0x05, 0x00, 0x1a, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x48, 0x00, 0x05, 0x00, 0x1a, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x23, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x48, 0x00, 0x05, 0x00, + 0x1a, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x48, 0x00, 0x05, 0x00, 0x1a, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x48, 0x00, 0x05, 0x00, 0x1a, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, + 0x23, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x47, 0x00, 0x03, 0x00, + 0x1a, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, + 0x49, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x48, 0x00, 0x04, 0x00, 0x4a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x48, 0x00, 0x05, 0x00, 0x4a, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x03, 0x00, 0x4a, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x04, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x22, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, 0x4c, 0x00, 0x00, 0x00, + 0x21, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x47, 
0x00, 0x04, 0x00, + 0x71, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x48, 0x00, 0x04, 0x00, 0x72, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x48, 0x00, 0x05, 0x00, 0x72, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x03, 0x00, 0x72, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x04, 0x00, 0x74, 0x00, 0x00, 0x00, 0x22, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, 0x74, 0x00, 0x00, 0x00, + 0x21, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, + 0x7a, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x48, 0x00, 0x04, 0x00, 0x7b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x48, 0x00, 0x05, 0x00, 0x7b, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x03, 0x00, 0x7b, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x04, 0x00, 0x7d, 0x00, 0x00, 0x00, 0x22, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, 0x7d, 0x00, 0x00, 0x00, + 0x21, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, + 0x82, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x48, 0x00, 0x04, 0x00, 0x83, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x48, 0x00, 0x05, 0x00, 0x83, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x03, 0x00, 0x83, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x04, 0x00, 0x85, 0x00, 0x00, 0x00, 0x22, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, 0x85, 0x00, 0x00, 0x00, + 0x21, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, + 0xb8, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x48, 0x00, 0x04, 0x00, 0xb9, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x19, 0x00, 0x00, 0x00, 0x48, 0x00, 0x05, 0x00, 0xb9, 
0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x03, 0x00, 0xb9, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x04, 0x00, 0xbb, 0x00, 0x00, 0x00, 0x22, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, 0xbb, 0x00, 0x00, 0x00, + 0x21, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, + 0xc6, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x19, 0x00, 0x00, 0x00, + 0x13, 0x00, 0x02, 0x00, 0x02, 0x00, 0x00, 0x00, 0x21, 0x00, 0x03, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x15, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x04, 0x00, 0x07, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x17, 0x00, 0x04, 0x00, 0x09, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00, + 0x0a, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, + 0x3b, 0x00, 0x04, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00, + 0x0d, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x2b, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x11, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, 0x0a, 0x00, 0x00, 0x00, + 0x15, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x19, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, + 0x1e, 0x00, 0x08, 0x00, 0x1a, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00, + 0x1b, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x1a, 0x00, 0x00, 0x00, + 0x3b, 0x00, 0x04, 0x00, 0x1b, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, + 0x09, 0x00, 0x00, 0x00, 0x15, 0x00, 0x04, 0x00, 0x1d, 
0x00, 0x00, 0x00, + 0x20, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x04, 0x00, + 0x1d, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x04, 0x00, 0x1f, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x04, 0x00, 0x1d, 0x00, 0x00, 0x00, + 0x2a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x16, 0x00, 0x03, 0x00, + 0x2e, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00, + 0x2f, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x2e, 0x00, 0x00, 0x00, + 0x2b, 0x00, 0x04, 0x00, 0x2e, 0x00, 0x00, 0x00, 0x31, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x14, 0x00, 0x02, 0x00, 0x3c, 0x00, 0x00, 0x00, + 0x1d, 0x00, 0x03, 0x00, 0x49, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x1e, 0x00, 0x03, 0x00, 0x4a, 0x00, 0x00, 0x00, 0x49, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x04, 0x00, 0x4b, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x4a, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, 0x4b, 0x00, 0x00, 0x00, + 0x4c, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x04, 0x00, + 0x1d, 0x00, 0x00, 0x00, 0x4e, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x04, 0x00, 0x54, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x04, 0x00, 0x1d, 0x00, 0x00, 0x00, + 0x62, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x04, 0x00, + 0x1d, 0x00, 0x00, 0x00, 0x68, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x1d, 0x00, 0x03, 0x00, 0x71, 0x00, 0x00, 0x00, 0x2e, 0x00, 0x00, 0x00, + 0x1e, 0x00, 0x03, 0x00, 0x72, 0x00, 0x00, 0x00, 0x71, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x04, 0x00, 0x73, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x72, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, 0x73, 0x00, 0x00, 0x00, + 0x74, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00, + 0x76, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x2e, 0x00, 0x00, 0x00, + 0x1d, 0x00, 0x03, 0x00, 0x7a, 0x00, 0x00, 0x00, 0x2e, 0x00, 0x00, 0x00, + 0x1e, 0x00, 0x03, 0x00, 0x7b, 0x00, 0x00, 0x00, 0x7a, 
0x00, 0x00, 0x00, + 0x20, 0x00, 0x04, 0x00, 0x7c, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x7b, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, 0x7c, 0x00, 0x00, 0x00, + 0x7d, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x1d, 0x00, 0x03, 0x00, + 0x82, 0x00, 0x00, 0x00, 0x2e, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x03, 0x00, + 0x83, 0x00, 0x00, 0x00, 0x82, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00, + 0x84, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x83, 0x00, 0x00, 0x00, + 0x3b, 0x00, 0x04, 0x00, 0x84, 0x00, 0x00, 0x00, 0x85, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x8f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x04, 0x00, + 0x92, 0x00, 0x00, 0x00, 0x2e, 0x00, 0x00, 0x00, 0x8f, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x04, 0x00, 0x93, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x92, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, 0x93, 0x00, 0x00, 0x00, + 0x94, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00, + 0x97, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x2e, 0x00, 0x00, 0x00, + 0x2b, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x99, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x9a, 0x00, 0x00, 0x00, 0x08, 0x01, 0x00, 0x00, 0x1d, 0x00, 0x03, 0x00, + 0xb8, 0x00, 0x00, 0x00, 0x2e, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x03, 0x00, + 0xb9, 0x00, 0x00, 0x00, 0xb8, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00, + 0xba, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0xb9, 0x00, 0x00, 0x00, + 0x3b, 0x00, 0x04, 0x00, 0xba, 0x00, 0x00, 0x00, 0xbb, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x04, 0x00, 0x1d, 0x00, 0x00, 0x00, + 0xbd, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x2c, 0x00, 0x06, 0x00, + 0x09, 0x00, 0x00, 0x00, 0xc6, 0x00, 0x00, 0x00, 0x8f, 0x00, 0x00, 0x00, + 0x11, 0x00, 0x00, 0x00, 0x11, 0x00, 0x00, 0x00, 0x36, 0x00, 0x05, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0xf8, 0x00, 0x02, 0x00, 0x05, 
0x00, 0x00, 0x00, + 0x3b, 0x00, 0x04, 0x00, 0x07, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x3b, 0x00, 0x04, 0x00, 0x07, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x23, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x3b, 0x00, 0x04, 0x00, 0x2f, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x32, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x3b, 0x00, 0x04, 0x00, 0x07, 0x00, 0x00, 0x00, 0x42, 0x00, 0x00, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x48, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x57, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x3b, 0x00, 0x04, 0x00, 0x07, 0x00, 0x00, 0x00, 0x60, 0x00, 0x00, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x66, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, + 0x2f, 0x00, 0x00, 0x00, 0x6e, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x3b, 0x00, 0x04, 0x00, 0x07, 0x00, 0x00, 0x00, 0x9b, 0x00, 0x00, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x41, 0x00, 0x05, 0x00, 0x0d, 0x00, 0x00, 0x00, + 0x0e, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, + 0x0e, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x0f, 0x00, 0x00, 0x00, 0x41, 0x00, 0x05, 0x00, 0x0d, 0x00, 0x00, 0x00, + 0x12, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x11, 
0x00, 0x00, 0x00, + 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x13, 0x00, 0x00, 0x00, + 0x12, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x13, 0x00, 0x00, 0x00, 0x41, 0x00, 0x05, 0x00, 0x0d, 0x00, 0x00, 0x00, + 0x16, 0x00, 0x00, 0x00, 0x15, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x17, 0x00, 0x00, 0x00, + 0x16, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x17, 0x00, 0x00, 0x00, 0x41, 0x00, 0x05, 0x00, 0x1f, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x00, 0x00, + 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x21, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x00, 0x00, 0x86, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x22, 0x00, 0x00, 0x00, 0x19, 0x00, 0x00, 0x00, 0x21, 0x00, 0x00, 0x00, + 0x3e, 0x00, 0x03, 0x00, 0x18, 0x00, 0x00, 0x00, 0x22, 0x00, 0x00, 0x00, + 0x41, 0x00, 0x05, 0x00, 0x1f, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, + 0x1c, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x25, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, + 0xc4, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0x26, 0x00, 0x00, 0x00, + 0x11, 0x00, 0x00, 0x00, 0x25, 0x00, 0x00, 0x00, 0x82, 0x00, 0x05, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x27, 0x00, 0x00, 0x00, 0x26, 0x00, 0x00, 0x00, + 0x11, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, 0x23, 0x00, 0x00, 0x00, + 0x27, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x29, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x41, 0x00, 0x05, 0x00, + 0x1f, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, + 0x2a, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x2c, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x00, 0x00, 0x84, 0x00, 0x05, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x2d, 0x00, 0x00, 0x00, 0x29, 0x00, 0x00, 0x00, + 0x2c, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, 0x28, 0x00, 0x00, 0x00, + 0x2d, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, 0x30, 
0x00, 0x00, 0x00, + 0x31, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x33, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, + 0x32, 0x00, 0x00, 0x00, 0x33, 0x00, 0x00, 0x00, 0xf9, 0x00, 0x02, 0x00, + 0x34, 0x00, 0x00, 0x00, 0xf8, 0x00, 0x02, 0x00, 0x34, 0x00, 0x00, 0x00, + 0xf6, 0x00, 0x04, 0x00, 0x36, 0x00, 0x00, 0x00, 0x37, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0xf9, 0x00, 0x02, 0x00, 0x38, 0x00, 0x00, 0x00, + 0xf8, 0x00, 0x02, 0x00, 0x38, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x39, 0x00, 0x00, 0x00, 0x32, 0x00, 0x00, 0x00, + 0x41, 0x00, 0x05, 0x00, 0x1f, 0x00, 0x00, 0x00, 0x3a, 0x00, 0x00, 0x00, + 0x1c, 0x00, 0x00, 0x00, 0x2a, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x00, 0x00, 0x3a, 0x00, 0x00, 0x00, + 0xb0, 0x00, 0x05, 0x00, 0x3c, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x00, 0x00, + 0x39, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x00, 0x00, 0xfa, 0x00, 0x04, 0x00, + 0x3d, 0x00, 0x00, 0x00, 0x35, 0x00, 0x00, 0x00, 0x36, 0x00, 0x00, 0x00, + 0xf8, 0x00, 0x02, 0x00, 0x35, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x3f, 0x00, 0x00, 0x00, 0x32, 0x00, 0x00, 0x00, + 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x86, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x41, 0x00, 0x00, 0x00, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, + 0x3e, 0x00, 0x03, 0x00, 0x3e, 0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x00, + 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x43, 0x00, 0x00, 0x00, + 0x32, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x44, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x45, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, + 0x84, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0x46, 0x00, 0x00, 0x00, + 0x44, 0x00, 0x00, 0x00, 0x45, 0x00, 0x00, 0x00, 0x82, 0x00, 0x05, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x47, 0x00, 0x00, 0x00, 0x43, 
0x00, 0x00, 0x00, + 0x46, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, 0x42, 0x00, 0x00, 0x00, + 0x47, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x4d, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x41, 0x00, 0x05, 0x00, + 0x1f, 0x00, 0x00, 0x00, 0x4f, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, + 0x4e, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x50, 0x00, 0x00, 0x00, 0x4f, 0x00, 0x00, 0x00, 0x84, 0x00, 0x05, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x51, 0x00, 0x00, 0x00, 0x4d, 0x00, 0x00, 0x00, + 0x50, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x52, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x00, 0x00, 0x80, 0x00, 0x05, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x53, 0x00, 0x00, 0x00, 0x51, 0x00, 0x00, 0x00, + 0x52, 0x00, 0x00, 0x00, 0x41, 0x00, 0x06, 0x00, 0x54, 0x00, 0x00, 0x00, + 0x55, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x2a, 0x00, 0x00, 0x00, + 0x53, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x56, 0x00, 0x00, 0x00, 0x55, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, + 0x48, 0x00, 0x00, 0x00, 0x56, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x58, 0x00, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, + 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x59, 0x00, 0x00, 0x00, + 0x42, 0x00, 0x00, 0x00, 0x41, 0x00, 0x05, 0x00, 0x1f, 0x00, 0x00, 0x00, + 0x5a, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x00, 0x00, + 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x5b, 0x00, 0x00, 0x00, + 0x5a, 0x00, 0x00, 0x00, 0x84, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x5c, 0x00, 0x00, 0x00, 0x59, 0x00, 0x00, 0x00, 0x5b, 0x00, 0x00, 0x00, + 0xc2, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0x5d, 0x00, 0x00, 0x00, + 0x58, 0x00, 0x00, 0x00, 0x5c, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x5e, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, + 0xc7, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0x5f, 0x00, 0x00, 0x00, + 0x5d, 0x00, 0x00, 0x00, 0x5e, 0x00, 0x00, 0x00, 0x3e, 
0x00, 0x03, 0x00, + 0x57, 0x00, 0x00, 0x00, 0x5f, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x61, 0x00, 0x00, 0x00, 0x32, 0x00, 0x00, 0x00, + 0x41, 0x00, 0x05, 0x00, 0x1f, 0x00, 0x00, 0x00, 0x63, 0x00, 0x00, 0x00, + 0x1c, 0x00, 0x00, 0x00, 0x62, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x64, 0x00, 0x00, 0x00, 0x63, 0x00, 0x00, 0x00, + 0x86, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0x65, 0x00, 0x00, 0x00, + 0x61, 0x00, 0x00, 0x00, 0x64, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, + 0x60, 0x00, 0x00, 0x00, 0x65, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x67, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x41, 0x00, 0x05, 0x00, 0x1f, 0x00, 0x00, 0x00, 0x69, 0x00, 0x00, 0x00, + 0x1c, 0x00, 0x00, 0x00, 0x68, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x6a, 0x00, 0x00, 0x00, 0x69, 0x00, 0x00, 0x00, + 0x84, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0x6b, 0x00, 0x00, 0x00, + 0x67, 0x00, 0x00, 0x00, 0x6a, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x6c, 0x00, 0x00, 0x00, 0x60, 0x00, 0x00, 0x00, + 0x80, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0x6d, 0x00, 0x00, 0x00, + 0x6b, 0x00, 0x00, 0x00, 0x6c, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, + 0x66, 0x00, 0x00, 0x00, 0x6d, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x6f, 0x00, 0x00, 0x00, 0x57, 0x00, 0x00, 0x00, + 0x70, 0x00, 0x04, 0x00, 0x2e, 0x00, 0x00, 0x00, 0x70, 0x00, 0x00, 0x00, + 0x6f, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x75, 0x00, 0x00, 0x00, 0x66, 0x00, 0x00, 0x00, 0x41, 0x00, 0x06, 0x00, + 0x76, 0x00, 0x00, 0x00, 0x77, 0x00, 0x00, 0x00, 0x74, 0x00, 0x00, 0x00, + 0x2a, 0x00, 0x00, 0x00, 0x75, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x2e, 0x00, 0x00, 0x00, 0x78, 0x00, 0x00, 0x00, 0x77, 0x00, 0x00, 0x00, + 0x85, 0x00, 0x05, 0x00, 0x2e, 0x00, 0x00, 0x00, 0x79, 0x00, 0x00, 0x00, + 0x70, 0x00, 0x00, 0x00, 0x78, 0x00, 0x00, 0x00, 0x3d, 
0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x7e, 0x00, 0x00, 0x00, 0x66, 0x00, 0x00, 0x00, + 0x41, 0x00, 0x06, 0x00, 0x76, 0x00, 0x00, 0x00, 0x7f, 0x00, 0x00, 0x00, + 0x7d, 0x00, 0x00, 0x00, 0x2a, 0x00, 0x00, 0x00, 0x7e, 0x00, 0x00, 0x00, + 0x3d, 0x00, 0x04, 0x00, 0x2e, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, + 0x7f, 0x00, 0x00, 0x00, 0x81, 0x00, 0x05, 0x00, 0x2e, 0x00, 0x00, 0x00, + 0x81, 0x00, 0x00, 0x00, 0x79, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, + 0x3e, 0x00, 0x03, 0x00, 0x6e, 0x00, 0x00, 0x00, 0x81, 0x00, 0x00, 0x00, + 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x86, 0x00, 0x00, 0x00, + 0x28, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x87, 0x00, 0x00, 0x00, 0x32, 0x00, 0x00, 0x00, 0x80, 0x00, 0x05, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x88, 0x00, 0x00, 0x00, 0x86, 0x00, 0x00, 0x00, + 0x87, 0x00, 0x00, 0x00, 0x41, 0x00, 0x06, 0x00, 0x76, 0x00, 0x00, 0x00, + 0x89, 0x00, 0x00, 0x00, 0x85, 0x00, 0x00, 0x00, 0x2a, 0x00, 0x00, 0x00, + 0x88, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x2e, 0x00, 0x00, 0x00, + 0x8a, 0x00, 0x00, 0x00, 0x89, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x2e, 0x00, 0x00, 0x00, 0x8b, 0x00, 0x00, 0x00, 0x6e, 0x00, 0x00, 0x00, + 0x85, 0x00, 0x05, 0x00, 0x2e, 0x00, 0x00, 0x00, 0x8c, 0x00, 0x00, 0x00, + 0x8a, 0x00, 0x00, 0x00, 0x8b, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x2e, 0x00, 0x00, 0x00, 0x8d, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, + 0x81, 0x00, 0x05, 0x00, 0x2e, 0x00, 0x00, 0x00, 0x8e, 0x00, 0x00, 0x00, + 0x8d, 0x00, 0x00, 0x00, 0x8c, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, + 0x30, 0x00, 0x00, 0x00, 0x8e, 0x00, 0x00, 0x00, 0xf9, 0x00, 0x02, 0x00, + 0x37, 0x00, 0x00, 0x00, 0xf8, 0x00, 0x02, 0x00, 0x37, 0x00, 0x00, 0x00, + 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x90, 0x00, 0x00, 0x00, + 0x32, 0x00, 0x00, 0x00, 0x80, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x91, 0x00, 0x00, 0x00, 0x90, 0x00, 0x00, 0x00, 0x8f, 0x00, 0x00, 0x00, + 0x3e, 0x00, 0x03, 0x00, 0x32, 0x00, 0x00, 0x00, 0x91, 
0x00, 0x00, 0x00, + 0xf9, 0x00, 0x02, 0x00, 0x34, 0x00, 0x00, 0x00, 0xf8, 0x00, 0x02, 0x00, + 0x36, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x95, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x2e, 0x00, 0x00, 0x00, 0x96, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, + 0x41, 0x00, 0x05, 0x00, 0x97, 0x00, 0x00, 0x00, 0x98, 0x00, 0x00, 0x00, + 0x94, 0x00, 0x00, 0x00, 0x95, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, + 0x98, 0x00, 0x00, 0x00, 0x96, 0x00, 0x00, 0x00, 0xe0, 0x00, 0x04, 0x00, + 0x99, 0x00, 0x00, 0x00, 0x99, 0x00, 0x00, 0x00, 0x9a, 0x00, 0x00, 0x00, + 0x3e, 0x00, 0x03, 0x00, 0x9b, 0x00, 0x00, 0x00, 0x19, 0x00, 0x00, 0x00, + 0xf9, 0x00, 0x02, 0x00, 0x9c, 0x00, 0x00, 0x00, 0xf8, 0x00, 0x02, 0x00, + 0x9c, 0x00, 0x00, 0x00, 0xf6, 0x00, 0x04, 0x00, 0x9e, 0x00, 0x00, 0x00, + 0x9f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf9, 0x00, 0x02, 0x00, + 0xa0, 0x00, 0x00, 0x00, 0xf8, 0x00, 0x02, 0x00, 0xa0, 0x00, 0x00, 0x00, + 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0xa1, 0x00, 0x00, 0x00, + 0x9b, 0x00, 0x00, 0x00, 0xac, 0x00, 0x05, 0x00, 0x3c, 0x00, 0x00, 0x00, + 0xa2, 0x00, 0x00, 0x00, 0xa1, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0xfa, 0x00, 0x04, 0x00, 0xa2, 0x00, 0x00, 0x00, 0x9d, 0x00, 0x00, 0x00, + 0x9e, 0x00, 0x00, 0x00, 0xf8, 0x00, 0x02, 0x00, 0x9d, 0x00, 0x00, 0x00, + 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0xa3, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0xa4, 0x00, 0x00, 0x00, 0x9b, 0x00, 0x00, 0x00, 0xb0, 0x00, 0x05, 0x00, + 0x3c, 0x00, 0x00, 0x00, 0xa5, 0x00, 0x00, 0x00, 0xa3, 0x00, 0x00, 0x00, + 0xa4, 0x00, 0x00, 0x00, 0xf7, 0x00, 0x03, 0x00, 0xa7, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0xfa, 0x00, 0x04, 0x00, 0xa5, 0x00, 0x00, 0x00, + 0xa6, 0x00, 0x00, 0x00, 0xa7, 0x00, 0x00, 0x00, 0xf8, 0x00, 0x02, 0x00, + 0xa6, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0xa8, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x3d, 
0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0xa9, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0xaa, 0x00, 0x00, 0x00, + 0x9b, 0x00, 0x00, 0x00, 0x80, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, + 0xab, 0x00, 0x00, 0x00, 0xa9, 0x00, 0x00, 0x00, 0xaa, 0x00, 0x00, 0x00, + 0x41, 0x00, 0x05, 0x00, 0x97, 0x00, 0x00, 0x00, 0xac, 0x00, 0x00, 0x00, + 0x94, 0x00, 0x00, 0x00, 0xab, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x2e, 0x00, 0x00, 0x00, 0xad, 0x00, 0x00, 0x00, 0xac, 0x00, 0x00, 0x00, + 0x41, 0x00, 0x05, 0x00, 0x97, 0x00, 0x00, 0x00, 0xae, 0x00, 0x00, 0x00, + 0x94, 0x00, 0x00, 0x00, 0xa8, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x2e, 0x00, 0x00, 0x00, 0xaf, 0x00, 0x00, 0x00, 0xae, 0x00, 0x00, 0x00, + 0x81, 0x00, 0x05, 0x00, 0x2e, 0x00, 0x00, 0x00, 0xb0, 0x00, 0x00, 0x00, + 0xaf, 0x00, 0x00, 0x00, 0xad, 0x00, 0x00, 0x00, 0x41, 0x00, 0x05, 0x00, + 0x97, 0x00, 0x00, 0x00, 0xb1, 0x00, 0x00, 0x00, 0x94, 0x00, 0x00, 0x00, + 0xa8, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, 0xb1, 0x00, 0x00, 0x00, + 0xb0, 0x00, 0x00, 0x00, 0xf9, 0x00, 0x02, 0x00, 0xa7, 0x00, 0x00, 0x00, + 0xf8, 0x00, 0x02, 0x00, 0xa7, 0x00, 0x00, 0x00, 0xe0, 0x00, 0x04, 0x00, + 0x99, 0x00, 0x00, 0x00, 0x99, 0x00, 0x00, 0x00, 0x9a, 0x00, 0x00, 0x00, + 0xf9, 0x00, 0x02, 0x00, 0x9f, 0x00, 0x00, 0x00, 0xf8, 0x00, 0x02, 0x00, + 0x9f, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0xb2, 0x00, 0x00, 0x00, 0x9b, 0x00, 0x00, 0x00, 0xc2, 0x00, 0x05, 0x00, + 0x06, 0x00, 0x00, 0x00, 0xb3, 0x00, 0x00, 0x00, 0xb2, 0x00, 0x00, 0x00, + 0x11, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, 0x9b, 0x00, 0x00, 0x00, + 0xb3, 0x00, 0x00, 0x00, 0xf9, 0x00, 0x02, 0x00, 0x9c, 0x00, 0x00, 0x00, + 0xf8, 0x00, 0x02, 0x00, 0x9e, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0xb4, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0xaa, 0x00, 0x05, 0x00, 0x3c, 0x00, 0x00, 0x00, 0xb5, 0x00, 0x00, 0x00, + 0xb4, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0xf7, 
0x00, 0x03, 0x00, + 0xb7, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xfa, 0x00, 0x04, 0x00, + 0xb5, 0x00, 0x00, 0x00, 0xb6, 0x00, 0x00, 0x00, 0xb7, 0x00, 0x00, 0x00, + 0xf8, 0x00, 0x02, 0x00, 0xb6, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0xbc, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x41, 0x00, 0x05, 0x00, 0x1f, 0x00, 0x00, 0x00, 0xbe, 0x00, 0x00, 0x00, + 0x1c, 0x00, 0x00, 0x00, 0xbd, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0xbf, 0x00, 0x00, 0x00, 0xbe, 0x00, 0x00, 0x00, + 0x84, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0xc0, 0x00, 0x00, 0x00, + 0xbc, 0x00, 0x00, 0x00, 0xbf, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0xc1, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x80, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0xc2, 0x00, 0x00, 0x00, + 0xc0, 0x00, 0x00, 0x00, 0xc1, 0x00, 0x00, 0x00, 0x41, 0x00, 0x05, 0x00, + 0x97, 0x00, 0x00, 0x00, 0xc3, 0x00, 0x00, 0x00, 0x94, 0x00, 0x00, 0x00, + 0x2a, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x2e, 0x00, 0x00, 0x00, + 0xc4, 0x00, 0x00, 0x00, 0xc3, 0x00, 0x00, 0x00, 0x41, 0x00, 0x06, 0x00, + 0x76, 0x00, 0x00, 0x00, 0xc5, 0x00, 0x00, 0x00, 0xbb, 0x00, 0x00, 0x00, + 0x2a, 0x00, 0x00, 0x00, 0xc2, 0x00, 0x00, 0x00, 0x3e, 0x00, 0x03, 0x00, + 0xc5, 0x00, 0x00, 0x00, 0xc4, 0x00, 0x00, 0x00, 0xf9, 0x00, 0x02, 0x00, + 0xb7, 0x00, 0x00, 0x00, 0xf8, 0x00, 0x02, 0x00, 0xb7, 0x00, 0x00, 0x00, + 0xfd, 0x00, 0x01, 0x00, 0x38, 0x00, 0x01, 0x00 +}; +static constexpr size_t kBonsaiVulkanQuantizedMatvecSpvSize = + sizeof(kBonsaiVulkanQuantizedMatvecSpv); diff --git a/feature/bonsai/src/androidMain/kotlin/com/shifthackz/aisdv1/feature/bonsai/AndroidBonsaiModelLayout.kt b/feature/bonsai/src/androidMain/kotlin/com/shifthackz/aisdv1/feature/bonsai/AndroidBonsaiModelLayout.kt new file mode 100644 index 000000000..2cfd759a3 --- /dev/null +++ b/feature/bonsai/src/androidMain/kotlin/com/shifthackz/aisdv1/feature/bonsai/AndroidBonsaiModelLayout.kt @@ -0,0 +1,160 @@ 
+package com.shifthackz.aisdv1.feature.bonsai + +import java.io.File +import java.io.FileOutputStream +import java.util.zip.ZipFile + +/** + * Resolves Android Bonsai model resources using the same layout rules as the iOS Swift runtime. + * + * @property rootPath directory containing the Bonsai resource folders. + * @property packedTransformerPath directory containing packed mflux transformer weights. + * @property textEncoderPath preferred text encoder directory. + * @property tokenizerPath tokenizer directory. + * @property vaePath VAE directory. + * @property schedulerPath scheduler directory. + * @author Dmitriy Moroz + */ +internal data class AndroidBonsaiModelLayout( + val rootPath: String, + val packedTransformerPath: String, + val textEncoderPath: String, + val tokenizerPath: String, + val vaePath: String, + val schedulerPath: String, +) { + + companion object { + private const val TEXT_ENCODER_MLX_DIRECTORY = "text_encoder-mlx-4bit" + private const val TEXT_ENCODER_LEGACY_DIRECTORY = "text_encoder" + private const val TRANSFORMER_DIRECTORY = "transformer-packed-mflux" + private const val TOKENIZER_DIRECTORY = "tokenizer" + private const val VAE_DIRECTORY = "vae" + private const val SCHEDULER_DIRECTORY = "scheduler" + private const val MODEL_ARCHIVE = "model.zip" + private const val EXTRACTED_DIRECTORY = "extracted" + private const val RESOURCES_DIRECTORY = "Resources" + private const val MAX_NESTED_SEARCH_DEPTH = 4 + + /** + * Resolves a Bonsai model path to its runtime layout, extracting `model.zip` when needed. + * + * @param modelPath selected model directory. + * @return resolved Bonsai layout. + * @throws IllegalStateException when required Bonsai resources cannot be found. 
+ * @author Dmitriy Moroz + */ + fun resolve(modelPath: String): AndroidBonsaiModelLayout { + val modelDirectory = File(modelPath) + find(inDirectory = modelDirectory)?.let { layout -> return layout } + + val archive = File(modelDirectory, MODEL_ARCHIVE) + if (!archive.isFile) { + throw IllegalStateException("Bonsai model resources not found at $modelPath.") + } + + val extracted = File(modelDirectory, EXTRACTED_DIRECTORY) + if (find(inDirectory = extracted) == null) { + if (extracted.exists()) extracted.deleteRecursively() + extracted.mkdirs() + archive.unzipSafely(destination = extracted) + } + + return find(inDirectory = extracted) + ?: throw IllegalStateException( + "Invalid Bonsai model layout: expected $TRANSFORMER_DIRECTORY, " + + "$TEXT_ENCODER_MLX_DIRECTORY or $TEXT_ENCODER_LEGACY_DIRECTORY, " + + "$TOKENIZER_DIRECTORY, $VAE_DIRECTORY, and $SCHEDULER_DIRECTORY directories.", + ) + } + + private fun find(inDirectory: File): AndroidBonsaiModelLayout? { + if (!inDirectory.exists()) return null + + directCandidates(inDirectory) + .firstNotNullOfOrNull(::layout) + ?.let { layout -> return layout } + + val rootDepth = inDirectory.toPath().nameCount + return inDirectory + .walkTopDown() + .filter(File::isDirectory) + .filter { candidate -> + candidate.toPath().nameCount - rootDepth <= MAX_NESTED_SEARCH_DEPTH + } + .firstNotNullOfOrNull(::layout) + } + + private fun directCandidates(root: File): List = listOf( + root, + File(root, RESOURCES_DIRECTORY), + File(root, EXTRACTED_DIRECTORY), + File(File(root, EXTRACTED_DIRECTORY), RESOURCES_DIRECTORY), + ) + + private fun layout(root: File): AndroidBonsaiModelLayout? 
{ + if (!isBonsaiRoot(root)) return null + val textEncoder = firstDirectory( + root = root, + names = listOf(TEXT_ENCODER_MLX_DIRECTORY, TEXT_ENCODER_LEGACY_DIRECTORY), + ) ?: return null + + return AndroidBonsaiModelLayout( + rootPath = root.path, + packedTransformerPath = File(root, TRANSFORMER_DIRECTORY).path, + textEncoderPath = textEncoder.path, + tokenizerPath = File(root, TOKENIZER_DIRECTORY).path, + vaePath = File(root, VAE_DIRECTORY).path, + schedulerPath = File(root, SCHEDULER_DIRECTORY).path, + ) + } + + private fun isBonsaiRoot(root: File): Boolean { + val quantizationConfig = File( + File(root, TRANSFORMER_DIRECTORY), + "quantization_config.json", + ) + val requiredDirectories = listOf( + File(root, TOKENIZER_DIRECTORY), + File(root, VAE_DIRECTORY), + File(root, SCHEDULER_DIRECTORY), + ) + + return quantizationConfig.isFile && + firstDirectory( + root = root, + names = listOf(TEXT_ENCODER_MLX_DIRECTORY, TEXT_ENCODER_LEGACY_DIRECTORY), + ) != null && + requiredDirectories.all(File::isDirectory) + } + + private fun firstDirectory( + root: File, + names: List, + ): File? 
= names + .map { name -> File(root, name) } + .firstOrNull(File::isDirectory) + + private fun File.unzipSafely(destination: File) { + val destinationRoot = destination.canonicalFile + ZipFile(this).use { zip -> + zip.entries().asSequence().forEach { entry -> + val target = File(destinationRoot, entry.name).canonicalFile + if (!target.path.startsWith(destinationRoot.path + File.separator)) { + throw IllegalStateException("Invalid Bonsai model archive entry: ${entry.name}.") + } + if (entry.isDirectory) { + target.mkdirs() + } else { + target.parentFile?.mkdirs() + zip.getInputStream(entry).use { input -> + FileOutputStream(target).use { output -> + input.copyTo(output) + } + } + } + } + } + } + } +} diff --git a/feature/bonsai/src/androidMain/kotlin/com/shifthackz/aisdv1/feature/bonsai/AndroidBonsaiRequestValidator.kt b/feature/bonsai/src/androidMain/kotlin/com/shifthackz/aisdv1/feature/bonsai/AndroidBonsaiRequestValidator.kt new file mode 100644 index 000000000..fe38ada37 --- /dev/null +++ b/feature/bonsai/src/androidMain/kotlin/com/shifthackz/aisdv1/feature/bonsai/AndroidBonsaiRequestValidator.kt @@ -0,0 +1,33 @@ +package com.shifthackz.aisdv1.feature.bonsai + +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload +import java.io.File + +/** + * Validates Android Bonsai generation requests before entering the native runtime. 
+ * + * @author Dmitriy Moroz + */ +internal object AndroidBonsaiRequestValidator { + + fun validate( + payload: TextToImagePayload, + modelPath: String, + ) { + if (payload.prompt.isBlank()) { + throw IllegalStateException("Prompt is required.") + } + + if (payload.width <= 0 || + payload.height <= 0 || + payload.width % 32 != 0 || + payload.height % 32 != 0 + ) { + throw IllegalStateException("Bonsai image size must be positive and divisible by 32.") + } + + if (!File(modelPath).exists()) { + throw IllegalStateException("Bonsai model resources were not found at $modelPath.") + } + } +} diff --git a/feature/bonsai/src/androidMain/kotlin/com/shifthackz/aisdv1/feature/bonsai/BonsaiDiffusionImpl.kt b/feature/bonsai/src/androidMain/kotlin/com/shifthackz/aisdv1/feature/bonsai/BonsaiDiffusionImpl.kt index 2fef60d52..e59325e97 100644 --- a/feature/bonsai/src/androidMain/kotlin/com/shifthackz/aisdv1/feature/bonsai/BonsaiDiffusionImpl.kt +++ b/feature/bonsai/src/androidMain/kotlin/com/shifthackz/aisdv1/feature/bonsai/BonsaiDiffusionImpl.kt @@ -1,43 +1,108 @@ package com.shifthackz.aisdv1.feature.bonsai +import android.os.Process import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus import com.shifthackz.aisdv1.domain.entity.TextToImagePayload import com.shifthackz.aisdv1.domain.feature.bonsai.BonsaiDiffusion +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.flowOf +import kotlinx.coroutines.flow.MutableSharedFlow +import kotlinx.coroutines.flow.onStart +import kotlinx.coroutines.currentCoroutineContext +import kotlinx.coroutines.job +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock +import kotlinx.coroutines.withContext /** - * Implements Android noop behavior for the iOS-only Bonsai feature layer. + * Android Bonsai runtime entry point backed by the NDK bridge. 
* - * @author Dmitriy Moroz + * The implementation keeps native execution off the caller thread, serializes + * generation requests, and relays step progress through the shared local + * diffusion status contract. */ internal class BonsaiDiffusionImpl : BonsaiDiffusion { - /** - * Executes the `process` step in the SDAI Bonsai feature layer. - * - * @param payload generation payload used by the operation. - * @param modelPath local model directory selected by the user. - * @author Dmitriy Moroz - */ + private val mutex = Mutex() + private val statusFlow = MutableSharedFlow(replay = 1) + override suspend fun process( payload: TextToImagePayload, modelPath: String, - ): String { - throw IllegalStateException("Bonsai Image generation is available on iOS only.") + ): String = mutex.withLock { + withContext(Dispatchers.Default) { + AndroidBonsaiRequestValidator.validate( + payload = payload, + modelPath = modelPath, + ) + BonsaiNativeBridge.ensureLoaded() + val layout = AndroidBonsaiModelLayout.resolve(modelPath) + statusFlow.tryEmit(LocalDiffusionStatus(current = 0, total = payload.samplingSteps)) + val job = currentCoroutineContext().job + val cancellationHandle = job.invokeOnCompletion { cause -> + if (cause is CancellationException) { + BonsaiNativeBridge.interrupt() + } + } + + try { + withBackgroundThreadPriority { + BonsaiNativeBridge.generate( + layout = layout, + prompt = payload.prompt, + negativePrompt = payload.negativePrompt, + samplingSteps = payload.samplingSteps, + cfgScale = payload.cfgScale, + width = payload.width, + height = payload.height, + seed = payload.seed, + batchCount = payload.batchCount.coerceAtLeast(1), + allowNsfw = payload.nsfw, + backend = payload.bonsaiBackend.key, + callback = object : BonsaiNativeBridge.ProgressCallback { + override fun onProgress(current: Int, total: Int) { + statusFlow.tryEmit( + LocalDiffusionStatus( + current = current, + total = total, + ), + ) + } + }, + ) + } + } finally { + cancellationHandle.dispose() + } + 
} + } + + private inline fun withBackgroundThreadPriority(block: () -> T): T { + val threadId = Process.myTid() + val previousPriority = runCatching { + Process.getThreadPriority(threadId) + }.getOrNull() + runCatching { + Process.setThreadPriority(threadId, Process.THREAD_PRIORITY_BACKGROUND) + } + return try { + block() + } finally { + if (previousPriority != null) { + runCatching { + Process.setThreadPriority(threadId, previousPriority) + } + } + } + } + + override suspend fun interrupt() { + if (BonsaiNativeBridge.isAvailable) { + runCatching { BonsaiNativeBridge.interrupt() } + } } - /** - * Performs the SDAI side effect handled by `interrupt`. - * - * @author Dmitriy Moroz - */ - override suspend fun interrupt() = Unit - - /** - * Loads SDAI data through `observeStatus`. - * - * @author Dmitriy Moroz - */ - override fun observeStatus(): Flow = flowOf(LocalDiffusionStatus(0, 0)) + override fun observeStatus(): Flow = statusFlow + .onStart { emit(LocalDiffusionStatus(current = 0, total = 0)) } } diff --git a/feature/bonsai/src/androidMain/kotlin/com/shifthackz/aisdv1/feature/bonsai/BonsaiNativeBridge.kt b/feature/bonsai/src/androidMain/kotlin/com/shifthackz/aisdv1/feature/bonsai/BonsaiNativeBridge.kt new file mode 100644 index 000000000..421d5ee2c --- /dev/null +++ b/feature/bonsai/src/androidMain/kotlin/com/shifthackz/aisdv1/feature/bonsai/BonsaiNativeBridge.kt @@ -0,0 +1,110 @@ +package com.shifthackz.aisdv1.feature.bonsai + +/** + * Thin Kotlin boundary around the Android Bonsai native library. + * + * The bridge centralizes library loading, model probing, generation, progress + * callbacks, and cancellation so the rest of the feature never calls JNI entry + * points directly. 
+ */ +internal object BonsaiNativeBridge { + + private val loadResult = runCatching { + System.loadLibrary(LIBRARY_NAME) + } + + val isAvailable: Boolean + get() = loadResult.isSuccess + + fun ensureLoaded() { + loadResult.getOrElse { error -> + throw IllegalStateException( + "Bonsai native runtime is not available on this Android build.", + error, + ) + } + } + + fun probe(layout: AndroidBonsaiModelLayout): String = + probeModel( + rootPath = layout.rootPath, + packedTransformerPath = layout.packedTransformerPath, + textEncoderPath = layout.textEncoderPath, + tokenizerPath = layout.tokenizerPath, + vaePath = layout.vaePath, + schedulerPath = layout.schedulerPath, + ) + + private external fun probeModel( + rootPath: String, + packedTransformerPath: String, + textEncoderPath: String, + tokenizerPath: String, + vaePath: String, + schedulerPath: String, + ): String + + fun generate( + layout: AndroidBonsaiModelLayout, + prompt: String, + negativePrompt: String, + samplingSteps: Int, + cfgScale: Float, + width: Int, + height: Int, + seed: String, + batchCount: Int, + allowNsfw: Boolean, + backend: String, + callback: ProgressCallback, + ): String = generateModel( + rootPath = layout.rootPath, + packedTransformerPath = layout.packedTransformerPath, + textEncoderPath = layout.textEncoderPath, + tokenizerPath = layout.tokenizerPath, + vaePath = layout.vaePath, + schedulerPath = layout.schedulerPath, + prompt = prompt, + negativePrompt = negativePrompt, + samplingSteps = samplingSteps, + cfgScale = cfgScale, + width = width, + height = height, + seed = seed, + batchCount = batchCount, + allowNsfw = allowNsfw, + backend = backend, + callback = callback, + ) + + private external fun generateModel( + rootPath: String, + packedTransformerPath: String, + textEncoderPath: String, + tokenizerPath: String, + vaePath: String, + schedulerPath: String, + prompt: String, + negativePrompt: String, + samplingSteps: Int, + cfgScale: Float, + width: Int, + height: Int, + seed: String, + 
batchCount: Int, + allowNsfw: Boolean, + backend: String, + callback: ProgressCallback, + ): String + + external fun interrupt() + + /** + * JNI callback used by the native runtime to report completed diffusion steps. + */ + interface ProgressCallback { + fun onProgress(current: Int, total: Int) + } +} + +private const val LIBRARY_NAME = "sdai_bonsai" diff --git a/feature/bonsai/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/feature/bonsai/AndroidBonsaiModelLayoutTest.kt b/feature/bonsai/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/feature/bonsai/AndroidBonsaiModelLayoutTest.kt new file mode 100644 index 000000000..16622784a --- /dev/null +++ b/feature/bonsai/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/feature/bonsai/AndroidBonsaiModelLayoutTest.kt @@ -0,0 +1,137 @@ +package com.shifthackz.aisdv1.feature.bonsai + +import org.junit.Assert.assertEquals +import org.junit.Assert.assertTrue +import org.junit.Rule +import org.junit.Test +import org.junit.rules.TemporaryFolder +import java.io.File +import java.util.zip.ZipEntry +import java.util.zip.ZipOutputStream + +class AndroidBonsaiModelLayoutTest { + + @get:Rule + val temporaryFolder = TemporaryFolder() + + @Test + fun `given direct Bonsai root, resolve returns direct layout`() { + val root = temporaryFolder.newFolder("bonsai") + root.createBonsaiLayout() + + val actual = AndroidBonsaiModelLayout.resolve(root.path) + + assertEquals(root.path, actual.rootPath) + assertEquals(File(root, "transformer-packed-mflux").path, actual.packedTransformerPath) + assertEquals(File(root, "text_encoder-mlx-4bit").path, actual.textEncoderPath) + assertEquals(File(root, "tokenizer").path, actual.tokenizerPath) + assertEquals(File(root, "vae").path, actual.vaePath) + assertEquals(File(root, "scheduler").path, actual.schedulerPath) + } + + @Test + fun `given both text encoder directories, resolve prefers mlx text encoder`() { + val root = temporaryFolder.newFolder("bonsai") + root.createBonsaiLayout(includeLegacyTextEncoder 
= true) + + val actual = AndroidBonsaiModelLayout.resolve(root.path) + + assertEquals(File(root, "text_encoder-mlx-4bit").path, actual.textEncoderPath) + } + + @Test + fun `given Resources Bonsai root, resolve returns resources layout`() { + val modelDirectory = temporaryFolder.newFolder("bonsai") + val resources = File(modelDirectory, "Resources").apply(File::mkdirs) + resources.createBonsaiLayout(includeMlxTextEncoder = false, includeLegacyTextEncoder = true) + + val actual = AndroidBonsaiModelLayout.resolve(modelDirectory.path) + + assertEquals(resources.path, actual.rootPath) + assertEquals(File(resources, "text_encoder").path, actual.textEncoderPath) + } + + @Test + fun `given nested Resources Bonsai root, resolve finds nested layout`() { + val modelDirectory = temporaryFolder.newFolder("bonsai") + val resources = File(modelDirectory, "wrapper/level/Resources").apply(File::mkdirs) + resources.createBonsaiLayout() + + val actual = AndroidBonsaiModelLayout.resolve(modelDirectory.path) + + assertEquals(resources.path, actual.rootPath) + } + + @Test + fun `given model zip, resolve extracts archive and returns extracted layout`() { + val modelDirectory = temporaryFolder.newFolder("bonsai") + val archive = File(modelDirectory, "model.zip") + archive.writeZip( + "ArchiveRoot/Resources/transformer-packed-mflux/quantization_config.json" to "{}", + "ArchiveRoot/Resources/text_encoder-mlx-4bit/model.safetensors" to "text", + "ArchiveRoot/Resources/tokenizer/tokenizer.json" to "{}", + "ArchiveRoot/Resources/vae/model.safetensors" to "vae", + "ArchiveRoot/Resources/scheduler/scheduler_config.json" to "{}", + ) + + val actual = AndroidBonsaiModelLayout.resolve(modelDirectory.path) + + val expectedRoot = File(modelDirectory, "extracted/ArchiveRoot/Resources") + assertEquals(expectedRoot.path, actual.rootPath) + assertTrue(File(expectedRoot, "tokenizer/tokenizer.json").isFile) + } + + @Test + fun `given invalid model directory, resolve reports missing resources`() { + val 
modelDirectory = temporaryFolder.newFolder("bonsai") + + val actual = runCatching { + AndroidBonsaiModelLayout.resolve(modelDirectory.path) + }.exceptionOrNull() + + assertEquals( + "Bonsai model resources not found at ${modelDirectory.path}.", + actual?.message, + ) + } + + @Test + fun `given unsafe model zip entry, resolve rejects archive`() { + val modelDirectory = temporaryFolder.newFolder("bonsai") + File(modelDirectory, "model.zip").writeZip("../escape.txt" to "escape") + + val actual = runCatching { + AndroidBonsaiModelLayout.resolve(modelDirectory.path) + }.exceptionOrNull() + + assertEquals( + "Invalid Bonsai model archive entry: ../escape.txt.", + actual?.message, + ) + } +} + +private fun File.createBonsaiLayout( + includeMlxTextEncoder: Boolean = true, + includeLegacyTextEncoder: Boolean = false, +) { + File(this, "transformer-packed-mflux").apply(File::mkdirs) + File(this, "transformer-packed-mflux/quantization_config.json").writeText("{}") + if (includeMlxTextEncoder) File(this, "text_encoder-mlx-4bit").mkdirs() + if (includeLegacyTextEncoder) File(this, "text_encoder").mkdirs() + File(this, "tokenizer").mkdirs() + File(this, "vae").mkdirs() + File(this, "scheduler").mkdirs() +} + +private fun File.writeZip(vararg entries: Pair) { + outputStream().use { output -> + ZipOutputStream(output).use { zip -> + entries.forEach { (name, content) -> + zip.putNextEntry(ZipEntry(name)) + zip.write(content.toByteArray()) + zip.closeEntry() + } + } + } +} diff --git a/feature/bonsai/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/feature/bonsai/AndroidBonsaiRequestValidatorTest.kt b/feature/bonsai/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/feature/bonsai/AndroidBonsaiRequestValidatorTest.kt new file mode 100644 index 000000000..0064e6111 --- /dev/null +++ b/feature/bonsai/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/feature/bonsai/AndroidBonsaiRequestValidatorTest.kt @@ -0,0 +1,100 @@ +package com.shifthackz.aisdv1.feature.bonsai + +import 
com.shifthackz.aisdv1.domain.entity.TextToImagePayload +import org.junit.Assert.assertEquals +import org.junit.Rule +import org.junit.Test +import org.junit.rules.TemporaryFolder + +class AndroidBonsaiRequestValidatorTest { + + @get:Rule + val temporaryFolder = TemporaryFolder() + + @Test + fun `given valid request, validate succeeds`() { + val modelDirectory = temporaryFolder.newFolder("bonsai") + + val actual = runCatching { + AndroidBonsaiRequestValidator.validate( + payload = makePayload(), + modelPath = modelDirectory.path, + ) + }.exceptionOrNull() + + assertEquals(null, actual) + } + + @Test + fun `given blank prompt, validate reports required prompt`() { + val modelDirectory = temporaryFolder.newFolder("bonsai") + + val actual = runCatching { + AndroidBonsaiRequestValidator.validate( + payload = makePayload(prompt = " \n "), + modelPath = modelDirectory.path, + ) + }.exceptionOrNull() + + assertEquals("Prompt is required.", actual?.message) + } + + @Test + fun `given invalid size, validate reports size requirement`() { + val modelDirectory = temporaryFolder.newFolder("bonsai") + + val actual = runCatching { + AndroidBonsaiRequestValidator.validate( + payload = makePayload(width = 130), + modelPath = modelDirectory.path, + ) + }.exceptionOrNull() + + assertEquals( + "Bonsai image size must be positive and divisible by 32.", + actual?.message, + ) + } + + @Test + fun `given missing model path, validate reports missing resources`() { + val missingPath = temporaryFolder.root.resolve("missing-bonsai").path + + val actual = runCatching { + AndroidBonsaiRequestValidator.validate( + payload = makePayload(), + modelPath = missingPath, + ) + }.exceptionOrNull() + + assertEquals( + "Bonsai model resources were not found at $missingPath.", + actual?.message, + ) + } +} + +private fun makePayload( + prompt: String = "a bonsai tree", + width: Int = 128, + height: Int = 128, +): TextToImagePayload = TextToImagePayload( + prompt = prompt, + negativePrompt = "", + 
samplingSteps = 4, + cfgScale = 1f, + width = width, + height = height, + restoreFaces = false, + seed = "1", + subSeed = "", + subSeedStrength = 0f, + sampler = "", + nsfw = false, + batchCount = 1, + style = null, + quality = null, + openAiModel = null, + stabilityAiClipGuidance = null, + stabilityAiStylePreset = null, +) diff --git a/feature/work/src/androidMain/kotlin/com/shifthackz/aisdv1/work/core/CoreGenerationWorker.kt b/feature/work/src/androidMain/kotlin/com/shifthackz/aisdv1/work/core/CoreGenerationWorker.kt index c996a8726..fa642e426 100644 --- a/feature/work/src/androidMain/kotlin/com/shifthackz/aisdv1/work/core/CoreGenerationWorker.kt +++ b/feature/work/src/androidMain/kotlin/com/shifthackz/aisdv1/work/core/CoreGenerationWorker.kt @@ -13,6 +13,7 @@ import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.feature.work.BackgroundWorkObserver import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.domain.usecase.generation.InterruptGenerationUseCase +import com.shifthackz.aisdv1.domain.usecase.generation.ObserveBonsaiProcessStatusUseCase import com.shifthackz.aisdv1.domain.usecase.generation.ObserveHordeProcessStatusUseCase import com.shifthackz.aisdv1.domain.usecase.generation.ObserveLocalDiffusionProcessStatusUseCase import kotlinx.coroutines.CoroutineScope @@ -24,51 +25,28 @@ import kotlinx.coroutines.launch import com.shifthackz.aisdv1.core.localization.R as LocalizationR /** - * Coordinates `CoreGenerationWorker` behavior in the SDAI background work feature layer. + * Shared foreground worker base for generation tasks. * - * @author Dmitriy Moroz + * It owns cancellation, progress subscriptions, foreground notifications, and + * background observer updates that are common to txt2img and img2img work. 
*/ internal abstract class CoreGenerationWorker( context: Context, workerParameters: WorkerParameters, pushNotificationManager: PushNotificationManager, activityIntentProvider: ActivityIntentProvider, - /** - * Exposes the `preferenceManager` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ private val preferenceManager: PreferenceManager, - /** - * Exposes the `backgroundWorkObserver` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ private val backgroundWorkObserver: BackgroundWorkObserver, - /** - * Exposes the `observeHordeProcessStatusUseCase` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ private val observeHordeProcessStatusUseCase: ObserveHordeProcessStatusUseCase, - /** - * Exposes the `observeLocalDiffusionProcessStatusUseCase` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ private val observeLocalDiffusionProcessStatusUseCase: ObserveLocalDiffusionProcessStatusUseCase, - /** - * Exposes the `interruptGenerationUseCase` value used by the SDAI background work feature layer. 
- * - * @author Dmitriy Moroz - */ + private val observeBonsaiProcessStatusUseCase: ObserveBonsaiProcessStatusUseCase, private val interruptGenerationUseCase: InterruptGenerationUseCase, ) : NotificationWorker( context = context, workerParameters = workerParameters, pushNotificationManager = pushNotificationManager, activityIntentProvider = activityIntentProvider, -){ +) { private val coroutineScope = CoroutineScope(SupervisorJob() + Dispatchers.Default) @@ -85,6 +63,15 @@ internal abstract class CoreGenerationWorker( coroutineScope.coroutineContext.cancelChildren() } + protected fun listenSourceStatus() { + when (source) { + ServerSource.HORDE -> listenHordeStatus() + ServerSource.LOCAL_MICROSOFT_ONNX -> listenLocalDiffusionStatus() + ServerSource.LOCAL_APPLE_BONSAI -> listenBonsaiStatus() + else -> Unit + } + } + protected fun listenHordeStatus() { coroutineScope.launch { observeHordeProcessStatusUseCase() @@ -139,6 +126,29 @@ internal abstract class CoreGenerationWorker( } } + protected fun listenBonsaiStatus() { + coroutineScope.launch { + observeBonsaiProcessStatusUseCase() + .catch { t -> errorLog(t) } + .collect { status -> + val title = applicationContext.getString(LocalizationR.string.notification_running_title) + val subTitle = applicationContext.getString( + LocalizationR.string.communicating_status_steps, + status.current.toString(), + status.total.toString(), + ) + backgroundWorkObserver.postStatusMessage(title, subTitle) + setForegroundNotification( + title = title, + body = subTitle, + silent = true, + progress = status.current to status.total, + canCancel = true, + ) + } + } + } + protected fun handleStart() { val title = applicationContext.getString(LocalizationR.string.notification_started_title) val subTitle = applicationContext.getString(LocalizationR.string.notification_running_sub_title) diff --git a/feature/work/src/androidMain/kotlin/com/shifthackz/aisdv1/work/di/SdaiWorkerFactory.kt 
b/feature/work/src/androidMain/kotlin/com/shifthackz/aisdv1/work/di/SdaiWorkerFactory.kt index ce4110d34..0eae8a70b 100644 --- a/feature/work/src/androidMain/kotlin/com/shifthackz/aisdv1/work/di/SdaiWorkerFactory.kt +++ b/feature/work/src/androidMain/kotlin/com/shifthackz/aisdv1/work/di/SdaiWorkerFactory.kt @@ -11,6 +11,7 @@ import com.shifthackz.aisdv1.domain.feature.work.BackgroundWorkObserver import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.domain.usecase.generation.ImageToImageUseCase import com.shifthackz.aisdv1.domain.usecase.generation.InterruptGenerationUseCase +import com.shifthackz.aisdv1.domain.usecase.generation.ObserveBonsaiProcessStatusUseCase import com.shifthackz.aisdv1.domain.usecase.generation.ObserveHordeProcessStatusUseCase import com.shifthackz.aisdv1.domain.usecase.generation.ObserveLocalDiffusionProcessStatusUseCase import com.shifthackz.aisdv1.domain.usecase.generation.TextToImageUseCase @@ -18,82 +19,26 @@ import com.shifthackz.aisdv1.work.task.ImageToImageTask import com.shifthackz.aisdv1.work.task.TextToImageTask /** - * Coordinates `SdaiWorkerFactory` behavior in the SDAI background work feature layer. + * WorkManager factory that wires SDAI generation workers with app services. * - * @author Dmitriy Moroz + * The factory keeps foreground notification dependencies and local-generation + * progress observers in one place so background txt2img/img2img tasks use the + * same runtime status pipeline as foreground generation. */ class SdaiWorkerFactory( - /** - * Exposes the `backgroundWorkObserver` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ private val backgroundWorkObserver: BackgroundWorkObserver, - /** - * Exposes the `pushNotificationManager` value used by the SDAI background work feature layer. 
- * - * @author Dmitriy Moroz - */ private val pushNotificationManager: PushNotificationManager, - /** - * Exposes the `preferenceManager` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ private val preferenceManager: PreferenceManager, - /** - * Exposes the `textToImageUseCase` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ private val textToImageUseCase: TextToImageUseCase, - /** - * Exposes the `imageToImageUseCase` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ private val imageToImageUseCase: ImageToImageUseCase, - /** - * Exposes the `observeHordeProcessStatusUseCase` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ private val observeHordeProcessStatusUseCase: ObserveHordeProcessStatusUseCase, - /** - * Exposes the `observeLocalDiffusionProcessStatusUseCase` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ private val observeLocalDiffusionProcessStatusUseCase: ObserveLocalDiffusionProcessStatusUseCase, - /** - * Exposes the `interruptGenerationUseCase` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ + private val observeBonsaiProcessStatusUseCase: ObserveBonsaiProcessStatusUseCase, private val interruptGenerationUseCase: InterruptGenerationUseCase, - /** - * Exposes the `fileProviderDescriptor` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ private val fileProviderDescriptor: FileProviderDescriptor, - /** - * Exposes the `activityIntentProvider` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ private val activityIntentProvider: ActivityIntentProvider, ) : WorkerFactory() { - /** - * Creates the SDAI value produced by `createWorker`. - * - * @param appContext app context value consumed by the API. 
- * @param workerClassName worker class name value consumed by the API. - * @param workerParameters worker parameters value consumed by the API. - * @return Result produced by `createWorker`. - * @author Dmitriy Moroz - */ override fun createWorker( appContext: Context, workerClassName: String, @@ -110,6 +55,7 @@ class SdaiWorkerFactory( textToImageUseCase = textToImageUseCase, observeHordeProcessStatusUseCase = observeHordeProcessStatusUseCase, observeLocalDiffusionProcessStatusUseCase = observeLocalDiffusionProcessStatusUseCase, + observeBonsaiProcessStatusUseCase = observeBonsaiProcessStatusUseCase, interruptGenerationUseCase = interruptGenerationUseCase, fileProviderDescriptor = fileProviderDescriptor, ) @@ -124,6 +70,7 @@ class SdaiWorkerFactory( imageToImageUseCase = imageToImageUseCase, observeHordeProcessStatusUseCase = observeHordeProcessStatusUseCase, observeLocalDiffusionProcessStatusUseCase = observeLocalDiffusionProcessStatusUseCase, + observeBonsaiProcessStatusUseCase = observeBonsaiProcessStatusUseCase, interruptGenerationUseCase = interruptGenerationUseCase, fileProviderDescriptor = fileProviderDescriptor, ) diff --git a/feature/work/src/androidMain/kotlin/com/shifthackz/aisdv1/work/mappers/TextToImagePayloadMappers.kt b/feature/work/src/androidMain/kotlin/com/shifthackz/aisdv1/work/mappers/TextToImagePayloadMappers.kt index 87a966c80..b90efd362 100644 --- a/feature/work/src/androidMain/kotlin/com/shifthackz/aisdv1/work/mappers/TextToImagePayloadMappers.kt +++ b/feature/work/src/androidMain/kotlin/com/shifthackz/aisdv1/work/mappers/TextToImagePayloadMappers.kt @@ -1,5 +1,6 @@ package com.shifthackz.aisdv1.work.mappers +import com.shifthackz.aisdv1.domain.entity.BonsaiBackend import com.shifthackz.aisdv1.domain.entity.OpenAiModel import com.shifthackz.aisdv1.domain.entity.FalAiAcceleration import com.shifthackz.aisdv1.domain.entity.FalAiImageSize @@ -14,9 +15,11 @@ import kotlinx.serialization.Serializable import kotlinx.serialization.json.Json /** - 
* Exposes the `payloadJson` value used by the SDAI background work feature layer. + * JSON codec for background txt2img payload handoff. * - * @author Dmitriy Moroz + * Unknown keys stay ignored so queued work from an older app version can still + * be read after new provider fields, such as local runtime backend choices, are + * added to the DTO. */ @OptIn(ExperimentalSerializationApi::class) internal val payloadJson = Json { @@ -25,10 +28,7 @@ internal val payloadJson = Json { } /** - * Converts SDAI data with `toByteArray`. - * - * @return Result produced by `toByteArray`. - * @author Dmitriy Moroz + * Serializes the domain payload for WorkManager cache storage. */ internal fun TextToImagePayload.toByteArray(): ByteArray { return payloadJson @@ -37,10 +37,7 @@ internal fun TextToImagePayload.toByteArray(): ByteArray { } /** - * Converts SDAI data with `toTextToImagePayload`. - * - * @return Result produced by `toTextToImagePayload`. - * @author Dmitriy Moroz + * Restores a cached WorkManager payload into the domain request model. */ internal fun ByteArray.toTextToImagePayload(): TextToImagePayload? { return runCatching { @@ -51,180 +48,42 @@ internal fun ByteArray.toTextToImagePayload(): TextToImagePayload? { } /** - * Carries `TextToImagePayloadDto` data through the SDAI background work feature layer. + * Stable serialized form for txt2img background work. * - * @author Dmitriy Moroz + * The DTO stores provider-specific options as strings so enum aliases can be + * parsed with backwards-compatible defaults when app versions change. */ @Serializable private data class TextToImagePayloadDto( - /** - * Exposes the `prompt` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ val prompt: String, - /** - * Exposes the `negativePrompt` value used by the SDAI background work feature layer. 
- * - * @author Dmitriy Moroz - */ val negativePrompt: String, - /** - * Exposes the `samplingSteps` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ val samplingSteps: Int, - /** - * Exposes the `cfgScale` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ val cfgScale: Float, - /** - * Exposes the `width` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ val width: Int, - /** - * Exposes the `height` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ val height: Int, - /** - * Exposes the `restoreFaces` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ val restoreFaces: Boolean, - /** - * Exposes the `seed` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ val seed: String, - /** - * Exposes the `subSeed` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ val subSeed: String, - /** - * Exposes the `subSeedStrength` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ val subSeedStrength: Float, - /** - * Exposes the `sampler` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ val sampler: String, - /** - * Exposes the `scheduler` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ val scheduler: String? = null, - /** - * Exposes the `nsfw` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ val nsfw: Boolean, - /** - * Exposes the `batchCount` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ val batchCount: Int, - /** - * Exposes the `style` value used by the SDAI background work feature layer. 
- * - * @author Dmitriy Moroz - */ val style: String?, - /** - * Exposes the `quality` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ val quality: String?, - /** - * Exposes the `openAiModel` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ val openAiModel: String?, - /** - * Exposes the `stabilityAiClipGuidance` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ val stabilityAiClipGuidance: String?, - /** - * Exposes the `stabilityAiStylePreset` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ val stabilityAiStylePreset: String?, - /** - * Exposes the `aDetailer` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ val aDetailer: ADetailerConfigDto = ADetailerConfigDto(), - /** - * Exposes the `hires` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ val hires: HiresConfigDto = HiresConfigDto(), - /** - * Exposes the `forgeModules` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ val forgeModules: List = emptyList(), - /** - * Exposes the `falAiModel` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ val falAiModel: String? = null, - /** - * Exposes the `falAiImageSize` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ val falAiImageSize: String? = null, - /** - * Exposes the `falAiAcceleration` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ val falAiAcceleration: String? = null, - /** - * Exposes the `sdxlBackend` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ val sdxlBackend: String? = null, - /** - * Exposes the `falAiSyncMode` value used by the SDAI background work feature layer. 
- * - * @author Dmitriy Moroz - */ + val bonsaiBackend: String? = null, val falAiSyncMode: Boolean = false, ) { - /** - * Converts SDAI data with `toPayload`. - * - * @author Dmitriy Moroz - */ fun toPayload(): TextToImagePayload = TextToImagePayload( prompt = prompt, negativePrompt = negativePrompt, @@ -252,21 +111,11 @@ private data class TextToImagePayloadDto( falAiImageSize = FalAiImageSize.parse(falAiImageSize), falAiAcceleration = FalAiAcceleration.parse(falAiAcceleration), sdxlBackend = SdxlBackend.parse(sdxlBackend), + bonsaiBackend = BonsaiBackend.parse(bonsaiBackend), falAiSyncMode = falAiSyncMode, ) - /** - * Provides the `companion object` singleton used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ companion object { - /** - * Executes the `from` step in the SDAI background work feature layer. - * - * @param payload generation payload used by the operation. - * @author Dmitriy Moroz - */ fun from(payload: TextToImagePayload): TextToImagePayloadDto = TextToImagePayloadDto( prompt = payload.prompt, negativePrompt = payload.negativePrompt, @@ -294,16 +143,14 @@ private data class TextToImagePayloadDto( falAiImageSize = payload.falAiImageSize.key, falAiAcceleration = payload.falAiAcceleration.key, sdxlBackend = payload.sdxlBackend.key, + bonsaiBackend = payload.bonsaiBackend.key, falAiSyncMode = payload.falAiSyncMode, ) } } /** - * Executes the `function` step in the SDAI background work feature layer. - * - * @return Result produced by `function`. - * @author Dmitriy Moroz + * Parses enum names from cached payload fields while keeping unknown values nullable. */ internal inline fun > String?.parseEnumOrNull(): T? 
{ return this?.let { value -> enumValues().firstOrNull { it.name == value } } diff --git a/feature/work/src/androidMain/kotlin/com/shifthackz/aisdv1/work/task/ImageToImageTask.kt b/feature/work/src/androidMain/kotlin/com/shifthackz/aisdv1/work/task/ImageToImageTask.kt index 2dc155e8e..a4a6676fb 100644 --- a/feature/work/src/androidMain/kotlin/com/shifthackz/aisdv1/work/task/ImageToImageTask.kt +++ b/feature/work/src/androidMain/kotlin/com/shifthackz/aisdv1/work/task/ImageToImageTask.kt @@ -9,6 +9,7 @@ import com.shifthackz.aisdv1.domain.feature.work.BackgroundWorkObserver import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.domain.usecase.generation.ImageToImageUseCase import com.shifthackz.aisdv1.domain.usecase.generation.InterruptGenerationUseCase +import com.shifthackz.aisdv1.domain.usecase.generation.ObserveBonsaiProcessStatusUseCase import com.shifthackz.aisdv1.domain.usecase.generation.ObserveHordeProcessStatusUseCase import com.shifthackz.aisdv1.domain.usecase.generation.ObserveLocalDiffusionProcessStatusUseCase import com.shifthackz.aisdv1.work.Constants @@ -20,9 +21,11 @@ import kotlinx.coroutines.CancellationException import java.io.File /** - * Coordinates `ImageToImageTask` behavior in the SDAI background work feature layer. + * Background img2img worker. * - * @author Dmitriy Moroz + * The task reads the cached serialized payload, subscribes to provider progress, + * runs the image-to-image use case, and publishes the result or failure through + * the common generation notification flow. 
*/ internal class ImageToImageTask( context: Context, @@ -32,24 +35,10 @@ internal class ImageToImageTask( preferenceManager: PreferenceManager, observeHordeProcessStatusUseCase: ObserveHordeProcessStatusUseCase, observeLocalDiffusionProcessStatusUseCase: ObserveLocalDiffusionProcessStatusUseCase, + observeBonsaiProcessStatusUseCase: ObserveBonsaiProcessStatusUseCase, interruptGenerationUseCase: InterruptGenerationUseCase, - /** - * Exposes the `backgroundWorkObserver` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ private val backgroundWorkObserver: BackgroundWorkObserver, - /** - * Exposes the `imageToImageUseCase` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ private val imageToImageUseCase: ImageToImageUseCase, - /** - * Exposes the `fileProviderDescriptor` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ private val fileProviderDescriptor: FileProviderDescriptor, ) : CoreGenerationWorker( context = context, @@ -60,6 +49,7 @@ internal class ImageToImageTask( backgroundWorkObserver = backgroundWorkObserver, observeHordeProcessStatusUseCase = observeHordeProcessStatusUseCase, observeLocalDiffusionProcessStatusUseCase = observeLocalDiffusionProcessStatusUseCase, + observeBonsaiProcessStatusUseCase = observeBonsaiProcessStatusUseCase, interruptGenerationUseCase = interruptGenerationUseCase, ) { @@ -88,8 +78,7 @@ internal class ImageToImageTask( return Result.failure() } - listenHordeStatus() - listenLocalDiffusionStatus() + listenSourceStatus() handleProcess() runCatching { imageToImageUseCase(payload) } diff --git a/feature/work/src/androidMain/kotlin/com/shifthackz/aisdv1/work/task/TextToImageTask.kt b/feature/work/src/androidMain/kotlin/com/shifthackz/aisdv1/work/task/TextToImageTask.kt index fc6269eb8..e77aa9933 100644 --- a/feature/work/src/androidMain/kotlin/com/shifthackz/aisdv1/work/task/TextToImageTask.kt +++ 
b/feature/work/src/androidMain/kotlin/com/shifthackz/aisdv1/work/task/TextToImageTask.kt @@ -10,6 +10,7 @@ import com.shifthackz.aisdv1.core.notification.PushNotificationManager import com.shifthackz.aisdv1.domain.feature.work.BackgroundWorkObserver import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.domain.usecase.generation.InterruptGenerationUseCase +import com.shifthackz.aisdv1.domain.usecase.generation.ObserveBonsaiProcessStatusUseCase import com.shifthackz.aisdv1.domain.usecase.generation.ObserveHordeProcessStatusUseCase import com.shifthackz.aisdv1.domain.usecase.generation.ObserveLocalDiffusionProcessStatusUseCase import com.shifthackz.aisdv1.domain.usecase.generation.TextToImageUseCase @@ -22,9 +23,11 @@ import kotlinx.coroutines.CancellationException import java.io.File /** - * Coordinates `TextToImageTask` behavior in the SDAI background work feature layer. + * Background txt2img worker. * - * @author Dmitriy Moroz + * The task reads the cached generation payload, starts provider-specific + * progress observation, executes txt2img generation, and reports success, + * cancellation, or failure through the foreground notification path. */ internal class TextToImageTask( context: Context, @@ -33,30 +36,11 @@ internal class TextToImageTask( activityIntentProvider: ActivityIntentProvider, observeHordeProcessStatusUseCase: ObserveHordeProcessStatusUseCase, observeLocalDiffusionProcessStatusUseCase: ObserveLocalDiffusionProcessStatusUseCase, + observeBonsaiProcessStatusUseCase: ObserveBonsaiProcessStatusUseCase, interruptGenerationUseCase: InterruptGenerationUseCase, - /** - * Exposes the `preferenceManager` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ private val preferenceManager: PreferenceManager, - /** - * Exposes the `backgroundWorkObserver` value used by the SDAI background work feature layer. 
- * - * @author Dmitriy Moroz - */ private val backgroundWorkObserver: BackgroundWorkObserver, - /** - * Exposes the `textToImageUseCase` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ private val textToImageUseCase: TextToImageUseCase, - /** - * Exposes the `fileProviderDescriptor` value used by the SDAI background work feature layer. - * - * @author Dmitriy Moroz - */ private val fileProviderDescriptor: FileProviderDescriptor, ) : CoreGenerationWorker( context = context, @@ -67,6 +51,7 @@ internal class TextToImageTask( backgroundWorkObserver = backgroundWorkObserver, observeHordeProcessStatusUseCase = observeHordeProcessStatusUseCase, observeLocalDiffusionProcessStatusUseCase = observeLocalDiffusionProcessStatusUseCase, + observeBonsaiProcessStatusUseCase = observeBonsaiProcessStatusUseCase, interruptGenerationUseCase = interruptGenerationUseCase, ) { @@ -117,8 +102,7 @@ internal class TextToImageTask( return Result.failure() } - listenHordeStatus() - listenLocalDiffusionStatus() + listenSourceStatus() handleProcess() runCatching { textToImageUseCase(payload) } diff --git a/presentation/src/androidMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/platform/ServerSetupLocalPathPickerButton.android.kt b/presentation/src/androidMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/platform/ServerSetupLocalPathPickerButton.android.kt index 80e8d4434..0984a7b1d 100644 --- a/presentation/src/androidMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/platform/ServerSetupLocalPathPickerButton.android.kt +++ b/presentation/src/androidMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/platform/ServerSetupLocalPathPickerButton.android.kt @@ -51,5 +51,17 @@ internal actual fun isLocalGenerationSetupAvailable(): Boolean = Build.VERSION.SDK_INT < Build.VERSION_CODES.R || Environment.isExternalStorageManager() internal actual fun isServerSourceAvailableOnPlatform(source: ServerSource): Boolean = - source 
!= ServerSource.LOCAL_APPLE_CORE_ML && - source != ServerSource.LOCAL_APPLE_BONSAI + isServerSourceAvailableOnAndroid(source, Build.SUPPORTED_64_BIT_ABIS) + +internal fun isServerSourceAvailableOnAndroid( + source: ServerSource, + supported64BitAbis: Array?, +): Boolean = when (source) { + ServerSource.LOCAL_APPLE_BONSAI -> isAndroidBonsaiSupportedInPrinciple(supported64BitAbis) + else -> true +} + +internal fun isAndroidBonsaiSupportedInPrinciple(supported64BitAbis: Array?): Boolean = + supported64BitAbis.orEmpty().any { abi -> abi == ANDROID_BONSAI_ABI } + +private const val ANDROID_BONSAI_ABI = "arm64-v8a" diff --git a/presentation/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/presentation/screen/gallery/detail/GalleryDetailViewModelTest.kt b/presentation/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/presentation/screen/gallery/detail/GalleryDetailViewModelTest.kt index dcee1cce6..2d3e39f1b 100644 --- a/presentation/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/presentation/screen/gallery/detail/GalleryDetailViewModelTest.kt +++ b/presentation/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/presentation/screen/gallery/detail/GalleryDetailViewModelTest.kt @@ -3,6 +3,7 @@ package com.shifthackz.aisdv1.presentation.screen.gallery.detail import com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider import com.shifthackz.aisdv1.core.common.appbuild.BuildType import com.shifthackz.aisdv1.core.common.appbuild.BuildVersion +import com.shifthackz.aisdv1.core.common.platform.Platform import com.shifthackz.aisdv1.core.common.schedulers.DispatchersProvider import com.shifthackz.aisdv1.domain.entity.AiGenerationResult import com.shifthackz.aisdv1.domain.usecase.caching.GetLastResultFromCacheUseCase @@ -42,6 +43,7 @@ class GalleryDetailViewModelTest { override val buildNumber = 1 override val version = BuildVersion() override val type = BuildType.FULL + override val platform: Platform = Platform.ANDROID } private val getGenerationResultUseCase = mockk() 
private val getAllGalleryUseCase = mockk() diff --git a/presentation/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModelTest.kt b/presentation/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModelTest.kt index c0d5c3f0c..037e0f587 100644 --- a/presentation/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModelTest.kt +++ b/presentation/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModelTest.kt @@ -4,6 +4,7 @@ import com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider import com.shifthackz.aisdv1.core.common.appbuild.BuildType import com.shifthackz.aisdv1.core.common.appbuild.BuildVersion import com.shifthackz.aisdv1.core.common.links.LinksProvider +import com.shifthackz.aisdv1.core.common.platform.Platform import com.shifthackz.aisdv1.core.common.schedulers.DispatchersProvider import com.shifthackz.aisdv1.domain.entity.NetworkUsage import com.shifthackz.aisdv1.domain.entity.ServerSource @@ -264,6 +265,7 @@ private object TestBuildInfoProvider : BuildInfoProvider { override val buildNumber: Int = 5598 override val version: BuildVersion = BuildVersion() override val type: BuildType = BuildType.FOSS + override val platform: Platform = Platform.ANDROID override fun toString(): String = "test-version" } diff --git a/presentation/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/platform/ServerSetupPlatformTest.kt b/presentation/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/platform/ServerSetupPlatformTest.kt new file mode 100644 index 000000000..88cc3aa62 --- /dev/null +++ b/presentation/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/platform/ServerSetupPlatformTest.kt @@ -0,0 +1,75 @@ +package com.shifthackz.aisdv1.presentation.screen.setup.platform + +import 
com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider +import com.shifthackz.aisdv1.core.common.appbuild.BuildType +import com.shifthackz.aisdv1.core.common.appbuild.BuildVersion +import com.shifthackz.aisdv1.core.common.platform.Platform +import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.domain.entity.ServerSourceReadiness +import com.shifthackz.aisdv1.presentation.model.displayName +import com.shifthackz.aisdv1.presentation.model.readinessFor +import com.shifthackz.aisdv1.presentation.screen.setup.mappers.allowedModes +import org.junit.Assert.assertEquals +import org.junit.Assert.assertFalse +import org.junit.Assert.assertTrue +import org.junit.Test + +class ServerSetupPlatformTest { + + @Test + fun `given Android build, expected Core ML excluded by allowed platforms`() { + assertFalse(ServerSource.LOCAL_APPLE_CORE_ML in TestAndroidBuildInfoProvider.allowedModes) + } + + @Test + fun `given Android device has arm64 ABI, expected Bonsai provider visible`() { + assertTrue(isAndroidBonsaiSupportedInPrinciple(arrayOf("armeabi-v7a", "arm64-v8a"))) + } + + @Test + fun `given Android device lacks arm64 ABI, expected Bonsai provider hidden`() { + assertFalse(isAndroidBonsaiSupportedInPrinciple(arrayOf("x86_64"))) + } + + @Test + fun `given platform reports arm64 ABI, expected Bonsai provider visible in setup list`() { + assertTrue( + isServerSourceAvailableOnAndroid( + source = ServerSource.LOCAL_APPLE_BONSAI, + supported64BitAbis = arrayOf("arm64-v8a"), + ), + ) + } + + @Test + fun `given Bonsai provider, expected platform specific display name`() { + assertEquals( + "Local Diffusion PrismML Bonsai", + ServerSource.LOCAL_APPLE_BONSAI.displayName(Platform.ANDROID), + ) + assertEquals( + "Silicon Diffusion PrismML Bonsai", + ServerSource.LOCAL_APPLE_BONSAI.displayName(Platform.IOS), + ) + } + + @Test + fun `given Bonsai provider, expected platform specific readiness`() { + assertEquals( + ServerSourceReadiness.EXPERIMENTAL, + 
ServerSource.LOCAL_APPLE_BONSAI.readinessFor(Platform.ANDROID), + ) + assertEquals( + ServerSourceReadiness.BETA, + ServerSource.LOCAL_APPLE_BONSAI.readinessFor(Platform.IOS), + ) + } +} + +private object TestAndroidBuildInfoProvider : BuildInfoProvider { + override val isDebug: Boolean = true + override val buildNumber: Int = 0 + override val version: BuildVersion = BuildVersion() + override val type: BuildType = BuildType.FULL + override val platform: Platform = Platform.ANDROID +} diff --git a/presentation/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/presentation/screen/storageusage/StorageUsageViewModelTest.kt b/presentation/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/presentation/screen/storageusage/StorageUsageViewModelTest.kt index 42d280480..5ce97ba05 100644 --- a/presentation/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/presentation/screen/storageusage/StorageUsageViewModelTest.kt +++ b/presentation/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/presentation/screen/storageusage/StorageUsageViewModelTest.kt @@ -3,6 +3,7 @@ package com.shifthackz.aisdv1.presentation.screen.storageusage import com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider import com.shifthackz.aisdv1.core.common.appbuild.BuildType import com.shifthackz.aisdv1.core.common.appbuild.BuildVersion +import com.shifthackz.aisdv1.core.common.platform.Platform import com.shifthackz.aisdv1.core.common.schedulers.DispatchersProvider import com.shifthackz.aisdv1.domain.entity.LocalAiModel import com.shifthackz.aisdv1.domain.usecase.downloadable.DeleteModelUseCase @@ -239,6 +240,7 @@ private object TestBuildInfoProvider : BuildInfoProvider { override val buildNumber: Int = 5598 override val version: BuildVersion = BuildVersion() override val type: BuildType = BuildType.FULL + override val platform: Platform = Platform.ANDROID } /** diff --git a/presentation/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageStateTest.kt 
b/presentation/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageStateTest.kt index bf05fca8c..d91a0fcf9 100644 --- a/presentation/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageStateTest.kt +++ b/presentation/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageStateTest.kt @@ -3,6 +3,7 @@ package com.shifthackz.aisdv1.presentation.screen.txt2img import com.shifthackz.aisdv1.core.validation.ValidationResult import com.shifthackz.aisdv1.core.validation.dimension.DimensionValidator import com.shifthackz.aisdv1.domain.entity.ADetailerConfig +import com.shifthackz.aisdv1.domain.entity.BonsaiBackend import com.shifthackz.aisdv1.domain.entity.ForgeModule import com.shifthackz.aisdv1.domain.entity.HiresConfig import com.shifthackz.aisdv1.domain.entity.Scheduler @@ -126,6 +127,7 @@ class TextToImageStateTest { cfgScale = 7.5f, batchCount = 4, nsfw = true, + bonsaiBackend = BonsaiBackend.VULKAN, ).mapToPayload() Assert.assertEquals(30, payload.samplingSteps) @@ -134,6 +136,17 @@ class TextToImageStateTest { Assert.assertEquals(DEFAULT_SIZE, payload.height) Assert.assertEquals(1, payload.batchCount) Assert.assertTrue(payload.nsfw) + Assert.assertEquals(BonsaiBackend.VULKAN, payload.bonsaiBackend) + } + + @Test + fun `given non Bonsai state with Vulkan Bonsai backend, expected payload uses Auto backend`() { + val payload = TextToImageState( + mode = ServerSource.AUTOMATIC1111, + bonsaiBackend = BonsaiBackend.VULKAN, + ).mapToPayload() + + Assert.assertEquals(BonsaiBackend.AUTO, payload.bonsaiBackend) } @Test diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/app/AppScaffold.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/app/AppScaffold.kt index ec59a122a..2aa3eac7e 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/app/AppScaffold.kt +++ 
b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/app/AppScaffold.kt @@ -107,6 +107,7 @@ internal fun AppScaffold( items = appDrawerItems( currentRoute = currentRoute, settings = settings, + buildInfoProvider = buildInfoProvider, router = router, ), header = { @@ -244,9 +245,10 @@ private fun appBottomNavigationItems( private fun appDrawerItems( currentRoute: AppRoute, settings: Settings, + buildInfoProvider: BuildInfoProvider, router: RootAppRouter, ): List { - val sourceName = settings.source.getName() + val sourceName = settings.source.getName(buildInfoProvider.platform) return buildList { add( DrawerSheetItem( diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/model/ServerSourceUi.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/model/ServerSourceUi.kt new file mode 100644 index 000000000..84372c478 --- /dev/null +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/model/ServerSourceUi.kt @@ -0,0 +1,78 @@ +package com.shifthackz.aisdv1.presentation.model + +import com.shifthackz.aisdv1.core.common.platform.Platform +import com.shifthackz.aisdv1.core.localization.Localization +import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.domain.entity.ServerSourceReadiness + +/** + * Returns whether the provider should be selectable on the current platform. + */ +internal fun ServerSource.isAvailableOn(platform: Platform): Boolean = + platform in allowedPlatforms + +/** + * Picks the user-facing readiness badge for the current platform. + */ +internal fun ServerSource.readinessFor(platform: Platform): ServerSourceReadiness = + readiness[platform] + +/** + * Full provider name for setup and source selection screens. 
+ * + * Bonsai keeps one domain source id for persistence, but the label is platform + * aware: iOS shows the Silicon Diffusion wording and Android shows Local + * Diffusion while preserving the same underlying provider configuration. + */ +internal fun ServerSource.displayName(platform: Platform): String = when (this) { + ServerSource.AUTOMATIC1111 -> Localization.string("srv_type_own") + ServerSource.SWARM_UI -> Localization.string("srv_type_swarm_ui") + ServerSource.HORDE -> Localization.string("srv_type_horde") + ServerSource.HUGGING_FACE -> Localization.string("srv_type_hugging_face") + ServerSource.OPEN_AI -> Localization.string("srv_type_open_ai") + ServerSource.STABILITY_AI -> Localization.string("srv_type_stability_ai") + ServerSource.FAL_AI -> Localization.string("srv_type_fal_ai") + ServerSource.ARLI_AI -> Localization.string("srv_type_arli_ai") + ServerSource.LOCAL_MICROSOFT_ONNX -> Localization.string("srv_type_local") + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> Localization.string("srv_type_media_pipe") + ServerSource.LOCAL_STABLE_DIFFUSION_CPP -> Localization.string("srv_type_sdxl") + ServerSource.LOCAL_APPLE_CORE_ML -> "Silicon Diffusion Core ML" + ServerSource.LOCAL_APPLE_BONSAI -> Localization.string( + when (platform) { + Platform.ANDROID -> "srv_type_bonsai_android" + Platform.IOS -> "srv_type_bonsai" + }, + ) +} + +/** + * Short provider label used where the full setup name would be too long. 
+ */ +internal fun ServerSource.shortDisplayName(): String = when (this) { + ServerSource.AUTOMATIC1111 -> Localization.string("srv_type_own_short") + ServerSource.SWARM_UI -> Localization.string("srv_type_swarm_ui") + ServerSource.HORDE -> Localization.string("srv_type_horde_short") + ServerSource.HUGGING_FACE -> Localization.string("srv_type_hugging_face_short") + ServerSource.OPEN_AI -> Localization.string("srv_type_open_ai") + ServerSource.STABILITY_AI -> Localization.string("srv_type_stability_ai") + ServerSource.FAL_AI -> Localization.string("srv_type_fal_ai") + ServerSource.ARLI_AI -> Localization.string("srv_type_arli_ai") + ServerSource.LOCAL_MICROSOFT_ONNX -> Localization.string("srv_type_local_short") + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> Localization.string("srv_type_media_pipe_short") + ServerSource.LOCAL_STABLE_DIFFUSION_CPP -> Localization.string("srv_type_sdxl_short") + ServerSource.LOCAL_APPLE_CORE_ML -> "Core ML" + ServerSource.LOCAL_APPLE_BONSAI -> Localization.string("srv_type_bonsai_short") +} + +/** + * Compact label for chips and small controls. 
+ */ +internal fun ServerSource.compactDisplayName(platform: Platform): String = when (this) { + ServerSource.LOCAL_MICROSOFT_ONNX, + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE, + ServerSource.LOCAL_STABLE_DIFFUSION_CPP, + ServerSource.LOCAL_APPLE_CORE_ML, + -> shortDisplayName() + + else -> displayName(platform) +} diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/benchmark/BenchmarkScreen.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/benchmark/BenchmarkScreen.kt index 3ec96c52d..1139b23b7 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/benchmark/BenchmarkScreen.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/benchmark/BenchmarkScreen.kt @@ -50,6 +50,7 @@ import androidx.compose.ui.graphics.vector.ImageVector import androidx.compose.ui.text.font.FontWeight import androidx.compose.ui.text.style.TextOverflow import androidx.compose.ui.unit.dp +import com.shifthackz.aisdv1.core.common.platform.Platform import com.shifthackz.aisdv1.core.extensions.shimmer import com.shifthackz.aisdv1.core.localization.Localization import com.shifthackz.aisdv1.core.model.asUiText @@ -57,6 +58,7 @@ import com.shifthackz.aisdv1.core.mvi.MviComponent import com.shifthackz.aisdv1.feature.benchmark.BenchmarkAccelerationStatus import com.shifthackz.aisdv1.domain.entity.SdxlBackend import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.domain.entity.ServerSourceType import com.shifthackz.aisdv1.feature.benchmark.BenchmarkAccelerator import com.shifthackz.aisdv1.feature.benchmark.BenchmarkDeviceInfo import com.shifthackz.aisdv1.feature.benchmark.BenchmarkPlatform @@ -65,6 +67,8 @@ import com.shifthackz.aisdv1.feature.benchmark.BenchmarkProviderRecommendation import com.shifthackz.aisdv1.feature.benchmark.BenchmarkResult import com.shifthackz.aisdv1.feature.benchmark.accelerationCapabilities import 
com.shifthackz.aisdv1.presentation.di.initKoin +import com.shifthackz.aisdv1.presentation.model.compactDisplayName +import com.shifthackz.aisdv1.presentation.model.isAvailableOn import com.shifthackz.aisdv1.presentation.navigation.router.BenchmarkRouter import com.shifthackz.aisdv1.presentation.theme.global.persistentBottomBarWindowInsets import com.shifthackz.aisdv1.presentation.theme.global.persistentTopAppBarWindowInsets @@ -462,6 +466,7 @@ private fun ProviderRecommendationsSection( providers.forEach { recommendation -> ProviderRecommendationBlock( recommendation = recommendation, + platform = deviceInfo?.platform?.toPlatform() ?: Platform.ANDROID, pending = recommendations.isEmpty(), ) } @@ -473,13 +478,14 @@ private fun ProviderRecommendationsSection( @Composable private fun ProviderRecommendationBlock( recommendation: BenchmarkProviderRecommendation, + platform: Platform, pending: Boolean, ) { Column(verticalArrangement = Arrangement.spacedBy(8.dp)) { ProviderRecommendationTitle( text = Localization.string( "benchmark_provider_recommended_settings", - recommendation.provider.displayName(), + recommendation.provider.compactDisplayName(platform), ), ) when { @@ -742,21 +748,17 @@ private fun BenchmarkDeviceInfo.deviceName(): String = .joinToString(" ") .ifBlank { Localization.string("benchmark_unknown") } -private fun BenchmarkDeviceInfo.localBenchmarkProviders(): List = when (platform) { - BenchmarkPlatform.ANDROID -> listOf( - ServerSource.LOCAL_MICROSOFT_ONNX, - ServerSource.LOCAL_GOOGLE_MEDIA_PIPE, - ServerSource.LOCAL_STABLE_DIFFUSION_CPP, - ) - BenchmarkPlatform.IOS -> listOf( - ServerSource.LOCAL_APPLE_CORE_ML, - ServerSource.LOCAL_APPLE_BONSAI, - ) - BenchmarkPlatform.UNKNOWN -> emptyList() -} +private fun BenchmarkDeviceInfo.localBenchmarkProviders(): List = + platform.toPlatformOrNull()?.let { targetPlatform -> + ServerSource.entries.filter { source -> + source.type == ServerSourceType.LOCAL && + source.isAvailableOn(targetPlatform) + } + }.orEmpty() 
private fun BenchmarkAccelerator.displayName(): String = when (this) { BenchmarkAccelerator.VULKAN -> "Vulkan backend" + BenchmarkAccelerator.BONSAI_VULKAN -> "Bonsai Vulkan compute" BenchmarkAccelerator.OPEN_CL -> "OpenCL backend" BenchmarkAccelerator.NNAPI -> "NNAPI delegate" BenchmarkAccelerator.METAL -> "Metal" @@ -787,24 +789,17 @@ private fun BenchmarkAccelerationStatus?.contentColor(): Color = when (this) { else -> Color.White } -private fun ServerSource.displayName(): String = when (this) { - ServerSource.AUTOMATIC1111 -> Localization.string("srv_type_own_short") - ServerSource.SWARM_UI -> Localization.string("srv_type_swarm_ui") - ServerSource.LOCAL_MICROSOFT_ONNX -> Localization.string("srv_type_local_short") - ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> Localization.string("srv_type_media_pipe_short") - ServerSource.LOCAL_STABLE_DIFFUSION_CPP -> Localization.string("srv_type_sdxl_short") - ServerSource.LOCAL_APPLE_CORE_ML -> "Core ML" - ServerSource.LOCAL_APPLE_BONSAI -> "Silicon Diffusion PrismML Bonsai" - ServerSource.HORDE -> Localization.string("srv_type_horde_short") - ServerSource.HUGGING_FACE -> Localization.string("srv_type_hugging_face_short") - ServerSource.OPEN_AI -> Localization.string("srv_type_open_ai") - ServerSource.STABILITY_AI -> Localization.string("srv_type_stability_ai") - ServerSource.FAL_AI -> Localization.string("srv_type_fal_ai") - ServerSource.ARLI_AI -> Localization.string("srv_type_arli_ai") -} - private fun SdxlBackend.displayName(): String = displayName +private fun BenchmarkPlatform.toPlatform(): Platform = + toPlatformOrNull() ?: Platform.ANDROID + +private fun BenchmarkPlatform.toPlatformOrNull(): Platform? 
= when (this) { + BenchmarkPlatform.ANDROID -> Platform.ANDROID + BenchmarkPlatform.IOS -> Platform.IOS + BenchmarkPlatform.UNKNOWN -> null +} + private fun BenchmarkProviderIssue.localizedText(): String = when (this) { BenchmarkProviderIssue.PLATFORM_UNSUPPORTED -> Localization.string("benchmark_issue_platform_unsupported") diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/benchmark/BenchmarkViewModel.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/benchmark/BenchmarkViewModel.kt index 6c20287e6..93bdbec99 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/benchmark/BenchmarkViewModel.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/benchmark/BenchmarkViewModel.kt @@ -1,16 +1,19 @@ package com.shifthackz.aisdv1.presentation.screen.benchmark +import com.shifthackz.aisdv1.core.common.platform.Platform import com.shifthackz.aisdv1.core.common.schedulers.DispatchersProvider import com.shifthackz.aisdv1.core.localization.Localization import com.shifthackz.aisdv1.core.mvi.BaseMviViewModel import com.shifthackz.aisdv1.core.mvi.EmptyEffect import com.shifthackz.aisdv1.feature.benchmark.BenchmarkManager +import com.shifthackz.aisdv1.feature.benchmark.BenchmarkPlatform import com.shifthackz.aisdv1.feature.benchmark.BenchmarkProviderIssue import com.shifthackz.aisdv1.feature.benchmark.BenchmarkProviderRecommendation import com.shifthackz.aisdv1.feature.benchmark.BenchmarkResult import com.shifthackz.aisdv1.presentation.navigation.router.BenchmarkRouter import com.shifthackz.aisdv1.domain.entity.SdxlBackend import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.presentation.model.compactDisplayName import kotlinx.coroutines.flow.catch import kotlinx.coroutines.withContext @@ -164,7 +167,10 @@ private fun BenchmarkResult.shareText(): String = buildString { 
appendLine("${Localization.string("benchmark_estimated_time_short")}: ${Localization.string("benchmark_seconds", estimatedTimeSeconds)}") providerRecommendations.forEach { recommendation -> appendLine() - appendLine("${recommendation.provider.displayName()} ${Localization.string("benchmark_recommendations")}") + appendLine( + "${recommendation.provider.compactDisplayName(deviceInfo.platform.toPlatform())} " + + Localization.string("benchmark_recommendations"), + ) if (recommendation.recommended) { appendLine("${Localization.string("benchmark_recommended_size")}: ${recommendation.width} x ${recommendation.height}") appendLine("${Localization.string("benchmark_recommended_steps")}: ${recommendation.samplingSteps}") @@ -208,24 +214,16 @@ private fun BenchmarkProviderIssue.localizedText(): String = when (this) { Localization.string("benchmark_issue_accelerator_api_not_available") } -private fun ServerSource.displayName(): String = when (this) { - ServerSource.AUTOMATIC1111 -> Localization.string("srv_type_own_short") - ServerSource.SWARM_UI -> Localization.string("srv_type_swarm_ui") - ServerSource.LOCAL_MICROSOFT_ONNX -> Localization.string("srv_type_local_short") - ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> Localization.string("srv_type_media_pipe_short") - ServerSource.LOCAL_STABLE_DIFFUSION_CPP -> Localization.string("srv_type_sdxl_short") - ServerSource.LOCAL_APPLE_CORE_ML -> "Core ML" - ServerSource.LOCAL_APPLE_BONSAI -> "Silicon Diffusion PrismML Bonsai" - ServerSource.HORDE -> Localization.string("srv_type_horde_short") - ServerSource.HUGGING_FACE -> Localization.string("srv_type_hugging_face_short") - ServerSource.OPEN_AI -> Localization.string("srv_type_open_ai") - ServerSource.STABILITY_AI -> Localization.string("srv_type_stability_ai") - ServerSource.FAL_AI -> Localization.string("srv_type_fal_ai") - ServerSource.ARLI_AI -> Localization.string("srv_type_arli_ai") -} - private fun SdxlBackend.displayName(): String = displayName +private fun 
BenchmarkPlatform.toPlatform(): Platform = when (this) { + BenchmarkPlatform.ANDROID, + BenchmarkPlatform.UNKNOWN, + -> Platform.ANDROID + + BenchmarkPlatform.IOS -> Platform.IOS +} + private fun com.shifthackz.aisdv1.feature.benchmark.BenchmarkDeviceInfo.deviceName(): String = listOf(manufacturer, model) .filter(String::isNotBlank) diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageBodyContent.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageBodyContent.kt index 68a9755ec..858b4db83 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageBodyContent.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageBodyContent.kt @@ -40,6 +40,7 @@ import androidx.compose.ui.text.style.TextAlign import androidx.compose.ui.text.style.TextOverflow import androidx.compose.ui.unit.dp import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.presentation.model.compactDisplayName import com.shifthackz.aisdv1.presentation.widget.scrollbar.verticalScrollbar @@ -186,7 +187,7 @@ internal fun ImageInputSection( ) { Column(modifier = Modifier.weight(1f)) { Text( - text = state.mode.displayName, + text = state.mode.compactDisplayName(state.platform), style = MaterialTheme.typography.titleSmall, fontWeight = FontWeight.W600, maxLines = 1, diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageContentHelpers.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageContentHelpers.kt index 13eeee7ca..5cd2599e5 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageContentHelpers.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageContentHelpers.kt 
@@ -5,19 +5,14 @@ package com.shifthackz.aisdv1.presentation.screen.img2img import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.runtime.Composable import androidx.compose.ui.text.input.TextFieldValue -import com.shifthackz.aisdv1.core.localization.Localization import com.shifthackz.aisdv1.domain.entity.AiGenerationResult -import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.presentation.modal.history.InputHistoryBottomSheet import com.shifthackz.aisdv1.presentation.model.ExtraType import com.shifthackz.aisdv1.presentation.widget.input.GenerationInputFormEvent import kotlin.math.roundToInt - /** - * Converts SDAI data with `toImageToImageIntent`. - * - * @author Dmitriy Moroz + * Maps shared generation form events into img2img intents. */ internal fun GenerationInputFormEvent.toImageToImageIntent(): ImageToImageIntent? = when (this) { is GenerationInputFormEvent.EditTag -> ImageToImageIntent.ShowEditTag( @@ -52,6 +47,7 @@ internal fun GenerationInputFormEvent.toImageToImageIntent(): ImageToImageIntent is GenerationInputFormEvent.UpdateFalAiImageSize -> ImageToImageIntent.UpdateFalAiImageSize(value) is GenerationInputFormEvent.UpdateFalAiAcceleration -> ImageToImageIntent.UpdateFalAiAcceleration(value) is GenerationInputFormEvent.UpdateSdxlBackend -> null + is GenerationInputFormEvent.UpdateBonsaiBackend -> null is GenerationInputFormEvent.UpdateFalAiSyncMode -> ImageToImageIntent.UpdateFalAiSyncMode(value) is GenerationInputFormEvent.UpdateArliAiModel -> ImageToImageIntent.UpdateArliAiModel(value) is GenerationInputFormEvent.UpdateStabilityAiStyle -> ImageToImageIntent.UpdateStabilityAiStyle(value) @@ -63,27 +59,11 @@ internal fun GenerationInputFormEvent.toImageToImageIntent(): ImageToImageIntent GenerationInputFormEvent.OpenADetailerInstallInstructions -> ImageToImageIntent.OpenADetailerInstallInstructions } -/** - * Executes the `appendPromptTag` step in the SDAI presentation layer. 
- * - * @param tag tag value consumed by the API. - * @return Result produced by `appendPromptTag`. - * @author Dmitriy Moroz - */ internal fun String.appendPromptTag(tag: String): String = listOf(this, tag.trim()) .filter(String::isNotBlank) .joinToString(", ") -/** - * Executes the `flushPendingTaggedText` step in the SDAI presentation layer. - * - * @param state state rendered or processed by the component. - * @param promptChipTextFieldState prompt chip text field state value consumed by the API. - * @param negativePromptChipTextFieldState negative prompt chip text field state value consumed by the API. - * @param processIntent process intent value consumed by the API. - * @author Dmitriy Moroz - */ internal fun flushPendingTaggedText( state: ImageToImageState, promptChipTextFieldState: androidx.compose.runtime.MutableState, @@ -105,13 +85,6 @@ internal fun flushPendingTaggedText( ?.also { negativePromptChipTextFieldState.value = TextFieldValue("") } } -/** - * Renders the `GenerationHistoryDialog` UI for the SDAI presentation layer. - * - * @param onClose callback invoked by the component. - * @param onGenerationSelected callback invoked by the component. - * @author Dmitriy Moroz - */ @Composable internal fun GenerationHistoryDialog( onClose: () -> Unit, @@ -123,101 +96,23 @@ internal fun GenerationHistoryDialog( ) } -/** - * Defines the `ImageToImagePanel` contract for the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ internal sealed interface ImageToImagePanel { - /** - * Provides the `History` singleton used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ data object History : ImageToImagePanel - /** - * Carries `Embeddings` data through the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ data class Embeddings( - /** - * Exposes the `prompt` value used by the SDAI presentation layer. 
- * - * @author Dmitriy Moroz - */ val prompt: String, - /** - * Exposes the `negativePrompt` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val negativePrompt: String, ) : ImageToImagePanel - /** - * Carries `Extras` data through the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ data class Extras( - /** - * Exposes the `prompt` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val prompt: String, - /** - * Exposes the `negativePrompt` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val negativePrompt: String, - /** - * Exposes the `type` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val type: ExtraType, ) : ImageToImagePanel } -/** - * Exposes the `AiGenerationResult` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ internal val AiGenerationResult.aspectRatio: Float get() = if (width > 0 && height > 0) width.toFloat() / height.toFloat() else 1f -/** - * Exposes the `ServerSource` value used by the SDAI presentation layer. 
- * - * @author Dmitriy Moroz - */ -internal val ServerSource.displayName: String - get() = when (this) { - ServerSource.AUTOMATIC1111 -> Localization.string("srv_type_own") - ServerSource.SWARM_UI -> Localization.string("srv_type_swarm_ui") - ServerSource.HORDE -> Localization.string("srv_type_horde") - ServerSource.HUGGING_FACE -> Localization.string("srv_type_hugging_face") - ServerSource.OPEN_AI -> Localization.string("srv_type_open_ai") - ServerSource.STABILITY_AI -> Localization.string("srv_type_stability_ai") - ServerSource.FAL_AI -> Localization.string("srv_type_fal_ai") - ServerSource.ARLI_AI -> Localization.string("srv_type_arli_ai") - ServerSource.LOCAL_MICROSOFT_ONNX -> Localization.string("srv_type_local_short") - ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> Localization.string("srv_type_media_pipe_short") - ServerSource.LOCAL_STABLE_DIFFUSION_CPP -> Localization.string("srv_type_sdxl_short") - ServerSource.LOCAL_APPLE_CORE_ML -> "Core ML" - ServerSource.LOCAL_APPLE_BONSAI -> "Silicon Diffusion PrismML Bonsai" - } - -/** - * Executes the `roundToString` step in the SDAI presentation layer. - * - * @return Result produced by `roundToString`. 
- * @author Dmitriy Moroz - */ internal fun Float.roundToString(): String { val rounded = (this * 100f).roundToInt() / 100f return rounded.toString() diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageState.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageState.kt index 12304fddc..f1e80d3dc 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageState.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageState.kt @@ -1,10 +1,12 @@ package com.shifthackz.aisdv1.presentation.screen.img2img import androidx.compose.runtime.Immutable +import com.shifthackz.aisdv1.core.common.platform.Platform import com.shifthackz.aisdv1.core.model.UiText import com.shifthackz.aisdv1.core.mvi.MviState import com.shifthackz.aisdv1.domain.entity.ADetailerConfig import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.domain.entity.BonsaiBackend import com.shifthackz.aisdv1.domain.entity.FalAiAcceleration import com.shifthackz.aisdv1.domain.entity.FalAiImageSize import com.shifthackz.aisdv1.domain.entity.FalAiModel @@ -24,320 +26,70 @@ import com.shifthackz.aisdv1.presentation.model.PromptTagEditRequest import com.shifthackz.aisdv1.presentation.widget.input.GenerationInputFormState /** - * Carries `ImageToImageState` data through the SDAI presentation layer. + * Complete render state for img2img. * - * @author Dmitriy Moroz + * It extends the shared generation form state with source-image, inpaint, and + * result handling fields while keeping platform-aware local provider controls + * consistent with the txt2img screen. */ @Immutable data class ImageToImageState( - /** - * Exposes the `loadingConfiguration` value used by the SDAI presentation layer. 
- * - * @author Dmitriy Moroz - */ val loadingConfiguration: Boolean = true, - /** - * Exposes the `imageBase64` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val imageBase64: String = "", - /** - * Exposes the `denoisingStrength` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val denoisingStrength: Float = DEFAULT_DENOISING_STRENGTH, - /** - * Exposes the `pickingImage` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val pickingImage: Boolean = false, - /** - * Exposes the `generating` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val generating: Boolean = false, - /** - * Exposes the `savingImage` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val savingImage: Boolean = false, - /** - * Exposes the `sharingImage` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val sharingImage: Boolean = false, - /** - * Exposes the `inPaint` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val inPaint: ImageInPaintState = ImageInPaintState(), - /** - * Exposes the `promptValidationError` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val promptValidationError: String? = null, - /** - * Exposes the `error` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val error: String? = null, - /** - * Exposes the `message` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val message: String? = null, - /** - * Exposes the `screenModal` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val screenModal: GenerationModal = GenerationModal.None, - /** - * Exposes the `results` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val results: List = emptyList(), - /** - * Exposes the `editTag` value used by the SDAI presentation layer. 
- * - * @author Dmitriy Moroz - */ val editTag: PromptTagEditRequest? = null, - /** - * Exposes the `onBoardingDemo` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val onBoardingDemo: Boolean = false, - /** - * Exposes the `mode` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ + val platform: Platform = Platform.ANDROID, override val mode: ServerSource = ServerSource.AUTOMATIC1111, - /** - * Exposes the `advancedToggleButtonVisible` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val advancedToggleButtonVisible: Boolean = true, - /** - * Exposes the `advancedOptionsVisible` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val advancedOptionsVisible: Boolean = false, - /** - * Exposes the `formPromptTaggedInput` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val formPromptTaggedInput: Boolean = false, - /** - * Exposes the `prompt` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val prompt: String = "", - /** - * Exposes the `negativePrompt` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val negativePrompt: String = "", - /** - * Exposes the `width` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val width: String = DEFAULT_SIZE.toString(), - /** - * Exposes the `height` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val height: String = DEFAULT_SIZE.toString(), - /** - * Exposes the `samplingSteps` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val samplingSteps: Int = 20, - /** - * Exposes the `cfgScale` value used by the SDAI presentation layer. 
- * - * @author Dmitriy Moroz - */ override val cfgScale: Float = 7f, - /** - * Exposes the `restoreFaces` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val restoreFaces: Boolean = false, - /** - * Exposes the `seed` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val seed: String = "", - /** - * Exposes the `subSeed` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val subSeed: String = "", - /** - * Exposes the `subSeedStrength` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val subSeedStrength: Float = 0f, - /** - * Exposes the `selectedSampler` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val selectedSampler: String = "", - /** - * Exposes the `selectedScheduler` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val selectedScheduler: Scheduler = Scheduler.AUTOMATIC, - /** - * Exposes the `availableForgeModules` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val availableForgeModules: List = emptyList(), - /** - * Exposes the `selectedForgeModules` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val selectedForgeModules: List = emptyList(), - /** - * Exposes the `availableSamplers` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val availableSamplers: List = emptyList(), - /** - * Exposes the `selectedStylePreset` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val selectedStylePreset: StabilityAiStylePreset = StabilityAiStylePreset.NONE, - /** - * Exposes the `selectedClipGuidancePreset` value used by the SDAI presentation layer. 
- * - * @author Dmitriy Moroz - */ override val selectedClipGuidancePreset: StabilityAiClipGuidance = StabilityAiClipGuidance.NONE, - /** - * Exposes the `openAiModel` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val openAiModel: OpenAiModel = OpenAiModel.default, - /** - * Exposes the `openAiSize` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val openAiSize: OpenAiSize = OpenAiSize.W1024_H1024, - /** - * Exposes the `openAiQuality` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val openAiQuality: OpenAiQuality = OpenAiQuality.AUTO, - /** - * Exposes the `falAiModel` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val falAiModel: FalAiModel = FalAiModel.defaultImageToImage, - /** - * Exposes the `falAiImageSize` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val falAiImageSize: FalAiImageSize = FalAiImageSize.default, - /** - * Exposes the `falAiAcceleration` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val falAiAcceleration: FalAiAcceleration = FalAiAcceleration.default, - /** - * Exposes the `falAiSyncMode` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val falAiSyncMode: Boolean = false, override val sdxlBackend: SdxlBackend = SdxlBackend.AUTO, + override val bonsaiBackend: BonsaiBackend = BonsaiBackend.AUTO, + override val bonsaiBackendSelectionVisible: Boolean = false, override val arliAiModels: List = emptyList(), override val arliAiModel: String = "", - /** - * Exposes the `widthValidationError` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val widthValidationError: UiText? = null, - /** - * Exposes the `heightValidationError` value used by the SDAI presentation layer. 
- * - * @author Dmitriy Moroz - */ override val heightValidationError: UiText? = null, - /** - * Exposes the `nsfw` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val nsfw: Boolean = false, - /** - * Exposes the `batchCount` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val batchCount: Int = 1, - /** - * Exposes the `hires` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val hires: HiresConfig = HiresConfig.DISABLED, - /** - * Exposes the `aDetailer` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val aDetailer: ADetailerConfig = ADetailerConfig.DISABLED, - /** - * Exposes the `aDetailerAvailable` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val aDetailerAvailable: Boolean = false, - /** - * Exposes the `aDetailerRefreshing` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val aDetailerRefreshing: Boolean = false, ) : MviState, GenerationInputFormState { @@ -372,10 +124,7 @@ data class ImageToImageState( } /** - * Converts SDAI data with `mapToPayload`. - * - * @param maskBase64 mask base64 value consumed by the API. - * @author Dmitriy Moroz + * Converts the current img2img state into the domain generation request. */ internal fun ImageToImageState.mapToPayload( maskBase64: String? = null, @@ -428,15 +177,5 @@ internal fun ImageToImageState.mapToPayload( arliAiModel = arliAiModel.takeIf { mode == ServerSource.ARLI_AI }.orEmpty(), ) -/** - * Exposes the `DEFAULT_SIZE` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ internal const val DEFAULT_SIZE = 512 -/** - * Exposes the `DEFAULT_DENOISING_STRENGTH` value used by the SDAI presentation layer. 
- * - * @author Dmitriy Moroz - */ internal const val DEFAULT_DENOISING_STRENGTH = 0.75f diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageViewModel.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageViewModel.kt index 93dd7dbce..1debe52cd 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageViewModel.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageViewModel.kt @@ -1,6 +1,7 @@ package com.shifthackz.aisdv1.presentation.screen.img2img import com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider +import com.shifthackz.aisdv1.core.common.platform.Platform import com.shifthackz.aisdv1.core.common.schedulers.DispatchersProvider import com.shifthackz.aisdv1.core.localization.Localization import com.shifthackz.aisdv1.core.mvi.BaseMviViewModel @@ -39,175 +40,45 @@ import kotlinx.coroutines.flow.catch import kotlinx.coroutines.withContext /** - * Coordinates `ImageToImageViewModel` behavior in the SDAI presentation layer. + * View-model for the img2img screen. * - * @author Dmitriy Moroz + * It keeps source-image state, shared generation controls, provider + * configuration, benchmark gating, and platform image-picking actions + * coordinated for the image-to-image flow. */ class ImageToImageViewModel( - /** - * Exposes the `dispatchersProvider` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val dispatchersProvider: DispatchersProvider, - /** - * Exposes the `getConfigurationUseCase` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val getConfigurationUseCase: GetConfigurationUseCase, - /** - * Exposes the `getStableDiffusionSamplersUseCase` value used by the SDAI presentation layer. 
- * - * @author Dmitriy Moroz - */ private val getStableDiffusionSamplersUseCase: GetStableDiffusionSamplersUseCase, - /** - * Exposes the `getForgeModulesUseCase` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val getForgeModulesUseCase: GetForgeModulesUseCase, - /** - * Exposes the `fetchAndGetArliAiModelsUseCase` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val fetchAndGetArliAiModelsUseCase: FetchAndGetArliAiModelsUseCase, - /** - * Exposes the `isADetailerAvailableUseCase` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val isADetailerAvailableUseCase: IsADetailerAvailableUseCase, - /** - * Exposes the `getRandomImageUseCase` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val getRandomImageUseCase: GetRandomImageUseCase, - /** - * Exposes the `imageToImageUseCase` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val imageToImageUseCase: ImageToImageUseCase, - /** - * Exposes the `saveGenerationResultUseCase` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val saveGenerationResultUseCase: SaveGenerationResultUseCase, - /** - * Exposes the `saveLastResultToCacheUseCase` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val saveLastResultToCacheUseCase: SaveLastResultToCacheUseCase, - /** - * Exposes the `interruptGenerationUseCase` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val interruptGenerationUseCase: InterruptGenerationUseCase, - /** - * Exposes the `observeHordeProcessStatusUseCase` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val observeHordeProcessStatusUseCase: ObserveHordeProcessStatusUseCase, - /** - * Exposes the `observeLocalDiffusionProcessStatusUseCase` value used by the SDAI presentation layer. 
- * - * @author Dmitriy Moroz - */ private val observeLocalDiffusionProcessStatusUseCase: ObserveLocalDiffusionProcessStatusUseCase, - /** - * Exposes the `preferenceManager` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val preferenceManager: PreferenceManager, - /** - * Exposes the `backgroundTaskManager` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val backgroundTaskManager: BackgroundTaskManager, - /** - * Exposes the `backgroundWorkObserver` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val backgroundWorkObserver: BackgroundWorkObserver, - /** - * Exposes the `wakeLockInterActor` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val wakeLockInterActor: WakeLockInterActor, - /** - * Exposes the `platformServices` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val platformServices: GenerationPlatformServices, - /** - * Exposes the `buildInfoProvider` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val buildInfoProvider: BuildInfoProvider, - /** - * Exposes the `generationFormUpdateEvent` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val generationFormUpdateEvent: GenerationFormUpdateEvent, - /** - * Exposes the `dimensionValidator` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val dimensionValidator: DimensionValidator, - /** - * Exposes the `localGenerationBenchmarkGateProvider` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val localGenerationBenchmarkGateProvider: () -> LocalGenerationBenchmarkGate, - /** - * Exposes the `imageSaver` value used by the SDAI presentation layer. 
- * - * @author Dmitriy Moroz - */ private val imageSaver: ImageSaver, - /** - * Exposes the `imageSharer` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val imageSharer: ImageSharer, - /** - * Exposes the `router` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val router: ImageToImageRouter, - /** - * Exposes the `platformActions` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val platformActions: ImageToImagePlatformActions, - /** - * Exposes the `onError` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val onError: (Throwable) -> Unit = {}, ) : BaseMviViewModel( - initialState = ImageToImageState(), + initialState = ImageToImageState( + platform = buildInfoProvider.platform, + bonsaiBackendSelectionVisible = buildInfoProvider.platform == Platform.ANDROID, + ), effectDispatcher = dispatchersProvider.immediate, ) { diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/settings/model/SettingsUiText.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/settings/model/SettingsUiText.kt index 3e7e21dba..155bd64f6 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/settings/model/SettingsUiText.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/settings/model/SettingsUiText.kt @@ -7,6 +7,7 @@ import com.shifthackz.aisdv1.core.localization.Localization import com.shifthackz.aisdv1.core.model.UiText import com.shifthackz.aisdv1.core.model.asUiText import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.presentation.model.shortDisplayName /** * Resolves a localization key to the UiText format consumed by Settings rows. 
@@ -21,18 +22,4 @@ internal fun text(key: String): UiText = Localization.string(key).asUiText() * * @author Dmitriy Moroz */ -internal fun ServerSource.shortTitle(): String = when (this) { - ServerSource.AUTOMATIC1111 -> Localization.string("srv_type_own_short") - ServerSource.HORDE -> Localization.string("srv_type_horde_short") - ServerSource.HUGGING_FACE -> Localization.string("srv_type_hugging_face_short") - ServerSource.OPEN_AI -> Localization.string("srv_type_open_ai") - ServerSource.STABILITY_AI -> Localization.string("srv_type_stability_ai") - ServerSource.FAL_AI -> Localization.string("srv_type_fal_ai") - ServerSource.ARLI_AI -> Localization.string("srv_type_arli_ai") - ServerSource.LOCAL_MICROSOFT_ONNX -> Localization.string("srv_type_local_short") - ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> Localization.string("srv_type_media_pipe_short") - ServerSource.LOCAL_STABLE_DIFFUSION_CPP -> Localization.string("srv_type_sdxl_short") - ServerSource.LOCAL_APPLE_CORE_ML -> "Core ML" - ServerSource.LOCAL_APPLE_BONSAI -> Localization.string("srv_type_bonsai_short") - ServerSource.SWARM_UI -> Localization.string("srv_type_swarm_ui") -} +internal fun ServerSource.shortTitle(): String = shortDisplayName() diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt index 6f1dbcb26..8763cdf3f 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt @@ -133,6 +133,7 @@ class ServerSetupViewModel( ) : BaseMviViewModel( initialState = ServerSetupState( showBackNavArrow = launchSource == LaunchSource.SETTINGS, + platform = buildInfoProvider.platform, allowedModes = buildInfoProvider.setupAllowedModes(), demoModeUrl = 
linksProvider.demoModeUrl, ), @@ -182,6 +183,7 @@ class ServerSetupViewModel( .map(HuggingFaceModel::alias) configuration.toServerSetupState( allowedModes = allowedModes, + platform = buildInfoProvider.platform, huggingFaceModels = models, localOnnxModels = onnxModels, localMediaPipeModels = mediaPipeModels, diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/component/ServerSetupWidgets.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/component/ServerSetupWidgets.kt index 2985144d2..9236c835a 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/component/ServerSetupWidgets.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/component/ServerSetupWidgets.kt @@ -19,6 +19,7 @@ import androidx.compose.runtime.Composable import androidx.compose.ui.Modifier import androidx.compose.ui.graphics.vector.ImageVector import androidx.compose.ui.unit.dp +import com.shifthackz.aisdv1.core.common.platform.Platform import com.shifthackz.aisdv1.core.localization.Localization import com.shifthackz.aisdv1.core.model.asUiText import com.shifthackz.aisdv1.domain.entity.LocalAiModel @@ -171,8 +172,7 @@ internal fun ServerSetupModal( } } -internal val ServerSource.icon: ImageVector - get() = when (this) { +internal fun ServerSource.icon(platform: Platform): ImageVector = when (this) { ServerSource.AUTOMATIC1111, ServerSource.SWARM_UI, -> Icons.Default.Computer @@ -191,11 +191,15 @@ internal val ServerSource.icon: ImageVector -> Icons.Default.Android ServerSource.LOCAL_APPLE_CORE_ML, - ServerSource.LOCAL_APPLE_BONSAI, -> BrandIcons.Apple + + ServerSource.LOCAL_APPLE_BONSAI -> when (platform) { + Platform.ANDROID -> Icons.Default.Android + Platform.IOS -> BrandIcons.Apple + } } -internal fun ServerSource.title(strings: ServerSetupStrings): String = when (this) { +internal fun ServerSource.title(strings: ServerSetupStrings, 
platform: Platform): String = when (this) { ServerSource.AUTOMATIC1111 -> strings.automaticTitle ServerSource.SWARM_UI -> strings.swarmTitle ServerSource.HORDE -> strings.hordeTitle @@ -208,10 +212,13 @@ internal fun ServerSource.title(strings: ServerSetupStrings): String = when (thi ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> strings.mediaPipeTitle ServerSource.LOCAL_STABLE_DIFFUSION_CPP -> strings.sdxlTitle ServerSource.LOCAL_APPLE_CORE_ML -> strings.coreMlTitle - ServerSource.LOCAL_APPLE_BONSAI -> strings.bonsaiTitle + ServerSource.LOCAL_APPLE_BONSAI -> when (platform) { + Platform.ANDROID -> strings.bonsaiAndroidTitle + Platform.IOS -> strings.bonsaiTitle + } } -internal fun ServerSource.subtitle(strings: ServerSetupStrings): String = when (this) { +internal fun ServerSource.subtitle(strings: ServerSetupStrings, platform: Platform): String = when (this) { ServerSource.AUTOMATIC1111 -> strings.automaticSubtitle ServerSource.SWARM_UI -> strings.swarmSubtitle ServerSource.HORDE -> strings.hordeSubtitle @@ -224,7 +231,10 @@ internal fun ServerSource.subtitle(strings: ServerSetupStrings): String = when ( ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> strings.mediaPipeSubtitle ServerSource.LOCAL_STABLE_DIFFUSION_CPP -> strings.sdxlSubtitle ServerSource.LOCAL_APPLE_CORE_ML -> strings.coreMlSubtitle - ServerSource.LOCAL_APPLE_BONSAI -> strings.bonsaiSubtitle + ServerSource.LOCAL_APPLE_BONSAI -> when (platform) { + Platform.ANDROID -> strings.bonsaiAndroidSubtitle + Platform.IOS -> strings.bonsaiSubtitle + } } internal fun ServerSetupState.ValidationError.message( diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/content/ServerSetupStrings.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/content/ServerSetupStrings.kt index 5a67b0fb1..51ce9effe 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/content/ServerSetupStrings.kt +++ 
b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/content/ServerSetupStrings.kt @@ -98,8 +98,10 @@ data class ServerSetupStrings( val sdxlSubtitle: String = Localization.string("hint_sdxl_sub_title"), val coreMlTitle: String = "Silicon Diffusion Core ML", val coreMlSubtitle: String = Localization.string("hint_core_ml_sub_title"), - val bonsaiTitle: String = "Silicon Diffusion PrismML Bonsai", + val bonsaiTitle: String = Localization.string("hint_bonsai_title"), + val bonsaiAndroidTitle: String = Localization.string("hint_bonsai_android_title"), val bonsaiSubtitle: String = Localization.string("hint_bonsai_sub_title"), + val bonsaiAndroidSubtitle: String = Localization.string("hint_bonsai_android_sub_title"), val localWarning: String = Localization.string("hint_local_diffusion_warning"), val localCustomSwitch: String = Localization.string("model_local_custom_switch"), val localPermissionHeader: String = Localization.string("model_local_permission_header"), diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/form/local/ServerSetupLocalForms.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/form/local/ServerSetupLocalForms.kt index 9d3fe4cf5..4720a9ff5 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/form/local/ServerSetupLocalForms.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/form/local/ServerSetupLocalForms.kt @@ -26,6 +26,8 @@ import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.presentation.screen.setup.component.SwitchRow import com.shifthackz.aisdv1.presentation.screen.setup.component.isCustom import com.shifthackz.aisdv1.presentation.screen.setup.component.message +import com.shifthackz.aisdv1.presentation.screen.setup.component.subtitle +import com.shifthackz.aisdv1.presentation.screen.setup.component.title import 
com.shifthackz.aisdv1.presentation.screen.setup.content.ServerSetupStrings import com.shifthackz.aisdv1.presentation.screen.setup.form.remote.ArliAiForm import com.shifthackz.aisdv1.presentation.screen.setup.form.remote.Automatic1111Form @@ -133,20 +135,8 @@ internal fun LocalGenerationForm( verticalArrangement = Arrangement.spacedBy(8.dp), ) { FormTitle( - title = when (state.mode) { - ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> strings.mediaPipeTitle - ServerSource.LOCAL_STABLE_DIFFUSION_CPP -> strings.sdxlTitle - ServerSource.LOCAL_APPLE_CORE_ML -> strings.coreMlTitle - ServerSource.LOCAL_APPLE_BONSAI -> strings.bonsaiTitle - else -> strings.localDiffusionTitle - }, - subtitle = when (state.mode) { - ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> strings.mediaPipeSubtitle - ServerSource.LOCAL_STABLE_DIFFUSION_CPP -> strings.sdxlSubtitle - ServerSource.LOCAL_APPLE_CORE_ML -> strings.coreMlSubtitle - ServerSource.LOCAL_APPLE_BONSAI -> strings.bonsaiSubtitle - else -> strings.localDiffusionSubtitle - }, + title = state.mode.title(strings, state.platform), + subtitle = state.mode.subtitle(strings, state.platform), ) HintText(text = strings.localWarning) if ( diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/mappers/ModesMapper.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/mappers/ModesMapper.kt index b133466b8..e110f14bd 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/mappers/ModesMapper.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/mappers/ModesMapper.kt @@ -2,6 +2,7 @@ package com.shifthackz.aisdv1.presentation.screen.setup.mappers import com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.presentation.model.isAvailableOn /** * Flavor-aware provider list before platform-specific availability filtering 
is applied. @@ -9,4 +10,7 @@ import com.shifthackz.aisdv1.domain.entity.ServerSource val BuildInfoProvider.allowedModes: List get() = ServerSource .entries - .filter { it.allowedInBuilds.contains(type) } + .filter { source -> + source.allowedInBuilds.contains(type) && + source.isAvailableOn(platform) + } diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/model/ServerSetupState.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/model/ServerSetupState.kt index d0843b735..eed09ac39 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/model/ServerSetupState.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/model/ServerSetupState.kt @@ -1,6 +1,7 @@ package com.shifthackz.aisdv1.presentation.screen.setup.model import androidx.compose.runtime.Immutable +import com.shifthackz.aisdv1.core.common.platform.Platform import com.shifthackz.aisdv1.core.mvi.MviState import com.shifthackz.aisdv1.domain.entity.Configuration import com.shifthackz.aisdv1.domain.entity.DownloadState @@ -24,6 +25,7 @@ data class ServerSetupState( val showBackNavArrow: Boolean = false, val step: Step = Step.SOURCE, val mode: ServerSource = ServerSource.AUTOMATIC1111, + val platform: Platform = Platform.ANDROID, val allowedModes: List = remoteSetupSources, val sourceSearchQuery: String = "", val sourceTypeFilter: ServerSourceType? 
= null, @@ -344,6 +346,7 @@ data class ServerSetupState( fun Configuration.toServerSetupState( allowedModes: List, + platform: Platform, huggingFaceModels: List, localOnnxModels: List = emptyList(), localMediaPipeModels: List = emptyList(), @@ -370,6 +373,7 @@ fun Configuration.toServerSetupState( loadingConfiguration = false, showBackNavArrow = showBackNavArrow, mode = safeMode, + platform = platform, allowedModes = allowedModes, allowLocalCustomModels = allowLocalCustomModels, serverUrl = serverUrl, diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/platform/ServerSetupPlatform.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/platform/ServerSetupPlatform.kt index 7f1167cb9..81f6459b3 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/platform/ServerSetupPlatform.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/platform/ServerSetupPlatform.kt @@ -23,6 +23,9 @@ internal expect fun ServerSetupLocalPathPickerButton( internal expect fun isLocalGenerationSetupAvailable(): Boolean /** - * Filters providers whose runtime cannot work on the current target. + * Filters providers that are impossible on the current target, such as a missing runtime ABI. + * + * Device capability and performance checks belong to benchmark/setup validation rather than the + * first provider list, so users can still discover providers supported in principle. 
*/ internal expect fun isServerSourceAvailableOnPlatform(source: ServerSource): Boolean diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/reducer/ServerSetupIntentProcessor.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/reducer/ServerSetupIntentProcessor.kt index 7e71436d7..6068723c0 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/reducer/ServerSetupIntentProcessor.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/reducer/ServerSetupIntentProcessor.kt @@ -6,6 +6,7 @@ import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.entity.ServerSourceReadiness import com.shifthackz.aisdv1.domain.entity.ServerSourceType import com.shifthackz.aisdv1.presentation.navigation.router.ServerSetupRouter +import com.shifthackz.aisdv1.presentation.model.readinessFor import com.shifthackz.aisdv1.presentation.screen.setup.model.HORDE_DEFAULT_API_KEY import com.shifthackz.aisdv1.presentation.screen.setup.model.ServerSetupEffect import com.shifthackz.aisdv1.presentation.screen.setup.model.ServerSetupIntent @@ -65,6 +66,7 @@ internal class ServerSetupIntentProcessor( is ServerSetupIntent.UpdateSourceTypeFilter -> updateState { if ( it.allowedModes.hasSourceFilterMatch( + state = it, type = intent.type, readinessFilters = it.sourceReadinessFilters, tags = it.sourceTagFilters, @@ -85,6 +87,7 @@ internal class ServerSetupIntentProcessor( } if ( state.allowedModes.hasSourceFilterMatch( + state = state, type = state.sourceTypeFilter, readinessFilters = nextFilters, tags = state.sourceTagFilters, @@ -105,6 +108,7 @@ internal class ServerSetupIntentProcessor( } if ( state.allowedModes.hasSourceFilterMatch( + state = state, type = state.sourceTypeFilter, readinessFilters = state.sourceReadinessFilters, tags = nextTags, @@ -244,18 +248,21 @@ internal class ServerSetupIntentProcessor( } 
private fun List.hasSourceFilterMatch( + state: ServerSetupState, type: ServerSourceType?, readinessFilters: Set, tags: Set, ): Boolean = any { source -> + val sourceReadiness = source.readinessFor(state.platform) (type == null || source.type == type) && - (readinessFilters.isEmpty() || source.readiness in readinessFilters) && + (readinessFilters.isEmpty() || sourceReadiness in readinessFilters) && source.featureTags.containsAll(tags) } private fun ServerSetupState.withMatchingFilterSource(): ServerSetupState { if ( listOf(mode).hasSourceFilterMatch( + state = this, type = sourceTypeFilter, readinessFilters = sourceReadinessFilters, tags = sourceTagFilters, @@ -265,6 +272,7 @@ private fun ServerSetupState.withMatchingFilterSource(): ServerSetupState { } val firstMatchingSource = allowedModes.firstOrNull { source -> listOf(source).hasSourceFilterMatch( + state = this, type = sourceTypeFilter, readinessFilters = sourceReadinessFilters, tags = sourceTagFilters, diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/source/SourceFiltersBottomSheet.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/source/SourceFiltersBottomSheet.kt index 8a22b175c..6d30caeee 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/source/SourceFiltersBottomSheet.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/source/SourceFiltersBottomSheet.kt @@ -25,9 +25,10 @@ import com.shifthackz.aisdv1.domain.entity.FeatureTag import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.entity.ServerSourceReadiness import com.shifthackz.aisdv1.domain.entity.ServerSourceType +import com.shifthackz.aisdv1.presentation.model.readinessFor +import com.shifthackz.aisdv1.presentation.screen.setup.content.ServerSetupStrings import com.shifthackz.aisdv1.presentation.screen.setup.model.ServerSetupIntent import 
com.shifthackz.aisdv1.presentation.screen.setup.model.ServerSetupState -import com.shifthackz.aisdv1.presentation.screen.setup.content.ServerSetupStrings import com.shifthackz.aisdv1.presentation.screen.setup.mappers.mapToUi /** @@ -55,7 +56,7 @@ internal fun SourceFiltersBottomSheet( .distinct() .sortedBy(FeatureTag::ordinal) val availableReadiness = state.allowedModes - .map(ServerSource::readiness) + .map { source -> source.readinessFor(state.platform) } .distinct() .sortedBy(ServerSourceReadiness::ordinal) Column( diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/source/SourceModeItem.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/source/SourceModeItem.kt index 1ddf7dd1c..8aaad067f 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/source/SourceModeItem.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/source/SourceModeItem.kt @@ -23,12 +23,14 @@ import androidx.compose.ui.graphics.Color import androidx.compose.ui.text.font.FontWeight import androidx.compose.ui.text.style.TextOverflow import androidx.compose.ui.unit.dp +import com.shifthackz.aisdv1.core.common.platform.Platform import com.shifthackz.aisdv1.domain.entity.ServerSource -import com.shifthackz.aisdv1.presentation.screen.setup.content.ServerSetupStrings +import com.shifthackz.aisdv1.presentation.model.readinessFor import com.shifthackz.aisdv1.presentation.screen.setup.component.icon -import com.shifthackz.aisdv1.presentation.screen.setup.mappers.mapToUi import com.shifthackz.aisdv1.presentation.screen.setup.component.subtitle import com.shifthackz.aisdv1.presentation.screen.setup.component.title +import com.shifthackz.aisdv1.presentation.screen.setup.content.ServerSetupStrings +import com.shifthackz.aisdv1.presentation.screen.setup.mappers.mapToUi /** * Renders one provider option in the provider selection list. 
@@ -43,10 +45,12 @@ import com.shifthackz.aisdv1.presentation.screen.setup.component.title internal fun SourceModeItem( source: ServerSource, selected: Boolean, + platform: Platform, strings: ServerSetupStrings, onClick: () -> Unit, ) { val shape = RoundedCornerShape(16.dp) + val readiness = source.readinessFor(platform) Surface( onClick = onClick, modifier = Modifier.fillMaxWidth(), @@ -68,7 +72,7 @@ internal fun SourceModeItem( modifier = Modifier .size(42.dp) .padding(top = 8.dp, bottom = 8.dp), - imageVector = source.icon, + imageVector = source.icon(platform), contentDescription = null, tint = MaterialTheme.colorScheme.primary, ) @@ -76,14 +80,14 @@ internal fun SourceModeItem( modifier = Modifier .align(Alignment.CenterVertically) .padding(vertical = 8.dp), - text = source.title(strings), + text = source.title(strings, platform), style = MaterialTheme.typography.bodyLarge, color = MaterialTheme.colorScheme.onSurface, ) } Text( modifier = Modifier.padding(horizontal = 8.dp), - text = source.subtitle(strings), + text = source.subtitle(strings, platform), style = MaterialTheme.typography.labelMedium, fontWeight = FontWeight.W500, color = MaterialTheme.colorScheme.onSurfaceVariant, @@ -92,9 +96,9 @@ internal fun SourceModeItem( modifier = Modifier.padding(4.dp), ) { SourceMetaChip( - text = source.readiness.mapToUi(strings), - containerColor = source.readiness.containerColor(), - contentColor = source.readiness.contentColor(), + text = readiness.mapToUi(strings), + containerColor = readiness.containerColor(), + contentColor = readiness.contentColor(), ) SourceMetaChip( text = strings.sourceVersion(source.version), diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/source/SourceSelectionFilters.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/source/SourceSelectionFilters.kt index dc043e57a..c85802580 100644 --- 
a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/source/SourceSelectionFilters.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/source/SourceSelectionFilters.kt @@ -5,11 +5,12 @@ import com.shifthackz.aisdv1.domain.entity.FeatureTag import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.entity.ServerSourceReadiness import com.shifthackz.aisdv1.domain.entity.ServerSourceType -import com.shifthackz.aisdv1.presentation.screen.setup.model.ServerSetupState -import com.shifthackz.aisdv1.presentation.screen.setup.content.ServerSetupStrings -import com.shifthackz.aisdv1.presentation.screen.setup.mappers.mapToUi +import com.shifthackz.aisdv1.presentation.model.readinessFor import com.shifthackz.aisdv1.presentation.screen.setup.component.subtitle import com.shifthackz.aisdv1.presentation.screen.setup.component.title +import com.shifthackz.aisdv1.presentation.screen.setup.content.ServerSetupStrings +import com.shifthackz.aisdv1.presentation.screen.setup.mappers.mapToUi +import com.shifthackz.aisdv1.presentation.screen.setup.model.ServerSetupState /** * Applies provider search, filters, and sort order to the available source list. @@ -42,7 +43,7 @@ internal fun ServerSetupState.visibleSources(strings: ServerSetupStrings): List< */ internal fun ServerSetupState.isSourceTypeEnabled(type: ServerSourceType): Boolean = sourceTypeFilter == type || - allowedModes.hasSourceFilterMatch(type, sourceReadinessFilters, sourceTagFilters) + allowedModes.hasSourceFilterMatch(this, type, sourceReadinessFilters, sourceTagFilters) /** * Checks whether adding [readiness] keeps the filter set non-empty. 
@@ -53,6 +54,7 @@ internal fun ServerSetupState.isSourceTypeEnabled(type: ServerSourceType): Boole internal fun ServerSetupState.isSourceReadinessEnabled(readiness: ServerSourceReadiness): Boolean = readiness in sourceReadinessFilters || allowedModes.hasSourceFilterMatch( + this, sourceTypeFilter, sourceReadinessFilters + readiness, sourceTagFilters, @@ -67,6 +69,7 @@ internal fun ServerSetupState.isSourceReadinessEnabled(readiness: ServerSourceRe internal fun ServerSetupState.isSourceTagEnabled(tag: FeatureTag): Boolean = tag in sourceTagFilters || allowedModes.hasSourceFilterMatch( + this, sourceTypeFilter, sourceReadinessFilters, sourceTagFilters + tag, @@ -111,7 +114,8 @@ private fun ServerSource.matchesFilters( if (typeFilter != null && type != typeFilter) { return false } - if (state.sourceReadinessFilters.isNotEmpty() && readiness !in state.sourceReadinessFilters) { + val sourceReadiness = readinessFor(state.platform) + if (state.sourceReadinessFilters.isNotEmpty() && sourceReadiness !in state.sourceReadinessFilters) { return false } if (state.sourceTagFilters.isNotEmpty() && !featureTags.containsAll(state.sourceTagFilters)) { @@ -127,10 +131,10 @@ private fun ServerSource.matchesFilters( } val searchableText = buildList { add(key) - add(title(strings)) - add(subtitle(strings)) + add(title(strings, state.platform)) + add(subtitle(strings, state.platform)) add(type.mapToUi(strings)) - add(readiness.mapToUi(strings)) + add(sourceReadiness.mapToUi(strings)) add(version) featureTags.forEach { tag -> add(tag.mapToUi()) } } @@ -140,12 +144,14 @@ private fun ServerSource.matchesFilters( } private fun List.hasSourceFilterMatch( + state: ServerSetupState, type: ServerSourceType?, readinessFilters: Set, tags: Set, ): Boolean = any { source -> + val sourceReadiness = source.readinessFor(state.platform) (type == null || source.type == type) && - (readinessFilters.isEmpty() || source.readiness in readinessFilters) && + (readinessFilters.isEmpty() || sourceReadiness in 
readinessFilters) && source.featureTags.containsAll(tags) } diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/source/SourceSelectionStep.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/source/SourceSelectionStep.kt index 9ad4192ea..49721a661 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/source/SourceSelectionStep.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/source/SourceSelectionStep.kt @@ -89,6 +89,7 @@ internal fun SourceSelectionStep( SourceModeItem( source = source, selected = state.mode == source, + platform = state.platform, strings = strings, onClick = { coroutineScope.launch { diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageContent.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageContent.kt index 7c43c0ab8..09cf06f63 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageContent.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageContent.kt @@ -46,125 +46,25 @@ import com.shifthackz.aisdv1.presentation.widget.input.GenerationInputFormEvent import com.shifthackz.aisdv1.presentation.widget.toolbar.GenerationBottomToolbar import com.shifthackz.aisdv1.presentation.widget.work.BackgroundWorkWidget -/** - * Carries `TextToImageStrings` data through the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ data class TextToImageStrings( - /** - * Exposes the `title` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val title: String = Localization.string("title_text_to_image"), - /** - * Exposes the `prompt` value used by the SDAI presentation layer. 
- * - * @author Dmitriy Moroz - */ val prompt: String = Localization.string("hint_prompt"), - /** - * Exposes the `negativePrompt` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val negativePrompt: String = Localization.string("hint_prompt_negative"), - /** - * Exposes the `width` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val width: String = Localization.string("width"), - /** - * Exposes the `height` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val height: String = Localization.string("height"), - /** - * Exposes the `steps` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val steps: String = Localization.string("gallery_info_field_sampling_steps"), - /** - * Exposes the `cfgScale` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val cfgScale: String = Localization.string("gallery_info_field_cfg"), - /** - * Exposes the `batch` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val batch: String = Localization.string("hint_batch_tag"), - /** - * Exposes the `generate` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val generate: String = Localization.string("action_generate"), - /** - * Exposes the `generating` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val generating: String = Localization.string("notification_running_title"), - /** - * Exposes the `save` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val save: String = Localization.string("action_save"), - /** - * Exposes the `savingImage` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val savingImage: String = Localization.string("message_image_saving"), - /** - * Exposes the `share` value used by the SDAI presentation layer. 
- * - * @author Dmitriy Moroz - */ val share: String = Localization.string("action_share_prompt"), - /** - * Exposes the `sharingImage` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val sharingImage: String = Localization.string("message_image_sharing"), - /** - * Exposes the `configureProvider` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val configureProvider: String = Localization.string("settings_item_config"), - /** - * Exposes the `sourceUnavailable` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val sourceUnavailable: String = Localization.string("error_source_android_only"), - /** - * Exposes the `results` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val results: String = Localization.string("title_generation_results"), - /** - * Exposes the `imageUnavailable` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val imageUnavailable: String = Localization.string("message_image_data_received"), - /** - * Exposes the `resultMeta` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val resultMeta: (AiGenerationResult) -> String = { result -> Localization.string( "generation_result_meta", @@ -176,14 +76,10 @@ data class TextToImageStrings( ) /** - * Renders the `TextToImageContent` UI for the SDAI presentation layer. + * Screen content for text-to-image generation. * - * @param state state rendered or processed by the component. - * @param processIntent process intent value consumed by the API. - * @param modifier Compose modifier applied to the rendered UI. - * @param strings strings value consumed by the API. - * @param useDrawerNavigation use drawer navigation value consumed by the API. - * @author Dmitriy Moroz + * The composable owns transient modal selection and pending prompt-chip text, + * while all durable state changes are sent back through [processIntent]. 
*/ @Composable fun TextToImageContent( @@ -374,9 +270,7 @@ fun TextToImageContent( } /** - * Converts SDAI data with `toTextToImageIntent`. - * - * @author Dmitriy Moroz + * Maps shared generation form events into txt2img intents. */ internal fun GenerationInputFormEvent.toTextToImageIntent(): TextToImageIntent? = when (this) { is GenerationInputFormEvent.EditTag -> TextToImageIntent.ShowEditTag( @@ -401,6 +295,7 @@ internal fun GenerationInputFormEvent.toTextToImageIntent(): TextToImageIntent? is GenerationInputFormEvent.UpdateFalAiImageSize -> TextToImageIntent.UpdateFalAiImageSize(value) is GenerationInputFormEvent.UpdateFalAiAcceleration -> TextToImageIntent.UpdateFalAiAcceleration(value) is GenerationInputFormEvent.UpdateSdxlBackend -> TextToImageIntent.UpdateSdxlBackend(value) + is GenerationInputFormEvent.UpdateBonsaiBackend -> TextToImageIntent.UpdateBonsaiBackend(value) is GenerationInputFormEvent.UpdateFalAiSyncMode -> TextToImageIntent.UpdateFalAiSyncMode(value) is GenerationInputFormEvent.UpdateArliAiModel -> TextToImageIntent.UpdateArliAiModel(value) is GenerationInputFormEvent.UpdatePrompt -> TextToImageIntent.UpdatePrompt(value) @@ -422,27 +317,22 @@ internal fun GenerationInputFormEvent.toTextToImageIntent(): TextToImageIntent? GenerationInputFormEvent.OpenADetailerInstallInstructions -> TextToImageIntent.OpenADetailerInstallInstructions } -/** - * Executes the `appendPromptTag` step in the SDAI presentation layer. - * - * @param tag tag value consumed by the API. - * @return Result produced by `appendPromptTag`. 
- * @author Dmitriy Moroz - */ +@Composable +private fun GenerationHistoryDialog( + onClose: () -> Unit, + onGenerationSelected: (AiGenerationResult) -> Unit, +) { + InputHistoryBottomSheet( + onClose = onClose, + onGenerationSelected = onGenerationSelected, + ) +} + private fun String.appendPromptTag(tag: String): String = listOf(this, tag.trim()) .filter(String::isNotBlank) .joinToString(", ") -/** - * Executes the `flushPendingTaggedText` step in the SDAI presentation layer. - * - * @param state state rendered or processed by the component. - * @param promptChipTextFieldState prompt chip text field state value consumed by the API. - * @param negativePromptChipTextFieldState negative prompt chip text field state value consumed by the API. - * @param processIntent process intent value consumed by the API. - * @author Dmitriy Moroz - */ private fun flushPendingTaggedText( state: TextToImageState, promptChipTextFieldState: androidx.compose.runtime.MutableState, @@ -464,78 +354,15 @@ private fun flushPendingTaggedText( ?.also { negativePromptChipTextFieldState.value = TextFieldValue() } } -/** - * Renders the `GenerationHistoryDialog` UI for the SDAI presentation layer. - * - * @param onClose callback invoked by the component. - * @param onGenerationSelected callback invoked by the component. - * @author Dmitriy Moroz - */ -@Composable -private fun GenerationHistoryDialog( - onClose: () -> Unit, - onGenerationSelected: (AiGenerationResult) -> Unit, -) { - InputHistoryBottomSheet( - onClose = onClose, - onGenerationSelected = onGenerationSelected, - ) -} - -/** - * Defines the `TextToImagePanel` contract for the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private sealed interface TextToImagePanel { - /** - * Provides the `History` singleton used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ data object History : TextToImagePanel - /** - * Carries `Embeddings` data through the SDAI presentation layer. 
- * - * @author Dmitriy Moroz - */ data class Embeddings( - /** - * Exposes the `prompt` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val prompt: String, - /** - * Exposes the `negativePrompt` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val negativePrompt: String, ) : TextToImagePanel - /** - * Carries `Extras` data through the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ data class Extras( - /** - * Exposes the `prompt` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val prompt: String, - /** - * Exposes the `negativePrompt` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val negativePrompt: String, - /** - * Exposes the `type` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val type: ExtraType, ) : TextToImagePanel } diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageControls.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageControls.kt index 279bb252a..3192f8baa 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageControls.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageControls.kt @@ -33,295 +33,15 @@ import androidx.compose.ui.layout.ContentScale import androidx.compose.ui.text.font.FontWeight import androidx.compose.ui.text.style.TextOverflow import androidx.compose.ui.unit.dp -import com.shifthackz.aisdv1.core.localization.Localization import com.shifthackz.aisdv1.core.model.UiText import com.shifthackz.aisdv1.core.model.asString import com.shifthackz.aisdv1.domain.entity.AiGenerationResult -import com.shifthackz.aisdv1.domain.entity.ServerSource import kotlin.math.roundToInt -/** - * Renders the `NumberField` UI for the SDAI presentation layer. 
- * - * @param value value value consumed by the API. - * @param label label value consumed by the API. - * @param error error value consumed by the API. - * @param onValueChange callback invoked by the component. - * @param modifier Compose modifier applied to the rendered UI. - * @author Dmitriy Moroz - */ -@Composable -internal fun NumberField( - value: String, - label: String, - error: UiText?, - onValueChange: (String) -> Unit, - modifier: Modifier = Modifier, -) { - OutlinedTextField( - modifier = modifier, - value = value, - onValueChange = onValueChange, - label = { Text(label) }, - singleLine = true, - isError = error != null, - supportingText = error?.let { message -> - { Text(message.asString()) } - }, - ) -} - -/** - * Renders the `SliderRow` UI for the SDAI presentation layer. - * - * @param label label value consumed by the API. - * @param value value value consumed by the API. - * @param content content value consumed by the API. - * @author Dmitriy Moroz - */ -@Composable -internal fun SliderRow( - label: String, - value: String, - content: @Composable () -> Unit, -) { - Column(verticalArrangement = Arrangement.spacedBy(4.dp)) { - Row( - modifier = Modifier.fillMaxWidth(), - horizontalArrangement = Arrangement.SpaceBetween, - verticalAlignment = Alignment.CenterVertically, - ) { - Text( - text = label, - style = MaterialTheme.typography.bodyMedium, - color = MaterialTheme.colorScheme.onSurfaceVariant, - ) - Text( - text = value, - style = MaterialTheme.typography.bodyMedium, - fontWeight = FontWeight.W600, - ) - } - content() - } -} - -/** - * Renders the `BatchControl` UI for the SDAI presentation layer. - * - * @param state state rendered or processed by the component. - * @param strings strings value consumed by the API. - * @param processIntent process intent value consumed by the API. 
- * @author Dmitriy Moroz - */ -@Composable -internal fun BatchControl( - state: TextToImageState, - strings: TextToImageStrings, - processIntent: (TextToImageIntent) -> Unit, -) { - Row( - modifier = Modifier.fillMaxWidth(), - horizontalArrangement = Arrangement.SpaceBetween, - verticalAlignment = Alignment.CenterVertically, - ) { - Text( - text = strings.batch, - style = MaterialTheme.typography.bodyMedium, - color = MaterialTheme.colorScheme.onSurfaceVariant, - ) - Row( - horizontalArrangement = Arrangement.spacedBy(8.dp), - verticalAlignment = Alignment.CenterVertically, - ) { - OutlinedButton( - contentPadding = PaddingValues(horizontal = 12.dp), - onClick = { - processIntent(TextToImageIntent.UpdateBatchCount(state.batchCount - 1)) - }, - ) { - Text("-") - } - Text( - text = state.batchCount.toString(), - style = MaterialTheme.typography.bodyLarge, - fontWeight = FontWeight.W600, - ) - OutlinedButton( - contentPadding = PaddingValues(horizontal = 12.dp), - onClick = { - processIntent(TextToImageIntent.UpdateBatchCount(state.batchCount + 1)) - }, - ) { - Text("+") - } - } - } -} - -/** - * Renders the `GeneratedImageItem` UI for the SDAI presentation layer. - * - * @param result result value consumed by the API. - * @param strings strings value consumed by the API. - * @param savingImage saving image value consumed by the API. - * @param sharingImage sharing image value consumed by the API. - * @param processIntent process intent value consumed by the API. 
- * @author Dmitriy Moroz - */ -@Composable -internal fun GeneratedImageItem( - result: AiGenerationResult, - strings: TextToImageStrings, - savingImage: Boolean, - sharingImage: Boolean, - processIntent: (TextToImageIntent) -> Unit, -) { - Surface( - modifier = Modifier.fillMaxWidth(), - shape = RoundedCornerShape(16.dp), - color = MaterialTheme.colorScheme.surface, - tonalElevation = 1.dp, - ) { - Column( - modifier = Modifier.padding(12.dp), - verticalArrangement = Arrangement.spacedBy(10.dp), - ) { - val imageBitmap = remember(result.image) { - result.image.decodeBase64ImageBitmap() - } - if (imageBitmap != null) { - Image( - modifier = Modifier - .fillMaxWidth() - .aspectRatio(result.aspectRatio), - bitmap = imageBitmap, - contentDescription = null, - contentScale = ContentScale.Crop, - ) - } else { - Box( - modifier = Modifier - .fillMaxWidth() - .aspectRatio(result.aspectRatio), - contentAlignment = Alignment.Center, - ) { - Icon( - modifier = Modifier.size(56.dp), - imageVector = Icons.Default.AutoFixNormal, - contentDescription = null, - tint = MaterialTheme.colorScheme.primary, - ) - } - Text( - text = strings.imageUnavailable, - style = MaterialTheme.typography.bodyMedium, - color = MaterialTheme.colorScheme.onSurfaceVariant, - ) - } - - Text( - text = result.prompt, - style = MaterialTheme.typography.bodyMedium, - maxLines = 3, - overflow = TextOverflow.Ellipsis, - ) - - Text( - text = strings.resultMeta(result), - style = MaterialTheme.typography.labelMedium, - color = MaterialTheme.colorScheme.onSurfaceVariant, - ) - - Row( - modifier = Modifier.fillMaxWidth(), - horizontalArrangement = Arrangement.spacedBy(8.dp), - ) { - val actionsEnabled = result.image.isNotBlank() && !savingImage && !sharingImage - OutlinedButton( - modifier = Modifier.weight(1f), - enabled = actionsEnabled, - onClick = { processIntent(TextToImageIntent.SaveResult(result.image)) }, - ) { - if (savingImage) { - CircularProgressIndicator( - modifier = Modifier.size(18.dp), - 
strokeWidth = 2.dp, - ) - } else { - Icon( - imageVector = Icons.Default.Save, - contentDescription = null, - ) - } - Text( - modifier = Modifier.padding(start = 8.dp), - text = if (savingImage) strings.savingImage else strings.save, - ) - } - OutlinedButton( - modifier = Modifier.weight(1f), - enabled = actionsEnabled, - onClick = { processIntent(TextToImageIntent.ShareResult(result.image)) }, - ) { - if (sharingImage) { - CircularProgressIndicator( - modifier = Modifier.size(18.dp), - strokeWidth = 2.dp, - ) - } else { - Icon( - imageVector = Icons.Default.Share, - contentDescription = null, - ) - } - Text( - modifier = Modifier.padding(start = 8.dp), - text = if (sharingImage) strings.sharingImage else strings.share, - ) - } - } - } - } -} - -/** - * Exposes the `AiGenerationResult` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ internal val AiGenerationResult.aspectRatio: Float get() = if (width > 0 && height > 0) width.toFloat() / height.toFloat() else 1f -/** - * Exposes the `ServerSource` value used by the SDAI presentation layer. 
- * - * @author Dmitriy Moroz - */ -internal val ServerSource.displayName: String - get() = when (this) { - ServerSource.AUTOMATIC1111 -> Localization.string("srv_type_own") - ServerSource.SWARM_UI -> Localization.string("srv_type_swarm_ui") - ServerSource.HORDE -> Localization.string("srv_type_horde") - ServerSource.HUGGING_FACE -> Localization.string("srv_type_hugging_face") - ServerSource.OPEN_AI -> Localization.string("srv_type_open_ai") - ServerSource.STABILITY_AI -> Localization.string("srv_type_stability_ai") - ServerSource.FAL_AI -> Localization.string("srv_type_fal_ai") - ServerSource.ARLI_AI -> Localization.string("srv_type_arli_ai") - ServerSource.LOCAL_MICROSOFT_ONNX -> Localization.string("srv_type_local_short") - ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> Localization.string("srv_type_media_pipe_short") - ServerSource.LOCAL_STABLE_DIFFUSION_CPP -> Localization.string("srv_type_sdxl_short") - ServerSource.LOCAL_APPLE_CORE_ML -> "Core ML" - ServerSource.LOCAL_APPLE_BONSAI -> "Silicon Diffusion PrismML Bonsai" - } - -/** - * Executes the `roundToString` step in the SDAI presentation layer. - * - * @return Result produced by `roundToString`. 
- * @author Dmitriy Moroz - */ internal fun Float.roundToString(): String { val rounded = (this * 10f).roundToInt() / 10f return rounded.toString() diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageFormContent.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageFormContent.kt index 2f9d0e101..262f1a0e0 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageFormContent.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageFormContent.kt @@ -22,6 +22,7 @@ import androidx.compose.ui.text.input.TextFieldValue import androidx.compose.ui.text.style.TextOverflow import androidx.compose.ui.unit.dp import com.shifthackz.aisdv1.core.model.asString +import com.shifthackz.aisdv1.presentation.model.compactDisplayName import com.shifthackz.aisdv1.presentation.platform.rememberExternalUrlLauncher import com.shifthackz.aisdv1.presentation.widget.input.GenerationInputForm import com.shifthackz.aisdv1.presentation.widget.input.GenerationInputFormEvent @@ -97,7 +98,7 @@ internal fun TextToImageForm( verticalArrangement = Arrangement.spacedBy(14.dp), ) { Text( - text = state.mode.displayName, + text = state.mode.compactDisplayName(state.platform), style = MaterialTheme.typography.labelLarge, color = MaterialTheme.colorScheme.onSurfaceVariant, maxLines = 1, diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageIntent.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageIntent.kt index d3499bd75..5dc66221e 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageIntent.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageIntent.kt @@ -3,6 +3,7 @@ package 
com.shifthackz.aisdv1.presentation.screen.txt2img import com.shifthackz.aisdv1.core.mvi.MviIntent import com.shifthackz.aisdv1.domain.entity.ADetailerConfig import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.domain.entity.BonsaiBackend import com.shifthackz.aisdv1.domain.entity.FalAiAcceleration import com.shifthackz.aisdv1.domain.entity.FalAiImageSize import com.shifthackz.aisdv1.domain.entity.FalAiModel @@ -18,413 +19,76 @@ import com.shifthackz.aisdv1.domain.entity.StabilityAiStylePreset import com.shifthackz.aisdv1.presentation.widget.input.GenerationAspectRatio /** - * Defines the `TextToImageIntent` contract for the SDAI presentation layer. + * User actions and form updates emitted by the txt2img screen. * - * @author Dmitriy Moroz + * Benchmark warning actions are modeled here with regular generation intents so + * the view-model can gate local providers without coupling the shared input + * form to benchmark-specific UI. */ sealed interface TextToImageIntent : MviIntent { - /** - * Provides the `OpenDrawer` singleton used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ data object OpenDrawer : TextToImageIntent - /** - * Provides the `NavigateBack` singleton used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ data object NavigateBack : TextToImageIntent - /** - * Provides the `ConfigureProvider` singleton used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ data object ConfigureProvider : TextToImageIntent - /** - * Provides the `Generate` singleton used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ data object Generate : TextToImageIntent - /** - * Opens the benchmark screen from the first local generation prompt. - * - * @author Dmitriy Moroz - */ data object RunBenchmarkFromPrompt : TextToImageIntent - /** - * Skips the first local generation benchmark prompt and continues generation. 
- * - * @author Dmitriy Moroz - */ data object SkipBenchmarkPrompt : TextToImageIntent - /** - * Continues generation after the benchmark recommendation warning. - * - * @author Dmitriy Moroz - */ data object ContinueAfterBenchmarkWarning : TextToImageIntent - /** - * Suppresses future benchmark recommendation warnings and continues generation. - * - * @author Dmitriy Moroz - */ data object SuppressBenchmarkWarningAndContinue : TextToImageIntent - /** - * Provides the `DismissModal` singleton used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ data object DismissModal : TextToImageIntent - /** - * Provides the `CancelGeneration` singleton used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ data object CancelGeneration : TextToImageIntent - /** - * Provides the `DismissError` singleton used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ data object DismissError : TextToImageIntent - /** - * Provides the `DismissMessage` singleton used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ data object DismissMessage : TextToImageIntent - /** - * Provides the `DismissEditTag` singleton used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ data object DismissEditTag : TextToImageIntent - /** - * Carries `SaveResult` data through the SDAI presentation layer. - * - * @param base64 Base64 image payload used by the operation. - * @author Dmitriy Moroz - */ data class SaveResult(val base64: String) : TextToImageIntent - /** - * Carries `ShareResult` data through the SDAI presentation layer. - * - * @param base64 Base64 image payload used by the operation. - * @author Dmitriy Moroz - */ data class ShareResult(val base64: String) : TextToImageIntent - /** - * Carries `SaveGenerationResults` data through the SDAI presentation layer. - * - * @param results results value consumed by the API. 
- * @author Dmitriy Moroz - */ data class SaveGenerationResults(val results: List) : TextToImageIntent - /** - * Carries `ViewGenerationResult` data through the SDAI presentation layer. - * - * @param result result value consumed by the API. - * @author Dmitriy Moroz - */ data class ViewGenerationResult(val result: AiGenerationResult) : TextToImageIntent - /** - * Carries `ReportGenerationResult` data through the SDAI presentation layer. - * - * @param result result value consumed by the API. - * @author Dmitriy Moroz - */ data class ReportGenerationResult(val result: AiGenerationResult) : TextToImageIntent - /** - * Carries `ShowEditTag` data through the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ data class ShowEditTag( - /** - * Exposes the `prompt` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val prompt: String, - /** - * Exposes the `negativePrompt` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val negativePrompt: String, - /** - * Exposes the `tag` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val tag: String, - /** - * Exposes the `isNegative` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val isNegative: Boolean, ) : TextToImageIntent - /** - * Carries `ApplyPrompts` data through the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ data class ApplyPrompts( - /** - * Exposes the `prompt` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val prompt: String, - /** - * Exposes the `negativePrompt` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val negativePrompt: String, ) : TextToImageIntent - /** - * Carries `ApplyGenerationResult` data through the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ data class ApplyGenerationResult( - /** - * Exposes the `ai` value used by the SDAI presentation layer. 
- * - * @author Dmitriy Moroz - */ val ai: AiGenerationResult, ) : TextToImageIntent - /** - * Carries `UpdateAdvancedOptionsVisibility` data through the SDAI presentation layer. - * - * @param visible visible value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateAdvancedOptionsVisibility(val visible: Boolean) : TextToImageIntent - /** - * Carries `UpdatePrompt` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdatePrompt(val value: String) : TextToImageIntent - /** - * Carries `UpdateNegativePrompt` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateNegativePrompt(val value: String) : TextToImageIntent - /** - * Carries `UpdateWidth` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateWidth(val value: String) : TextToImageIntent - /** - * Provides the `SwapDimensions` singleton used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ data object SwapDimensions : TextToImageIntent - /** - * Carries `ApplyAspectRatio` data through the SDAI presentation layer. - * - * @param ratio ratio value consumed by the API. - * @author Dmitriy Moroz - */ data class ApplyAspectRatio(val ratio: GenerationAspectRatio) : TextToImageIntent - /** - * Carries `UpdateHeight` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateHeight(val value: String) : TextToImageIntent - /** - * Carries `UpdateSamplingSteps` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateSamplingSteps(val value: Int) : TextToImageIntent - /** - * Carries `UpdateCfgScale` data through the SDAI presentation layer. 
- * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateCfgScale(val value: Float) : TextToImageIntent - /** - * Carries `UpdateRestoreFaces` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateRestoreFaces(val value: Boolean) : TextToImageIntent - /** - * Carries `UpdateSeed` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateSeed(val value: String) : TextToImageIntent - /** - * Carries `UpdateSubSeed` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateSubSeed(val value: String) : TextToImageIntent - /** - * Carries `UpdateSubSeedStrength` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateSubSeedStrength(val value: Float) : TextToImageIntent - /** - * Carries `UpdateSampler` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateSampler(val value: String) : TextToImageIntent - /** - * Carries `UpdateScheduler` data through the SDAI presentation layer. - * - * @param value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateScheduler(val value: Scheduler) : TextToImageIntent - /** - * Carries `UpdateForgeModules` data through the SDAI presentation layer. - * - * @param value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateForgeModules(val value: List) : TextToImageIntent - /** - * Carries `UpdateNsfw` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. 
- * @author Dmitriy Moroz - */ data class UpdateNsfw(val value: Boolean) : TextToImageIntent - /** - * Carries `UpdateBatchCount` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateBatchCount(val value: Int) : TextToImageIntent - /** - * Carries `UpdateOpenAiModel` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateOpenAiModel(val value: OpenAiModel) : TextToImageIntent - /** - * Carries `UpdateOpenAiSize` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateOpenAiSize(val value: OpenAiSize) : TextToImageIntent - /** - * Carries `UpdateOpenAiQuality` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateOpenAiQuality(val value: OpenAiQuality) : TextToImageIntent - /** - * Carries `UpdateFalAiModel` data through the SDAI presentation layer. - * - * @param value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateFalAiModel(val value: FalAiModel) : TextToImageIntent - /** - * Carries `UpdateFalAiImageSize` data through the SDAI presentation layer. - * - * @param value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateFalAiImageSize(val value: FalAiImageSize) : TextToImageIntent - /** - * Carries `UpdateFalAiAcceleration` data through the SDAI presentation layer. - * - * @param value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateFalAiAcceleration(val value: FalAiAcceleration) : TextToImageIntent - /** - * Carries `UpdateSdxlBackend` data through the SDAI presentation layer. - * - * @param value value consumed by the API. 
- * @author Dmitriy Moroz - */ data class UpdateSdxlBackend(val value: SdxlBackend) : TextToImageIntent - /** - * Carries `UpdateFalAiSyncMode` data through the SDAI presentation layer. - * - * @param value value consumed by the API. - * @author Dmitriy Moroz - */ + data class UpdateBonsaiBackend(val value: BonsaiBackend) : TextToImageIntent data class UpdateFalAiSyncMode(val value: Boolean) : TextToImageIntent - /** - * Carries `UpdateArliAiModel` data through the SDAI presentation layer. - * - * @param value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateArliAiModel(val value: String) : TextToImageIntent - /** - * Carries `UpdateStabilityAiStyle` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateStabilityAiStyle(val value: StabilityAiStylePreset) : TextToImageIntent - /** - * Carries `UpdateStabilityAiClipGuidance` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateStabilityAiClipGuidance(val value: StabilityAiClipGuidance) : TextToImageIntent - /** - * Carries `UpdateHiresConfig` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateHiresConfig(val value: HiresConfig) : TextToImageIntent - /** - * Carries `UpdateADetailerConfig` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateADetailerConfig(val value: ADetailerConfig) : TextToImageIntent - /** - * Provides the `RefreshADetailerAvailability` singleton used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ data object RefreshADetailerAvailability : TextToImageIntent - /** - * Provides the `OpenADetailerInstallInstructions` singleton used by the SDAI presentation layer. 
- * - * @author Dmitriy Moroz - */ data object OpenADetailerInstallInstructions : TextToImageIntent } diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageIntentProcessor.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageIntentProcessor.kt index 81ccb0854..0a9a1d401 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageIntentProcessor.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageIntentProcessor.kt @@ -11,79 +11,25 @@ import kotlin.math.max import kotlin.math.roundToInt /** - * Coordinates `TextToImageIntentProcessor` behavior in the SDAI presentation layer. + * Applies synchronous txt2img UI intents. * - * @author Dmitriy Moroz + * Navigation, form mutations, and result actions are handled here so the + * view-model can keep asynchronous generation and configuration loading logic + * separate from simple state transitions. */ internal class TextToImageIntentProcessor( - /** - * Exposes the `router` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val router: TextToImageRouter, - /** - * Exposes the `updateState` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val updateState: (((TextToImageState) -> TextToImageState) -> Unit), - /** - * Exposes the `generate` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val generate: () -> Unit, - /** - * Exposes the `cancelGeneration` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val cancelGeneration: () -> Unit, - /** - * Exposes the `saveImage` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val saveImage: (String) -> Unit, - /** - * Exposes the `shareImage` value used by the SDAI presentation layer. 
- * - * @author Dmitriy Moroz - */ private val shareImage: (String) -> Unit, - /** - * Exposes the `saveGenerationResults` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val saveGenerationResults: (List) -> Unit, - /** - * Exposes the `viewGenerationResult` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val viewGenerationResult: (AiGenerationResult) -> Unit, - /** - * Exposes the `reportGenerationResult` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val reportGenerationResult: (AiGenerationResult) -> Unit, - /** - * Exposes the `applyGenerationResult` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val applyGenerationResult: (AiGenerationResult) -> Unit, ) { - /** - * Executes the `process` step in the SDAI presentation layer. - * - * @param intent intent to process in the MVI workflow. - * @author Dmitriy Moroz - */ fun process(intent: TextToImageIntent) { when (intent) { TextToImageIntent.OpenDrawer -> router.openDrawer() @@ -274,6 +220,9 @@ internal class TextToImageIntentProcessor( is TextToImageIntent.UpdateSdxlBackend -> updateState { it.copy(sdxlBackend = intent.value, message = null) } + is TextToImageIntent.UpdateBonsaiBackend -> updateState { + it.copy(bonsaiBackend = intent.value, message = null) + } is TextToImageIntent.UpdateFalAiSyncMode -> updateState { it.copy(falAiSyncMode = intent.value, message = null) } diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageState.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageState.kt index efd45572f..d1cc535fb 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageState.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageState.kt @@ -1,10 +1,12 @@ 
package com.shifthackz.aisdv1.presentation.screen.txt2img import androidx.compose.runtime.Immutable +import com.shifthackz.aisdv1.core.common.platform.Platform import com.shifthackz.aisdv1.core.model.UiText import com.shifthackz.aisdv1.core.mvi.MviState import com.shifthackz.aisdv1.domain.entity.ADetailerConfig import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.domain.entity.BonsaiBackend import com.shifthackz.aisdv1.domain.entity.FalAiAcceleration import com.shifthackz.aisdv1.domain.entity.FalAiImageSize import com.shifthackz.aisdv1.domain.entity.FalAiModel @@ -24,296 +26,66 @@ import com.shifthackz.aisdv1.presentation.model.PromptTagEditRequest import com.shifthackz.aisdv1.presentation.widget.input.GenerationInputFormState /** - * Carries `TextToImageState` data through the SDAI presentation layer. + * Complete render state for txt2img. * - * @author Dmitriy Moroz + * The state combines screen lifecycle flags, generated results, validation + * feedback, current platform metadata, and the shared generation form contract + * used by provider-specific controls such as Android Bonsai backend selection. */ @Immutable data class TextToImageState( - /** - * Exposes the `loadingConfiguration` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val loadingConfiguration: Boolean = true, - /** - * Exposes the `generating` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val generating: Boolean = false, - /** - * Exposes the `savingImage` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val savingImage: Boolean = false, - /** - * Exposes the `sharingImage` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val sharingImage: Boolean = false, - /** - * Exposes the `promptValidationError` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val promptValidationError: UiText? 
= null, - /** - * Exposes the `error` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val error: UiText? = null, - /** - * Exposes the `message` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val message: UiText? = null, - /** - * Exposes the `screenModal` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val screenModal: GenerationModal = GenerationModal.None, - /** - * Exposes the `results` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val results: List = emptyList(), - /** - * Exposes the `editTag` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val editTag: PromptTagEditRequest? = null, - /** - * Exposes the `onBoardingDemo` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val onBoardingDemo: Boolean = false, - /** - * Exposes the `mode` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ + val platform: Platform = Platform.ANDROID, override val mode: ServerSource = ServerSource.AUTOMATIC1111, - /** - * Exposes the `advancedToggleButtonVisible` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val advancedToggleButtonVisible: Boolean = true, - /** - * Exposes the `advancedOptionsVisible` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val advancedOptionsVisible: Boolean = false, - /** - * Exposes the `formPromptTaggedInput` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val formPromptTaggedInput: Boolean = false, - /** - * Exposes the `prompt` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val prompt: String = "", - /** - * Exposes the `negativePrompt` value used by the SDAI presentation layer. 
- * - * @author Dmitriy Moroz - */ override val negativePrompt: String = "", - /** - * Exposes the `width` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val width: String = DEFAULT_SIZE.toString(), - /** - * Exposes the `height` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val height: String = DEFAULT_SIZE.toString(), - /** - * Exposes the `samplingSteps` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val samplingSteps: Int = 20, - /** - * Exposes the `cfgScale` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val cfgScale: Float = 7f, - /** - * Exposes the `restoreFaces` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val restoreFaces: Boolean = false, - /** - * Exposes the `seed` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val seed: String = "", - /** - * Exposes the `subSeed` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val subSeed: String = "", - /** - * Exposes the `subSeedStrength` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val subSeedStrength: Float = 0f, - /** - * Exposes the `selectedSampler` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val selectedSampler: String = "", - /** - * Exposes the `selectedScheduler` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val selectedScheduler: Scheduler = Scheduler.AUTOMATIC, - /** - * Exposes the `availableForgeModules` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val availableForgeModules: List = emptyList(), - /** - * Exposes the `selectedForgeModules` value used by the SDAI presentation layer. 
- * - * @author Dmitriy Moroz - */ override val selectedForgeModules: List = emptyList(), - /** - * Exposes the `availableSamplers` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val availableSamplers: List = emptyList(), - /** - * Exposes the `selectedStylePreset` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val selectedStylePreset: StabilityAiStylePreset = StabilityAiStylePreset.NONE, - /** - * Exposes the `selectedClipGuidancePreset` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val selectedClipGuidancePreset: StabilityAiClipGuidance = StabilityAiClipGuidance.NONE, - /** - * Exposes the `openAiModel` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val openAiModel: OpenAiModel = OpenAiModel.default, - /** - * Exposes the `openAiSize` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val openAiSize: OpenAiSize = OpenAiSize.W1024_H1024, - /** - * Exposes the `openAiQuality` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val openAiQuality: OpenAiQuality = OpenAiQuality.AUTO, - /** - * Exposes the `falAiModel` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val falAiModel: FalAiModel = FalAiModel.defaultTextToImage, - /** - * Exposes the `falAiImageSize` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val falAiImageSize: FalAiImageSize = FalAiImageSize.default, - /** - * Exposes the `falAiAcceleration` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val falAiAcceleration: FalAiAcceleration = FalAiAcceleration.default, - /** - * Exposes the `falAiSyncMode` value used by the SDAI presentation layer. 
- * - * @author Dmitriy Moroz - */ override val falAiSyncMode: Boolean = false, override val sdxlBackend: SdxlBackend = SdxlBackend.AUTO, + override val bonsaiBackend: BonsaiBackend = BonsaiBackend.AUTO, + override val bonsaiBackendSelectionVisible: Boolean = false, override val arliAiModels: List = emptyList(), override val arliAiModel: String = "", - /** - * Exposes the `widthValidationError` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val widthValidationError: UiText? = null, - /** - * Exposes the `heightValidationError` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val heightValidationError: UiText? = null, - /** - * Exposes the `nsfw` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val nsfw: Boolean = false, - /** - * Exposes the `batchCount` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val batchCount: Int = 1, - /** - * Exposes the `hires` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val hires: HiresConfig = HiresConfig.DISABLED, - /** - * Exposes the `aDetailer` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val aDetailer: ADetailerConfig = ADetailerConfig.DISABLED, - /** - * Exposes the `aDetailerAvailable` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val aDetailerAvailable: Boolean = false, - /** - * Exposes the `aDetailerRefreshing` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val aDetailerRefreshing: Boolean = false, ) : MviState, GenerationInputFormState { @@ -332,9 +104,7 @@ data class TextToImageState( } /** - * Converts SDAI data with `mapToPayload`. - * - * @author Dmitriy Moroz + * Converts the current txt2img state into the domain generation request. 
*/ internal fun TextToImageState.mapToPayload(): TextToImagePayload = TextToImagePayload( prompt = prompt.trim(), @@ -400,13 +170,11 @@ internal fun TextToImageState.mapToPayload(): TextToImagePayload = TextToImagePa sdxlBackend = sdxlBackend.takeIf { mode == ServerSource.LOCAL_STABLE_DIFFUSION_CPP } ?: SdxlBackend.AUTO, + bonsaiBackend = bonsaiBackend.takeIf { + mode == ServerSource.LOCAL_APPLE_BONSAI + } ?: BonsaiBackend.AUTO, falAiSyncMode = falAiSyncMode, arliAiModel = arliAiModel.takeIf { mode == ServerSource.ARLI_AI }.orEmpty(), ) -/** - * Exposes the `DEFAULT_SIZE` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ internal const val DEFAULT_SIZE = 512 diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt index d3ba9f36b..05384e47f 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt @@ -1,6 +1,7 @@ package com.shifthackz.aisdv1.presentation.screen.txt2img import com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider +import com.shifthackz.aisdv1.core.common.platform.Platform import com.shifthackz.aisdv1.core.common.schedulers.DispatchersProvider import com.shifthackz.aisdv1.core.mvi.BaseMviViewModel import com.shifthackz.aisdv1.core.mvi.EmptyEffect @@ -38,181 +39,46 @@ import kotlinx.coroutines.flow.catch import kotlinx.coroutines.withContext /** - * Coordinates `TextToImageViewModel` behavior in the SDAI presentation layer. + * View-model for the txt2img screen. 
* - * @author Dmitriy Moroz + * It loads provider configuration, mirrors the shared generation form state, + * gates local providers through benchmark recommendations, and delegates actual + * generation work to the action handler. */ class TextToImageViewModel( - /** - * Exposes the `dispatchersProvider` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val dispatchersProvider: DispatchersProvider, - /** - * Exposes the `getConfigurationUseCase` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val getConfigurationUseCase: GetConfigurationUseCase, - /** - * Exposes the `getStableDiffusionSamplersUseCase` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val getStableDiffusionSamplersUseCase: GetStableDiffusionSamplersUseCase, - /** - * Exposes the `getForgeModulesUseCase` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val getForgeModulesUseCase: GetForgeModulesUseCase, - /** - * Exposes the `fetchAndGetArliAiModelsUseCase` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val fetchAndGetArliAiModelsUseCase: FetchAndGetArliAiModelsUseCase, - /** - * Exposes the `isADetailerAvailableUseCase` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val isADetailerAvailableUseCase: IsADetailerAvailableUseCase, - /** - * Exposes the `textToImageUseCase` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val textToImageUseCase: TextToImageUseCase, - /** - * Exposes the `saveGenerationResultUseCase` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val saveGenerationResultUseCase: SaveGenerationResultUseCase, - /** - * Exposes the `saveLastResultToCacheUseCase` value used by the SDAI presentation layer. 
- * - * @author Dmitriy Moroz - */ private val saveLastResultToCacheUseCase: SaveLastResultToCacheUseCase, - /** - * Exposes the `interruptGenerationUseCase` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val interruptGenerationUseCase: InterruptGenerationUseCase, - /** - * Exposes the `observeHordeProcessStatusUseCase` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val observeHordeProcessStatusUseCase: ObserveHordeProcessStatusUseCase, - /** - * Exposes the `observeLocalDiffusionProcessStatusUseCase` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val observeLocalDiffusionProcessStatusUseCase: ObserveLocalDiffusionProcessStatusUseCase, - /** - * Exposes the `observeStableDiffusionCppProcessStatusUseCase` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val observeStableDiffusionCppProcessStatusUseCase: ObserveStableDiffusionCppProcessStatusUseCase, - /** - * Exposes the `observeCoreMlProcessStatusUseCase` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val observeCoreMlProcessStatusUseCase: ObserveCoreMlProcessStatusUseCase, - /** - * Exposes the `observeBonsaiProcessStatusUseCase` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val observeBonsaiProcessStatusUseCase: ObserveBonsaiProcessStatusUseCase, - /** - * Exposes the `preferenceManager` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val preferenceManager: PreferenceManager, - /** - * Exposes the `backgroundTaskManager` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val backgroundTaskManager: BackgroundTaskManager, - /** - * Exposes the `backgroundWorkObserver` value used by the SDAI presentation layer. 
- * - * @author Dmitriy Moroz - */ private val backgroundWorkObserver: BackgroundWorkObserver, - /** - * Exposes the `wakeLockInterActor` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val wakeLockInterActor: WakeLockInterActor, - /** - * Exposes the `platformServices` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val platformServices: GenerationPlatformServices, - /** - * Exposes the `buildInfoProvider` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val buildInfoProvider: BuildInfoProvider, - /** - * Exposes the `generationFormUpdateEvent` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val generationFormUpdateEvent: GenerationFormUpdateEvent, - /** - * Exposes the `dimensionValidator` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val dimensionValidator: DimensionValidator, - /** - * Exposes the `localGenerationBenchmarkGateProvider` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val localGenerationBenchmarkGateProvider: () -> LocalGenerationBenchmarkGate, - /** - * Exposes the `imageSaver` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val imageSaver: ImageSaver, - /** - * Exposes the `imageSharer` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val imageSharer: ImageSharer, - /** - * Exposes the `router` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ private val router: TextToImageRouter, - /** - * Exposes the `onError` value used by the SDAI presentation layer. 
- * - * @author Dmitriy Moroz - */ private val onError: (Throwable) -> Unit = {}, ) : BaseMviViewModel( - initialState = TextToImageState(), + initialState = TextToImageState( + platform = buildInfoProvider.platform, + bonsaiBackendSelectionVisible = buildInfoProvider.platform == Platform.ANDROID, + ), effectDispatcher = dispatchersProvider.immediate, ) { diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt index 78690f4c8..84e1f6428 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt @@ -20,6 +20,7 @@ import androidx.compose.ui.text.input.TextFieldValue import androidx.compose.ui.unit.dp import com.shifthackz.aisdv1.core.localization.Localization import com.shifthackz.aisdv1.core.model.asUiText +import com.shifthackz.aisdv1.domain.entity.BonsaiBackend import com.shifthackz.aisdv1.domain.entity.FalAiImageSize import com.shifthackz.aisdv1.domain.entity.FalAiModel import com.shifthackz.aisdv1.domain.entity.OpenAiModel @@ -115,6 +116,19 @@ fun GenerationInputForm( displayDelegate = { it.displayName.asUiText() }, ) } + if ( + state.mode == ServerSource.LOCAL_APPLE_BONSAI && + state.bonsaiBackendSelectionVisible + ) { + DropdownTextField( + modifier = Modifier.padding(top = 8.dp), + label = Localization.string("hint_bonsai_backend").asUiText(), + value = state.bonsaiBackend, + items = BonsaiBackend.entries, + onItemSelected = { onEvent(GenerationInputFormEvent.UpdateBonsaiBackend(it)) }, + displayDelegate = { it.displayName.asUiText() }, + ) + } } if (state.formPromptTaggedInput) { ChipTextFieldWithItem( diff --git 
a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputFormEvent.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputFormEvent.kt index d4017f527..ca5ccfcd3 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputFormEvent.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputFormEvent.kt @@ -1,6 +1,7 @@ package com.shifthackz.aisdv1.presentation.widget.input import com.shifthackz.aisdv1.domain.entity.ADetailerConfig +import com.shifthackz.aisdv1.domain.entity.BonsaiBackend import com.shifthackz.aisdv1.domain.entity.FalAiAcceleration import com.shifthackz.aisdv1.domain.entity.FalAiImageSize import com.shifthackz.aisdv1.domain.entity.FalAiModel @@ -15,296 +16,62 @@ import com.shifthackz.aisdv1.domain.entity.StabilityAiClipGuidance import com.shifthackz.aisdv1.domain.entity.StabilityAiStylePreset /** - * Defines the `GenerationInputFormEvent` contract for the SDAI presentation layer. + * Events emitted by the reusable generation input form. * - * @author Dmitriy Moroz + * Screen-specific processors translate these into txt2img/img2img intents so + * the form can stay shared while each screen decides how to persist settings, + * validate dimensions, and handle provider-specific side effects. */ sealed interface GenerationInputFormEvent { - /** - * Carries `UpdateAdvancedOptionsVisibility` data through the SDAI presentation layer. - * - * @param visible visible value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateAdvancedOptionsVisibility(val visible: Boolean) : GenerationInputFormEvent - /** - * Carries `UpdatePrompt` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. 
- * @author Dmitriy Moroz - */ data class UpdatePrompt(val value: String) : GenerationInputFormEvent - /** - * Carries `UpdateNegativePrompt` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateNegativePrompt(val value: String) : GenerationInputFormEvent - /** - * Carries `EditTag` data through the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ data class EditTag( - /** - * Exposes the `prompt` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val prompt: String, - /** - * Exposes the `negativePrompt` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val negativePrompt: String, - /** - * Exposes the `tag` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val tag: String, - /** - * Exposes the `isNegative` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val isNegative: Boolean, ) : GenerationInputFormEvent - /** - * Carries `UpdateWidth` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateWidth(val value: String) : GenerationInputFormEvent - /** - * Provides the `SwapDimensions` singleton used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ data object SwapDimensions : GenerationInputFormEvent - /** - * Carries `ApplyAspectRatio` data through the SDAI presentation layer. - * - * @param ratio ratio value consumed by the API. - * @author Dmitriy Moroz - */ data class ApplyAspectRatio(val ratio: GenerationAspectRatio) : GenerationInputFormEvent - /** - * Carries `UpdateHeight` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. 
- * @author Dmitriy Moroz - */ data class UpdateHeight(val value: String) : GenerationInputFormEvent - /** - * Carries `UpdateSamplingSteps` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateSamplingSteps(val value: Int) : GenerationInputFormEvent - /** - * Carries `UpdateCfgScale` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateCfgScale(val value: Float) : GenerationInputFormEvent - /** - * Carries `UpdateRestoreFaces` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateRestoreFaces(val value: Boolean) : GenerationInputFormEvent - /** - * Carries `UpdateSeed` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateSeed(val value: String) : GenerationInputFormEvent - /** - * Carries `UpdateSubSeed` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateSubSeed(val value: String) : GenerationInputFormEvent - /** - * Carries `UpdateSubSeedStrength` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateSubSeedStrength(val value: Float) : GenerationInputFormEvent - /** - * Carries `UpdateSampler` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateSampler(val value: String) : GenerationInputFormEvent - /** - * Carries `UpdateScheduler` data through the SDAI presentation layer. - * - * @param value value consumed by the API. 
- * @author Dmitriy Moroz - */ data class UpdateScheduler(val value: Scheduler) : GenerationInputFormEvent - /** - * Carries `UpdateForgeModules` data through the SDAI presentation layer. - * - * @param value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateForgeModules(val value: List) : GenerationInputFormEvent - /** - * Carries `UpdateNsfw` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateNsfw(val value: Boolean) : GenerationInputFormEvent - /** - * Carries `UpdateBatch` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateBatch(val value: Int) : GenerationInputFormEvent - /** - * Carries `UpdateOpenAiModel` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateOpenAiModel(val value: OpenAiModel) : GenerationInputFormEvent - /** - * Carries `UpdateOpenAiSize` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateOpenAiSize(val value: OpenAiSize) : GenerationInputFormEvent - /** - * Carries `UpdateOpenAiQuality` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateOpenAiQuality(val value: OpenAiQuality) : GenerationInputFormEvent - /** - * Carries `UpdateFalAiModel` data through the SDAI presentation layer. - * - * @param value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateFalAiModel(val value: FalAiModel) : GenerationInputFormEvent - /** - * Carries `UpdateFalAiImageSize` data through the SDAI presentation layer. - * - * @param value value consumed by the API. 
- * @author Dmitriy Moroz - */ data class UpdateFalAiImageSize(val value: FalAiImageSize) : GenerationInputFormEvent - /** - * Carries `UpdateFalAiAcceleration` data through the SDAI presentation layer. - * - * @param value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateFalAiAcceleration(val value: FalAiAcceleration) : GenerationInputFormEvent - /** - * Carries `UpdateSdxlBackend` data through the SDAI presentation layer. - * - * @param value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateSdxlBackend(val value: SdxlBackend) : GenerationInputFormEvent - /** - * Carries `UpdateFalAiSyncMode` data through the SDAI presentation layer. - * - * @param value value consumed by the API. - * @author Dmitriy Moroz - */ + data class UpdateBonsaiBackend(val value: BonsaiBackend) : GenerationInputFormEvent data class UpdateFalAiSyncMode(val value: Boolean) : GenerationInputFormEvent - /** - * Carries `UpdateArliAiModel` data through the SDAI presentation layer. - * - * @param value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateArliAiModel(val value: String) : GenerationInputFormEvent - /** - * Carries `UpdateStabilityAiStyle` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateStabilityAiStyle(val value: StabilityAiStylePreset) : GenerationInputFormEvent - /** - * Carries `UpdateStabilityAiClipGuidance` data through the SDAI presentation layer. - * - * @param value value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateStabilityAiClipGuidance(val value: StabilityAiClipGuidance) : GenerationInputFormEvent - /** - * Carries `UpdateHiresConfig` data through the SDAI presentation layer. - * - * @param value value consumed by the API. 
- * @author Dmitriy Moroz - */ data class UpdateHiresConfig(val value: HiresConfig) : GenerationInputFormEvent - /** - * Carries `UpdateADetailerConfig` data through the SDAI presentation layer. - * - * @param value value consumed by the API. - * @author Dmitriy Moroz - */ data class UpdateADetailerConfig(val value: ADetailerConfig) : GenerationInputFormEvent - /** - * Provides the `RefreshADetailerAvailability` singleton used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ data object RefreshADetailerAvailability : GenerationInputFormEvent - /** - * Provides the `OpenADetailerInstallInstructions` singleton used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ data object OpenADetailerInstallInstructions : GenerationInputFormEvent } /** - * Carries `GenerationAspectRatio` data through the SDAI presentation layer. - * - * @author Dmitriy Moroz + * Preset aspect ratios applied by resizing the active width/height fields. */ enum class GenerationAspectRatio( - /** - * Exposes the `displayName` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val displayName: String, - /** - * Exposes the `width` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val width: Int, - /** - * Exposes the `height` value used by the SDAI presentation layer. 
- * - * @author Dmitriy Moroz - */ val height: Int, ) { SQUARE("1:1", 1, 1), diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputFormState.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputFormState.kt index 06b2af7a3..f16d69e38 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputFormState.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputFormState.kt @@ -2,6 +2,7 @@ package com.shifthackz.aisdv1.presentation.widget.input import com.shifthackz.aisdv1.core.model.UiText import com.shifthackz.aisdv1.domain.entity.ADetailerConfig +import com.shifthackz.aisdv1.domain.entity.BonsaiBackend import com.shifthackz.aisdv1.domain.entity.FalAiAcceleration import com.shifthackz.aisdv1.domain.entity.FalAiImageSize import com.shifthackz.aisdv1.domain.entity.FalAiModel @@ -17,252 +18,62 @@ import com.shifthackz.aisdv1.domain.entity.StabilityAiClipGuidance import com.shifthackz.aisdv1.domain.entity.StabilityAiStylePreset /** - * Defines the `GenerationInputFormState` contract for the SDAI presentation layer. + * Shared state contract for generation controls. * - * @author Dmitriy Moroz + * Txt2img and img2img screens implement this interface so the common form can + * render provider-specific controls, including local runtime backend selectors, + * without owning screen-level loading, result, or navigation state. */ interface GenerationInputFormState { - /** - * Exposes the `onBoardingDemo` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val onBoardingDemo: Boolean - /** - * Exposes the `mode` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val mode: ServerSource - /** - * Exposes the `advancedToggleButtonVisible` value used by the SDAI presentation layer. 
- * - * @author Dmitriy Moroz - */ val advancedToggleButtonVisible: Boolean - /** - * Exposes the `advancedOptionsVisible` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val advancedOptionsVisible: Boolean - /** - * Exposes the `formPromptTaggedInput` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val formPromptTaggedInput: Boolean - /** - * Exposes the `prompt` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val prompt: String - /** - * Exposes the `negativePrompt` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val negativePrompt: String - /** - * Exposes the `width` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val width: String - /** - * Exposes the `height` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val height: String - /** - * Exposes the `samplingSteps` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val samplingSteps: Int - /** - * Exposes the `cfgScale` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val cfgScale: Float - /** - * Exposes the `restoreFaces` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val restoreFaces: Boolean - /** - * Exposes the `seed` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val seed: String - /** - * Exposes the `subSeed` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val subSeed: String - /** - * Exposes the `subSeedStrength` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val subSeedStrength: Float - /** - * Exposes the `selectedSampler` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val selectedSampler: String - /** - * Exposes the `selectedScheduler` value used by the SDAI presentation layer. 
- * - * @author Dmitriy Moroz - */ val selectedScheduler: Scheduler - /** - * Exposes the `availableForgeModules` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val availableForgeModules: List - /** - * Exposes the `selectedForgeModules` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val selectedForgeModules: List - /** - * Exposes the `availableSamplers` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val availableSamplers: List - /** - * Exposes the `selectedStylePreset` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val selectedStylePreset: StabilityAiStylePreset - /** - * Exposes the `selectedClipGuidancePreset` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val selectedClipGuidancePreset: StabilityAiClipGuidance - /** - * Exposes the `openAiModel` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val openAiModel: OpenAiModel - /** - * Exposes the `openAiSize` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val openAiSize: OpenAiSize - /** - * Exposes the `openAiQuality` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val openAiQuality: OpenAiQuality - /** - * Exposes the `falAiModel` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val falAiModel: FalAiModel - /** - * Exposes the `falAiImageSize` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val falAiImageSize: FalAiImageSize - /** - * Exposes the `falAiAcceleration` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val falAiAcceleration: FalAiAcceleration - /** - * Exposes the `falAiSyncMode` value used by the SDAI presentation layer. 
- * - * @author Dmitriy Moroz - */ val falAiSyncMode: Boolean val sdxlBackend: SdxlBackend + val bonsaiBackend: BonsaiBackend + val bonsaiBackendSelectionVisible: Boolean + get() = false val arliAiModels: List val arliAiModel: String - /** - * Exposes the `widthValidationError` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val widthValidationError: UiText? - /** - * Exposes the `heightValidationError` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val heightValidationError: UiText? - /** - * Exposes the `nsfw` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val nsfw: Boolean - /** - * Exposes the `batchCount` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val batchCount: Int - /** - * Exposes the `hires` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val hires: HiresConfig - /** - * Exposes the `aDetailer` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val aDetailer: ADetailerConfig - /** - * Exposes the `aDetailerAvailable` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val aDetailerAvailable: Boolean - /** - * Exposes the `aDetailerRefreshing` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val aDetailerRefreshing: Boolean - /** - * Exposes the `promptKeywords` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val promptKeywords: List get() = prompt.split(",") .map { it.trim() } .filter { it.isNotEmpty() } - /** - * Exposes the `negativePromptKeywords` value used by the SDAI presentation layer. 
- * - * @author Dmitriy Moroz - */ val negativePromptKeywords: List get() = negativePrompt.split(",") .map { it.trim() } diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt index 6eaaad08a..90ed0ff6f 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt @@ -1,39 +1,22 @@ package com.shifthackz.aisdv1.presentation.widget.source import androidx.compose.runtime.Composable -import com.shifthackz.aisdv1.core.localization.Localization +import com.shifthackz.aisdv1.core.common.platform.Platform import com.shifthackz.aisdv1.core.model.UiText import com.shifthackz.aisdv1.core.model.asString import com.shifthackz.aisdv1.core.model.asUiText import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.presentation.model.displayName /** - * Renders the `getName` UI for the SDAI presentation layer. - * - * @author Dmitriy Moroz + * Resolves the localized provider label for composables that need a plain string. */ @Composable -fun ServerSource.getName(): String = getNameUiText().asString() +fun ServerSource.getName(platform: Platform = Platform.ANDROID): String = + getNameUiText(platform).asString() /** - * Loads SDAI data through `getNameUiText`. - * - * @author Dmitriy Moroz + * Resolves the provider label as [UiText] for non-composable consumers. 
*/ -fun ServerSource.getNameUiText(): UiText = Localization.string( - when (this) { - ServerSource.AUTOMATIC1111 -> "srv_type_own" - ServerSource.HORDE -> "srv_type_horde" - ServerSource.LOCAL_MICROSOFT_ONNX -> "srv_type_local" - ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> "srv_type_media_pipe" - ServerSource.LOCAL_STABLE_DIFFUSION_CPP -> "srv_type_sdxl" - ServerSource.LOCAL_APPLE_CORE_ML -> return UiText.Static("Silicon Diffusion Core ML") - ServerSource.LOCAL_APPLE_BONSAI -> return UiText.Static("Silicon Diffusion PrismML Bonsai") - ServerSource.HUGGING_FACE -> "srv_type_hugging_face" - ServerSource.OPEN_AI -> "srv_type_open_ai" - ServerSource.STABILITY_AI -> "srv_type_stability_ai" - ServerSource.FAL_AI -> "srv_type_fal_ai" - ServerSource.ARLI_AI -> "srv_type_arli_ai" - ServerSource.SWARM_UI -> "srv_type_swarm_ui" - }, -).asUiText() +fun ServerSource.getNameUiText(platform: Platform = Platform.ANDROID): UiText = + displayName(platform).asUiText() diff --git a/presentation/src/iosMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/platform/ServerSetupLocalPathPickerButton.ios.kt b/presentation/src/iosMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/platform/ServerSetupLocalPathPickerButton.ios.kt index ebd48a5d9..9212ab224 100644 --- a/presentation/src/iosMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/platform/ServerSetupLocalPathPickerButton.ios.kt +++ b/presentation/src/iosMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/platform/ServerSetupLocalPathPickerButton.ios.kt @@ -35,11 +35,6 @@ internal actual fun ServerSetupLocalPathPickerButton( internal actual fun isLocalGenerationSetupAvailable(): Boolean = true internal actual fun isServerSourceAvailableOnPlatform(source: ServerSource): Boolean = when (source) { - ServerSource.LOCAL_MICROSOFT_ONNX, - ServerSource.LOCAL_GOOGLE_MEDIA_PIPE, - ServerSource.LOCAL_STABLE_DIFFUSION_CPP, - -> false - ServerSource.LOCAL_APPLE_CORE_ML, -> isAppleLocalRuntimeAvailable()