diff --git a/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs b/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs
index 118efa55c..0cb0b808f 100644
--- a/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs
+++ b/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs
@@ -4,6 +4,7 @@
using StabilityMatrix.Core.Models.Packages;
using StabilityMatrix.Core.Python;
using StabilityMatrix.Core.Services;
+using StabilityMatrix.Core.Services.Rocm;
namespace StabilityMatrix.Core.Helper.Factory;
@@ -18,6 +19,7 @@ public class PackageFactory : IPackageFactory
private readonly IUvManager uvManager;
private readonly IPyInstallationManager pyInstallationManager;
private readonly IPipWheelService pipWheelService;
+ private readonly IRocmPackageHelper rocmPackageHelper;
///
/// Mapping of package.Name to package
@@ -32,7 +34,9 @@ public PackageFactory(
IPrerequisiteHelper prerequisiteHelper,
IPyInstallationManager pyInstallationManager,
IPyRunner pyRunner,
- IPipWheelService pipWheelService
+ IUvManager uvManager,
+ IPipWheelService pipWheelService,
+ IRocmPackageHelper rocmPackageHelper
)
{
this.githubApiCache = githubApiCache;
@@ -40,8 +44,10 @@ IPipWheelService pipWheelService
this.downloadService = downloadService;
this.prerequisiteHelper = prerequisiteHelper;
this.pyRunner = pyRunner;
+ this.uvManager = uvManager;
this.pyInstallationManager = pyInstallationManager;
this.pipWheelService = pipWheelService;
+ this.rocmPackageHelper = rocmPackageHelper;
this.basePackages = basePackages.ToDictionary(x => x.Name);
}
@@ -55,7 +61,8 @@ public BasePackage GetNewBasePackage(InstalledPackage installedPackage)
downloadService,
prerequisiteHelper,
pyInstallationManager,
- pipWheelService
+ pipWheelService,
+ rocmPackageHelper
),
"Fooocus" => new Fooocus(
githubApiCache,
@@ -280,7 +287,8 @@ public BasePackage GetNewBasePackage(InstalledPackage installedPackage)
downloadService,
prerequisiteHelper,
pyInstallationManager,
- pipWheelService
+ pipWheelService,
+ rocmPackageHelper
),
_ => throw new ArgumentOutOfRangeException(nameof(installedPackage)),
};
diff --git a/StabilityMatrix.Core/Helper/HardwareInfo/GpuInfo.cs b/StabilityMatrix.Core/Helper/HardwareInfo/GpuInfo.cs
index eedcb556c..0013f65bc 100644
--- a/StabilityMatrix.Core/Helper/HardwareInfo/GpuInfo.cs
+++ b/StabilityMatrix.Core/Helper/HardwareInfo/GpuInfo.cs
@@ -1,4 +1,6 @@
-namespace StabilityMatrix.Core.Helper.HardwareInfo;
+using StabilityMatrix.Core.Models.Rocm;
+
+namespace StabilityMatrix.Core.Helper.HardwareInfo;
public record GpuInfo
{
@@ -62,11 +64,7 @@ public bool IsLegacyNvidiaGpu()
public bool IsWindowsRocmSupportedGpu()
{
- var gfx = GetAmdGfxArch();
- if (gfx is null)
- return false;
-
- return gfx.StartsWith("gfx110") || gfx.StartsWith("gfx120") || gfx.Equals("gfx1151");
+ return WindowsRocmSupport.IsSupportedGpu(this);
}
public bool IsAmd => Name?.Contains("amd", StringComparison.OrdinalIgnoreCase) ?? false;
@@ -84,7 +82,7 @@ public bool IsWindowsRocmSupportedGpu()
return name switch
{
// RDNA4
- _ when Has("R9700") || Has("9070") => "gfx1201",
+ _ when Has("R9700") || Has("R9600") || Has("9070") => "gfx1201",
_ when Has("9060") => "gfx1200",
// RDNA3.5 APUs
@@ -112,6 +110,9 @@ _ when Has("660M") || Has("680M") => "gfx1035",
_ when Has("6300") || Has("6400") || Has("6450") || Has("6500") || Has("6550") || Has("6500M") =>
"gfx1034",
+ // RDNA2 Steam Deck APU
+ _ when Has("Van Gogh") || Has("Sephiroth") => "gfx1033",
+
// RDNA2 Navi23
_ when Has("6600") || Has("6650") || Has("6700S") || Has("6800S") || Has("6600M") => "gfx1032",
@@ -121,6 +122,16 @@ _ when Has("6700") || Has("6750") || Has("6800M") || Has("6850M") => "gfx1031",
// RDNA2 Navi21 (big die)
_ when Has("6800") || Has("6900") || Has("6950") => "gfx1030",
+ // RDNA1 Navi10 XT (incl. Pro card)
+ _ when Has("5600") || Has("5700") || Has("v520") => "gfx1010",
+
+ // RDNA1 Navi10 XTX
+ _ when Has("5500") => "gfx1012",
+
+ // Vega/GCN5 Dedicated GPUs
+ _ when Has("pro vii") || HasNoSpace("provii") => "gfx90X",
+ _ when Has("rx vega") || Has("vega 64") || Has("vega 56") || Has("vega frontier") => "gfx900",
+ _ when Has("radeon vii") || HasNoSpace("radeonvii") => "gfx906",
_ => null,
};
diff --git a/StabilityMatrix.Core/Helper/HardwareInfo/HardwareHelper.cs b/StabilityMatrix.Core/Helper/HardwareInfo/HardwareHelper.cs
index 8458c730b..93f093d41 100644
--- a/StabilityMatrix.Core/Helper/HardwareInfo/HardwareHelper.cs
+++ b/StabilityMatrix.Core/Helper/HardwareInfo/HardwareHelper.cs
@@ -7,6 +7,7 @@
using Microsoft.Win32;
using NLog;
using StabilityMatrix.Core.Extensions;
+using StabilityMatrix.Core.Models.Rocm;
namespace StabilityMatrix.Core.Helper.HardwareInfo;
@@ -316,12 +317,11 @@ public static bool HasAmdGpu()
return IterGpuInfo().Any(gpu => gpu.IsAmd);
}
- public static bool HasWindowsRocmSupportedGpu() =>
- IterGpuInfo().Any(gpu => gpu is { IsAmd: true, Name: not null } && gpu.IsWindowsRocmSupportedGpu());
+ public static bool HasWindowsRocmSupportedGpu() => IterGpuInfo().Any(WindowsRocmSupport.IsSupportedGpu);
public static GpuInfo? GetWindowsRocmSupportedGpu()
{
- return IterGpuInfo().FirstOrDefault(gpu => gpu.IsWindowsRocmSupportedGpu());
+ return IterGpuInfo().FirstOrDefault(WindowsRocmSupport.IsSupportedGpu);
}
public static bool HasIntelGpu() => IterGpuInfo().Any(gpu => gpu.IsIntel);
diff --git a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs
index a4c34649d..759406c35 100644
--- a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs
+++ b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs
@@ -13,9 +13,11 @@
using StabilityMatrix.Core.Models.Packages.Config;
using StabilityMatrix.Core.Models.Packages.Extensions;
using StabilityMatrix.Core.Models.Progress;
+using StabilityMatrix.Core.Models.Rocm;
using StabilityMatrix.Core.Processes;
using StabilityMatrix.Core.Python;
using StabilityMatrix.Core.Services;
+using StabilityMatrix.Core.Services.Rocm;
namespace StabilityMatrix.Core.Models.Packages;
@@ -26,7 +28,8 @@ public class ComfyUI(
IDownloadService downloadService,
IPrerequisiteHelper prerequisiteHelper,
IPyInstallationManager pyInstallationManager,
- IPipWheelService pipWheelService
+ IPipWheelService pipWheelService,
+ IRocmPackageHelper? rocmPackageHelper = null
)
: BaseGitPackage(
githubApi,
@@ -38,6 +41,7 @@ IPipWheelService pipWheelService
)
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
+
public override string Name => "ComfyUI";
public override string DisplayName { get; set; } = "ComfyUI";
public override string Author => "comfyanonymous";
@@ -247,7 +251,7 @@ IPipWheelService pipWheelService
Name = "Enable DirectML",
Type = LaunchOptionType.Bool,
InitialValue =
- !HardwareHelper.HasWindowsRocmSupportedGpu()
+ !HasWindowsRocmSupport()
&& HardwareHelper.PreferDirectMLOrZluda()
&& this is not ComfyZluda,
Options = ["--directml"],
@@ -264,7 +268,9 @@ IPipWheelService pipWheelService
{
Name = "Cross Attention Method",
Type = LaunchOptionType.Bool,
- InitialValue = "--use-pytorch-cross-attention",
+ InitialValue = DefaultToQuadCrossAttention()
+ ? "--use-quad-cross-attention" // For Legacy AMD GPUs.
+ : "--use-pytorch-cross-attention",
Options =
[
"--use-split-cross-attention",
@@ -362,69 +368,36 @@ public override async Task InstallPackage(
.ConfigureAwait(false);
var torchIndex = options.PythonOptions.TorchIndex ?? GetRecommendedTorchVersion();
- var gfxArch =
- SettingsManager.Settings.PreferredGpu?.GetAmdGfxArch()
- ?? HardwareHelper.GetWindowsRocmSupportedGpu()?.GetAmdGfxArch();
-
- // Special case for Windows ROCm Nightly builds
- if (
- Compat.IsWindows
- && !string.IsNullOrWhiteSpace(gfxArch)
- && torchIndex is TorchIndex.Rocm
- && options.PythonOptions.PythonVersion >= PyVersion.Parse("3.11.0")
- )
+ var isLegacyNvidia =
+ torchIndex == TorchIndex.Cuda
+ && (
+ SettingsManager.Settings.PreferredGpu?.IsLegacyNvidiaGpu()
+ ?? HardwareHelper.HasLegacyNvidiaGpu()
+ );
+
+ if (Compat.IsWindows && torchIndex == TorchIndex.Rocm && HasWindowsRocmSupport())
{
- var config = new PipInstallConfig
+ if (rocmPackageHelper is null)
{
- RequirementsFilePaths = ["requirements.txt"],
- ExtraPipArgs = ["numpy<2"],
- SkipTorchInstall = true,
- PostInstallPipArgs = ["typing-extensions>=4.15.0"],
- };
- await StandardPipInstallProcessAsync(
+ throw new InvalidOperationException(
+ "Windows ROCm installation requires the shared ROCm helper to resolve gfx-specific index URLs."
+ );
+ }
+
+ await rocmPackageHelper
+ .InstallWindowsNativePackageAsync(
venvRunner,
- options,
+ installLocation,
installedPackage,
- config,
- onConsoleOutput,
+ WindowsRocmProfile,
progress,
+ onConsoleOutput,
cancellationToken
)
.ConfigureAwait(false);
-
- progress?.Report(
- new ProgressReport(-1f, "Installing ROCm nightly torch...", isIndeterminate: true)
- );
- var indexUrl = gfxArch switch
- {
- "gfx1150" => "https://rocm.nightlies.amd.com/v2-staging/gfx1150", // Strix/Gorgon Point
- "gfx1151" => "https://rocm.nightlies.amd.com/v2/gfx1151", // Strix Halo
- _ when gfxArch.StartsWith("gfx110") => "https://rocm.nightlies.amd.com/v2/gfx110X-all",
- _ when gfxArch.StartsWith("gfx120") => "https://rocm.nightlies.amd.com/v2/gfx120X-all",
- _ => throw new ArgumentOutOfRangeException(
- nameof(gfxArch),
- $"Unsupported GFX Arch: {gfxArch}"
- ),
- };
-
- var torchPipArgs = new PipInstallArgs()
- .AddArgs("--pre", "--upgrade")
- .WithTorch()
- .WithTorchVision()
- .WithTorchAudio()
- .AddArgs("--index-url", indexUrl);
-
- await venvRunner.PipInstall(torchPipArgs, onConsoleOutput).ConfigureAwait(false);
}
- else // Standard installation path for all other cases
+ else
{
- var isLegacyNvidia =
- torchIndex == TorchIndex.Cuda
- && (
- SettingsManager.Settings.PreferredGpu?.IsLegacyNvidiaGpu()
- ?? HardwareHelper.HasLegacyNvidiaGpu()
- );
-
var config = new PipInstallConfig
{
RequirementsFilePaths = ["requirements.txt"],
@@ -479,7 +452,11 @@ await StandardPipInstallProcessAsync(
SettingsManager.Settings.PreferredGpu?.IsBlackwellGpu()
?? HardwareHelper.HasBlackwellGpu(),
WorkingDirectory = installLocation,
- EnvironmentVariables = GetEnvVars(venvRunner.EnvironmentVariables),
+ EnvironmentVariables = GetEnvVars(
+ venvRunner.EnvironmentVariables,
+ installLocation,
+ installedPackage
+ ),
};
await step.ExecuteAsync(progress).ConfigureAwait(false);
@@ -529,7 +506,7 @@ public override async Task RunPackage(
await SetupVenv(installLocation, pythonVersion: PyVersion.Parse(installedPackage.PythonVersion))
.ConfigureAwait(false);
- VenvRunner.UpdateEnvironmentVariables(GetEnvVars);
+ VenvRunner.UpdateEnvironmentVariables(env => GetEnvVars(env, installLocation, installedPackage));
// Check for old NVIDIA driver version with cu130 installations
var isNvidia = SettingsManager.Settings.PreferredGpu?.IsNvidia ?? HardwareHelper.HasNvidiaGpu();
@@ -613,13 +590,7 @@ public override TorchIndex GetRecommendedTorchVersion()
{
var preferRocm =
(Compat.IsLinux && (SettingsManager.Settings.PreferredGpu?.IsAmd ?? HardwareHelper.PreferRocm()))
- || (
- Compat.IsWindows
- && (
- SettingsManager.Settings.PreferredGpu?.IsWindowsRocmSupportedGpu()
- ?? HardwareHelper.HasWindowsRocmSupportedGpu()
- )
- );
+ || HasWindowsRocmSupport();
if (AvailableTorchIndices.Contains(TorchIndex.Rocm) && preferRocm)
{
@@ -629,6 +600,54 @@ public override TorchIndex GetRecommendedTorchVersion()
return base.GetRecommendedTorchVersion();
}
+ /// Windows ROCm install profile for ComfyUI.
+ private static readonly RocmPackageProfile WindowsRocmProfile = new()
+ {
+ RequiresRocmSdk = true,
+ ExtraInstallPipArgs = ["numpy<2"],
+ PostInstallPipArgs = ["typing-extensions>=4.15.0"],
+ UpgradePackages = true,
+ ExtraEnvironmentFactory = BuildComfyWindowsRocmEnvironment,
+ };
+
+ private static IReadOnlyDictionary BuildComfyWindowsRocmEnvironment(
+ RocmRuntimeContext runtimeContext
+ )
+ {
+ return WindowsRocmSupport.IsModernArchitecture(runtimeContext.RuntimeGfxArch)
+ ? new Dictionary { ["COMFYUI_ENABLE_MIOPEN"] = "1" }
+ : new Dictionary();
+ }
+
+ /// Uses the shared ROCm helper for Windows ROCm eligibility checks so ComfyUI does not maintain its own support matrix.
+ private bool HasWindowsRocmSupport()
+ {
+ if (!Compat.IsWindows)
+ return false;
+
+ if (rocmPackageHelper is null)
+ return false;
+
+ var compatibility = rocmPackageHelper.GetCompatibility(WindowsRocmProfile);
+
+ return compatibility.IsCompatible;
+ }
+
+ /// Defaults legacy Windows ROCm GPUs to quad cross-attention because PyTorch cross-attention is considerably slower
+ /// and not as supported on older AMD architectures.
+ private bool DefaultToQuadCrossAttention()
+ {
+ if (!Compat.IsWindows || !HasWindowsRocmSupport())
+ return false;
+
+ var gpu = SettingsManager.Settings.PreferredGpu;
+ var gfxArch = WindowsRocmSupport.IsSupportedGpu(gpu)
+ ? gpu?.GetAmdGfxArch()
+ : HardwareHelper.GetWindowsRocmSupportedGpu()?.GetAmdGfxArch();
+
+ return WindowsRocmSupport.PreferLegacyAttentionFallback(gfxArch);
+ }
+
public override IPackageExtensionManager ExtensionManager =>
new ComfyExtensionManager(this, settingsManager);
@@ -979,21 +998,29 @@ await PipWheelService
.ConfigureAwait(false);
}
- private ImmutableDictionary GetEnvVars(ImmutableDictionary env)
+ private ImmutableDictionary GetEnvVars(
+ ImmutableDictionary env,
+ string installLocation,
+ InstalledPackage installedPackage
+ )
{
// if we're not on windows or we don't have a windows rocm gpu, return original env
- var hasRocmGpu =
- SettingsManager.Settings.PreferredGpu?.IsWindowsRocmSupportedGpu()
- ?? HardwareHelper.HasWindowsRocmSupportedGpu();
+ var hasRocmGpu = HasWindowsRocmSupport();
if (!Compat.IsWindows || !hasRocmGpu)
return env;
- // set some experimental speed improving env vars for Windows ROCm
- return env.SetItem("PYTORCH_TUNABLEOP_ENABLED", "1")
- .SetItem("MIOPEN_FIND_MODE", "2")
- .SetItem("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "1")
- .SetItem("PYTORCH_ALLOC_CONF", "max_split_size_mb:6144,garbage_collection_threshold:0.8") // greatly helps prevent GPU OOM and instability/driver timeouts/OS hard locks and decreases dependency on Tiled VAE at standard res's
- .SetItem("COMFYUI_ENABLE_MIOPEN", "1"); // re-enables "cudnn" in ComfyUI as it's needed for MiOpen to function properly
+ if (rocmPackageHelper is not null)
+ {
+ var rocmEnvironment = rocmPackageHelper.BuildLaunchEnvironment(
+ installLocation,
+ installedPackage,
+ WindowsRocmProfile
+ );
+
+ return env.SetItems(rocmEnvironment);
+ }
+
+ return env;
}
}
diff --git a/StabilityMatrix.Core/Models/Packages/Wan2GP.cs b/StabilityMatrix.Core/Models/Packages/Wan2GP.cs
index 2a00a6262..0c91f23a2 100644
--- a/StabilityMatrix.Core/Models/Packages/Wan2GP.cs
+++ b/StabilityMatrix.Core/Models/Packages/Wan2GP.cs
@@ -6,9 +6,11 @@
using StabilityMatrix.Core.Helper.HardwareInfo;
using StabilityMatrix.Core.Models.FileInterfaces;
using StabilityMatrix.Core.Models.Progress;
+using StabilityMatrix.Core.Models.Rocm;
using StabilityMatrix.Core.Processes;
using StabilityMatrix.Core.Python;
using StabilityMatrix.Core.Services;
+using StabilityMatrix.Core.Services.Rocm;
namespace StabilityMatrix.Core.Models.Packages;
@@ -30,7 +32,8 @@ public class Wan2GP(
IDownloadService downloadService,
IPrerequisiteHelper prerequisiteHelper,
IPyInstallationManager pyInstallationManager,
- IPipWheelService pipWheelService
+ IPipWheelService pipWheelService,
+ IRocmPackageHelper? rocmPackageHelper = null
)
: BaseGitPackage(
githubApi,
@@ -41,6 +44,13 @@ IPipWheelService pipWheelService
pipWheelService
)
{
+ private static readonly RocmPackageProfile WindowsRocmProfile = new()
+ {
+ RequiresRocmSdk = true,
+ UpgradePackages = true,
+ PostInstallPipArgs = ["hf-xet", "setuptools", "numpy==1.26.4"],
+ };
+
public override string Name => "Wan2GP";
public override string DisplayName { get; set; } = "Wan2GP";
public override string Author => "deepbeepmeep";
@@ -64,7 +74,7 @@ IPipWheelService pipWheelService
public override bool IsCompatible =>
HardwareHelper.HasNvidiaGpu()
- || (Compat.IsWindows ? HardwareHelper.HasWindowsRocmSupportedGpu() : HardwareHelper.HasAmdGpu());
+ || (Compat.IsWindows ? HasWindowsRocmSupport() : HardwareHelper.HasAmdGpu());
public override string MainBranch => "main";
public override bool ShouldIgnoreReleases => true;
@@ -72,7 +82,7 @@ IPipWheelService pipWheelService
public override Dictionary> SharedOutputFolders =>
new() { [SharedOutputType.Img2Vid] = ["outputs"] };
- // AMD ROCm requires Python 3.11, NVIDIA uses 3.10
+ // Wan2GP currently uses Python 3.11 for ROCm and 3.10 for CUDA.
public override PyVersion RecommendedPythonVersion =>
IsAmdRocm ? Python.PyInstallationManager.Python_3_11_13 : Python.PyInstallationManager.Python_3_10_17;
@@ -86,6 +96,17 @@ IPipWheelService pipWheelService
///
private bool IsAmdRocm => GetRecommendedTorchVersion() == TorchIndex.Rocm;
+ private bool HasWindowsRocmSupport()
+ {
+ if (!Compat.IsWindows)
+ return false;
+
+ if (rocmPackageHelper is null)
+ return HardwareHelper.HasWindowsRocmSupportedGpu();
+
+ return rocmPackageHelper.GetCompatibility(WindowsRocmProfile).IsCompatible;
+ }
+
///
/// Python wrapper script that patches logging to also print to stdout/stderr, so
/// StabilityMatrix can capture the output. Wan2GP logs through Gradio UI notifications
@@ -213,8 +234,8 @@ public override TorchIndex GetRecommendedTorchVersion()
(
Compat.IsWindows
&& (
- SettingsManager.Settings.PreferredGpu?.IsWindowsRocmSupportedGpu()
- ?? HardwareHelper.HasWindowsRocmSupportedGpu()
+ WindowsRocmSupport.IsSupportedGpu(SettingsManager.Settings.PreferredGpu)
+ || HasWindowsRocmSupport()
)
)
|| (
@@ -256,7 +277,15 @@ public override async Task InstallPackage(
if (torchIndex == TorchIndex.Rocm)
{
- await InstallAmdRocmAsync(venvRunner, progress, onConsoleOutput).ConfigureAwait(false);
+ await InstallAmdRocmAsync(
+ venvRunner,
+ installLocation,
+ installedPackage,
+ progress,
+ onConsoleOutput,
+ cancellationToken
+ )
+ .ConfigureAwait(false);
}
else
{
@@ -359,68 +388,53 @@ await venvRunner
private async Task InstallAmdRocmAsync(
IPyVenvRunner venvRunner,
+ string installLocation,
+ InstalledPackage installedPackage,
IProgress? progress,
- Action? onConsoleOutput
+ Action? onConsoleOutput,
+ CancellationToken cancellationToken
)
{
- progress?.Report(new ProgressReport(-1f, "Upgrading pip...", isIndeterminate: true));
- await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false);
-
if (Compat.IsWindows)
{
- // Windows AMD ROCm - special TheRock wheels
- progress?.Report(
- new ProgressReport(-1f, "Installing PyTorch ROCm wheels...", isIndeterminate: true)
- );
-
- // Set environment variable for wheel filename check bypass
- venvRunner.UpdateEnvironmentVariables(env => env.SetItem("UV_SKIP_WHEEL_FILENAME_CHECK", "1"));
+ if (rocmPackageHelper is null)
+ {
+ throw new InvalidOperationException(
+ "Windows ROCm installation for Wan2GP requires the shared ROCm helper."
+ );
+ }
- // Install PyTorch ROCm wheels from TheRock releases (Python 3.11)
- await venvRunner
- .PipInstall(
- "https://github.com/scottt/rocm-TheRock/releases/download/v6.5.0rc-pytorch-gfx110x/torch-2.7.0a0+rocm_git3f903c3-cp311-cp311-win_amd64.whl",
- onConsoleOutput
+ await rocmPackageHelper
+ .InstallWindowsNativePackageAsync(
+ venvRunner,
+ installLocation,
+ installedPackage,
+ WindowsRocmProfile,
+ progress,
+ onConsoleOutput,
+ cancellationToken
)
.ConfigureAwait(false);
- await venvRunner
- .PipInstall(
- "https://github.com/scottt/rocm-TheRock/releases/download/v6.5.0rc-pytorch-gfx110x/torchaudio-2.7.0a0+52638ef-cp311-cp311-win_amd64.whl",
- onConsoleOutput
- )
- .ConfigureAwait(false);
+ return;
+ }
- await venvRunner
- .PipInstall(
- "https://github.com/scottt/rocm-TheRock/releases/download/v6.5.0rc-pytorch-gfx110x/torchvision-0.22.0+9eb57cd-cp311-cp311-win_amd64.whl",
- onConsoleOutput
- )
- .ConfigureAwait(false);
+ progress?.Report(new ProgressReport(-1f, "Upgrading pip...", isIndeterminate: true));
+ await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false);
- // Install requirements directly using -r flag (handles @ URL syntax properly)
- progress?.Report(new ProgressReport(-1f, "Installing requirements...", isIndeterminate: true));
- await venvRunner.PipInstall("-r requirements.txt", onConsoleOutput).ConfigureAwait(false);
- }
- else
- {
- // Linux AMD ROCm - standard PyTorch ROCm
- // Install requirements directly using -r flag (handles @ URL syntax properly)
- progress?.Report(new ProgressReport(-1f, "Installing requirements...", isIndeterminate: true));
- await venvRunner.PipInstall("-r requirements.txt", onConsoleOutput).ConfigureAwait(false);
-
- // Install torch with ROCm index (force reinstall to ensure correct version)
- progress?.Report(new ProgressReport(-1f, "Installing PyTorch ROCm...", isIndeterminate: true));
- var torchArgs = new PipInstallArgs()
- .WithTorch("==2.7.0")
- .WithTorchVision("==0.22.0")
- .WithTorchAudio("==2.7.0")
- .WithTorchExtraIndex("rocm6.3")
- .AddArg("--force-reinstall")
- .AddArg("--no-deps");
-
- await venvRunner.PipInstall(torchArgs, onConsoleOutput).ConfigureAwait(false);
- }
+ progress?.Report(new ProgressReport(-1f, "Installing requirements...", isIndeterminate: true));
+ await venvRunner.PipInstall("-r requirements.txt", onConsoleOutput).ConfigureAwait(false);
+
+ progress?.Report(new ProgressReport(-1f, "Installing PyTorch ROCm...", isIndeterminate: true));
+ var torchArgs = new PipInstallArgs()
+ .WithTorch()
+ .WithTorchVision()
+ .WithTorchAudio()
+ .WithTorchExtraIndex("rocm7.2")
+ .AddArg("--force-reinstall")
+ .AddArg("--no-deps");
+
+ await venvRunner.PipInstall(torchArgs, onConsoleOutput).ConfigureAwait(false);
// Install additional packages
await venvRunner.PipInstall("hf-xet setuptools numpy==1.26.4", onConsoleOutput).ConfigureAwait(false);
@@ -437,6 +451,16 @@ public override async Task RunPackage(
await SetupVenv(installLocation, pythonVersion: PyVersion.Parse(installedPackage.PythonVersion))
.ConfigureAwait(false);
+ if (Compat.IsWindows && rocmPackageHelper is not null && HasWindowsRocmSupport())
+ {
+ var rocmEnvironment = rocmPackageHelper.BuildLaunchEnvironment(
+ installLocation,
+ installedPackage,
+ WindowsRocmProfile
+ );
+ VenvRunner.UpdateEnvironmentVariables(env => env.SetItems(rocmEnvironment));
+ }
+
// Fix for distutils compatibility issue with Python 3.10 and setuptools
VenvRunner.UpdateEnvironmentVariables(env => env.SetItem("SETUPTOOLS_USE_DISTUTILS", "stdlib"));
diff --git a/StabilityMatrix.Core/Models/Rocm/RocmCompatibilityResult.cs b/StabilityMatrix.Core/Models/Rocm/RocmCompatibilityResult.cs
new file mode 100644
index 000000000..401f3ada4
--- /dev/null
+++ b/StabilityMatrix.Core/Models/Rocm/RocmCompatibilityResult.cs
@@ -0,0 +1,17 @@
+using StabilityMatrix.Core.Helper.HardwareInfo;
+
+namespace StabilityMatrix.Core.Models.Rocm;
+
+///
+/// Describes whether a package/profile is currently compatible with ROCm on the active machine.
+///
+public class RocmCompatibilityResult
+{
+ public bool IsCompatible { get; init; }
+
+ public string? FailureReason { get; init; }
+
+ public GpuInfo? SelectedGpu { get; init; }
+
+ public string? ResolvedGfxArch { get; init; }
+}
diff --git a/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs b/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs
new file mode 100644
index 000000000..e01261700
--- /dev/null
+++ b/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs
@@ -0,0 +1,57 @@
+namespace StabilityMatrix.Core.Models.Rocm;
+
+///
+/// Controls how ROCm helper defaults, package-specific variables, and user overrides are layered at launch.
+///
+public class RocmEnvironmentOptions
+{
+ ///
+ /// When true, package-specific environment additions may be merged on top of helper defaults.
+ ///
+ public bool IncludePackageOverrides { get; init; } = true;
+
+ ///
+ /// When true, user-defined Stability Matrix environment variables may override helper/package defaults last.
+ ///
+ public bool IncludeUserOverrides { get; init; } = true;
+
+ ///
+ /// When set, overrides the default PyTorch allocator tuning string added by the ROCm helper.
+ ///
+ public string? PyTorchAllocConf { get; init; } = "max_split_size_mb:512,garbage_collection_threshold:0.8";
+
+ ///
+ /// When set, configures MIOpen find mode for helper-managed ROCm defaults.
+ ///
+ public string? MiopenFindMode { get; init; } = "2";
+
+ ///
+ /// When set, configures MIOpen search cutoff for helper-managed ROCm defaults.
+ ///
+ public string? MiopenSearchCutoff { get; init; } = "1";
+
+ ///
+ /// When set, configures MIOpen find enforcement behavior for helper-managed ROCm defaults.
+ ///
+ public string? MiopenFindEnforce { get; init; } = "3";
+
+ ///
+ /// When set, controls whether AMD Triton-backed flash attention is enabled by helper defaults.
+ ///
+ public string? FlashAttentionTritonAmdEnable { get; init; } = "TRUE";
+
+ ///
+ /// When true, helper-managed defaults will enable ROCm AOTriton on modern Windows ROCm architectures.
+ ///
+ public bool ApplyAotritonExperimental { get; init; } = true;
+
+ ///
+ /// When true, helper-managed defaults will force math SDP on legacy ROCm architectures.
+ ///
+ public bool ApplyLegacySdpFallback { get; init; } = true;
+
+ ///
+ /// When true, helper-managed defaults will apply the RDNA1 HSA override mask when needed.
+ ///
+ public bool ApplyRdna1Override { get; init; } = true;
+}
diff --git a/StabilityMatrix.Core/Models/Rocm/RocmInstallContext.cs b/StabilityMatrix.Core/Models/Rocm/RocmInstallContext.cs
new file mode 100644
index 000000000..597eb4fe6
--- /dev/null
+++ b/StabilityMatrix.Core/Models/Rocm/RocmInstallContext.cs
@@ -0,0 +1,11 @@
+namespace StabilityMatrix.Core.Models.Rocm;
+
+///
+/// Captures ROCm-related facts needed during package install or update flows.
+///
+public class RocmInstallContext
+{
+ public string? RuntimeGfxArch { get; init; }
+
+ public string? RocmPackageIndexUrl { get; init; }
+}
diff --git a/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs b/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs
new file mode 100644
index 000000000..d1abf16f8
--- /dev/null
+++ b/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs
@@ -0,0 +1,55 @@
+using StabilityMatrix.Core.Models.Progress;
+
+namespace StabilityMatrix.Core.Models.Rocm;
+
+///
+/// Declares what a package expects from the ROCm helper.
+/// Package classes should describe intent here rather than hardcoding ROCm decisions inline.
+///
+public class RocmPackageProfile
+{
+ public bool RequiresRocmSdk { get; init; }
+
+ ///
+ /// Requirement files to install after helper-owned ROCm runtime / torch bootstrap steps complete.
+ ///
+ public IEnumerable RequirementsFilePaths { get; init; } = ["requirements.txt"];
+
+ ///
+ /// Package requirement entries to exclude because the helper installs them from ROCm-specific feeds.
+ ///
+ public string RequirementsExcludePattern { get; init; } = @"(torch(vision|audio)?|xformers)([^a-z].*)?";
+
+ ///
+ /// Extra package-specific pip arguments to include when installing requirements after helper bootstrap.
+ ///
+ public IEnumerable ExtraInstallPipArgs { get; init; } = [];
+
+ ///
+ /// Extra package-specific pip arguments to install after requirements and torch are complete.
+ ///
+ public IEnumerable PostInstallPipArgs { get; init; } = [];
+
+ ///
+ /// When true, helper-managed requirements installs should use --upgrade.
+ ///
+ public bool UpgradePackages { get; init; }
+
+ ///
+ /// When true, helper-managed torch installs should force reinstall the selected ROCm wheel set.
+ ///
+ public bool ForceReinstallTorch { get; init; } = true;
+
+ ///
+ /// Optional callback for package-specific environment variables derived from a resolved ROCm context.
+ ///
+ public Func<
+ RocmRuntimeContext,
+ IReadOnlyDictionary
+ >? ExtraEnvironmentFactory { get; init; }
+
+ ///
+ /// Controls whether package-specific environment variables should be layered on top of helper defaults.
+ ///
+ public RocmEnvironmentOptions EnvironmentOptions { get; init; } = new();
+}
diff --git a/StabilityMatrix.Core/Models/Rocm/RocmRuntimeContext.cs b/StabilityMatrix.Core/Models/Rocm/RocmRuntimeContext.cs
new file mode 100644
index 000000000..1fdda7914
--- /dev/null
+++ b/StabilityMatrix.Core/Models/Rocm/RocmRuntimeContext.cs
@@ -0,0 +1,18 @@
+using StabilityMatrix.Core.Helper.HardwareInfo;
+
+namespace StabilityMatrix.Core.Models.Rocm;
+
+///
+/// Captures resolved ROCm facts for a package launch or runtime decision.
+/// This model is intended to separate hardware/runtime facts from package policy.
+///
+public class RocmRuntimeContext
+{
+ public bool IsSupported { get; init; }
+
+ public string? FailureReason { get; init; }
+
+ public GpuInfo? SelectedGpu { get; init; }
+
+ public string? RuntimeGfxArch { get; init; }
+}
diff --git a/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs b/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs
new file mode 100644
index 000000000..ee000b1bd
--- /dev/null
+++ b/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs
@@ -0,0 +1,68 @@
+using StabilityMatrix.Core.Helper.HardwareInfo;
+
+namespace StabilityMatrix.Core.Models.Rocm;
+
+///
+/// Centralizes Windows ROCm support and architecture policy so hardware detection, package selection,
+/// installation, and shared launch decisions use the same support map.
+///
+public static class WindowsRocmSupport
+{
+ public static bool IsSupportedGpu(GpuInfo? gpu)
+ {
+ if (gpu is null || !gpu.IsAmd || string.IsNullOrWhiteSpace(gpu.Name))
+ return false;
+
+ return IsSupportedArchitecture(gpu.GetAmdGfxArch());
+ }
+
+ public static bool IsSupportedArchitecture(string? gfxArch)
+ {
+ return TryGetPackageIndexUrl(gfxArch) is not null;
+ }
+
+ public static bool IsModernArchitecture(string? gfxArch)
+ {
+ return gfxArch?.StartsWith("gfx110", StringComparison.OrdinalIgnoreCase) == true
+ || gfxArch?.StartsWith("gfx115", StringComparison.OrdinalIgnoreCase) == true
+ || gfxArch?.StartsWith("gfx120", StringComparison.OrdinalIgnoreCase) == true;
+ }
+
+ public static bool IsLegacyArchitecture(string? gfxArch)
+ {
+ return IsSupportedArchitecture(gfxArch) && !IsModernArchitecture(gfxArch);
+ }
+
+ public static bool PreferLegacyAttentionFallback(string? gfxArch)
+ {
+ return IsLegacyArchitecture(gfxArch);
+ }
+
+ public static bool IsRdna1Architecture(string? gfxArch)
+ {
+ return gfxArch?.StartsWith("gfx101", StringComparison.OrdinalIgnoreCase) == true;
+ }
+
+ public static string? TryGetPackageIndexUrl(string? gfxArch)
+ {
+ return gfxArch switch
+ {
+ "gfx900" => "https://rocm.nightlies.amd.com/v2-staging/gfx900/", // Vega 10
+ "gfx906" => "https://rocm.nightlies.amd.com/v2-staging/gfx906/", // Radeon VII, Vega 20
+ "gfx90X" => "https://rocm.nightlies.amd.com/v2-staging/gfx90X/", // Radeon Pro VII
+ var s when s != null && s.StartsWith("gfx101", StringComparison.OrdinalIgnoreCase) =>
+ "https://rocm.nightlies.amd.com/v2-staging/gfx101X-dgpu/", // RDNA1 (5000 series, Pro)
+ var s when s != null && s.StartsWith("gfx103", StringComparison.OrdinalIgnoreCase) =>
+ "https://rocm.nightlies.amd.com/v2-staging/gfx103X-all/", // RDNA2 (6000 series, 6xxM Mobile, Steam Deck)
+ var s when s != null && s.StartsWith("gfx110", StringComparison.OrdinalIgnoreCase) =>
+ "https://rocm.nightlies.amd.com/v2/gfx110X-all/", // RDNA3 (7000 series, 7xxM Mobile)
+ "gfx1150" => "https://rocm.nightlies.amd.com/v2-staging/gfx1150/", // RDNA3.5 (Strix/Gorgon Point)
+ "gfx1151" => "https://rocm.nightlies.amd.com/v2/gfx1151/", // RDNA3.5 (Strix Halo)
+ "gfx1152" => "https://rocm.nightlies.amd.com/v2-staging/gfx1152/", // RDNA3.5 (Kraken Point)
+ "gfx1153" => "https://rocm.nightlies.amd.com/v2-staging/gfx1153/", // RDNA3.5 (Medusa Point)
+ var s when s != null && s.StartsWith("gfx120", StringComparison.OrdinalIgnoreCase) =>
+ "https://rocm.nightlies.amd.com/v2/gfx120X-all/", // RDNA4 (9000 series)
+ _ => null,
+ };
+ }
+}
diff --git a/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs
new file mode 100644
index 000000000..03b4ce0c2
--- /dev/null
+++ b/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs
@@ -0,0 +1,58 @@
+using StabilityMatrix.Core.Models;
+using StabilityMatrix.Core.Models.Progress;
+using StabilityMatrix.Core.Models.Rocm;
+using StabilityMatrix.Core.Processes;
+using StabilityMatrix.Core.Python;
+
+namespace StabilityMatrix.Core.Services.Rocm;
+
+///
+/// Defines the ROCm helper surface area shared by ROCm-capable packages.
+///
+public interface IRocmPackageHelper
+{
+ ///
+ /// Evaluates whether the current machine and package profile are compatible with ROCm.
+ ///
+ RocmCompatibilityResult GetCompatibility(RocmPackageProfile profile);
+
+ ///
+ /// Resolves the runtime ROCm facts needed for package launch and environment construction.
+ ///
+ RocmRuntimeContext ResolveRuntimeContext(
+ string installLocation,
+ InstalledPackage installedPackage,
+ RocmPackageProfile profile
+ );
+
+ ///
+ /// Resolves the ROCm facts needed during package installation or update operations.
+ ///
+ RocmInstallContext ResolveInstallContext(
+ string installLocation,
+ InstalledPackage installedPackage,
+ RocmPackageProfile profile
+ );
+
+ ///
+ /// Builds a launch-time environment dictionary from resolved ROCm runtime data.
+ ///
+ IReadOnlyDictionary BuildLaunchEnvironment(
+ string installLocation,
+ InstalledPackage installedPackage,
+ RocmPackageProfile profile
+ );
+
+ ///
+ /// Performs the Windows-native ROCm bootstrap/install flow for a package using helper-resolved gfx-family feed URLs.
+ ///
+ Task InstallWindowsNativePackageAsync(
+ IPyVenvRunner venvRunner,
+ string installLocation,
+ InstalledPackage installedPackage,
+ RocmPackageProfile profile,
+ IProgress? progress = null,
+ Action? onConsoleOutput = null,
+ CancellationToken cancellationToken = default
+ );
+}
diff --git a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs
new file mode 100644
index 000000000..a5382ac91
--- /dev/null
+++ b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs
@@ -0,0 +1,793 @@
+using System.Collections.Immutable;
+using System.Text.Json;
+using Injectio.Attributes;
+using NLog;
+using StabilityMatrix.Core.Exceptions;
+using StabilityMatrix.Core.Helper;
+using StabilityMatrix.Core.Helper.HardwareInfo;
+using StabilityMatrix.Core.Models;
+using StabilityMatrix.Core.Models.FileInterfaces;
+using StabilityMatrix.Core.Models.Progress;
+using StabilityMatrix.Core.Models.Rocm;
+using StabilityMatrix.Core.Processes;
+using StabilityMatrix.Core.Python;
+using StabilityMatrix.Core.Services;
+
+namespace StabilityMatrix.Core.Services.Rocm;
+
+///
+/// Provides the shared ROCm helper surface area used by ROCm-capable packages.
+///
+[RegisterSingleton]
+public class RocmPackageHelper(ISettingsManager settingsManager) : IRocmPackageHelper
+{
+ private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
+ private static readonly StringComparer EnvComparer = StringComparer.OrdinalIgnoreCase;
+
+ ///
+ public RocmCompatibilityResult GetCompatibility(RocmPackageProfile profile)
+ {
+ return BuildCompatibilityResult(profile);
+ }
+
+ ///
+ public RocmRuntimeContext ResolveRuntimeContext(
+ string installLocation,
+ InstalledPackage installedPackage,
+ RocmPackageProfile profile
+ )
+ {
+ _ = installLocation;
+ _ = installedPackage;
+
+ var compatibility = BuildCompatibilityResult(profile);
+ if (!compatibility.IsCompatible)
+ {
+ return new RocmRuntimeContext
+ {
+ IsSupported = false,
+ FailureReason = compatibility.FailureReason,
+ SelectedGpu = compatibility.SelectedGpu,
+ RuntimeGfxArch = compatibility.ResolvedGfxArch,
+ };
+ }
+
+ var supportedAmdGpus = GetAmdGpuCandidates(forceRefresh: true)
+ .Where(IsSupportedWindowsRocmGpu)
+ .ToList();
+
+ var selectedGpu =
+ compatibility.SelectedGpu
+ ?? TryResolvePreferredAmdGpu(supportedAmdGpus, settingsManager.Settings.PreferredGpu)
+ ?? supportedAmdGpus.FirstOrDefault();
+
+ var runtimeGfxArch =
+ compatibility.ResolvedGfxArch
+ ?? selectedGpu?.GetAmdGfxArch()
+ ?? GetSupportedFallbackGfxArch(supportedAmdGpus);
+
+ return new RocmRuntimeContext
+ {
+ IsSupported = true,
+ SelectedGpu = selectedGpu,
+ RuntimeGfxArch = runtimeGfxArch,
+ };
+ }
+
+ ///
+ public RocmInstallContext ResolveInstallContext(
+ string installLocation,
+ InstalledPackage installedPackage,
+ RocmPackageProfile profile
+ )
+ {
+ _ = installLocation;
+ _ = installedPackage;
+
+ var supportedAmdGpus = GetAmdGpuCandidates(forceRefresh: true)
+ .Where(IsSupportedWindowsRocmGpu)
+ .ToList();
+
+ var preferredGfxArch = TryResolvePreferredAmdGfxArch(
+ supportedAmdGpus,
+ settingsManager.Settings.PreferredGpu
+ );
+
+ var runtimeGfxArch = preferredGfxArch ?? GetSupportedFallbackGfxArch(supportedAmdGpus);
+ var windowsNativeIndexUrl = WindowsRocmSupport.TryGetPackageIndexUrl(runtimeGfxArch);
+
+ return new RocmInstallContext
+ {
+ RuntimeGfxArch = runtimeGfxArch,
+ RocmPackageIndexUrl = windowsNativeIndexUrl,
+ };
+ }
+
+ ///
+ public IReadOnlyDictionary BuildLaunchEnvironment(
+ string installLocation,
+ InstalledPackage installedPackage,
+ RocmPackageProfile profile
+ )
+ {
+ var runtimeContext = ResolveRuntimeContext(installLocation, installedPackage, profile);
+
+ if (!runtimeContext.IsSupported)
+ return new Dictionary();
+
+ var helperEnvironment = BuildHelperLaunchEnvironment(runtimeContext, profile);
+ var packageEnvironment =
+ profile.ExtraEnvironmentFactory?.Invoke(runtimeContext) ?? new Dictionary();
+
+ var mergedEnvironment = MergeLaunchEnvironment(
+ helperEnvironment,
+ packageEnvironment,
+ profile.EnvironmentOptions
+ );
+
+ return mergedEnvironment;
+ }
+
+ ///
+ public async Task InstallWindowsNativePackageAsync(
+ IPyVenvRunner venvRunner,
+ string installLocation,
+ InstalledPackage installedPackage,
+ RocmPackageProfile profile,
+ IProgress? progress = null,
+ Action? onConsoleOutput = null,
+ CancellationToken cancellationToken = default
+ )
+ {
+ var compatibility = GetCompatibility(profile);
+ if (!compatibility.IsCompatible)
+ {
+ throw new InvalidOperationException(
+ compatibility.FailureReason
+ ?? "Windows ROCm installation is not supported for the current machine."
+ );
+ }
+
+ var installContext = ResolveInstallContext(installLocation, installedPackage, profile);
+
+ var rocmPackageIndexUrl = installContext.RocmPackageIndexUrl;
+
+ if (string.IsNullOrWhiteSpace(rocmPackageIndexUrl))
+ {
+ throw new ApplicationException(
+ $"No Windows ROCm Technical Preview index URL is available for '{installContext.RuntimeGfxArch ?? "unknown"}'."
+ );
+ }
+
+ progress?.Report(new ProgressReport(-1f, "Upgrading pip...", isIndeterminate: true));
+ await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false);
+
+ if (profile.RequiresRocmSdk)
+ {
+ progress?.Report(new ProgressReport(-1f, "Installing ROCm runtime...", isIndeterminate: true));
+ var rocmRuntimeArgs = new PipInstallArgs()
+ .AddKeyedArgs("--index-url", ["--index-url", rocmPackageIndexUrl])
+ .AddArgs("rocm[devel,libraries]");
+
+ if (installedPackage.PipOverrides != null)
+ {
+ rocmRuntimeArgs = rocmRuntimeArgs.WithUserOverrides(installedPackage.PipOverrides);
+ }
+
+ await venvRunner.PipInstall(rocmRuntimeArgs, onConsoleOutput).ConfigureAwait(false);
+
+ progress?.Report(new ProgressReport(-1f, "Initializing ROCm SDK...", isIndeterminate: true));
+ await InitializeWindowsNativeRocmSdkAsync(installLocation, onConsoleOutput, cancellationToken)
+ .ConfigureAwait(false);
+ }
+
+ progress?.Report(
+ new ProgressReport(-1f, "Installing package requirements...", isIndeterminate: true)
+ );
+
+ var requirementsPipArgs = new PipInstallArgs([.. profile.ExtraInstallPipArgs]);
+ if (profile.UpgradePackages)
+ {
+ requirementsPipArgs = requirementsPipArgs.AddArg("--upgrade");
+ }
+
+ foreach (var relativePath in profile.RequirementsFilePaths)
+ {
+ var requirementsFile = new FilePath(venvRunner.WorkingDirectory ?? installLocation, relativePath);
+ if (!requirementsFile.Exists)
+ continue;
+
+ var requirementsContent = await requirementsFile
+ .ReadAllTextAsync(cancellationToken)
+ .ConfigureAwait(false);
+
+ requirementsPipArgs = requirementsPipArgs.WithParsedFromRequirementsTxt(
+ requirementsContent,
+ profile.RequirementsExcludePattern
+ );
+ }
+
+ if (installedPackage.PipOverrides != null)
+ {
+ requirementsPipArgs = requirementsPipArgs.WithUserOverrides(installedPackage.PipOverrides);
+ }
+
+ await venvRunner.PipInstall(requirementsPipArgs, onConsoleOutput).ConfigureAwait(false);
+
+ progress?.Report(new ProgressReport(-1f, "Installing ROCm torch...", isIndeterminate: true));
+ var torchArgs = new PipInstallArgs()
+ .AddArg("--pre")
+ .AddArg("--upgrade")
+ .AddKeyedArgs("--index-url", ["--index-url", rocmPackageIndexUrl])
+ .WithTorch()
+ .WithTorchAudio()
+ .WithTorchVision();
+
+ if (profile.ForceReinstallTorch)
+ {
+ torchArgs = torchArgs.AddArg("--force-reinstall");
+ }
+
+ if (installedPackage.PipOverrides != null)
+ {
+ torchArgs = torchArgs.WithUserOverrides(installedPackage.PipOverrides);
+ }
+
+ await venvRunner.PipInstall(torchArgs, onConsoleOutput).ConfigureAwait(false);
+
+ if (profile.RequiresRocmSdk)
+ {
+ await AlignRocmSdkDevelVersionAsync(venvRunner, rocmPackageIndexUrl, onConsoleOutput)
+ .ConfigureAwait(false);
+
+ progress?.Report(new ProgressReport(-1f, "Reinitializing ROCm SDK...", isIndeterminate: true));
+ await InitializeWindowsNativeRocmSdkAsync(installLocation, onConsoleOutput, cancellationToken)
+ .ConfigureAwait(false);
+ }
+
+ if (profile.PostInstallPipArgs.Any())
+ {
+ var postInstallPipArgs = new PipInstallArgs([.. profile.PostInstallPipArgs]);
+ if (installedPackage.PipOverrides != null)
+ {
+ postInstallPipArgs = postInstallPipArgs.WithUserOverrides(installedPackage.PipOverrides);
+ }
+
+ await venvRunner.PipInstall(postInstallPipArgs, onConsoleOutput).ConfigureAwait(false);
+ }
+
+ await VerifyWindowsNativeTorchInstallAsync(venvRunner, onConsoleOutput, cancellationToken)
+ .ConfigureAwait(false);
+
+ if (profile.RequiresRocmSdk)
+ {
+ await VerifyWindowsNativeRocmRuntimeAsync(installLocation, onConsoleOutput, cancellationToken)
+ .ConfigureAwait(false);
+ }
+ }
+
+ ///
+ /// Builds a compatibility result from the current machine state and package profile.
+ /// This keeps the first ROCm helper slice focused on hardware capability and GPU selection only.
+ ///
+ private RocmCompatibilityResult BuildCompatibilityResult(RocmPackageProfile profile)
+ {
+ var amdGpus = GetAmdGpuCandidates(forceRefresh: false).ToList();
+ if (amdGpus.Count == 0)
+ {
+ return new RocmCompatibilityResult
+ {
+ IsCompatible = false,
+ FailureReason = "No AMD GPU was detected for ROCm evaluation.",
+ };
+ }
+
+ var preferredGpu = settingsManager.Settings.PreferredGpu;
+
+ var supportedAmdGpus = amdGpus.Where(IsSupportedWindowsRocmGpu).ToList();
+ if (supportedAmdGpus.Count == 0)
+ {
+ return new RocmCompatibilityResult
+ {
+ IsCompatible = false,
+ FailureReason = GetUnsupportedGpuReason(amdGpus),
+ };
+ }
+
+ var selectedGpu =
+ TryResolvePreferredAmdGpu(supportedAmdGpus, preferredGpu) ?? supportedAmdGpus.First();
+ var resolvedGfxArch = selectedGpu.GetAmdGfxArch() ?? GetSupportedFallbackGfxArch(supportedAmdGpus);
+
+ return new RocmCompatibilityResult
+ {
+ IsCompatible = !string.IsNullOrWhiteSpace(resolvedGfxArch),
+ FailureReason = string.IsNullOrWhiteSpace(resolvedGfxArch)
+ ? "No supported AMD GFX architecture could be resolved for ROCm."
+ : null,
+ SelectedGpu = selectedGpu,
+ ResolvedGfxArch = resolvedGfxArch,
+ };
+ }
+
+ ///
+ /// Returns AMD GPUs from Stability Matrix's internal hardware model.
+ /// This is the canonical GPU source for the ROCm helper and intentionally avoids package-local probing.
+ ///
+ private static IReadOnlyList GetAmdGpuCandidates(bool forceRefresh = false)
+ {
+ return HardwareHelper.IterGpuInfo(forceRefresh).Where(gpu => gpu.IsAmd).ToList();
+ }
+
+ ///
+ /// Resolves the preferred AMD GPU when the configured preference is still present in the current hardware list.
+ ///
+ private static GpuInfo? TryResolvePreferredAmdGpu(
+ IEnumerable availableGpus,
+ GpuInfo? preferredGpu
+ )
+ {
+ if (preferredGpu is null || !preferredGpu.IsAmd)
+ return null;
+
+ var preferredMatch = availableGpus.FirstOrDefault(gpu => gpu.Equals(preferredGpu));
+ if (preferredMatch is not null)
+ return preferredMatch;
+
+ if (!string.IsNullOrWhiteSpace(preferredGpu.Name))
+ {
+ Logger.Info(
+ "Preferred GPU {PreferredGpuName} was ignored for ROCm detection because it is not present in current hardware enumeration.",
+ preferredGpu.Name
+ );
+ }
+
+ return null;
+ }
+
+ ///
+ /// Resolves the preferred AMD GFX architecture when the configured GPU is supported and currently present.
+ ///
+ private static string? TryResolvePreferredAmdGfxArch(
+ IEnumerable availableGpus,
+ GpuInfo? preferredGpu
+ )
+ {
+ var resolvedPreferredGpu = TryResolvePreferredAmdGpu(availableGpus, preferredGpu);
+ return resolvedPreferredGpu is not null && IsSupportedWindowsRocmGpu(resolvedPreferredGpu)
+ ? resolvedPreferredGpu.GetAmdGfxArch()
+ : null;
+ }
+
+ ///
+ /// Resolves the first supported AMD GFX architecture from the current machine state when no preferred GPU applies.
+ ///
+ private static string? GetSupportedFallbackGfxArch(IEnumerable availableGpus)
+ {
+ return availableGpus
+ .Where(IsSupportedWindowsRocmGpu)
+ .Select(gpu => gpu.GetAmdGfxArch())
+ .FirstOrDefault(IsSupportedWindowsRocmArchitecture);
+ }
+
+ ///
+ /// Determines whether a GPU is supported by the Windows ROCm install flow currently modeled by the helper.
+ ///
+ private static bool IsSupportedWindowsRocmGpu(GpuInfo gpu)
+ {
+ return WindowsRocmSupport.IsSupportedGpu(gpu);
+ }
+
+ ///
+ /// Determines whether a resolved AMD GFX architecture falls inside the Windows ROCm support set currently modeled by the helper.
+ ///
+ private static bool IsSupportedWindowsRocmArchitecture(string? gfxArch)
+ {
+ return WindowsRocmSupport.IsSupportedArchitecture(gfxArch);
+ }
+
+ ///
+ /// Produces a readable incompatibility reason when AMD hardware is present but not usable for Windows ROCm.
+ ///
+ private static string GetUnsupportedGpuReason(IReadOnlyList amdGpus)
+ {
+ _ = amdGpus;
+ return "No AMD GPU with a supported Windows ROCm architecture was detected.";
+ }
+
+ ///
+ /// Verifies that the installed torch build still reports usable ROCm metadata after helper-managed installs complete.
+ ///
+ private static async Task VerifyWindowsNativeTorchInstallAsync(
+ IPyVenvRunner venvRunner,
+ Action? onConsoleOutput,
+ CancellationToken cancellationToken
+ )
+ {
+ var torchInfo = await venvRunner.PipShow("torch").ConfigureAwait(false);
+ if (torchInfo is null)
+ {
+ throw new ApplicationException("torch was not installed after Windows ROCm setup.");
+ }
+
+ var verificationResult = await venvRunner
+ .Run(
+ "-c \"import json, torch; print(json.dumps({'version': torch.__version__, 'hip': torch.version.hip, 'cuda': torch.cuda.is_available()}))\""
+ )
+ .ConfigureAwait(false);
+
+ var verificationOutput = (verificationResult.StandardOutput ?? string.Empty).Trim();
+ if (string.IsNullOrWhiteSpace(verificationOutput))
+ {
+ throw new ApplicationException("Torch verification produced no output.");
+ }
+
+ var verificationJson = TryExtractJsonObject(verificationOutput);
+ if (string.IsNullOrWhiteSpace(verificationJson))
+ {
+ throw new ApplicationException($"Unexpected torch verification output: {verificationOutput}");
+ }
+
+ JsonDocument verificationDocument;
+ try
+ {
+ verificationDocument = JsonDocument.Parse(verificationJson);
+ }
+ catch (Exception exception)
+ {
+ throw new ApplicationException(
+ $"Unexpected torch verification output: {verificationOutput}",
+ exception
+ );
+ }
+
+ using (verificationDocument)
+ {
+ var root = verificationDocument.RootElement;
+ var version = root.TryGetProperty("version", out var versionElement)
+ ? versionElement.GetString()
+ : null;
+ var hipVersion = root.TryGetProperty("hip", out var hipElement) ? hipElement.GetString() : null;
+ var cudaAvailable = root.TryGetProperty("cuda", out var cudaElement) && cudaElement.GetBoolean();
+
+ if (!IsUsableWindowsNativeTorchBuild(version, hipVersion))
+ {
+ throw new ApplicationException(
+ $"Installed torch is not a usable ROCm build. Verification output: {verificationOutput}"
+ );
+ }
+
+ if (!cudaAvailable)
+ {
+ onConsoleOutput?.Invoke(
+ ProcessOutput.FromStdErrLine(
+ $"Torch verification warning: installed ROCm torch build reported cuda={cudaAvailable}; continuing because ROCm metadata was detected (version={version}, hip={hipVersion})."
+ )
+ );
+ }
+
+ onConsoleOutput?.Invoke(
+ ProcessOutput.FromStdOutLine(
+ $"Torch verification: version={version}, hip={hipVersion}, cuda={cudaAvailable}"
+ )
+ );
+ }
+
+ _ = cancellationToken;
+ }
+
+ ///
+ /// Runs rocm-sdk init after the helper-managed runtime packages are installed so the Windows ROCm SDK can prepare the venv.
+ ///
+ private static async Task InitializeWindowsNativeRocmSdkAsync(
+ string installLocation,
+ Action? onConsoleOutput,
+ CancellationToken cancellationToken
+ )
+ {
+ var rocmSdkExe = Path.Combine(installLocation, "venv", "Scripts", "rocm-sdk.exe");
+ if (!File.Exists(rocmSdkExe))
+ {
+ throw new FileNotFoundException("rocm-sdk.exe was not installed", rocmSdkExe);
+ }
+
+ using var rocmSdkProcess = ProcessRunner.StartAnsiProcess(
+ rocmSdkExe,
+ ["init"],
+ installLocation,
+ onConsoleOutput
+ );
+
+ await rocmSdkProcess.WaitForExitAsync(cancellationToken).ConfigureAwait(false);
+ if (rocmSdkProcess.ExitCode != 0)
+ {
+ throw new ProcessException($"rocm-sdk init failed with code {rocmSdkProcess.ExitCode}");
+ }
+ }
+
+ ///
+ /// Uses AMD's bundled hipInfo.exe to confirm the installed Windows ROCm runtime can enumerate a ROCm-capable GPU.
+ ///
+ private static async Task VerifyWindowsNativeRocmRuntimeAsync(
+ string installLocation,
+ Action? onConsoleOutput,
+ CancellationToken cancellationToken
+ )
+ {
+ var rocmSdkExe = Path.Combine(installLocation, "venv", "Scripts", "rocm-sdk.exe");
+ if (!File.Exists(rocmSdkExe))
+ {
+ throw new FileNotFoundException("rocm-sdk.exe was not installed", rocmSdkExe);
+ }
+
+ var rocmBinResult = await ProcessRunner
+ .GetProcessResultAsync(rocmSdkExe, ["path", "--bin"], installLocation, useUtf8Encoding: true)
+ .ConfigureAwait(false);
+
+ var rocmBinPath = (rocmBinResult.StandardOutput ?? string.Empty).Trim();
+ if (!rocmBinResult.IsSuccessExitCode || string.IsNullOrWhiteSpace(rocmBinPath))
+ {
+ var rocmBinOutput = CombineProcessOutput(
+ rocmBinResult.StandardOutput,
+ rocmBinResult.StandardError
+ );
+ throw new ApplicationException(
+ $"ROCm runtime verification failed while resolving the ROCm SDK bin path. Output: {rocmBinOutput}"
+ );
+ }
+
+ var hipInfoExe = Path.Combine(rocmBinPath, $"hipInfo{Compat.ExeExtension}");
+ if (!File.Exists(hipInfoExe))
+ {
+ throw new FileNotFoundException(
+ "hipInfo.exe was not found in the ROCm SDK bin directory",
+ hipInfoExe
+ );
+ }
+
+ var hipInfoResult = await ProcessRunner
+ .GetProcessResultAsync(
+ hipInfoExe,
+ [],
+ installLocation,
+ new Dictionary { ["PATH"] = rocmBinPath },
+ useUtf8Encoding: true
+ )
+ .ConfigureAwait(false);
+
+ var hipInfoOutput = CombineProcessOutput(hipInfoResult.StandardOutput, hipInfoResult.StandardError);
+ if (!hipInfoResult.IsSuccessExitCode)
+ {
+ var runtimeFailureReason = TryGetWindowsNativeRocmRuntimeFailureReason(hipInfoOutput);
+ throw new ApplicationException(
+ runtimeFailureReason is null
+ ? $"ROCm runtime verification failed while probing the installed runtime with hipInfo.exe. Output: {hipInfoOutput}"
+ : $"ROCm runtime verification failed: {runtimeFailureReason} Output: {hipInfoOutput}"
+ );
+ }
+
+ onConsoleOutput?.Invoke(
+ ProcessOutput.FromStdOutLine(
+ $"ROCm runtime verification succeeded via hipInfo.exe: {hipInfoOutput}"
+ )
+ );
+
+ _ = cancellationToken;
+ }
+
+ ///
+ /// Reinstalls rocm-sdk-devel to the resolved ROCm build version when the torch step downgrades the runtime stack.
+ ///
+ private static async Task AlignRocmSdkDevelVersionAsync(
+ IPyVenvRunner venvRunner,
+ string rocmPackageIndexUrl,
+ Action? onConsoleOutput
+ )
+ {
+ var rocmInfo = await venvRunner.PipShow("rocm").ConfigureAwait(false);
+ var rocmSdkDevelInfo = await venvRunner.PipShow("rocm-sdk-devel").ConfigureAwait(false);
+ var torchInfo = await venvRunner.PipShow("torch").ConfigureAwait(false);
+
+ var targetVersion = GetRocmSdkDevelAlignmentVersion(
+ rocmInfo?.Version,
+ rocmSdkDevelInfo?.Version,
+ torchInfo?.Version
+ );
+
+ if (string.IsNullOrWhiteSpace(targetVersion))
+ return;
+
+ onConsoleOutput?.Invoke(
+ ProcessOutput.FromStdErrLine(
+ $"Aligning rocm-sdk-devel from version={rocmSdkDevelInfo?.Version ?? "not-installed"} to version={targetVersion} to match the resolved ROCm torch/runtime build."
+ )
+ );
+
+ var alignmentArgs = new PipInstallArgs()
+ .AddKeyedArgs("--index-url", ["--index-url", rocmPackageIndexUrl])
+ .AddArg("--force-reinstall")
+ .AddArg($"rocm-sdk-devel=={targetVersion}");
+
+ await venvRunner.PipInstall(alignmentArgs, onConsoleOutput).ConfigureAwait(false);
+ }
+
+ internal static bool IsUsableWindowsNativeTorchBuild(string? version, string? hipVersion)
+ {
+ if (!string.IsNullOrWhiteSpace(hipVersion))
+ return true;
+
+ return !string.IsNullOrWhiteSpace(version)
+ && version.Contains("rocm", StringComparison.OrdinalIgnoreCase);
+ }
+
+ internal static string? GetRocmSdkDevelAlignmentVersion(
+ string? rocmVersion,
+ string? rocmSdkDevelVersion,
+ string? torchVersion = null
+ )
+ {
+ var targetVersion = !string.IsNullOrWhiteSpace(rocmVersion)
+ ? rocmVersion
+ : TryExtractRocmBuildVersion(torchVersion);
+
+ if (string.IsNullOrWhiteSpace(targetVersion))
+ return null;
+
+ return string.Equals(targetVersion, rocmSdkDevelVersion, StringComparison.OrdinalIgnoreCase)
+ ? null
+ : targetVersion;
+ }
+
+ internal static string? TryGetWindowsNativeRocmRuntimeFailureReason(string? output)
+ {
+ if (string.IsNullOrWhiteSpace(output))
+ return null;
+
+ if (output.Contains("no ROCm-capable device is detected", StringComparison.OrdinalIgnoreCase))
+ {
+ return "the installed ROCm runtime could not detect a ROCm-capable GPU on this system.";
+ }
+
+ if (output.Contains("No WDDM adapters found", StringComparison.OrdinalIgnoreCase))
+ {
+ return "the ROCm runtime could not find any compatible WDDM adapters for the current GPU/driver stack.";
+ }
+
+ return null;
+ }
+
+ internal static string? TryExtractRocmBuildVersion(string? torchVersion)
+ {
+ if (string.IsNullOrWhiteSpace(torchVersion))
+ return null;
+
+ var rocmMarkerIndex = torchVersion.IndexOf("rocm", StringComparison.OrdinalIgnoreCase);
+ if (rocmMarkerIndex < 0)
+ return null;
+
+ var rocmBuildVersion = torchVersion[(rocmMarkerIndex + "rocm".Length)..].Trim();
+ return string.IsNullOrWhiteSpace(rocmBuildVersion) ? null : rocmBuildVersion;
+ }
+
+ internal static string? TryExtractJsonObject(string output)
+ {
+ if (string.IsNullOrWhiteSpace(output))
+ return null;
+
+ var trimmedOutput = output.Trim();
+
+ for (var index = 0; index < trimmedOutput.Length; index++)
+ {
+ if (trimmedOutput[index] != '{')
+ continue;
+
+ try
+ {
+ using var document = JsonDocument.Parse(trimmedOutput[index..]);
+ return document.RootElement.GetRawText();
+ }
+ catch (JsonException) { }
+ }
+
+ return null;
+ }
+
+ internal static string CombineProcessOutput(string? standardOutput, string? standardError)
+ {
+ var sections = new[] { standardOutput?.Trim(), standardError?.Trim() }.Where(section =>
+ !string.IsNullOrWhiteSpace(section)
+ );
+
+ return string.Join(Environment.NewLine, sections);
+ }
+
+ ///
+ /// Builds helper-owned ROCm launch variables from the resolved runtime context and package profile.
+ ///
+ private IReadOnlyDictionary BuildHelperLaunchEnvironment(
+ RocmRuntimeContext runtimeContext,
+ RocmPackageProfile profile
+ )
+ {
+ var environment = new Dictionary(EnvComparer);
+ var options = profile.EnvironmentOptions;
+ var gfxArch = runtimeContext.RuntimeGfxArch;
+
+ ApplyDefaultLaunchEnvironment(environment, gfxArch, options);
+
+ return environment;
+ }
+
+ private void ApplyDefaultLaunchEnvironment(
+ IDictionary environment,
+ string? gfxArch,
+ RocmEnvironmentOptions options
+ )
+ {
+ SetIfNotNull(environment, "FLASH_ATTENTION_TRITON_AMD_ENABLE", options.FlashAttentionTritonAmdEnable);
+ SetIfNotNull(environment, "MIOPEN_FIND_MODE", options.MiopenFindMode);
+ SetIfNotNull(environment, "MIOPEN_SEARCH_CUTOFF", options.MiopenSearchCutoff);
+ SetIfNotNull(environment, "MIOPEN_FIND_ENFORCE", options.MiopenFindEnforce);
+ SetIfNotNull(environment, "PYTORCH_ALLOC_CONF", options.PyTorchAllocConf);
+
+ if (options.ApplyAotritonExperimental && WindowsRocmSupport.IsModernArchitecture(gfxArch))
+ {
+ environment["TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL"] = "1";
+ }
+
+ if (options.ApplyLegacySdpFallback && WindowsRocmSupport.IsLegacyArchitecture(gfxArch))
+ {
+ environment["TORCH_BACKENDS_CUDA_FLASH_SDP_ENABLED"] = "0";
+ environment["TORCH_BACKENDS_CUDA_MEM_EFF_SDP_ENABLED"] = "0";
+ environment["TORCH_BACKENDS_CUDA_MATH_SDP_ENABLED"] = "1";
+ }
+
+ if (options.ApplyRdna1Override && WindowsRocmSupport.IsRdna1Architecture(gfxArch))
+ {
+ environment["HSA_OVERRIDE_GFX_VERSION"] = "10.1.0";
+ }
+ }
+
+ private static void SetIfNotNull(IDictionary environment, string key, string? value)
+ {
+ if (!string.IsNullOrWhiteSpace(value))
+ {
+ environment[key] = value;
+ }
+ }
+
+ ///
+ /// Merges helper-owned and package-specific launch environment variables.
+ ///
+ private IReadOnlyDictionary MergeLaunchEnvironment(
+ IReadOnlyDictionary helperEnvironment,
+ IReadOnlyDictionary packageEnvironment,
+ RocmEnvironmentOptions options
+ )
+ {
+ var merged = new Dictionary(EnvComparer);
+
+ foreach (var source in new[] { helperEnvironment, packageEnvironment })
+ {
+ if (ReferenceEquals(source, packageEnvironment) && !options.IncludePackageOverrides)
+ continue;
+
+ foreach (var pair in source)
+ {
+ merged[pair.Key] = pair.Value;
+ }
+ }
+
+ if (
+ options.IncludeUserOverrides
+ && settingsManager.Settings.EnvironmentVariables is { Count: > 0 } userOverrides
+ )
+ {
+ foreach (var pair in userOverrides)
+ {
+ merged[pair.Key] = pair.Value;
+ }
+ }
+
+ return merged;
+ }
+}
diff --git a/StabilityMatrix.Tests/Core/RocmPackageHelperTests.cs b/StabilityMatrix.Tests/Core/RocmPackageHelperTests.cs
new file mode 100644
index 000000000..0afb7b86d
--- /dev/null
+++ b/StabilityMatrix.Tests/Core/RocmPackageHelperTests.cs
@@ -0,0 +1,176 @@
+using System.Text.Json;
+using StabilityMatrix.Core.Helper;
+using StabilityMatrix.Core.Helper.HardwareInfo;
+using StabilityMatrix.Core.Models.Rocm;
+using StabilityMatrix.Core.Services.Rocm;
+
+namespace StabilityMatrix.Tests.Core;
+
+[TestClass]
+public class RocmPackageHelperTests
+{
+ [TestMethod]
+ public void GetRocmSdkDevelAlignmentVersion_ReturnsRocmVersion_WhenVersionsMismatch()
+ {
+ var targetVersion = RocmPackageHelper.GetRocmSdkDevelAlignmentVersion(
+ rocmVersion: "7.13.0a20260416",
+ rocmSdkDevelVersion: "7.13.0a20260501"
+ );
+
+ Assert.AreEqual("7.13.0a20260416", targetVersion);
+ }
+
+ [TestMethod]
+ public void GetRocmSdkDevelAlignmentVersion_ReturnsNull_WhenVersionsAlreadyMatch()
+ {
+ var targetVersion = RocmPackageHelper.GetRocmSdkDevelAlignmentVersion(
+ rocmVersion: "7.13.0a20260416",
+ rocmSdkDevelVersion: "7.13.0a20260416"
+ );
+
+ Assert.IsNull(targetVersion);
+ }
+
+ [TestMethod]
+ public void GetRocmSdkDevelAlignmentVersion_FallsBackToTorchBuildVersion()
+ {
+ var targetVersion = RocmPackageHelper.GetRocmSdkDevelAlignmentVersion(
+ rocmVersion: null,
+ rocmSdkDevelVersion: "7.13.0a20260501",
+ torchVersion: "2.11.0+rocm7.13.0a20260416"
+ );
+
+ Assert.AreEqual("7.13.0a20260416", targetVersion);
+ }
+
+ [TestMethod]
+ public void TryExtractRocmBuildVersion_ReturnsNull_WhenTorchVersionHasNoRocmTag()
+ {
+ var rocmBuildVersion = RocmPackageHelper.TryExtractRocmBuildVersion("2.11.0");
+
+ Assert.IsNull(rocmBuildVersion);
+ }
+
+ [TestMethod]
+ public void TryExtractRocmBuildVersion_ReturnsVersionSuffix_WhenTorchVersionContainsRocmTag()
+ {
+ var rocmBuildVersion = RocmPackageHelper.TryExtractRocmBuildVersion("2.11.0+rocm7.13.0a20260416");
+
+ Assert.AreEqual("7.13.0a20260416", rocmBuildVersion);
+ }
+
+ [TestMethod]
+ public void IsUsableWindowsNativeTorchBuild_ReturnsTrue_WhenHipMetadataExists()
+ {
+ var isUsable = RocmPackageHelper.IsUsableWindowsNativeTorchBuild(
+ version: "test-version",
+ hipVersion: "test-hip-version"
+ );
+
+ Assert.IsTrue(isUsable);
+ }
+
+ [TestMethod]
+ public void IsUsableWindowsNativeTorchBuild_ReturnsTrue_WhenVersionContainsRocm()
+ {
+ var isUsable = RocmPackageHelper.IsUsableWindowsNativeTorchBuild(
+ version: "test-version+rocm",
+ hipVersion: null
+ );
+
+ Assert.IsTrue(isUsable);
+ }
+
+ [TestMethod]
+ public void IsUsableWindowsNativeTorchBuild_ReturnsFalse_WhenNoRocmMetadataExists()
+ {
+ var isUsable = RocmPackageHelper.IsUsableWindowsNativeTorchBuild(
+ version: "test-version",
+ hipVersion: null
+ );
+
+ Assert.IsFalse(isUsable);
+ }
+
+ [TestMethod]
+ public void TryExtractJsonObject_ReturnsJson_WhenOutputContainsDiagnosticPrefix()
+ {
+ const string output =
+ "warning: ROCm topology probe emitted diagnostic output"
+ + "\nwarning: continuing with torch verification"
+ + "\n{\"version\": \"test-version\", \"hip\": \"test-hip-version\", \"cuda\": false}";
+
+ var json = RocmPackageHelper.TryExtractJsonObject(output);
+
+ Assert.IsNotNull(json);
+
+ using var document = JsonDocument.Parse(json);
+ var root = document.RootElement;
+
+ Assert.AreEqual("test-version", root.GetProperty("version").GetString());
+ Assert.AreEqual("test-hip-version", root.GetProperty("hip").GetString());
+ Assert.IsFalse(root.GetProperty("cuda").GetBoolean());
+ }
+
+ [TestMethod]
+ public void TryExtractJsonObject_ReturnsNull_WhenOutputContainsNoJson()
+ {
+ const string output =
+ "warning: ROCm topology probe emitted diagnostic output\n"
+ + "warning: no JSON payload was produced";
+
+ var json = RocmPackageHelper.TryExtractJsonObject(output);
+
+ Assert.IsNull(json);
+ }
+
+ [TestMethod]
+ public void TryGetWindowsNativeRocmRuntimeFailureReason_ReturnsDeviceDetectionMessage()
+ {
+ const string output = "checkHipErrors() HIP API error = 0100 \"no ROCm-capable device is detected\"";
+
+ var reason = RocmPackageHelper.TryGetWindowsNativeRocmRuntimeFailureReason(output);
+
+ Assert.AreEqual(
+ "the installed ROCm runtime could not detect a ROCm-capable GPU on this system.",
+ reason
+ );
+ }
+
+ [TestMethod]
+ public void TryGetWindowsNativeRocmRuntimeFailureReason_ReturnsWddmMessage()
+ {
+ const string output = "warning: No WDDM adapters found.";
+
+ var reason = RocmPackageHelper.TryGetWindowsNativeRocmRuntimeFailureReason(output);
+
+ Assert.AreEqual(
+ "the ROCm runtime could not find any compatible WDDM adapters for the current GPU/driver stack.",
+ reason
+ );
+ }
+
+ [TestMethod]
+ public void CombineProcessOutput_JoinsStdoutAndStderr()
+ {
+ var combined = RocmPackageHelper.CombineProcessOutput("stdout line", "stderr line");
+
+ Assert.AreEqual($"stdout line{Environment.NewLine}stderr line", combined);
+ }
+
+ [TestMethod]
+ public void WindowsRocmSupport_TryGetPackageIndexUrl_ReturnsExpectedIndex_ForKrakenPoint()
+ {
+ var indexUrl = WindowsRocmSupport.TryGetPackageIndexUrl("gfx1152");
+
+ Assert.AreEqual("https://rocm.nightlies.amd.com/v2-staging/gfx1152/", indexUrl);
+ }
+
+ [TestMethod]
+ public void WindowsRocmSupport_IsSupportedGpu_ReturnsTrue_ForSupportedAmdGpu()
+ {
+ var gpu = new GpuInfo { Name = "AMD Radeon RX 9070 XT", MemoryBytes = 16UL * Size.GiB };
+
+ Assert.IsTrue(WindowsRocmSupport.IsSupportedGpu(gpu));
+ }
+}
diff --git a/StabilityMatrix.Tests/Helper/PackageFactoryTests.cs b/StabilityMatrix.Tests/Helper/PackageFactoryTests.cs
index f78027039..bf45d60c4 100644
--- a/StabilityMatrix.Tests/Helper/PackageFactoryTests.cs
+++ b/StabilityMatrix.Tests/Helper/PackageFactoryTests.cs
@@ -24,6 +24,8 @@ public void Setup()
null!,
null!,
null!,
+ null!,
+ null!,
null!
);
}