diff --git a/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs b/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs index 118efa55c..0cb0b808f 100644 --- a/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs +++ b/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs @@ -4,6 +4,7 @@ using StabilityMatrix.Core.Models.Packages; using StabilityMatrix.Core.Python; using StabilityMatrix.Core.Services; +using StabilityMatrix.Core.Services.Rocm; namespace StabilityMatrix.Core.Helper.Factory; @@ -18,6 +19,7 @@ public class PackageFactory : IPackageFactory private readonly IUvManager uvManager; private readonly IPyInstallationManager pyInstallationManager; private readonly IPipWheelService pipWheelService; + private readonly IRocmPackageHelper rocmPackageHelper; /// /// Mapping of package.Name to package @@ -32,7 +34,9 @@ public PackageFactory( IPrerequisiteHelper prerequisiteHelper, IPyInstallationManager pyInstallationManager, IPyRunner pyRunner, - IPipWheelService pipWheelService + IUvManager uvManager, + IPipWheelService pipWheelService, + IRocmPackageHelper rocmPackageHelper ) { this.githubApiCache = githubApiCache; @@ -40,8 +44,10 @@ IPipWheelService pipWheelService this.downloadService = downloadService; this.prerequisiteHelper = prerequisiteHelper; this.pyRunner = pyRunner; + this.uvManager = uvManager; this.pyInstallationManager = pyInstallationManager; this.pipWheelService = pipWheelService; + this.rocmPackageHelper = rocmPackageHelper; this.basePackages = basePackages.ToDictionary(x => x.Name); } @@ -55,7 +61,8 @@ public BasePackage GetNewBasePackage(InstalledPackage installedPackage) downloadService, prerequisiteHelper, pyInstallationManager, - pipWheelService + pipWheelService, + rocmPackageHelper ), "Fooocus" => new Fooocus( githubApiCache, @@ -280,7 +287,8 @@ public BasePackage GetNewBasePackage(InstalledPackage installedPackage) downloadService, prerequisiteHelper, pyInstallationManager, - pipWheelService + pipWheelService, + rocmPackageHelper ), _ => throw new ArgumentOutOfRangeException(nameof(installedPackage)), }; diff --git a/StabilityMatrix.Core/Helper/HardwareInfo/GpuInfo.cs b/StabilityMatrix.Core/Helper/HardwareInfo/GpuInfo.cs index eedcb556c..0013f65bc 100644 --- a/StabilityMatrix.Core/Helper/HardwareInfo/GpuInfo.cs +++ b/StabilityMatrix.Core/Helper/HardwareInfo/GpuInfo.cs @@ -1,4 +1,6 @@ -namespace StabilityMatrix.Core.Helper.HardwareInfo; +using StabilityMatrix.Core.Models.Rocm; + +namespace StabilityMatrix.Core.Helper.HardwareInfo; public record GpuInfo { @@ -62,11 +64,7 @@ public bool IsLegacyNvidiaGpu() public bool IsWindowsRocmSupportedGpu() { - var gfx = GetAmdGfxArch(); - if (gfx is null) - return false; - - return gfx.StartsWith("gfx110") || gfx.StartsWith("gfx120") || gfx.Equals("gfx1151"); + return WindowsRocmSupport.IsSupportedGpu(this); } public bool IsAmd => Name?.Contains("amd", StringComparison.OrdinalIgnoreCase) ?? false; @@ -84,7 +82,7 @@ public bool IsWindowsRocmSupportedGpu() return name switch { // RDNA4 - _ when Has("R9700") || Has("9070") => "gfx1201", + _ when Has("R9700") || Has("R9600") || Has("9070") => "gfx1201", _ when Has("9060") => "gfx1200", // RDNA3.5 APUs @@ -112,6 +110,9 @@ _ when Has("660M") || Has("680M") => "gfx1035", _ when Has("6300") || Has("6400") || Has("6450") || Has("6500") || Has("6550") || Has("6500M") => "gfx1034", + // RDNA2 Steam Deck APU + _ when Has("Van Gogh") || Has("Sephiroth") => "gfx1033", + // RDNA2 Navi23 _ when Has("6600") || Has("6650") || Has("6700S") || Has("6800S") || Has("6600M") => "gfx1032", @@ -121,6 +122,16 @@ _ when Has("6700") || Has("6750") || Has("6800M") || Has("6850M") => "gfx1031", // RDNA2 Navi21 (big die) _ when Has("6800") || Has("6900") || Has("6950") => "gfx1030", + // RDNA1 Navi10 XT (incl. Pro card) + _ when Has("5600") || Has("5700") || Has("v520") => "gfx1010", + + // RDNA1 Navi10 XTX + _ when Has("5500") => "gfx1012", + + // Vega/GCN5 Dedicated GPUs + _ when Has("pro vii") || HasNoSpace("provii") => "gfx90X", + _ when Has("rx vega") || Has("vega 64") || Has("vega 56") || Has("vega frontier") => "gfx900", + _ when Has("radeon vii") || HasNoSpace("radeonvii") => "gfx906", _ => null, }; diff --git a/StabilityMatrix.Core/Helper/HardwareInfo/HardwareHelper.cs b/StabilityMatrix.Core/Helper/HardwareInfo/HardwareHelper.cs index 8458c730b..93f093d41 100644 --- a/StabilityMatrix.Core/Helper/HardwareInfo/HardwareHelper.cs +++ b/StabilityMatrix.Core/Helper/HardwareInfo/HardwareHelper.cs @@ -7,6 +7,7 @@ using Microsoft.Win32; using NLog; using StabilityMatrix.Core.Extensions; +using StabilityMatrix.Core.Models.Rocm; namespace StabilityMatrix.Core.Helper.HardwareInfo; @@ -316,12 +317,11 @@ public static bool HasAmdGpu() return IterGpuInfo().Any(gpu => gpu.IsAmd); } - public static bool HasWindowsRocmSupportedGpu() => - IterGpuInfo().Any(gpu => gpu is { IsAmd: true, Name: not null } && gpu.IsWindowsRocmSupportedGpu()); + public static bool HasWindowsRocmSupportedGpu() => IterGpuInfo().Any(WindowsRocmSupport.IsSupportedGpu); public static GpuInfo? GetWindowsRocmSupportedGpu() { - return IterGpuInfo().FirstOrDefault(gpu => gpu.IsWindowsRocmSupportedGpu()); + return IterGpuInfo().FirstOrDefault(WindowsRocmSupport.IsSupportedGpu); } public static bool HasIntelGpu() => IterGpuInfo().Any(gpu => gpu.IsIntel); diff --git a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs index a4c34649d..759406c35 100644 --- a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs +++ b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs @@ -13,9 +13,11 @@ using StabilityMatrix.Core.Models.Packages.Config; using StabilityMatrix.Core.Models.Packages.Extensions; using StabilityMatrix.Core.Models.Progress; +using StabilityMatrix.Core.Models.Rocm; using StabilityMatrix.Core.Processes; using StabilityMatrix.Core.Python; using StabilityMatrix.Core.Services; +using StabilityMatrix.Core.Services.Rocm; namespace StabilityMatrix.Core.Models.Packages; @@ -26,7 +28,8 @@ public class ComfyUI( IDownloadService downloadService, IPrerequisiteHelper prerequisiteHelper, IPyInstallationManager pyInstallationManager, - IPipWheelService pipWheelService + IPipWheelService pipWheelService, + IRocmPackageHelper? rocmPackageHelper = null ) : BaseGitPackage( githubApi, @@ -38,6 +41,7 @@ IPipWheelService pipWheelService ) { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); + public override string Name => "ComfyUI"; public override string DisplayName { get; set; } = "ComfyUI"; public override string Author => "comfyanonymous"; @@ -247,7 +251,7 @@ IPipWheelService pipWheelService Name = "Enable DirectML", Type = LaunchOptionType.Bool, InitialValue = - !HardwareHelper.HasWindowsRocmSupportedGpu() + !HasWindowsRocmSupport() && HardwareHelper.PreferDirectMLOrZluda() && this is not ComfyZluda, Options = ["--directml"], @@ -264,7 +268,9 @@ IPipWheelService pipWheelService { Name = "Cross Attention Method", Type = LaunchOptionType.Bool, - InitialValue = "--use-pytorch-cross-attention", + InitialValue = DefaultToQuadCrossAttention() + ? "--use-quad-cross-attention" // For Legacy AMD GPUs. + : "--use-pytorch-cross-attention", Options = [ "--use-split-cross-attention", @@ -362,69 +368,36 @@ public override async Task InstallPackage( .ConfigureAwait(false); var torchIndex = options.PythonOptions.TorchIndex ?? GetRecommendedTorchVersion(); - var gfxArch = - SettingsManager.Settings.PreferredGpu?.GetAmdGfxArch() - ?? HardwareHelper.GetWindowsRocmSupportedGpu()?.GetAmdGfxArch(); - - // Special case for Windows ROCm Nightly builds - if ( - Compat.IsWindows - && !string.IsNullOrWhiteSpace(gfxArch) - && torchIndex is TorchIndex.Rocm - && options.PythonOptions.PythonVersion >= PyVersion.Parse("3.11.0") - ) + var isLegacyNvidia = + torchIndex == TorchIndex.Cuda + && ( + SettingsManager.Settings.PreferredGpu?.IsLegacyNvidiaGpu() + ?? HardwareHelper.HasLegacyNvidiaGpu() + ); + + if (Compat.IsWindows && torchIndex == TorchIndex.Rocm && HasWindowsRocmSupport()) { - var config = new PipInstallConfig + if (rocmPackageHelper is null) { - RequirementsFilePaths = ["requirements.txt"], - ExtraPipArgs = ["numpy<2"], - SkipTorchInstall = true, - PostInstallPipArgs = ["typing-extensions>=4.15.0"], - }; - await StandardPipInstallProcessAsync( + throw new InvalidOperationException( + "Windows ROCm installation requires the shared ROCm helper to resolve gfx-specific index URLs." + ); + } + + await rocmPackageHelper + .InstallWindowsNativePackageAsync( venvRunner, - options, + installLocation, installedPackage, - config, - onConsoleOutput, + WindowsRocmProfile, progress, + onConsoleOutput, cancellationToken ) .ConfigureAwait(false); - - progress?.Report( - new ProgressReport(-1f, "Installing ROCm nightly torch...", isIndeterminate: true) - ); - var indexUrl = gfxArch switch - { - "gfx1150" => "https://rocm.nightlies.amd.com/v2-staging/gfx1150", // Strix/Gorgon Point - "gfx1151" => "https://rocm.nightlies.amd.com/v2/gfx1151", // Strix Halo - _ when gfxArch.StartsWith("gfx110") => "https://rocm.nightlies.amd.com/v2/gfx110X-all", - _ when gfxArch.StartsWith("gfx120") => "https://rocm.nightlies.amd.com/v2/gfx120X-all", - _ => throw new ArgumentOutOfRangeException( - nameof(gfxArch), - $"Unsupported GFX Arch: {gfxArch}" - ), - }; - - var torchPipArgs = new PipInstallArgs() - .AddArgs("--pre", "--upgrade") - .WithTorch() - .WithTorchVision() - .WithTorchAudio() - .AddArgs("--index-url", indexUrl); - - await venvRunner.PipInstall(torchPipArgs, onConsoleOutput).ConfigureAwait(false); } - else // Standard installation path for all other cases + else { - var isLegacyNvidia = - torchIndex == TorchIndex.Cuda - && ( - SettingsManager.Settings.PreferredGpu?.IsLegacyNvidiaGpu() - ?? HardwareHelper.HasLegacyNvidiaGpu() - ); - var config = new PipInstallConfig { RequirementsFilePaths = ["requirements.txt"], @@ -479,7 +452,11 @@ await StandardPipInstallProcessAsync( SettingsManager.Settings.PreferredGpu?.IsBlackwellGpu() ?? HardwareHelper.HasBlackwellGpu(), WorkingDirectory = installLocation, - EnvironmentVariables = GetEnvVars(venvRunner.EnvironmentVariables), + EnvironmentVariables = GetEnvVars( + venvRunner.EnvironmentVariables, + installLocation, + installedPackage + ), }; await step.ExecuteAsync(progress).ConfigureAwait(false); @@ -529,7 +506,7 @@ public override async Task RunPackage( await SetupVenv(installLocation, pythonVersion: PyVersion.Parse(installedPackage.PythonVersion)) .ConfigureAwait(false); - VenvRunner.UpdateEnvironmentVariables(GetEnvVars); + VenvRunner.UpdateEnvironmentVariables(env => GetEnvVars(env, installLocation, installedPackage)); // Check for old NVIDIA driver version with cu130 installations var isNvidia = SettingsManager.Settings.PreferredGpu?.IsNvidia ?? HardwareHelper.HasNvidiaGpu(); @@ -613,13 +590,7 @@ public override TorchIndex GetRecommendedTorchVersion() { var preferRocm = (Compat.IsLinux && (SettingsManager.Settings.PreferredGpu?.IsAmd ?? HardwareHelper.PreferRocm())) - || ( - Compat.IsWindows - && ( - SettingsManager.Settings.PreferredGpu?.IsWindowsRocmSupportedGpu() - ?? HardwareHelper.HasWindowsRocmSupportedGpu() - ) - ); + || HasWindowsRocmSupport(); if (AvailableTorchIndices.Contains(TorchIndex.Rocm) && preferRocm) { @@ -629,6 +600,54 @@ public override TorchIndex GetRecommendedTorchVersion() return base.GetRecommendedTorchVersion(); } + /// Windows ROCm install profile for ComfyUI. + private static readonly RocmPackageProfile WindowsRocmProfile = new() + { + RequiresRocmSdk = true, + ExtraInstallPipArgs = ["numpy<2"], + PostInstallPipArgs = ["typing-extensions>=4.15.0"], + UpgradePackages = true, + ExtraEnvironmentFactory = BuildComfyWindowsRocmEnvironment, + }; + + private static IReadOnlyDictionary BuildComfyWindowsRocmEnvironment( + RocmRuntimeContext runtimeContext + ) + { + return WindowsRocmSupport.IsModernArchitecture(runtimeContext.RuntimeGfxArch) + ? new Dictionary { ["COMFYUI_ENABLE_MIOPEN"] = "1" } + : new Dictionary(); + } + + /// Uses the shared ROCm helper for Windows ROCm eligibility checks so ComfyUI does not maintain its own support matrix. + private bool HasWindowsRocmSupport() + { + if (!Compat.IsWindows) + return false; + + if (rocmPackageHelper is null) + return false; + + var compatibility = rocmPackageHelper.GetCompatibility(WindowsRocmProfile); + + return compatibility.IsCompatible; + } + + /// Defaults legacy Windows ROCm GPUs to quad cross-attention because PyTorch cross-attention is considerably slower + /// and not as supported on older AMD architectures. + private bool DefaultToQuadCrossAttention() + { + if (!Compat.IsWindows || !HasWindowsRocmSupport()) + return false; + + var gpu = SettingsManager.Settings.PreferredGpu; + var gfxArch = WindowsRocmSupport.IsSupportedGpu(gpu) + ? gpu?.GetAmdGfxArch() + : HardwareHelper.GetWindowsRocmSupportedGpu()?.GetAmdGfxArch(); + + return WindowsRocmSupport.PreferLegacyAttentionFallback(gfxArch); + } + public override IPackageExtensionManager ExtensionManager => new ComfyExtensionManager(this, settingsManager); @@ -979,21 +998,29 @@ await PipWheelService .ConfigureAwait(false); } - private ImmutableDictionary GetEnvVars(ImmutableDictionary env) + private ImmutableDictionary GetEnvVars( + ImmutableDictionary env, + string installLocation, + InstalledPackage installedPackage + ) { // if we're not on windows or we don't have a windows rocm gpu, return original env - var hasRocmGpu = - SettingsManager.Settings.PreferredGpu?.IsWindowsRocmSupportedGpu() - ?? HardwareHelper.HasWindowsRocmSupportedGpu(); + var hasRocmGpu = HasWindowsRocmSupport(); if (!Compat.IsWindows || !hasRocmGpu) return env; - // set some experimental speed improving env vars for Windows ROCm - return env.SetItem("PYTORCH_TUNABLEOP_ENABLED", "1") - .SetItem("MIOPEN_FIND_MODE", "2") - .SetItem("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "1") - .SetItem("PYTORCH_ALLOC_CONF", "max_split_size_mb:6144,garbage_collection_threshold:0.8") // greatly helps prevent GPU OOM and instability/driver timeouts/OS hard locks and decreases dependency on Tiled VAE at standard res's - .SetItem("COMFYUI_ENABLE_MIOPEN", "1"); // re-enables "cudnn" in ComfyUI as it's needed for MiOpen to function properly + if (rocmPackageHelper is not null) + { + var rocmEnvironment = rocmPackageHelper.BuildLaunchEnvironment( + installLocation, + installedPackage, + WindowsRocmProfile + ); + + return env.SetItems(rocmEnvironment); + } + + return env; } } diff --git a/StabilityMatrix.Core/Models/Packages/Wan2GP.cs b/StabilityMatrix.Core/Models/Packages/Wan2GP.cs index 2a00a6262..0c91f23a2 100644 --- a/StabilityMatrix.Core/Models/Packages/Wan2GP.cs +++ b/StabilityMatrix.Core/Models/Packages/Wan2GP.cs @@ -6,9 +6,11 @@ using StabilityMatrix.Core.Helper.HardwareInfo; using StabilityMatrix.Core.Models.FileInterfaces; using StabilityMatrix.Core.Models.Progress; +using StabilityMatrix.Core.Models.Rocm; using StabilityMatrix.Core.Processes; using StabilityMatrix.Core.Python; using StabilityMatrix.Core.Services; +using StabilityMatrix.Core.Services.Rocm; namespace StabilityMatrix.Core.Models.Packages; @@ -30,7 +32,8 @@ public class Wan2GP( IDownloadService downloadService, IPrerequisiteHelper prerequisiteHelper, IPyInstallationManager pyInstallationManager, - IPipWheelService pipWheelService + IPipWheelService pipWheelService, + IRocmPackageHelper? rocmPackageHelper = null ) : BaseGitPackage( githubApi, @@ -41,6 +44,13 @@ IPipWheelService pipWheelService pipWheelService ) { + private static readonly RocmPackageProfile WindowsRocmProfile = new() + { + RequiresRocmSdk = true, + UpgradePackages = true, + PostInstallPipArgs = ["hf-xet", "setuptools", "numpy==1.26.4"], + }; + public override string Name => "Wan2GP"; public override string DisplayName { get; set; } = "Wan2GP"; public override string Author => "deepbeepmeep"; @@ -64,7 +74,7 @@ IPipWheelService pipWheelService public override bool IsCompatible => HardwareHelper.HasNvidiaGpu() - || (Compat.IsWindows ? HardwareHelper.HasWindowsRocmSupportedGpu() : HardwareHelper.HasAmdGpu()); + || (Compat.IsWindows ? HasWindowsRocmSupport() : HardwareHelper.HasAmdGpu()); public override string MainBranch => "main"; public override bool ShouldIgnoreReleases => true; @@ -72,7 +82,7 @@ IPipWheelService pipWheelService public override Dictionary> SharedOutputFolders => new() { [SharedOutputType.Img2Vid] = ["outputs"] }; - // AMD ROCm requires Python 3.11, NVIDIA uses 3.10 + // Wan2GP currently uses Python 3.11 for ROCm and 3.10 for CUDA. public override PyVersion RecommendedPythonVersion => IsAmdRocm ? Python.PyInstallationManager.Python_3_11_13 : Python.PyInstallationManager.Python_3_10_17; @@ -86,6 +96,17 @@ IPipWheelService pipWheelService /// private bool IsAmdRocm => GetRecommendedTorchVersion() == TorchIndex.Rocm; + private bool HasWindowsRocmSupport() + { + if (!Compat.IsWindows) + return false; + + if (rocmPackageHelper is null) + return HardwareHelper.HasWindowsRocmSupportedGpu(); + + return rocmPackageHelper.GetCompatibility(WindowsRocmProfile).IsCompatible; + } + /// /// Python wrapper script that patches logging to also print to stdout/stderr, so /// StabilityMatrix can capture the output. Wan2GP logs through Gradio UI notifications @@ -213,8 +234,8 @@ public override TorchIndex GetRecommendedTorchVersion() ( Compat.IsWindows && ( - SettingsManager.Settings.PreferredGpu?.IsWindowsRocmSupportedGpu() - ?? HardwareHelper.HasWindowsRocmSupportedGpu() + WindowsRocmSupport.IsSupportedGpu(SettingsManager.Settings.PreferredGpu) + || HasWindowsRocmSupport() ) ) || ( @@ -256,7 +277,15 @@ public override async Task InstallPackage( if (torchIndex == TorchIndex.Rocm) { - await InstallAmdRocmAsync(venvRunner, progress, onConsoleOutput).ConfigureAwait(false); + await InstallAmdRocmAsync( + venvRunner, + installLocation, + installedPackage, + progress, + onConsoleOutput, + cancellationToken + ) + .ConfigureAwait(false); } else { @@ -359,68 +388,53 @@ await venvRunner private async Task InstallAmdRocmAsync( IPyVenvRunner venvRunner, + string installLocation, + InstalledPackage installedPackage, IProgress? progress, - Action? onConsoleOutput + Action? onConsoleOutput, + CancellationToken cancellationToken ) { - progress?.Report(new ProgressReport(-1f, "Upgrading pip...", isIndeterminate: true)); - await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false); - if (Compat.IsWindows) { - // Windows AMD ROCm - special TheRock wheels - progress?.Report( - new ProgressReport(-1f, "Installing PyTorch ROCm wheels...", isIndeterminate: true) - ); - - // Set environment variable for wheel filename check bypass - venvRunner.UpdateEnvironmentVariables(env => env.SetItem("UV_SKIP_WHEEL_FILENAME_CHECK", "1")); + if (rocmPackageHelper is null) + { + throw new InvalidOperationException( + "Windows ROCm installation for Wan2GP requires the shared ROCm helper." + ); + } - // Install PyTorch ROCm wheels from TheRock releases (Python 3.11) - await venvRunner - .PipInstall( - "https://github.com/scottt/rocm-TheRock/releases/download/v6.5.0rc-pytorch-gfx110x/torch-2.7.0a0+rocm_git3f903c3-cp311-cp311-win_amd64.whl", - onConsoleOutput + await rocmPackageHelper + .InstallWindowsNativePackageAsync( + venvRunner, + installLocation, + installedPackage, + WindowsRocmProfile, + progress, + onConsoleOutput, + cancellationToken ) .ConfigureAwait(false); - await venvRunner - .PipInstall( - "https://github.com/scottt/rocm-TheRock/releases/download/v6.5.0rc-pytorch-gfx110x/torchaudio-2.7.0a0+52638ef-cp311-cp311-win_amd64.whl", - onConsoleOutput - ) - .ConfigureAwait(false); + return; + } - await venvRunner - .PipInstall( - "https://github.com/scottt/rocm-TheRock/releases/download/v6.5.0rc-pytorch-gfx110x/torchvision-0.22.0+9eb57cd-cp311-cp311-win_amd64.whl", - onConsoleOutput - ) - .ConfigureAwait(false); + progress?.Report(new ProgressReport(-1f, "Upgrading pip...", isIndeterminate: true)); + await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false); - // Install requirements directly using -r flag (handles @ URL syntax properly) - progress?.Report(new ProgressReport(-1f, "Installing requirements...", isIndeterminate: true)); - await venvRunner.PipInstall("-r requirements.txt", onConsoleOutput).ConfigureAwait(false); - } - else - { - // Linux AMD ROCm - standard PyTorch ROCm - // Install requirements directly using -r flag (handles @ URL syntax properly) - progress?.Report(new ProgressReport(-1f, "Installing requirements...", isIndeterminate: true)); - await venvRunner.PipInstall("-r requirements.txt", onConsoleOutput).ConfigureAwait(false); - - // Install torch with ROCm index (force reinstall to ensure correct version) - progress?.Report(new ProgressReport(-1f, "Installing PyTorch ROCm...", isIndeterminate: true)); - var torchArgs = new PipInstallArgs() - .WithTorch("==2.7.0") - .WithTorchVision("==0.22.0") - .WithTorchAudio("==2.7.0") - .WithTorchExtraIndex("rocm6.3") - .AddArg("--force-reinstall") - .AddArg("--no-deps"); - - await venvRunner.PipInstall(torchArgs, onConsoleOutput).ConfigureAwait(false); - } + progress?.Report(new ProgressReport(-1f, "Installing requirements...", isIndeterminate: true)); + await venvRunner.PipInstall("-r requirements.txt", onConsoleOutput).ConfigureAwait(false); + + progress?.Report(new ProgressReport(-1f, "Installing PyTorch ROCm...", isIndeterminate: true)); + var torchArgs = new PipInstallArgs() + .WithTorch() + .WithTorchVision() + .WithTorchAudio() + .WithTorchExtraIndex("rocm7.2") + .AddArg("--force-reinstall") + .AddArg("--no-deps"); + + await venvRunner.PipInstall(torchArgs, onConsoleOutput).ConfigureAwait(false); // Install additional packages await venvRunner.PipInstall("hf-xet setuptools numpy==1.26.4", onConsoleOutput).ConfigureAwait(false); @@ -437,6 +451,16 @@ public override async Task RunPackage( await SetupVenv(installLocation, pythonVersion: PyVersion.Parse(installedPackage.PythonVersion)) .ConfigureAwait(false); + if (Compat.IsWindows && rocmPackageHelper is not null && HasWindowsRocmSupport()) + { + var rocmEnvironment = rocmPackageHelper.BuildLaunchEnvironment( + installLocation, + installedPackage, + WindowsRocmProfile + ); + VenvRunner.UpdateEnvironmentVariables(env => env.SetItems(rocmEnvironment)); + } + // Fix for distutils compatibility issue with Python 3.10 and setuptools VenvRunner.UpdateEnvironmentVariables(env => env.SetItem("SETUPTOOLS_USE_DISTUTILS", "stdlib")); diff --git a/StabilityMatrix.Core/Models/Rocm/RocmCompatibilityResult.cs b/StabilityMatrix.Core/Models/Rocm/RocmCompatibilityResult.cs new file mode 100644 index 000000000..401f3ada4 --- /dev/null +++ b/StabilityMatrix.Core/Models/Rocm/RocmCompatibilityResult.cs @@ -0,0 +1,17 @@ +using StabilityMatrix.Core.Helper.HardwareInfo; + +namespace StabilityMatrix.Core.Models.Rocm; + +/// +/// Describes whether a package/profile is currently compatible with ROCm on the active machine. +/// +public class RocmCompatibilityResult +{ + public bool IsCompatible { get; init; } + + public string? FailureReason { get; init; } + + public GpuInfo? SelectedGpu { get; init; } + + public string? ResolvedGfxArch { get; init; } +} diff --git a/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs b/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs new file mode 100644 index 000000000..e01261700 --- /dev/null +++ b/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs @@ -0,0 +1,57 @@ +namespace StabilityMatrix.Core.Models.Rocm; + +/// +/// Controls how ROCm helper defaults, package-specific variables, and user overrides are layered at launch. +/// +public class RocmEnvironmentOptions +{ + /// + /// When true, package-specific environment additions may be merged on top of helper defaults. + /// + public bool IncludePackageOverrides { get; init; } = true; + + /// + /// When true, user-defined Stability Matrix environment variables may override helper/package defaults last. + /// + public bool IncludeUserOverrides { get; init; } = true; + + /// + /// When set, overrides the default PyTorch allocator tuning string added by the ROCm helper. + /// + public string? PyTorchAllocConf { get; init; } = "max_split_size_mb:512,garbage_collection_threshold:0.8"; + + /// + /// When set, configures MIOpen find mode for helper-managed ROCm defaults. + /// + public string? MiopenFindMode { get; init; } = "2"; + + /// + /// When set, configures MIOpen search cutoff for helper-managed ROCm defaults. + /// + public string? MiopenSearchCutoff { get; init; } = "1"; + + /// + /// When set, configures MIOpen find enforcement behavior for helper-managed ROCm defaults. + /// + public string? MiopenFindEnforce { get; init; } = "3"; + + /// + /// When set, controls whether AMD Triton-backed flash attention is enabled by helper defaults. + /// + public string? FlashAttentionTritonAmdEnable { get; init; } = "TRUE"; + + /// + /// When true, helper-managed defaults will enable ROCm AOTriton on modern Windows ROCm architectures. + /// + public bool ApplyAotritonExperimental { get; init; } = true; + + /// + /// When true, helper-managed defaults will force math SDP on legacy ROCm architectures. + /// + public bool ApplyLegacySdpFallback { get; init; } = true; + + /// + /// When true, helper-managed defaults will apply the RDNA1 HSA override mask when needed. + /// + public bool ApplyRdna1Override { get; init; } = true; +} diff --git a/StabilityMatrix.Core/Models/Rocm/RocmInstallContext.cs b/StabilityMatrix.Core/Models/Rocm/RocmInstallContext.cs new file mode 100644 index 000000000..597eb4fe6 --- /dev/null +++ b/StabilityMatrix.Core/Models/Rocm/RocmInstallContext.cs @@ -0,0 +1,11 @@ +namespace StabilityMatrix.Core.Models.Rocm; + +/// +/// Captures ROCm-related facts needed during package install or update flows. +/// +public class RocmInstallContext +{ + public string? RuntimeGfxArch { get; init; } + + public string? RocmPackageIndexUrl { get; init; } +} diff --git a/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs b/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs new file mode 100644 index 000000000..d1abf16f8 --- /dev/null +++ b/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs @@ -0,0 +1,55 @@ +using StabilityMatrix.Core.Models.Progress; + +namespace StabilityMatrix.Core.Models.Rocm; + +/// +/// Declares what a package expects from the ROCm helper. +/// Package classes should describe intent here rather than hardcoding ROCm decisions inline. +/// +public class RocmPackageProfile +{ + public bool RequiresRocmSdk { get; init; } + + /// + /// Requirement files to install after helper-owned ROCm runtime / torch bootstrap steps complete. + /// + public IEnumerable RequirementsFilePaths { get; init; } = ["requirements.txt"]; + + /// + /// Package requirement entries to exclude because the helper installs them from ROCm-specific feeds. + /// + public string RequirementsExcludePattern { get; init; } = @"(torch(vision|audio)?|xformers)([^a-z].*)?"; + + /// + /// Extra package-specific pip arguments to include when installing requirements after helper bootstrap. + /// + public IEnumerable ExtraInstallPipArgs { get; init; } = []; + + /// + /// Extra package-specific pip arguments to install after requirements and torch are complete. + /// + public IEnumerable PostInstallPipArgs { get; init; } = []; + + /// + /// When true, helper-managed requirements installs should use --upgrade. + /// + public bool UpgradePackages { get; init; } + + /// + /// When true, helper-managed torch installs should force reinstall the selected ROCm wheel set. + /// + public bool ForceReinstallTorch { get; init; } = true; + + /// + /// Optional callback for package-specific environment variables derived from a resolved ROCm context. + /// + public Func< + RocmRuntimeContext, + IReadOnlyDictionary + >? ExtraEnvironmentFactory { get; init; } + + /// + /// Controls whether package-specific environment variables should be layered on top of helper defaults. + /// + public RocmEnvironmentOptions EnvironmentOptions { get; init; } = new(); +} diff --git a/StabilityMatrix.Core/Models/Rocm/RocmRuntimeContext.cs b/StabilityMatrix.Core/Models/Rocm/RocmRuntimeContext.cs new file mode 100644 index 000000000..1fdda7914 --- /dev/null +++ b/StabilityMatrix.Core/Models/Rocm/RocmRuntimeContext.cs @@ -0,0 +1,18 @@ +using StabilityMatrix.Core.Helper.HardwareInfo; + +namespace StabilityMatrix.Core.Models.Rocm; + +/// +/// Captures resolved ROCm facts for a package launch or runtime decision. +/// This model is intended to separate hardware/runtime facts from package policy. +/// +public class RocmRuntimeContext +{ + public bool IsSupported { get; init; } + + public string? FailureReason { get; init; } + + public GpuInfo? SelectedGpu { get; init; } + + public string? RuntimeGfxArch { get; init; } +} diff --git a/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs b/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs new file mode 100644 index 000000000..ee000b1bd --- /dev/null +++ b/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs @@ -0,0 +1,68 @@ +using StabilityMatrix.Core.Helper.HardwareInfo; + +namespace StabilityMatrix.Core.Models.Rocm; + +/// +/// Centralizes Windows ROCm support and architecture policy so hardware detection, package selection, +/// installation, and shared launch decisions use the same support map. +/// +public static class WindowsRocmSupport +{ + public static bool IsSupportedGpu(GpuInfo? gpu) + { + if (gpu is null || !gpu.IsAmd || string.IsNullOrWhiteSpace(gpu.Name)) + return false; + + return IsSupportedArchitecture(gpu.GetAmdGfxArch()); + } + + public static bool IsSupportedArchitecture(string? gfxArch) + { + return TryGetPackageIndexUrl(gfxArch) is not null; + } + + public static bool IsModernArchitecture(string? gfxArch) + { + return gfxArch?.StartsWith("gfx110", StringComparison.OrdinalIgnoreCase) == true + || gfxArch?.StartsWith("gfx115", StringComparison.OrdinalIgnoreCase) == true + || gfxArch?.StartsWith("gfx120", StringComparison.OrdinalIgnoreCase) == true; + } + + public static bool IsLegacyArchitecture(string? gfxArch) + { + return IsSupportedArchitecture(gfxArch) && !IsModernArchitecture(gfxArch); + } + + public static bool PreferLegacyAttentionFallback(string? gfxArch) + { + return IsLegacyArchitecture(gfxArch); + } + + public static bool IsRdna1Architecture(string? gfxArch) + { + return gfxArch?.StartsWith("gfx101", StringComparison.OrdinalIgnoreCase) == true; + } + + public static string? TryGetPackageIndexUrl(string? gfxArch) + { + return gfxArch switch + { + "gfx900" => "https://rocm.nightlies.amd.com/v2-staging/gfx900/", // Vega 10 + "gfx906" => "https://rocm.nightlies.amd.com/v2-staging/gfx906/", // Radeon VII, Vega 20 + "gfx90X" => "https://rocm.nightlies.amd.com/v2-staging/gfx90X/", // Radeon Pro VII + var s when s != null && s.StartsWith("gfx101", StringComparison.OrdinalIgnoreCase) => + "https://rocm.nightlies.amd.com/v2-staging/gfx101X-dgpu/", // RDNA1 (5000 series, Pro) + var s when s != null && s.StartsWith("gfx103", StringComparison.OrdinalIgnoreCase) => + "https://rocm.nightlies.amd.com/v2-staging/gfx103X-all/", // RDNA2 (6000 series, 6xxM Mobile, Steam Deck) + var s when s != null && s.StartsWith("gfx110", StringComparison.OrdinalIgnoreCase) => + "https://rocm.nightlies.amd.com/v2/gfx110X-all/", // RDNA3 (7000 series, 7xxM Mobile) + "gfx1150" => "https://rocm.nightlies.amd.com/v2-staging/gfx1150/", // RDNA3.5 (Strix/Gorgon Point) + "gfx1151" => "https://rocm.nightlies.amd.com/v2/gfx1151/", // RDNA3.5 (Strix Halo) + "gfx1152" => "https://rocm.nightlies.amd.com/v2-staging/gfx1152/", // RDNA3.5 (Kraken Point) + "gfx1153" => "https://rocm.nightlies.amd.com/v2-staging/gfx1153/", // RDNA3.5 (Medusa Point) + var s when s != null && s.StartsWith("gfx120", StringComparison.OrdinalIgnoreCase) => + "https://rocm.nightlies.amd.com/v2/gfx120X-all/", // RDNA4 (9000 series) + _ => null, + }; + } +} diff --git a/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs new file mode 100644 index 000000000..03b4ce0c2 --- /dev/null +++ b/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs @@ -0,0 +1,58 @@ +using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Models.Progress; +using StabilityMatrix.Core.Models.Rocm; +using StabilityMatrix.Core.Processes; +using StabilityMatrix.Core.Python; + +namespace StabilityMatrix.Core.Services.Rocm; + +/// +/// Defines the ROCm helper surface area shared by ROCm-capable packages. +/// +public interface IRocmPackageHelper +{ + /// + /// Evaluates whether the current machine and package profile are compatible with ROCm. + /// + RocmCompatibilityResult GetCompatibility(RocmPackageProfile profile); + + /// + /// Resolves the runtime ROCm facts needed for package launch and environment construction. + /// + RocmRuntimeContext ResolveRuntimeContext( + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile + ); + + /// + /// Resolves the ROCm facts needed during package installation or update operations. + /// + RocmInstallContext ResolveInstallContext( + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile + ); + + /// + /// Builds a launch-time environment dictionary from resolved ROCm runtime data. + /// + IReadOnlyDictionary BuildLaunchEnvironment( + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile + ); + + /// + /// Performs the Windows-native ROCm bootstrap/install flow for a package using helper-resolved gfx-family feed URLs. + /// + Task InstallWindowsNativePackageAsync( + IPyVenvRunner venvRunner, + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile, + IProgress? progress = null, + Action? onConsoleOutput = null, + CancellationToken cancellationToken = default + ); +} diff --git a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs new file mode 100644 index 000000000..a5382ac91 --- /dev/null +++ b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs @@ -0,0 +1,793 @@ +using System.Collections.Immutable; +using System.Text.Json; +using Injectio.Attributes; +using NLog; +using StabilityMatrix.Core.Exceptions; +using StabilityMatrix.Core.Helper; +using StabilityMatrix.Core.Helper.HardwareInfo; +using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Models.FileInterfaces; +using StabilityMatrix.Core.Models.Progress; +using StabilityMatrix.Core.Models.Rocm; +using StabilityMatrix.Core.Processes; +using StabilityMatrix.Core.Python; +using StabilityMatrix.Core.Services; + +namespace StabilityMatrix.Core.Services.Rocm; + +/// +/// Provides the shared ROCm helper surface area used by ROCm-capable packages. +/// +[RegisterSingleton] +public class RocmPackageHelper(ISettingsManager settingsManager) : IRocmPackageHelper +{ + private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); + private static readonly StringComparer EnvComparer = StringComparer.OrdinalIgnoreCase; + + /// + public RocmCompatibilityResult GetCompatibility(RocmPackageProfile profile) + { + return BuildCompatibilityResult(profile); + } + + /// + public RocmRuntimeContext ResolveRuntimeContext( + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile + ) + { + _ = installLocation; + _ = installedPackage; + + var compatibility = BuildCompatibilityResult(profile); + if (!compatibility.IsCompatible) + { + return new RocmRuntimeContext + { + IsSupported = false, + FailureReason = compatibility.FailureReason, + SelectedGpu = compatibility.SelectedGpu, + RuntimeGfxArch = compatibility.ResolvedGfxArch, + }; + } + + var supportedAmdGpus = GetAmdGpuCandidates(forceRefresh: true) + .Where(IsSupportedWindowsRocmGpu) + .ToList(); + + var selectedGpu = + compatibility.SelectedGpu + ?? TryResolvePreferredAmdGpu(supportedAmdGpus, settingsManager.Settings.PreferredGpu) + ?? supportedAmdGpus.FirstOrDefault(); + + var runtimeGfxArch = + compatibility.ResolvedGfxArch + ?? selectedGpu?.GetAmdGfxArch() + ?? GetSupportedFallbackGfxArch(supportedAmdGpus); + + return new RocmRuntimeContext + { + IsSupported = true, + SelectedGpu = selectedGpu, + RuntimeGfxArch = runtimeGfxArch, + }; + } + + /// + public RocmInstallContext ResolveInstallContext( + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile + ) + { + _ = installLocation; + _ = installedPackage; + + var supportedAmdGpus = GetAmdGpuCandidates(forceRefresh: true) + .Where(IsSupportedWindowsRocmGpu) + .ToList(); + + var preferredGfxArch = TryResolvePreferredAmdGfxArch( + supportedAmdGpus, + settingsManager.Settings.PreferredGpu + ); + + var runtimeGfxArch = preferredGfxArch ?? GetSupportedFallbackGfxArch(supportedAmdGpus); + var windowsNativeIndexUrl = WindowsRocmSupport.TryGetPackageIndexUrl(runtimeGfxArch); + + return new RocmInstallContext + { + RuntimeGfxArch = runtimeGfxArch, + RocmPackageIndexUrl = windowsNativeIndexUrl, + }; + } + + /// + public IReadOnlyDictionary BuildLaunchEnvironment( + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile + ) + { + var runtimeContext = ResolveRuntimeContext(installLocation, installedPackage, profile); + + if (!runtimeContext.IsSupported) + return new Dictionary(); + + var helperEnvironment = BuildHelperLaunchEnvironment(runtimeContext, profile); + var packageEnvironment = + profile.ExtraEnvironmentFactory?.Invoke(runtimeContext) ?? new Dictionary(); + + var mergedEnvironment = MergeLaunchEnvironment( + helperEnvironment, + packageEnvironment, + profile.EnvironmentOptions + ); + + return mergedEnvironment; + } + + /// + public async Task InstallWindowsNativePackageAsync( + IPyVenvRunner venvRunner, + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile, + IProgress? progress = null, + Action? onConsoleOutput = null, + CancellationToken cancellationToken = default + ) + { + var compatibility = GetCompatibility(profile); + if (!compatibility.IsCompatible) + { + throw new InvalidOperationException( + compatibility.FailureReason + ?? "Windows ROCm installation is not supported for the current machine." + ); + } + + var installContext = ResolveInstallContext(installLocation, installedPackage, profile); + + var rocmPackageIndexUrl = installContext.RocmPackageIndexUrl; + + if (string.IsNullOrWhiteSpace(rocmPackageIndexUrl)) + { + throw new ApplicationException( + $"No Windows ROCm Technical Preview index URL is available for '{installContext.RuntimeGfxArch ?? "unknown"}'." + ); + } + + progress?.Report(new ProgressReport(-1f, "Upgrading pip...", isIndeterminate: true)); + await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false); + + if (profile.RequiresRocmSdk) + { + progress?.Report(new ProgressReport(-1f, "Installing ROCm runtime...", isIndeterminate: true)); + var rocmRuntimeArgs = new PipInstallArgs() + .AddKeyedArgs("--index-url", ["--index-url", rocmPackageIndexUrl]) + .AddArgs("rocm[devel,libraries]"); + + if (installedPackage.PipOverrides != null) + { + rocmRuntimeArgs = rocmRuntimeArgs.WithUserOverrides(installedPackage.PipOverrides); + } + + await venvRunner.PipInstall(rocmRuntimeArgs, onConsoleOutput).ConfigureAwait(false); + + progress?.Report(new ProgressReport(-1f, "Initializing ROCm SDK...", isIndeterminate: true)); + await InitializeWindowsNativeRocmSdkAsync(installLocation, onConsoleOutput, cancellationToken) + .ConfigureAwait(false); + } + + progress?.Report( + new ProgressReport(-1f, "Installing package requirements...", isIndeterminate: true) + ); + + var requirementsPipArgs = new PipInstallArgs([.. profile.ExtraInstallPipArgs]); + if (profile.UpgradePackages) + { + requirementsPipArgs = requirementsPipArgs.AddArg("--upgrade"); + } + + foreach (var relativePath in profile.RequirementsFilePaths) + { + var requirementsFile = new FilePath(venvRunner.WorkingDirectory ?? installLocation, relativePath); + if (!requirementsFile.Exists) + continue; + + var requirementsContent = await requirementsFile + .ReadAllTextAsync(cancellationToken) + .ConfigureAwait(false); + + requirementsPipArgs = requirementsPipArgs.WithParsedFromRequirementsTxt( + requirementsContent, + profile.RequirementsExcludePattern + ); + } + + if (installedPackage.PipOverrides != null) + { + requirementsPipArgs = requirementsPipArgs.WithUserOverrides(installedPackage.PipOverrides); + } + + await venvRunner.PipInstall(requirementsPipArgs, onConsoleOutput).ConfigureAwait(false); + + progress?.Report(new ProgressReport(-1f, "Installing ROCm torch...", isIndeterminate: true)); + var torchArgs = new PipInstallArgs() + .AddArg("--pre") + .AddArg("--upgrade") + .AddKeyedArgs("--index-url", ["--index-url", rocmPackageIndexUrl]) + .WithTorch() + .WithTorchAudio() + .WithTorchVision(); + + if (profile.ForceReinstallTorch) + { + torchArgs = torchArgs.AddArg("--force-reinstall"); + } + + if (installedPackage.PipOverrides != null) + { + torchArgs = torchArgs.WithUserOverrides(installedPackage.PipOverrides); + } + + await venvRunner.PipInstall(torchArgs, onConsoleOutput).ConfigureAwait(false); + + if (profile.RequiresRocmSdk) + { + await AlignRocmSdkDevelVersionAsync(venvRunner, rocmPackageIndexUrl, onConsoleOutput) + .ConfigureAwait(false); + + progress?.Report(new ProgressReport(-1f, "Reinitializing ROCm SDK...", isIndeterminate: true)); + await InitializeWindowsNativeRocmSdkAsync(installLocation, onConsoleOutput, cancellationToken) + .ConfigureAwait(false); + } + + if (profile.PostInstallPipArgs.Any()) + { + var postInstallPipArgs = new PipInstallArgs([.. profile.PostInstallPipArgs]); + if (installedPackage.PipOverrides != null) + { + postInstallPipArgs = postInstallPipArgs.WithUserOverrides(installedPackage.PipOverrides); + } + + await venvRunner.PipInstall(postInstallPipArgs, onConsoleOutput).ConfigureAwait(false); + } + + await VerifyWindowsNativeTorchInstallAsync(venvRunner, onConsoleOutput, cancellationToken) + .ConfigureAwait(false); + + if (profile.RequiresRocmSdk) + { + await VerifyWindowsNativeRocmRuntimeAsync(installLocation, onConsoleOutput, cancellationToken) + .ConfigureAwait(false); + } + } + + /// + /// Builds a compatibility result from the current machine state and package profile. + /// This keeps the first ROCm helper slice focused on hardware capability and GPU selection only. + /// + private RocmCompatibilityResult BuildCompatibilityResult(RocmPackageProfile profile) + { + var amdGpus = GetAmdGpuCandidates(forceRefresh: false).ToList(); + if (amdGpus.Count == 0) + { + return new RocmCompatibilityResult + { + IsCompatible = false, + FailureReason = "No AMD GPU was detected for ROCm evaluation.", + }; + } + + var preferredGpu = settingsManager.Settings.PreferredGpu; + + var supportedAmdGpus = amdGpus.Where(IsSupportedWindowsRocmGpu).ToList(); + if (supportedAmdGpus.Count == 0) + { + return new RocmCompatibilityResult + { + IsCompatible = false, + FailureReason = GetUnsupportedGpuReason(amdGpus), + }; + } + + var selectedGpu = + TryResolvePreferredAmdGpu(supportedAmdGpus, preferredGpu) ?? supportedAmdGpus.First(); + var resolvedGfxArch = selectedGpu.GetAmdGfxArch() ?? GetSupportedFallbackGfxArch(supportedAmdGpus); + + return new RocmCompatibilityResult + { + IsCompatible = !string.IsNullOrWhiteSpace(resolvedGfxArch), + FailureReason = string.IsNullOrWhiteSpace(resolvedGfxArch) + ? "No supported AMD GFX architecture could be resolved for ROCm." + : null, + SelectedGpu = selectedGpu, + ResolvedGfxArch = resolvedGfxArch, + }; + } + + /// + /// Returns AMD GPUs from Stability Matrix's internal hardware model. + /// This is the canonical GPU source for the ROCm helper and intentionally avoids package-local probing. + /// + private static IReadOnlyList GetAmdGpuCandidates(bool forceRefresh = false) + { + return HardwareHelper.IterGpuInfo(forceRefresh).Where(gpu => gpu.IsAmd).ToList(); + } + + /// + /// Resolves the preferred AMD GPU when the configured preference is still present in the current hardware list. + /// + private static GpuInfo? TryResolvePreferredAmdGpu( + IEnumerable availableGpus, + GpuInfo? preferredGpu + ) + { + if (preferredGpu is null || !preferredGpu.IsAmd) + return null; + + var preferredMatch = availableGpus.FirstOrDefault(gpu => gpu.Equals(preferredGpu)); + if (preferredMatch is not null) + return preferredMatch; + + if (!string.IsNullOrWhiteSpace(preferredGpu.Name)) + { + Logger.Info( + "Preferred GPU {PreferredGpuName} was ignored for ROCm detection because it is not present in current hardware enumeration.", + preferredGpu.Name + ); + } + + return null; + } + + /// + /// Resolves the preferred AMD GFX architecture when the configured GPU is supported and currently present. + /// + private static string? TryResolvePreferredAmdGfxArch( + IEnumerable availableGpus, + GpuInfo? preferredGpu + ) + { + var resolvedPreferredGpu = TryResolvePreferredAmdGpu(availableGpus, preferredGpu); + return resolvedPreferredGpu is not null && IsSupportedWindowsRocmGpu(resolvedPreferredGpu) + ? resolvedPreferredGpu.GetAmdGfxArch() + : null; + } + + /// + /// Resolves the first supported AMD GFX architecture from the current machine state when no preferred GPU applies. + /// + private static string? GetSupportedFallbackGfxArch(IEnumerable availableGpus) + { + return availableGpus + .Where(IsSupportedWindowsRocmGpu) + .Select(gpu => gpu.GetAmdGfxArch()) + .FirstOrDefault(IsSupportedWindowsRocmArchitecture); + } + + /// + /// Determines whether a GPU is supported by the Windows ROCm install flow currently modeled by the helper. + /// + private static bool IsSupportedWindowsRocmGpu(GpuInfo gpu) + { + return WindowsRocmSupport.IsSupportedGpu(gpu); + } + + /// + /// Determines whether a resolved AMD GFX architecture falls inside the Windows ROCm support set currently modeled by the helper. + /// + private static bool IsSupportedWindowsRocmArchitecture(string? gfxArch) + { + return WindowsRocmSupport.IsSupportedArchitecture(gfxArch); + } + + /// + /// Produces a readable incompatibility reason when AMD hardware is present but not usable for Windows ROCm. + /// + private static string GetUnsupportedGpuReason(IReadOnlyList amdGpus) + { + _ = amdGpus; + return "No AMD GPU with a supported Windows ROCm architecture was detected."; + } + + /// + /// Verifies that the installed torch build still reports usable ROCm metadata after helper-managed installs complete. + /// + private static async Task VerifyWindowsNativeTorchInstallAsync( + IPyVenvRunner venvRunner, + Action? onConsoleOutput, + CancellationToken cancellationToken + ) + { + var torchInfo = await venvRunner.PipShow("torch").ConfigureAwait(false); + if (torchInfo is null) + { + throw new ApplicationException("torch was not installed after Windows ROCm setup."); + } + + var verificationResult = await venvRunner + .Run( + "-c \"import json, torch; print(json.dumps({'version': torch.__version__, 'hip': torch.version.hip, 'cuda': torch.cuda.is_available()}))\"" + ) + .ConfigureAwait(false); + + var verificationOutput = (verificationResult.StandardOutput ?? string.Empty).Trim(); + if (string.IsNullOrWhiteSpace(verificationOutput)) + { + throw new ApplicationException("Torch verification produced no output."); + } + + var verificationJson = TryExtractJsonObject(verificationOutput); + if (string.IsNullOrWhiteSpace(verificationJson)) + { + throw new ApplicationException($"Unexpected torch verification output: {verificationOutput}"); + } + + JsonDocument verificationDocument; + try + { + verificationDocument = JsonDocument.Parse(verificationJson); + } + catch (Exception exception) + { + throw new ApplicationException( + $"Unexpected torch verification output: {verificationOutput}", + exception + ); + } + + using (verificationDocument) + { + var root = verificationDocument.RootElement; + var version = root.TryGetProperty("version", out var versionElement) + ? versionElement.GetString() + : null; + var hipVersion = root.TryGetProperty("hip", out var hipElement) ? hipElement.GetString() : null; + var cudaAvailable = root.TryGetProperty("cuda", out var cudaElement) && cudaElement.GetBoolean(); + + if (!IsUsableWindowsNativeTorchBuild(version, hipVersion)) + { + throw new ApplicationException( + $"Installed torch is not a usable ROCm build. Verification output: {verificationOutput}" + ); + } + + if (!cudaAvailable) + { + onConsoleOutput?.Invoke( + ProcessOutput.FromStdErrLine( + $"Torch verification warning: installed ROCm torch build reported cuda={cudaAvailable}; continuing because ROCm metadata was detected (version={version}, hip={hipVersion})." + ) + ); + } + + onConsoleOutput?.Invoke( + ProcessOutput.FromStdOutLine( + $"Torch verification: version={version}, hip={hipVersion}, cuda={cudaAvailable}" + ) + ); + } + + _ = cancellationToken; + } + + /// + /// Runs rocm-sdk init after the helper-managed runtime packages are installed so the Windows ROCm SDK can prepare the venv. + /// + private static async Task InitializeWindowsNativeRocmSdkAsync( + string installLocation, + Action? onConsoleOutput, + CancellationToken cancellationToken + ) + { + var rocmSdkExe = Path.Combine(installLocation, "venv", "Scripts", "rocm-sdk.exe"); + if (!File.Exists(rocmSdkExe)) + { + throw new FileNotFoundException("rocm-sdk.exe was not installed", rocmSdkExe); + } + + using var rocmSdkProcess = ProcessRunner.StartAnsiProcess( + rocmSdkExe, + ["init"], + installLocation, + onConsoleOutput + ); + + await rocmSdkProcess.WaitForExitAsync(cancellationToken).ConfigureAwait(false); + if (rocmSdkProcess.ExitCode != 0) + { + throw new ProcessException($"rocm-sdk init failed with code {rocmSdkProcess.ExitCode}"); + } + } + + /// + /// Uses AMD's bundled hipInfo.exe to confirm the installed Windows ROCm runtime can enumerate a ROCm-capable GPU. + /// + private static async Task VerifyWindowsNativeRocmRuntimeAsync( + string installLocation, + Action? onConsoleOutput, + CancellationToken cancellationToken + ) + { + var rocmSdkExe = Path.Combine(installLocation, "venv", "Scripts", "rocm-sdk.exe"); + if (!File.Exists(rocmSdkExe)) + { + throw new FileNotFoundException("rocm-sdk.exe was not installed", rocmSdkExe); + } + + var rocmBinResult = await ProcessRunner + .GetProcessResultAsync(rocmSdkExe, ["path", "--bin"], installLocation, useUtf8Encoding: true) + .ConfigureAwait(false); + + var rocmBinPath = (rocmBinResult.StandardOutput ?? string.Empty).Trim(); + if (!rocmBinResult.IsSuccessExitCode || string.IsNullOrWhiteSpace(rocmBinPath)) + { + var rocmBinOutput = CombineProcessOutput( + rocmBinResult.StandardOutput, + rocmBinResult.StandardError + ); + throw new ApplicationException( + $"ROCm runtime verification failed while resolving the ROCm SDK bin path. Output: {rocmBinOutput}" + ); + } + + var hipInfoExe = Path.Combine(rocmBinPath, $"hipInfo{Compat.ExeExtension}"); + if (!File.Exists(hipInfoExe)) + { + throw new FileNotFoundException( + "hipInfo.exe was not found in the ROCm SDK bin directory", + hipInfoExe + ); + } + + var hipInfoResult = await ProcessRunner + .GetProcessResultAsync( + hipInfoExe, + [], + installLocation, + new Dictionary { ["PATH"] = rocmBinPath }, + useUtf8Encoding: true + ) + .ConfigureAwait(false); + + var hipInfoOutput = CombineProcessOutput(hipInfoResult.StandardOutput, hipInfoResult.StandardError); + if (!hipInfoResult.IsSuccessExitCode) + { + var runtimeFailureReason = TryGetWindowsNativeRocmRuntimeFailureReason(hipInfoOutput); + throw new ApplicationException( + runtimeFailureReason is null + ? $"ROCm runtime verification failed while probing the installed runtime with hipInfo.exe. Output: {hipInfoOutput}" + : $"ROCm runtime verification failed: {runtimeFailureReason} Output: {hipInfoOutput}" + ); + } + + onConsoleOutput?.Invoke( + ProcessOutput.FromStdOutLine( + $"ROCm runtime verification succeeded via hipInfo.exe: {hipInfoOutput}" + ) + ); + + _ = cancellationToken; + } + + /// + /// Reinstalls rocm-sdk-devel to the resolved ROCm build version when the torch step downgrades the runtime stack. + /// + private static async Task AlignRocmSdkDevelVersionAsync( + IPyVenvRunner venvRunner, + string rocmPackageIndexUrl, + Action? onConsoleOutput + ) + { + var rocmInfo = await venvRunner.PipShow("rocm").ConfigureAwait(false); + var rocmSdkDevelInfo = await venvRunner.PipShow("rocm-sdk-devel").ConfigureAwait(false); + var torchInfo = await venvRunner.PipShow("torch").ConfigureAwait(false); + + var targetVersion = GetRocmSdkDevelAlignmentVersion( + rocmInfo?.Version, + rocmSdkDevelInfo?.Version, + torchInfo?.Version + ); + + if (string.IsNullOrWhiteSpace(targetVersion)) + return; + + onConsoleOutput?.Invoke( + ProcessOutput.FromStdErrLine( + $"Aligning rocm-sdk-devel from version={rocmSdkDevelInfo?.Version ?? "not-installed"} to version={targetVersion} to match the resolved ROCm torch/runtime build." + ) + ); + + var alignmentArgs = new PipInstallArgs() + .AddKeyedArgs("--index-url", ["--index-url", rocmPackageIndexUrl]) + .AddArg("--force-reinstall") + .AddArg($"rocm-sdk-devel=={targetVersion}"); + + await venvRunner.PipInstall(alignmentArgs, onConsoleOutput).ConfigureAwait(false); + } + + internal static bool IsUsableWindowsNativeTorchBuild(string? version, string? hipVersion) + { + if (!string.IsNullOrWhiteSpace(hipVersion)) + return true; + + return !string.IsNullOrWhiteSpace(version) + && version.Contains("rocm", StringComparison.OrdinalIgnoreCase); + } + + internal static string? GetRocmSdkDevelAlignmentVersion( + string? rocmVersion, + string? rocmSdkDevelVersion, + string? torchVersion = null + ) + { + var targetVersion = !string.IsNullOrWhiteSpace(rocmVersion) + ? rocmVersion + : TryExtractRocmBuildVersion(torchVersion); + + if (string.IsNullOrWhiteSpace(targetVersion)) + return null; + + return string.Equals(targetVersion, rocmSdkDevelVersion, StringComparison.OrdinalIgnoreCase) + ? null + : targetVersion; + } + + internal static string? TryGetWindowsNativeRocmRuntimeFailureReason(string? output) + { + if (string.IsNullOrWhiteSpace(output)) + return null; + + if (output.Contains("no ROCm-capable device is detected", StringComparison.OrdinalIgnoreCase)) + { + return "the installed ROCm runtime could not detect a ROCm-capable GPU on this system."; + } + + if (output.Contains("No WDDM adapters found", StringComparison.OrdinalIgnoreCase)) + { + return "the ROCm runtime could not find any compatible WDDM adapters for the current GPU/driver stack."; + } + + return null; + } + + internal static string? TryExtractRocmBuildVersion(string? torchVersion) + { + if (string.IsNullOrWhiteSpace(torchVersion)) + return null; + + var rocmMarkerIndex = torchVersion.IndexOf("rocm", StringComparison.OrdinalIgnoreCase); + if (rocmMarkerIndex < 0) + return null; + + var rocmBuildVersion = torchVersion[(rocmMarkerIndex + "rocm".Length)..].Trim(); + return string.IsNullOrWhiteSpace(rocmBuildVersion) ? null : rocmBuildVersion; + } + + internal static string? TryExtractJsonObject(string output) + { + if (string.IsNullOrWhiteSpace(output)) + return null; + + var trimmedOutput = output.Trim(); + + for (var index = 0; index < trimmedOutput.Length; index++) + { + if (trimmedOutput[index] != '{') + continue; + + try + { + using var document = JsonDocument.Parse(trimmedOutput[index..]); + return document.RootElement.GetRawText(); + } + catch (JsonException) { } + } + + return null; + } + + internal static string CombineProcessOutput(string? standardOutput, string? standardError) + { + var sections = new[] { standardOutput?.Trim(), standardError?.Trim() }.Where(section => + !string.IsNullOrWhiteSpace(section) + ); + + return string.Join(Environment.NewLine, sections); + } + + /// + /// Builds helper-owned ROCm launch variables from the resolved runtime context and package profile. + /// + private IReadOnlyDictionary BuildHelperLaunchEnvironment( + RocmRuntimeContext runtimeContext, + RocmPackageProfile profile + ) + { + var environment = new Dictionary(EnvComparer); + var options = profile.EnvironmentOptions; + var gfxArch = runtimeContext.RuntimeGfxArch; + + ApplyDefaultLaunchEnvironment(environment, gfxArch, options); + + return environment; + } + + private void ApplyDefaultLaunchEnvironment( + IDictionary environment, + string? gfxArch, + RocmEnvironmentOptions options + ) + { + SetIfNotNull(environment, "FLASH_ATTENTION_TRITON_AMD_ENABLE", options.FlashAttentionTritonAmdEnable); + SetIfNotNull(environment, "MIOPEN_FIND_MODE", options.MiopenFindMode); + SetIfNotNull(environment, "MIOPEN_SEARCH_CUTOFF", options.MiopenSearchCutoff); + SetIfNotNull(environment, "MIOPEN_FIND_ENFORCE", options.MiopenFindEnforce); + SetIfNotNull(environment, "PYTORCH_ALLOC_CONF", options.PyTorchAllocConf); + + if (options.ApplyAotritonExperimental && WindowsRocmSupport.IsModernArchitecture(gfxArch)) + { + environment["TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL"] = "1"; + } + + if (options.ApplyLegacySdpFallback && WindowsRocmSupport.IsLegacyArchitecture(gfxArch)) + { + environment["TORCH_BACKENDS_CUDA_FLASH_SDP_ENABLED"] = "0"; + environment["TORCH_BACKENDS_CUDA_MEM_EFF_SDP_ENABLED"] = "0"; + environment["TORCH_BACKENDS_CUDA_MATH_SDP_ENABLED"] = "1"; + } + + if (options.ApplyRdna1Override && WindowsRocmSupport.IsRdna1Architecture(gfxArch)) + { + environment["HSA_OVERRIDE_GFX_VERSION"] = "10.1.0"; + } + } + + private static void SetIfNotNull(IDictionary environment, string key, string? value) + { + if (!string.IsNullOrWhiteSpace(value)) + { + environment[key] = value; + } + } + + /// + /// Merges helper-owned and package-specific launch environment variables. + /// + private IReadOnlyDictionary MergeLaunchEnvironment( + IReadOnlyDictionary helperEnvironment, + IReadOnlyDictionary packageEnvironment, + RocmEnvironmentOptions options + ) + { + var merged = new Dictionary(EnvComparer); + + foreach (var source in new[] { helperEnvironment, packageEnvironment }) + { + if (ReferenceEquals(source, packageEnvironment) && !options.IncludePackageOverrides) + continue; + + foreach (var pair in source) + { + merged[pair.Key] = pair.Value; + } + } + + if ( + options.IncludeUserOverrides + && settingsManager.Settings.EnvironmentVariables is { Count: > 0 } userOverrides + ) + { + foreach (var pair in userOverrides) + { + merged[pair.Key] = pair.Value; + } + } + + return merged; + } +} diff --git a/StabilityMatrix.Tests/Core/RocmPackageHelperTests.cs b/StabilityMatrix.Tests/Core/RocmPackageHelperTests.cs new file mode 100644 index 000000000..0afb7b86d --- /dev/null +++ b/StabilityMatrix.Tests/Core/RocmPackageHelperTests.cs @@ -0,0 +1,176 @@ +using System.Text.Json; +using StabilityMatrix.Core.Helper; +using StabilityMatrix.Core.Helper.HardwareInfo; +using StabilityMatrix.Core.Models.Rocm; +using StabilityMatrix.Core.Services.Rocm; + +namespace StabilityMatrix.Tests.Core; + +[TestClass] +public class RocmPackageHelperTests +{ + [TestMethod] + public void GetRocmSdkDevelAlignmentVersion_ReturnsRocmVersion_WhenVersionsMismatch() + { + var targetVersion = RocmPackageHelper.GetRocmSdkDevelAlignmentVersion( + rocmVersion: "7.13.0a20260416", + rocmSdkDevelVersion: "7.13.0a20260501" + ); + + Assert.AreEqual("7.13.0a20260416", targetVersion); + } + + [TestMethod] + public void GetRocmSdkDevelAlignmentVersion_ReturnsNull_WhenVersionsAlreadyMatch() + { + var targetVersion = RocmPackageHelper.GetRocmSdkDevelAlignmentVersion( + rocmVersion: "7.13.0a20260416", + rocmSdkDevelVersion: "7.13.0a20260416" + ); + + Assert.IsNull(targetVersion); + } + + [TestMethod] + public void GetRocmSdkDevelAlignmentVersion_FallsBackToTorchBuildVersion() + { + var targetVersion = RocmPackageHelper.GetRocmSdkDevelAlignmentVersion( + rocmVersion: null, + rocmSdkDevelVersion: "7.13.0a20260501", + torchVersion: "2.11.0+rocm7.13.0a20260416" + ); + + Assert.AreEqual("7.13.0a20260416", targetVersion); + } + + [TestMethod] + public void TryExtractRocmBuildVersion_ReturnsNull_WhenTorchVersionHasNoRocmTag() + { + var rocmBuildVersion = RocmPackageHelper.TryExtractRocmBuildVersion("2.11.0"); + + Assert.IsNull(rocmBuildVersion); + } + + [TestMethod] + public void TryExtractRocmBuildVersion_ReturnsVersionSuffix_WhenTorchVersionContainsRocmTag() + { + var rocmBuildVersion = RocmPackageHelper.TryExtractRocmBuildVersion("2.11.0+rocm7.13.0a20260416"); + + Assert.AreEqual("7.13.0a20260416", rocmBuildVersion); + } + + [TestMethod] + public void IsUsableWindowsNativeTorchBuild_ReturnsTrue_WhenHipMetadataExists() + { + var isUsable = RocmPackageHelper.IsUsableWindowsNativeTorchBuild( + version: "test-version", + hipVersion: "test-hip-version" + ); + + Assert.IsTrue(isUsable); + } + + [TestMethod] + public void IsUsableWindowsNativeTorchBuild_ReturnsTrue_WhenVersionContainsRocm() + { + var isUsable = RocmPackageHelper.IsUsableWindowsNativeTorchBuild( + version: "test-version+rocm", + hipVersion: null + ); + + Assert.IsTrue(isUsable); + } + + [TestMethod] + public void IsUsableWindowsNativeTorchBuild_ReturnsFalse_WhenNoRocmMetadataExists() + { + var isUsable = RocmPackageHelper.IsUsableWindowsNativeTorchBuild( + version: "test-version", + hipVersion: null + ); + + Assert.IsFalse(isUsable); + } + + [TestMethod] + public void TryExtractJsonObject_ReturnsJson_WhenOutputContainsDiagnosticPrefix() + { + const string output = + "warning: ROCm topology probe emitted diagnostic output" + + "\nwarning: continuing with torch verification" + + "\n{\"version\": \"test-version\", \"hip\": \"test-hip-version\", \"cuda\": false}"; + + var json = RocmPackageHelper.TryExtractJsonObject(output); + + Assert.IsNotNull(json); + + using var document = JsonDocument.Parse(json); + var root = document.RootElement; + + Assert.AreEqual("test-version", root.GetProperty("version").GetString()); + Assert.AreEqual("test-hip-version", root.GetProperty("hip").GetString()); + Assert.IsFalse(root.GetProperty("cuda").GetBoolean()); + } + + [TestMethod] + public void TryExtractJsonObject_ReturnsNull_WhenOutputContainsNoJson() + { + const string output = + "warning: ROCm topology probe emitted diagnostic output\n" + + "warning: no JSON payload was produced"; + + var json = RocmPackageHelper.TryExtractJsonObject(output); + + Assert.IsNull(json); + } + + [TestMethod] + public void TryGetWindowsNativeRocmRuntimeFailureReason_ReturnsDeviceDetectionMessage() + { + const string output = "checkHipErrors() HIP API error = 0100 \"no ROCm-capable device is detected\""; + + var reason = RocmPackageHelper.TryGetWindowsNativeRocmRuntimeFailureReason(output); + + Assert.AreEqual( + "the installed ROCm runtime could not detect a ROCm-capable GPU on this system.", + reason + ); + } + + [TestMethod] + public void TryGetWindowsNativeRocmRuntimeFailureReason_ReturnsWddmMessage() + { + const string output = "warning: No WDDM adapters found."; + + var reason = RocmPackageHelper.TryGetWindowsNativeRocmRuntimeFailureReason(output); + + Assert.AreEqual( + "the ROCm runtime could not find any compatible WDDM adapters for the current GPU/driver stack.", + reason + ); + } + + [TestMethod] + public void CombineProcessOutput_JoinsStdoutAndStderr() + { + var combined = RocmPackageHelper.CombineProcessOutput("stdout line", "stderr line"); + + Assert.AreEqual($"stdout line{Environment.NewLine}stderr line", combined); + } + + [TestMethod] + public void WindowsRocmSupport_TryGetPackageIndexUrl_ReturnsExpectedIndex_ForKrakenPoint() + { + var indexUrl = WindowsRocmSupport.TryGetPackageIndexUrl("gfx1152"); + + Assert.AreEqual("https://rocm.nightlies.amd.com/v2-staging/gfx1152/", indexUrl); + } + + [TestMethod] + public void WindowsRocmSupport_IsSupportedGpu_ReturnsTrue_ForSupportedAmdGpu() + { + var gpu = new GpuInfo { Name = "AMD Radeon RX 9070 XT", MemoryBytes = 16UL * Size.GiB }; + + Assert.IsTrue(WindowsRocmSupport.IsSupportedGpu(gpu)); + } +} diff --git a/StabilityMatrix.Tests/Helper/PackageFactoryTests.cs b/StabilityMatrix.Tests/Helper/PackageFactoryTests.cs index f78027039..bf45d60c4 100644 --- a/StabilityMatrix.Tests/Helper/PackageFactoryTests.cs +++ b/StabilityMatrix.Tests/Helper/PackageFactoryTests.cs @@ -24,6 +24,8 @@ public void Setup() null!, null!, null!, + null!, + null!, null! ); }