From a03623b86d016d097d49661fc0b5f2809b5a6092 Mon Sep 17 00:00:00 2001 From: NeuralFault Date: Tue, 21 Apr 2026 19:02:05 -0400 Subject: [PATCH 01/10] Scaffold universal ROCm helper types and service - Add initial ROCm helper structure - Set up ROCm helper foundation Compile test sucessful. --- .../Models/Rocm/RocmCompatibilityResult.cs | 17 +++ .../Models/Rocm/RocmEnvironmentOptions.cs | 33 ++++++ .../Models/Rocm/RocmInstallContext.cs | 23 ++++ .../Models/Rocm/RocmPackageProfile.cs | 65 +++++++++++ .../Models/Rocm/RocmRuntimeContext.cs | 33 ++++++ .../Models/Rocm/RocmSdkPaths.cs | 22 ++++ .../Services/Rocm/IRocmPackageHelper.cs | 79 ++++++++++++++ .../Services/Rocm/RocmPackageHelper.cs | 103 ++++++++++++++++++ 8 files changed, 375 insertions(+) create mode 100644 StabilityMatrix.Core/Models/Rocm/RocmCompatibilityResult.cs create mode 100644 StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs create mode 100644 StabilityMatrix.Core/Models/Rocm/RocmInstallContext.cs create mode 100644 StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs create mode 100644 StabilityMatrix.Core/Models/Rocm/RocmRuntimeContext.cs create mode 100644 StabilityMatrix.Core/Models/Rocm/RocmSdkPaths.cs create mode 100644 StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs create mode 100644 StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs diff --git a/StabilityMatrix.Core/Models/Rocm/RocmCompatibilityResult.cs b/StabilityMatrix.Core/Models/Rocm/RocmCompatibilityResult.cs new file mode 100644 index 00000000..401f3ada --- /dev/null +++ b/StabilityMatrix.Core/Models/Rocm/RocmCompatibilityResult.cs @@ -0,0 +1,17 @@ +using StabilityMatrix.Core.Helper.HardwareInfo; + +namespace StabilityMatrix.Core.Models.Rocm; + +/// +/// Describes whether a package/profile is currently compatible with ROCm on the active machine. +/// +public class RocmCompatibilityResult +{ + public bool IsCompatible { get; init; } + + public string? FailureReason { get; init; } + + public GpuInfo? SelectedGpu { get; init; } + + public string? ResolvedGfxArch { get; init; } +} diff --git a/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs b/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs new file mode 100644 index 00000000..11c2bbfb --- /dev/null +++ b/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs @@ -0,0 +1,33 @@ +namespace StabilityMatrix.Core.Models.Rocm; + +/// +/// Controls how helper-generated, package-specific, and user-defined environment variables +/// should be layered together once the helper has real behavior. +/// +public class RocmEnvironmentOptions +{ + /// + /// Determines the merge order used when multiple environment sources provide the same key. + /// + public RocmEnvironmentOverlayPriority OverlayPriority { get; init; } = + RocmEnvironmentOverlayPriority.HelperThenPackageThenUser; + + /// + /// When true, package-specific environment additions may be merged on top of helper defaults. + /// + public bool IncludePackageOverrides { get; init; } = true; + + /// + /// When true, user-defined Stability Matrix environment variables may be merged last. + /// + public bool IncludeUserOverrides { get; init; } = true; +} + +/// +/// Describes the intended precedence of environment sources for ROCm-enabled package launches. +/// +public enum RocmEnvironmentOverlayPriority +{ + HelperThenPackageThenUser, + HelperThenUserThenPackage, +} diff --git a/StabilityMatrix.Core/Models/Rocm/RocmInstallContext.cs b/StabilityMatrix.Core/Models/Rocm/RocmInstallContext.cs new file mode 100644 index 00000000..2055dd71 --- /dev/null +++ b/StabilityMatrix.Core/Models/Rocm/RocmInstallContext.cs @@ -0,0 +1,23 @@ +using StabilityMatrix.Core.Models; + +namespace StabilityMatrix.Core.Models.Rocm; + +/// +/// Captures ROCm-related facts needed during package install or update flows. +/// +public class RocmInstallContext +{ + public string? PreferredGfxArch { get; init; } + + public string? RuntimeGfxArch { get; init; } + + public TorchIndex TorchIndex { get; init; } = TorchIndex.Rocm; + + public string? WheelCompatibilityHints { get; init; } + + public string? SdkRoot { get; init; } + + public RocmSdkPaths SdkPaths { get; init; } = new(); + + public IReadOnlyDictionary Environment { get; init; } = new Dictionary(); +} diff --git a/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs b/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs new file mode 100644 index 00000000..a15c247d --- /dev/null +++ b/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs @@ -0,0 +1,65 @@ +using StabilityMatrix.Core.Models.Progress; + +namespace StabilityMatrix.Core.Models.Rocm; + +/// +/// Declares what a package expects from the ROCm helper. +/// Package classes should describe intent here rather than hardcoding ROCm decisions inline. +/// +public class RocmPackageProfile +{ + /// + /// Logical package name for diagnostics and profile-specific decisions. + /// + public string PackageName { get; init; } = string.Empty; + + public bool RequiresWindows { get; init; } + + public bool RequiresRocmSdk { get; init; } + + public bool NeedsRuntimeGfxResolution { get; init; } + + public bool NeedsHipPath { get; init; } + + public bool NeedsRocmPath { get; init; } + + public bool NeedsTritonOverrideArch { get; init; } + + public bool NeedsRdna1Override { get; init; } + + public bool NeedsLegacySdpFallback { get; init; } + + public bool NeedsAotritonExperimental { get; init; } + + public bool NeedsTunableOpCache { get; init; } + + public bool NeedsTritonCache { get; init; } + + public bool NeedsMIOpenDbPaths { get; init; } + + public bool NeedsRocblasPaths { get; init; } + + /// + /// Optional callback for package-specific cache path variables. + /// The helper will eventually merge these with its own defaults. + /// + public Func>? CacheDirectoryFactory { get; init; } + + /// + /// Optional callback for package-specific environment variables derived from a resolved ROCm context. + /// + public Func< + RocmRuntimeContext, + IReadOnlyDictionary + >? ExtraEnvironmentFactory { get; init; } + + /// + /// Optional progress message prefix or label that package code can surface during install/update work. + /// + public string? ProgressLabel { get; init; } + + /// + /// Controls how helper, package, and user-defined environment variables should be merged. + /// + public RocmEnvironmentOptions EnvironmentOptions { get; init; } = new(); +} diff --git a/StabilityMatrix.Core/Models/Rocm/RocmRuntimeContext.cs b/StabilityMatrix.Core/Models/Rocm/RocmRuntimeContext.cs new file mode 100644 index 00000000..87c88ba6 --- /dev/null +++ b/StabilityMatrix.Core/Models/Rocm/RocmRuntimeContext.cs @@ -0,0 +1,33 @@ +using StabilityMatrix.Core.Helper.HardwareInfo; + +namespace StabilityMatrix.Core.Models.Rocm; + +/// +/// Captures resolved ROCm facts for a package launch or runtime decision. +/// This model is intended to separate hardware/runtime facts from package policy. +/// +public class RocmRuntimeContext +{ + public bool IsSupported { get; init; } + + public string? FailureReason { get; init; } + + public GpuInfo? SelectedGpu { get; init; } + + public string? RuntimeGfxArch { get; init; } + + public bool IsLegacyGpu { get; init; } + + public bool IsRdna1 { get; init; } + + public string? HipPath { get; init; } + + public string? RocmPath { get; init; } + + public string? RocmSdkSitePackagesPath { get; init; } + + public RocmSdkPaths SdkPaths { get; init; } = new(); + + public IReadOnlyDictionary ResolvedEnvironment { get; init; } = + new Dictionary(); +} diff --git a/StabilityMatrix.Core/Models/Rocm/RocmSdkPaths.cs b/StabilityMatrix.Core/Models/Rocm/RocmSdkPaths.cs new file mode 100644 index 00000000..5789744f --- /dev/null +++ b/StabilityMatrix.Core/Models/Rocm/RocmSdkPaths.cs @@ -0,0 +1,22 @@ +namespace StabilityMatrix.Core.Models.Rocm; + +/// +/// Represents ROCm SDK-related paths resolved for a package install. +/// These values are intentionally plain data so package code can decide which paths matter. +/// +public class RocmSdkPaths +{ + public string? RocmRoot { get; init; } + + public string? HipPath { get; init; } + + public string? RocmPath { get; init; } + + public string? RocmSdkSitePackagesPath { get; init; } + + public string? MioPenDbPath { get; init; } + + public string? RocblasDbPath { get; init; } + + public string? RocblasLibraryPath { get; init; } +} diff --git a/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs new file mode 100644 index 00000000..5b9383fc --- /dev/null +++ b/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs @@ -0,0 +1,79 @@ +using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Models.Rocm; +using StabilityMatrix.Core.Python; + +namespace StabilityMatrix.Core.Services.Rocm; + +/// +/// Defines the ROCm helper surface area shared by ROCm-capable packages. +/// +public interface IRocmPackageHelper +{ + /// + /// Evaluates whether the current machine and package profile are compatible with ROCm. + /// + Task GetCompatibilityAsync( + RocmPackageProfile profile, + CancellationToken cancellationToken = default + ); + + /// + /// Resolves the runtime ROCm facts needed for package launch and environment construction. + /// + Task ResolveRuntimeContextAsync( + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile, + CancellationToken cancellationToken = default + ); + + /// + /// Resolves the ROCm facts needed during package installation or update operations. + /// + Task ResolveInstallContextAsync( + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile, + CancellationToken cancellationToken = default + ); + + /// + /// Builds an install-time environment dictionary from a resolved install context. + /// + IReadOnlyDictionary BuildInstallEnvironment( + string installLocation, + RocmInstallContext context, + RocmPackageProfile profile + ); + + /// + /// Re-resolves ROCm install facts after a package update changes dependencies or runtime state. + /// + Task RefreshPackageAfterUpdateAsync( + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile, + CancellationToken cancellationToken = default + ); + + /// + /// Builds a launch-time environment dictionary from resolved ROCm runtime data. + /// + Task> BuildLaunchEnvironmentAsync( + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile, + CancellationToken cancellationToken = default + ); + + /// + /// Applies a resolved launch environment to the provided Python venv runner. + /// + Task ApplyLaunchEnvironmentAsync( + IPyVenvRunner venvRunner, + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile, + CancellationToken cancellationToken = default + ); +} diff --git a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs new file mode 100644 index 00000000..c353dc15 --- /dev/null +++ b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs @@ -0,0 +1,103 @@ +using System.Collections.Immutable; +using Injectio.Attributes; +using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Models.Rocm; +using StabilityMatrix.Core.Python; + +namespace StabilityMatrix.Core.Services.Rocm; + +/// +/// Provides the shared ROCm helper surface area used by ROCm-capable packages. +/// +[RegisterSingleton] +public class RocmPackageHelper : IRocmPackageHelper +{ + private const string NotImplementedMessage = "ROCm helper behavior has not been implemented yet."; + + /// + public Task GetCompatibilityAsync( + RocmPackageProfile profile, + CancellationToken cancellationToken = default + ) + { + return Task.FromResult( + new RocmCompatibilityResult { IsCompatible = false, FailureReason = NotImplementedMessage } + ); + } + + /// + public Task ResolveRuntimeContextAsync( + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile, + CancellationToken cancellationToken = default + ) + { + return Task.FromResult( + new RocmRuntimeContext { IsSupported = false, FailureReason = NotImplementedMessage } + ); + } + + /// + public Task ResolveInstallContextAsync( + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile, + CancellationToken cancellationToken = default + ) + { + return Task.FromResult(new RocmInstallContext()); + } + + /// + public IReadOnlyDictionary BuildInstallEnvironment( + string installLocation, + RocmInstallContext context, + RocmPackageProfile profile + ) + { + return new Dictionary(); + } + + /// + public Task RefreshPackageAfterUpdateAsync( + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile, + CancellationToken cancellationToken = default + ) + { + return Task.FromResult(new RocmInstallContext()); + } + + /// + public Task> BuildLaunchEnvironmentAsync( + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile, + CancellationToken cancellationToken = default + ) + { + return Task.FromResult>(new Dictionary()); + } + + /// + public async Task ApplyLaunchEnvironmentAsync( + IPyVenvRunner venvRunner, + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile, + CancellationToken cancellationToken = default + ) + { + var environment = await BuildLaunchEnvironmentAsync( + installLocation, + installedPackage, + profile, + cancellationToken + ) + .ConfigureAwait(false); + + venvRunner.UpdateEnvironmentVariables(env => env.SetItems(environment)); + } +} From 218aff9a764a6131cae0e63f43170053d6dc73ea Mon Sep 17 00:00:00 2001 From: NeuralFault Date: Tue, 21 Apr 2026 19:17:43 -0400 Subject: [PATCH 02/10] Implement ROCm GPU detection helper --- .../Services/Rocm/RocmPackageHelper.cs | 288 +++++++++++++++++- 1 file changed, 280 insertions(+), 8 deletions(-) diff --git a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs index c353dc15..1adb141e 100644 --- a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs +++ b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs @@ -1,8 +1,12 @@ using System.Collections.Immutable; using Injectio.Attributes; +using NLog; +using StabilityMatrix.Core.Helper; +using StabilityMatrix.Core.Helper.HardwareInfo; using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models.Rocm; using StabilityMatrix.Core.Python; +using StabilityMatrix.Core.Services; namespace StabilityMatrix.Core.Services.Rocm; @@ -10,9 +14,25 @@ namespace StabilityMatrix.Core.Services.Rocm; /// Provides the shared ROCm helper surface area used by ROCm-capable packages. /// [RegisterSingleton] -public class RocmPackageHelper : IRocmPackageHelper +public class RocmPackageHelper(ISettingsManager settingsManager) : IRocmPackageHelper { - private const string NotImplementedMessage = "ROCm helper behavior has not been implemented yet."; + private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); + + private static readonly string[] UnsupportedRdna2ModelMarkers = + [ + "680m", + "660m", + "610m", + "rx6300", + "w6300", + "rx6400", + "w6400", + "rx6450", + "rx6550", + ]; + + private const string EnvironmentNotImplementedMessage = + "ROCm helper environment composition has not been implemented yet."; /// public Task GetCompatibilityAsync( @@ -20,9 +40,7 @@ public Task GetCompatibilityAsync( CancellationToken cancellationToken = default ) { - return Task.FromResult( - new RocmCompatibilityResult { IsCompatible = false, FailureReason = NotImplementedMessage } - ); + return Task.FromResult(BuildCompatibilityResult(profile)); } /// @@ -33,8 +51,43 @@ public Task ResolveRuntimeContextAsync( CancellationToken cancellationToken = default ) { + var compatibility = BuildCompatibilityResult(profile); + if (!compatibility.IsCompatible) + { + return Task.FromResult( + new RocmRuntimeContext + { + IsSupported = false, + FailureReason = compatibility.FailureReason, + SelectedGpu = compatibility.SelectedGpu, + RuntimeGfxArch = compatibility.ResolvedGfxArch, + } + ); + } + + var supportedAmdGpus = GetAmdGpuCandidates(forceRefresh: true) + .Where(IsSupportedWindowsRocmGpu) + .ToList(); + + var selectedGpu = + compatibility.SelectedGpu + ?? TryResolvePreferredAmdGpu(supportedAmdGpus, settingsManager.Settings.PreferredGpu) + ?? supportedAmdGpus.FirstOrDefault(); + + var runtimeGfxArch = + compatibility.ResolvedGfxArch + ?? selectedGpu?.GetAmdGfxArch() + ?? GetSupportedFallbackGfxArch(supportedAmdGpus); + return Task.FromResult( - new RocmRuntimeContext { IsSupported = false, FailureReason = NotImplementedMessage } + new RocmRuntimeContext + { + IsSupported = true, + SelectedGpu = selectedGpu, + RuntimeGfxArch = runtimeGfxArch, + IsLegacyGpu = IsLegacyArchitecture(runtimeGfxArch), + IsRdna1 = IsRdna1Architecture(runtimeGfxArch), + } ); } @@ -46,7 +99,22 @@ public Task ResolveInstallContextAsync( CancellationToken cancellationToken = default ) { - return Task.FromResult(new RocmInstallContext()); + var supportedAmdGpus = GetAmdGpuCandidates(forceRefresh: true) + .Where(IsSupportedWindowsRocmGpu) + .ToList(); + + var preferredGfxArch = TryResolvePreferredAmdGfxArch( + supportedAmdGpus, + settingsManager.Settings.PreferredGpu + ); + + return Task.FromResult( + new RocmInstallContext + { + PreferredGfxArch = preferredGfxArch, + RuntimeGfxArch = preferredGfxArch ?? GetSupportedFallbackGfxArch(supportedAmdGpus), + } + ); } /// @@ -56,6 +124,9 @@ public IReadOnlyDictionary BuildInstallEnvironment( RocmPackageProfile profile ) { + _ = installLocation; + _ = context; + _ = profile; return new Dictionary(); } @@ -67,7 +138,7 @@ public Task RefreshPackageAfterUpdateAsync( CancellationToken cancellationToken = default ) { - return Task.FromResult(new RocmInstallContext()); + return ResolveInstallContextAsync(installLocation, installedPackage, profile, cancellationToken); } /// @@ -78,6 +149,10 @@ public Task> BuildLaunchEnvironmentAsync( CancellationToken cancellationToken = default ) { + _ = installLocation; + _ = installedPackage; + _ = profile; + _ = cancellationToken; return Task.FromResult>(new Dictionary()); } @@ -100,4 +175,201 @@ public async Task ApplyLaunchEnvironmentAsync( venvRunner.UpdateEnvironmentVariables(env => env.SetItems(environment)); } + + /// + /// Builds a compatibility result from the current machine state and package profile. + /// This keeps the first ROCm helper slice focused on hardware capability and GPU selection only. + /// + private RocmCompatibilityResult BuildCompatibilityResult(RocmPackageProfile profile) + { + if (profile.RequiresWindows && !Compat.IsWindows) + { + return new RocmCompatibilityResult + { + IsCompatible = false, + FailureReason = "This ROCm profile currently requires Windows.", + }; + } + + var amdGpus = GetAmdGpuCandidates(forceRefresh: true).ToList(); + if (amdGpus.Count == 0) + { + return new RocmCompatibilityResult + { + IsCompatible = false, + FailureReason = "No AMD GPU was detected for ROCm evaluation.", + }; + } + + var preferredGpu = settingsManager.Settings.PreferredGpu; + if (preferredGpu is not null && IsExplicitlyUnsupportedRdna2Gpu(preferredGpu)) + { + return new RocmCompatibilityResult + { + IsCompatible = false, + FailureReason = $"Selected GPU '{preferredGpu.Name}' is unsupported for Windows ROCm.", + SelectedGpu = preferredGpu, + }; + } + + var supportedAmdGpus = amdGpus.Where(IsSupportedWindowsRocmGpu).ToList(); + if (supportedAmdGpus.Count == 0) + { + return new RocmCompatibilityResult + { + IsCompatible = false, + FailureReason = GetUnsupportedGpuReason(amdGpus), + }; + } + + var selectedGpu = + TryResolvePreferredAmdGpu(supportedAmdGpus, preferredGpu) ?? supportedAmdGpus.First(); + var resolvedGfxArch = selectedGpu.GetAmdGfxArch() ?? GetSupportedFallbackGfxArch(supportedAmdGpus); + + return new RocmCompatibilityResult + { + IsCompatible = !string.IsNullOrWhiteSpace(resolvedGfxArch), + FailureReason = string.IsNullOrWhiteSpace(resolvedGfxArch) + ? "No supported AMD GFX architecture could be resolved for ROCm." + : null, + SelectedGpu = selectedGpu, + ResolvedGfxArch = resolvedGfxArch, + }; + } + + /// + /// Returns AMD GPUs from Stability Matrix's internal hardware model. + /// This is the canonical GPU source for the ROCm helper and intentionally avoids package-local probing. + /// + private static IReadOnlyList GetAmdGpuCandidates(bool forceRefresh = false) + { + return HardwareHelper.IterGpuInfo(forceRefresh).Where(gpu => gpu.IsAmd).ToList(); + } + + /// + /// Resolves the preferred AMD GPU when the configured preference is still present in the current hardware list. + /// + private static GpuInfo? TryResolvePreferredAmdGpu( + IEnumerable availableGpus, + GpuInfo? preferredGpu + ) + { + if (preferredGpu is null || !preferredGpu.IsAmd) + return null; + + var preferredMatch = availableGpus.FirstOrDefault(gpu => gpu.Equals(preferredGpu)); + if (preferredMatch is not null) + return preferredMatch; + + if (!string.IsNullOrWhiteSpace(preferredGpu.Name)) + { + Logger.Info( + "Preferred GPU {PreferredGpuName} was ignored for ROCm detection because it is not present in current hardware enumeration.", + preferredGpu.Name + ); + } + + return null; + } + + /// + /// Resolves the preferred AMD GFX architecture when the configured GPU is supported and currently present. + /// + private static string? TryResolvePreferredAmdGfxArch( + IEnumerable availableGpus, + GpuInfo? preferredGpu + ) + { + var resolvedPreferredGpu = TryResolvePreferredAmdGpu(availableGpus, preferredGpu); + return resolvedPreferredGpu is not null && IsSupportedWindowsRocmGpu(resolvedPreferredGpu) + ? resolvedPreferredGpu.GetAmdGfxArch() + : null; + } + + /// + /// Resolves the first supported AMD GFX architecture from the current machine state when no preferred GPU applies. + /// + private static string? GetSupportedFallbackGfxArch(IEnumerable availableGpus) + { + return availableGpus + .Where(IsSupportedWindowsRocmGpu) + .Select(gpu => gpu.GetAmdGfxArch()) + .FirstOrDefault(IsSupportedWindowsRocmArchitecture); + } + + /// + /// Determines whether a GPU is supported by the Windows ROCm install flow currently modeled by the helper. + /// Unsupported low-end RDNA2/APU models are filtered explicitly even when they identify as AMD hardware. + /// + private static bool IsSupportedWindowsRocmGpu(GpuInfo gpu) + { + if (!gpu.IsAmd || string.IsNullOrWhiteSpace(gpu.Name)) + return false; + + if (IsExplicitlyUnsupportedRdna2Gpu(gpu)) + return false; + + return IsSupportedWindowsRocmArchitecture(gpu.GetAmdGfxArch()); + } + + /// + /// Identifies Windows ROCm-incompatible RDNA2 models that need to remain outside the supported GPU set. + /// + private static bool IsExplicitlyUnsupportedRdna2Gpu(GpuInfo gpu) + { + if (!gpu.IsAmd || string.IsNullOrWhiteSpace(gpu.Name)) + return false; + + var normalizedName = gpu.Name.Replace(" ", string.Empty, StringComparison.Ordinal).ToLowerInvariant(); + return UnsupportedRdna2ModelMarkers.Any(normalizedName.Contains); + } + + /// + /// Determines whether a resolved AMD GFX architecture falls inside the Windows ROCm support set currently modeled by the helper. + /// + private static bool IsSupportedWindowsRocmArchitecture(string? gfxArch) + { + return gfxArch switch + { + var s when s != null && s.StartsWith("gfx101", StringComparison.OrdinalIgnoreCase) => true, + var s when s != null && s.StartsWith("gfx103", StringComparison.OrdinalIgnoreCase) => true, + var s when s != null && s.StartsWith("gfx110", StringComparison.OrdinalIgnoreCase) => true, + "gfx1150" or "gfx1151" or "gfx1152" or "gfx1153" => true, + var s when s != null && s.StartsWith("gfx120", StringComparison.OrdinalIgnoreCase) => true, + _ => false, + }; + } + + /// + /// Returns true for architectures that need the legacy ROCm runtime path. + /// + private static bool IsLegacyArchitecture(string? gfxArch) + { + return gfxArch is not null + && ( + gfxArch.StartsWith("gfx101", StringComparison.OrdinalIgnoreCase) + || gfxArch.StartsWith("gfx103", StringComparison.OrdinalIgnoreCase) + ); + } + + /// + /// Returns true for RDNA1 architectures that need dedicated override handling. + /// + private static bool IsRdna1Architecture(string? gfxArch) + { + return gfxArch?.StartsWith("gfx101", StringComparison.OrdinalIgnoreCase) == true; + } + + /// + /// Produces a readable incompatibility reason when AMD hardware is present but not usable for Windows ROCm. + /// + private static string GetUnsupportedGpuReason(IReadOnlyList amdGpus) + { + if (amdGpus.Any(IsExplicitlyUnsupportedRdna2Gpu)) + { + return "Detected only unsupported AMD RDNA2 GPUs for Windows ROCm. Unsupported models include Radeon 680M/660M/610M and RX 6300/6400/6450/6550-class GPUs."; + } + + return "No AMD GPU with a supported Windows ROCm architecture was detected."; + } } From 70f852d523520f27d9c500e99462280112d313d2 Mon Sep 17 00:00:00 2001 From: NeuralFault Date: Tue, 21 Apr 2026 20:11:58 -0400 Subject: [PATCH 03/10] Initial ComfyUI.cs intergration - Add initial ROCm helper calls/config - Removed pre-existing Windows ROCm blocks which will be obsolete following helper implementation --- .../Models/Packages/ComfyUI.cs | 156 +++++++----------- 1 file changed, 62 insertions(+), 94 deletions(-) diff --git a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs index a4c34649..a0b5fb44 100644 --- a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs +++ b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs @@ -13,9 +13,11 @@ using StabilityMatrix.Core.Models.Packages.Config; using StabilityMatrix.Core.Models.Packages.Extensions; using StabilityMatrix.Core.Models.Progress; +using StabilityMatrix.Core.Models.Rocm; using StabilityMatrix.Core.Processes; using StabilityMatrix.Core.Python; using StabilityMatrix.Core.Services; +using StabilityMatrix.Core.Services.Rocm; namespace StabilityMatrix.Core.Models.Packages; @@ -26,7 +28,8 @@ public class ComfyUI( IDownloadService downloadService, IPrerequisiteHelper prerequisiteHelper, IPyInstallationManager pyInstallationManager, - IPipWheelService pipWheelService + IPipWheelService pipWheelService, + IRocmPackageHelper? rocmPackageHelper = null ) : BaseGitPackage( githubApi, @@ -38,6 +41,14 @@ IPipWheelService pipWheelService ) { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); + + private static readonly RocmPackageProfile WindowsRocmProfile = new() + { + PackageName = "ComfyUI", + RequiresWindows = true, + NeedsRuntimeGfxResolution = true, + }; + public override string Name => "ComfyUI"; public override string DisplayName { get; set; } = "ComfyUI"; public override string Author => "comfyanonymous"; @@ -247,7 +258,7 @@ IPipWheelService pipWheelService Name = "Enable DirectML", Type = LaunchOptionType.Bool, InitialValue = - !HardwareHelper.HasWindowsRocmSupportedGpu() + !HasWindowsRocmSupport() && HardwareHelper.PreferDirectMLOrZluda() && this is not ComfyZluda, Options = ["--directml"], @@ -362,91 +373,34 @@ public override async Task InstallPackage( .ConfigureAwait(false); var torchIndex = options.PythonOptions.TorchIndex ?? GetRecommendedTorchVersion(); - var gfxArch = - SettingsManager.Settings.PreferredGpu?.GetAmdGfxArch() - ?? HardwareHelper.GetWindowsRocmSupportedGpu()?.GetAmdGfxArch(); - - // Special case for Windows ROCm Nightly builds - if ( - Compat.IsWindows - && !string.IsNullOrWhiteSpace(gfxArch) - && torchIndex is TorchIndex.Rocm - && options.PythonOptions.PythonVersion >= PyVersion.Parse("3.11.0") - ) - { - var config = new PipInstallConfig - { - RequirementsFilePaths = ["requirements.txt"], - ExtraPipArgs = ["numpy<2"], - SkipTorchInstall = true, - PostInstallPipArgs = ["typing-extensions>=4.15.0"], - }; - await StandardPipInstallProcessAsync( - venvRunner, - options, - installedPackage, - config, - onConsoleOutput, - progress, - cancellationToken - ) - .ConfigureAwait(false); - - progress?.Report( - new ProgressReport(-1f, "Installing ROCm nightly torch...", isIndeterminate: true) + var isLegacyNvidia = + torchIndex == TorchIndex.Cuda + && ( + SettingsManager.Settings.PreferredGpu?.IsLegacyNvidiaGpu() + ?? HardwareHelper.HasLegacyNvidiaGpu() ); - var indexUrl = gfxArch switch - { - "gfx1150" => "https://rocm.nightlies.amd.com/v2-staging/gfx1150", // Strix/Gorgon Point - "gfx1151" => "https://rocm.nightlies.amd.com/v2/gfx1151", // Strix Halo - _ when gfxArch.StartsWith("gfx110") => "https://rocm.nightlies.amd.com/v2/gfx110X-all", - _ when gfxArch.StartsWith("gfx120") => "https://rocm.nightlies.amd.com/v2/gfx120X-all", - _ => throw new ArgumentOutOfRangeException( - nameof(gfxArch), - $"Unsupported GFX Arch: {gfxArch}" - ), - }; - - var torchPipArgs = new PipInstallArgs() - .AddArgs("--pre", "--upgrade") - .WithTorch() - .WithTorchVision() - .WithTorchAudio() - .AddArgs("--index-url", indexUrl); - - await venvRunner.PipInstall(torchPipArgs, onConsoleOutput).ConfigureAwait(false); - } - else // Standard installation path for all other cases + + var config = new PipInstallConfig { - var isLegacyNvidia = - torchIndex == TorchIndex.Cuda - && ( - SettingsManager.Settings.PreferredGpu?.IsLegacyNvidiaGpu() - ?? HardwareHelper.HasLegacyNvidiaGpu() - ); + RequirementsFilePaths = ["requirements.txt"], + ExtraPipArgs = ["numpy<2"], + TorchaudioVersion = " ", // Request torchaudio without a specific version + CudaIndex = isLegacyNvidia ? "cu126" : "cu130", + RocmIndex = "rocm7.2", + UpgradePackages = true, + PostInstallPipArgs = ["typing-extensions>=4.15.0"], + }; - var config = new PipInstallConfig - { - RequirementsFilePaths = ["requirements.txt"], - ExtraPipArgs = ["numpy<2"], - TorchaudioVersion = " ", // Request torchaudio without a specific version - CudaIndex = isLegacyNvidia ? "cu126" : "cu130", - RocmIndex = "rocm7.2", - UpgradePackages = true, - PostInstallPipArgs = ["typing-extensions>=4.15.0"], - }; - - await StandardPipInstallProcessAsync( - venvRunner, - options, - installedPackage, - config, - onConsoleOutput, - progress, - cancellationToken - ) - .ConfigureAwait(false); - } + await StandardPipInstallProcessAsync( + venvRunner, + options, + installedPackage, + config, + onConsoleOutput, + progress, + cancellationToken + ) + .ConfigureAwait(false); try { @@ -613,13 +567,7 @@ public override TorchIndex GetRecommendedTorchVersion() { var preferRocm = (Compat.IsLinux && (SettingsManager.Settings.PreferredGpu?.IsAmd ?? HardwareHelper.PreferRocm())) - || ( - Compat.IsWindows - && ( - SettingsManager.Settings.PreferredGpu?.IsWindowsRocmSupportedGpu() - ?? HardwareHelper.HasWindowsRocmSupportedGpu() - ) - ); + || HasWindowsRocmSupport(); if (AvailableTorchIndices.Contains(TorchIndex.Rocm) && preferRocm) { @@ -629,6 +577,28 @@ public override TorchIndex GetRecommendedTorchVersion() return base.GetRecommendedTorchVersion(); } + /// + /// Uses the shared ROCm helper for Windows ROCm eligibility checks so ComfyUI does not maintain its own support matrix. + /// + private bool HasWindowsRocmSupport() + { + if (!Compat.IsWindows) + return false; + + if (rocmPackageHelper is null) + { + return SettingsManager.Settings.PreferredGpu?.IsWindowsRocmSupportedGpu() + ?? HardwareHelper.HasWindowsRocmSupportedGpu(); + } + + var compatibility = rocmPackageHelper + .GetCompatibilityAsync(WindowsRocmProfile) + .GetAwaiter() + .GetResult(); + + return compatibility.IsCompatible; + } + public override IPackageExtensionManager ExtensionManager => new ComfyExtensionManager(this, settingsManager); @@ -982,9 +952,7 @@ await PipWheelService private ImmutableDictionary GetEnvVars(ImmutableDictionary env) { // if we're not on windows or we don't have a windows rocm gpu, return original env - var hasRocmGpu = - SettingsManager.Settings.PreferredGpu?.IsWindowsRocmSupportedGpu() - ?? HardwareHelper.HasWindowsRocmSupportedGpu(); + var hasRocmGpu = HasWindowsRocmSupport(); if (!Compat.IsWindows || !hasRocmGpu) return env; From 84f7f95eeb5d1733f93f1f3341ffbb2ad79451e7 Mon Sep 17 00:00:00 2001 From: NeuralFault Date: Tue, 21 Apr 2026 21:15:57 -0400 Subject: [PATCH 04/10] implement helper-owned Windows ROCm install flow - Windows ROCm install/bootstrap logic into shared ROCm helper - Add gfx-family mapping for Windows-native TheRock ROCm URLs - Route ComfyUI Win Rocm installs through helper-resolved ROCm runtime, rocm-sdk, and pytorch setup - Prevent requirements.txt from overwritting helper-installed ROCm torch packages - Add helper-owned post-install torch verification and improve unsupported GPU failure handling --- .../Models/Packages/ComfyUI.cs | 98 +++-- .../Models/Rocm/RocmInstallContext.cs | 4 + .../Models/Rocm/RocmPackageProfile.cs | 30 ++ .../Services/Rocm/IRocmPackageHelper.cs | 15 + .../Services/Rocm/RocmPackageHelper.cs | 336 +++++++++++++++++- 5 files changed, 447 insertions(+), 36 deletions(-) diff --git a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs index a0b5fb44..e3f54269 100644 --- a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs +++ b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs @@ -46,7 +46,19 @@ public class ComfyUI( { PackageName = "ComfyUI", RequiresWindows = true, + RequiresRocmSdk = true, NeedsRuntimeGfxResolution = true, + NeedsAotritonExperimental = true, + NeedsTunableOpCache = true, + ExtraInstallPipArgs = ["numpy<2"], + PostInstallPipArgs = ["typing-extensions>=4.15.0"], + UpgradePackages = true, + ExtraEnvironmentFactory = _ => new Dictionary + { + ["MIOPEN_FIND_MODE"] = "2", + ["PYTORCH_ALLOC_CONF"] = "max_split_size_mb:512,garbage_collection_threshold:0.8", + ["COMFYUI_ENABLE_MIOPEN"] = "1", + }, }; public override string Name => "ComfyUI"; @@ -380,27 +392,51 @@ public override async Task InstallPackage( ?? HardwareHelper.HasLegacyNvidiaGpu() ); - var config = new PipInstallConfig + if (Compat.IsWindows && torchIndex == TorchIndex.Rocm && HasWindowsRocmSupport()) { - RequirementsFilePaths = ["requirements.txt"], - ExtraPipArgs = ["numpy<2"], - TorchaudioVersion = " ", // Request torchaudio without a specific version - CudaIndex = isLegacyNvidia ? "cu126" : "cu130", - RocmIndex = "rocm7.2", - UpgradePackages = true, - PostInstallPipArgs = ["typing-extensions>=4.15.0"], - }; + if (rocmPackageHelper is null) + { + throw new InvalidOperationException( + "Windows ROCm installation requires the shared ROCm helper to resolve gfx-specific index URLs." + ); + } - await StandardPipInstallProcessAsync( - venvRunner, - options, - installedPackage, - config, - onConsoleOutput, - progress, - cancellationToken - ) - .ConfigureAwait(false); + await rocmPackageHelper + .InstallWindowsNativePackageAsync( + venvRunner, + installLocation, + installedPackage, + WindowsRocmProfile, + progress, + onConsoleOutput, + cancellationToken + ) + .ConfigureAwait(false); + } + else + { + var config = new PipInstallConfig + { + RequirementsFilePaths = ["requirements.txt"], + ExtraPipArgs = ["numpy<2"], + TorchaudioVersion = " ", // Request torchaudio without a specific version + CudaIndex = isLegacyNvidia ? "cu126" : "cu130", + RocmIndex = "rocm7.2", + UpgradePackages = true, + PostInstallPipArgs = ["typing-extensions>=4.15.0"], + }; + + await StandardPipInstallProcessAsync( + venvRunner, + options, + installedPackage, + config, + onConsoleOutput, + progress, + cancellationToken + ) + .ConfigureAwait(false); + } try { @@ -433,7 +469,11 @@ await StandardPipInstallProcessAsync( SettingsManager.Settings.PreferredGpu?.IsBlackwellGpu() ?? HardwareHelper.HasBlackwellGpu(), WorkingDirectory = installLocation, - EnvironmentVariables = GetEnvVars(venvRunner.EnvironmentVariables), + EnvironmentVariables = GetEnvVars( + venvRunner.EnvironmentVariables, + installLocation, + installedPackage + ), }; await step.ExecuteAsync(progress).ConfigureAwait(false); @@ -483,7 +523,7 @@ public override async Task RunPackage( await SetupVenv(installLocation, pythonVersion: PyVersion.Parse(installedPackage.PythonVersion)) .ConfigureAwait(false); - VenvRunner.UpdateEnvironmentVariables(GetEnvVars); + VenvRunner.UpdateEnvironmentVariables(env => GetEnvVars(env, installLocation, installedPackage)); // Check for old NVIDIA driver version with cu130 installations var isNvidia = SettingsManager.Settings.PreferredGpu?.IsNvidia ?? HardwareHelper.HasNvidiaGpu(); @@ -949,7 +989,11 @@ await PipWheelService .ConfigureAwait(false); } - private ImmutableDictionary GetEnvVars(ImmutableDictionary env) + private ImmutableDictionary GetEnvVars( + ImmutableDictionary env, + string installLocation, + InstalledPackage installedPackage + ) { // if we're not on windows or we don't have a windows rocm gpu, return original env var hasRocmGpu = HasWindowsRocmSupport(); @@ -957,6 +1001,16 @@ private ImmutableDictionary GetEnvVars(ImmutableDictionary + /// Requirement files to install after helper-owned ROCm runtime / torch bootstrap steps complete. + /// + public IEnumerable RequirementsFilePaths { get; init; } = ["requirements.txt"]; + + /// + /// Package requirement entries to exclude because the helper installs them from ROCm-specific feeds. + /// + public string RequirementsExcludePattern { get; init; } = @"(torch(vision|audio)?|xformers)([^a-z].*)?"; + + /// + /// Extra package-specific pip arguments to include when installing requirements after helper bootstrap. + /// + public IEnumerable ExtraInstallPipArgs { get; init; } = []; + + /// + /// Extra package-specific pip arguments to install after requirements and torch are complete. + /// + public IEnumerable PostInstallPipArgs { get; init; } = []; + + /// + /// When true, helper-managed requirements installs should use --upgrade. + /// + public bool UpgradePackages { get; init; } + + /// + /// When true, helper-managed torch installs should force reinstall the selected ROCm wheel set. + /// + public bool ForceReinstallTorch { get; init; } = true; + /// /// Optional callback for package-specific cache path variables. /// The helper will eventually merge these with its own defaults. diff --git a/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs index 5b9383fc..0fac954e 100644 --- a/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs +++ b/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs @@ -1,5 +1,7 @@ using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Models.Progress; using StabilityMatrix.Core.Models.Rocm; +using StabilityMatrix.Core.Processes; using StabilityMatrix.Core.Python; namespace StabilityMatrix.Core.Services.Rocm; @@ -76,4 +78,17 @@ Task ApplyLaunchEnvironmentAsync( RocmPackageProfile profile, CancellationToken cancellationToken = default ); + + /// + /// Performs the Windows-native ROCm bootstrap/install flow for a package using helper-resolved gfx-family feed URLs. + /// + Task InstallWindowsNativePackageAsync( + IPyVenvRunner venvRunner, + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile, + IProgress? progress = null, + Action? onConsoleOutput = null, + CancellationToken cancellationToken = default + ); } diff --git a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs index 1adb141e..a2c700b6 100644 --- a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs +++ b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs @@ -1,10 +1,15 @@ using System.Collections.Immutable; +using System.Text.Json; using Injectio.Attributes; using NLog; +using StabilityMatrix.Core.Exceptions; using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Helper.HardwareInfo; using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Models.FileInterfaces; +using StabilityMatrix.Core.Models.Progress; using StabilityMatrix.Core.Models.Rocm; +using StabilityMatrix.Core.Processes; using StabilityMatrix.Core.Python; using StabilityMatrix.Core.Services; @@ -31,9 +36,6 @@ public class RocmPackageHelper(ISettingsManager settingsManager) : IRocmPackageH "rx6550", ]; - private const string EnvironmentNotImplementedMessage = - "ROCm helper environment composition has not been implemented yet."; - /// public Task GetCompatibilityAsync( RocmPackageProfile profile, @@ -99,6 +101,10 @@ public Task ResolveInstallContextAsync( CancellationToken cancellationToken = default ) { + _ = installLocation; + _ = installedPackage; + _ = cancellationToken; + var supportedAmdGpus = GetAmdGpuCandidates(forceRefresh: true) .Where(IsSupportedWindowsRocmGpu) .ToList(); @@ -108,11 +114,16 @@ public Task ResolveInstallContextAsync( settingsManager.Settings.PreferredGpu ); + var runtimeGfxArch = preferredGfxArch ?? GetSupportedFallbackGfxArch(supportedAmdGpus); + var windowsNativeIndexUrl = TryGetWindowsNativeRocmIndexUrl(runtimeGfxArch); + return Task.FromResult( new RocmInstallContext { PreferredGfxArch = preferredGfxArch, - RuntimeGfxArch = preferredGfxArch ?? GetSupportedFallbackGfxArch(supportedAmdGpus), + RuntimeGfxArch = runtimeGfxArch, + RocmPackageIndexUrl = windowsNativeIndexUrl, + RocmTorchIndexUrl = windowsNativeIndexUrl, } ); } @@ -151,9 +162,30 @@ public Task> BuildLaunchEnvironmentAsync( { _ = installLocation; _ = installedPackage; - _ = profile; - _ = cancellationToken; - return Task.FromResult>(new Dictionary()); + + var runtimeContext = ResolveRuntimeContextAsync( + installLocation, + installedPackage, + profile, + cancellationToken + ) + .GetAwaiter() + .GetResult(); + + if (!runtimeContext.IsSupported) + return Task.FromResult>(new Dictionary()); + + var helperEnvironment = BuildHelperLaunchEnvironment(runtimeContext, profile); + var packageEnvironment = + profile.ExtraEnvironmentFactory?.Invoke(runtimeContext) ?? new Dictionary(); + + var mergedEnvironment = MergeLaunchEnvironment( + helperEnvironment, + packageEnvironment, + profile.EnvironmentOptions + ); + + return Task.FromResult>(mergedEnvironment); } /// @@ -176,6 +208,146 @@ public async Task ApplyLaunchEnvironmentAsync( venvRunner.UpdateEnvironmentVariables(env => env.SetItems(environment)); } + /// + public async Task InstallWindowsNativePackageAsync( + IPyVenvRunner venvRunner, + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile, + IProgress? progress = null, + Action? onConsoleOutput = null, + CancellationToken cancellationToken = default + ) + { + var compatibility = await GetCompatibilityAsync(profile, cancellationToken).ConfigureAwait(false); + if (!compatibility.IsCompatible) + { + throw new ApplicationException( + compatibility.FailureReason + ?? "Windows ROCm installation is not supported for the current machine." + ); + } + + var installContext = await ResolveInstallContextAsync( + installLocation, + installedPackage, + profile, + cancellationToken + ) + .ConfigureAwait(false); + + var rocmPackageIndexUrl = installContext.RocmPackageIndexUrl; + var rocmTorchIndexUrl = installContext.RocmTorchIndexUrl ?? rocmPackageIndexUrl; + + if (string.IsNullOrWhiteSpace(rocmPackageIndexUrl) || string.IsNullOrWhiteSpace(rocmTorchIndexUrl)) + { + throw new ApplicationException( + $"No Windows ROCm Technical Preview index URL is available for '{installContext.RuntimeGfxArch ?? "unknown"}'." + ); + } + + progress?.Report(new ProgressReport(-1f, "Upgrading pip...", isIndeterminate: true)); + await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false); + + if (profile.RequiresRocmSdk) + { + progress?.Report(new ProgressReport(-1f, "Installing ROCm runtime...", isIndeterminate: true)); + var rocmRuntimeArgs = new PipInstallArgs() + .AddKeyedArgs("--index-url", ["--index-url", rocmPackageIndexUrl]) + .AddArgs("rocm[devel,libraries]", "--no-warn-script-location"); + + if (installedPackage.PipOverrides != null) + { + rocmRuntimeArgs = rocmRuntimeArgs.WithUserOverrides(installedPackage.PipOverrides); + } + + await venvRunner.PipInstall(rocmRuntimeArgs, onConsoleOutput).ConfigureAwait(false); + + progress?.Report(new ProgressReport(-1f, "Initializing ROCm SDK...", isIndeterminate: true)); + var rocmSdkExe = Path.Combine(installLocation, "venv", "Scripts", "rocm-sdk.exe"); + if (!File.Exists(rocmSdkExe)) + { + throw new FileNotFoundException("rocm-sdk.exe was not installed", rocmSdkExe); + } + + using var rocmSdkProcess = ProcessRunner.StartAnsiProcess( + rocmSdkExe, + ["init"], + installLocation, + onConsoleOutput + ); + + await rocmSdkProcess.WaitForExitAsync(cancellationToken).ConfigureAwait(false); + if (rocmSdkProcess.ExitCode != 0) + { + throw new ProcessException($"rocm-sdk init failed with code {rocmSdkProcess.ExitCode}"); + } + } + + progress?.Report(new ProgressReport(-1f, "Installing ROCm torch...", isIndeterminate: true)); + var torchArgs = new PipInstallArgs() + .AddKeyedArgs("--index-url", ["--index-url", rocmTorchIndexUrl]) + .AddArgs("torch", "torchaudio", "torchvision", "--no-warn-script-location"); + + if (profile.ForceReinstallTorch) + { + torchArgs = torchArgs.AddArg("--force-reinstall"); + } + + if (installedPackage.PipOverrides != null) + { + torchArgs = torchArgs.WithUserOverrides(installedPackage.PipOverrides); + } + + await venvRunner.PipInstall(torchArgs, onConsoleOutput).ConfigureAwait(false); + + progress?.Report( + new ProgressReport(-1f, "Installing package requirements...", isIndeterminate: true) + ); + + var requirementsPipArgs = new PipInstallArgs([.. profile.ExtraInstallPipArgs]); + if (profile.UpgradePackages) + { + requirementsPipArgs = requirementsPipArgs.AddArg("--upgrade"); + } + + foreach (var relativePath in profile.RequirementsFilePaths) + { + var requirementsFile = new FilePath(venvRunner.WorkingDirectory ?? installLocation, relativePath); + if (!requirementsFile.Exists) + continue; + + var requirementsContent = await requirementsFile + .ReadAllTextAsync(cancellationToken) + .ConfigureAwait(false); + + requirementsPipArgs = requirementsPipArgs.WithParsedFromRequirementsTxt( + requirementsContent, + profile.RequirementsExcludePattern + ); + } + + if (installedPackage.PipOverrides != null) + { + requirementsPipArgs = requirementsPipArgs.WithUserOverrides(installedPackage.PipOverrides); + } + + await venvRunner.PipInstall(requirementsPipArgs, onConsoleOutput).ConfigureAwait(false); + + if (!profile.PostInstallPipArgs.Any()) + return; + + var postInstallPipArgs = new PipInstallArgs([.. profile.PostInstallPipArgs]); + if (installedPackage.PipOverrides != null) + { + postInstallPipArgs = postInstallPipArgs.WithUserOverrides(installedPackage.PipOverrides); + } + + await venvRunner.PipInstall(postInstallPipArgs, onConsoleOutput).ConfigureAwait(false); + + await VerifyWindowsNativeTorchInstallAsync(venvRunner, onConsoleOutput).ConfigureAwait(false); + } + /// /// Builds a compatibility result from the current machine state and package profile. /// This keeps the first ROCm helper slice focused on hardware capability and GPU selection only. @@ -309,7 +481,7 @@ private static bool IsSupportedWindowsRocmGpu(GpuInfo gpu) if (IsExplicitlyUnsupportedRdna2Gpu(gpu)) return false; - return IsSupportedWindowsRocmArchitecture(gpu.GetAmdGfxArch()); + return TryGetWindowsNativeRocmIndexUrl(gpu.GetAmdGfxArch()) is not null; } /// @@ -328,15 +500,30 @@ private static bool IsExplicitlyUnsupportedRdna2Gpu(GpuInfo gpu) /// Determines whether a resolved AMD GFX architecture falls inside the Windows ROCm support set currently modeled by the helper. /// private static bool IsSupportedWindowsRocmArchitecture(string? gfxArch) + { + return TryGetWindowsNativeRocmIndexUrl(gfxArch) is not null; + } + + /// + /// Maps an AMD GFX architecture identifier to the Windows-native ROCm Technical Preview feed URL. + /// + private static string? TryGetWindowsNativeRocmIndexUrl(string? gfxArch) { return gfxArch switch { - var s when s != null && s.StartsWith("gfx101", StringComparison.OrdinalIgnoreCase) => true, - var s when s != null && s.StartsWith("gfx103", StringComparison.OrdinalIgnoreCase) => true, - var s when s != null && s.StartsWith("gfx110", StringComparison.OrdinalIgnoreCase) => true, - "gfx1150" or "gfx1151" or "gfx1152" or "gfx1153" => true, - var s when s != null && s.StartsWith("gfx120", StringComparison.OrdinalIgnoreCase) => true, - _ => false, + var s when s != null && s.StartsWith("gfx101", StringComparison.OrdinalIgnoreCase) => + "https://rocm.nightlies.amd.com/v2-staging/gfx101X-dgpu/", + var s when s != null && s.StartsWith("gfx103", StringComparison.OrdinalIgnoreCase) => + "https://rocm.nightlies.amd.com/v2-staging/gfx103X-dgpu/", + var s when s != null && s.StartsWith("gfx110", StringComparison.OrdinalIgnoreCase) => + "https://rocm.nightlies.amd.com/v2/gfx110X-all/", + "gfx1150" => "https://rocm.nightlies.amd.com/v2-staging/gfx1150/", + "gfx1151" => "https://rocm.nightlies.amd.com/v2/gfx1151/", + "gfx1152" => "https://rocm.nightlies.amd.com/v2-staging/gfx1152/", + "gfx1153" => "https://rocm.nightlies.amd.com/v2-staging/gfx1153/", + var s when s != null && s.StartsWith("gfx120", StringComparison.OrdinalIgnoreCase) => + "https://rocm.nightlies.amd.com/v2/gfx120X-all/", + _ => null, }; } @@ -372,4 +559,125 @@ private static string GetUnsupportedGpuReason(IReadOnlyList amdGpus) return "No AMD GPU with a supported Windows ROCm architecture was detected."; } + + /// + /// Verifies that the installed torch build still reports a usable ROCm runtime after helper-managed installs complete. + /// + private static async Task VerifyWindowsNativeTorchInstallAsync( + IPyVenvRunner venvRunner, + Action? onConsoleOutput + ) + { + var torchInfo = await venvRunner.PipShow("torch").ConfigureAwait(false); + if (torchInfo is null) + { + throw new ApplicationException("torch was not installed after Windows ROCm setup."); + } + + var verificationResult = await venvRunner + .Run( + "-c \"import json, torch; print(json.dumps({'version': torch.__version__, 'hip': torch.version.hip, 'cuda': torch.cuda.is_available()}))\"" + ) + .ConfigureAwait(false); + + var verificationOutput = (verificationResult.StandardOutput ?? string.Empty).Trim(); + if (string.IsNullOrWhiteSpace(verificationOutput)) + { + throw new ApplicationException("Torch verification produced no output."); + } + + JsonDocument verificationDocument; + try + { + verificationDocument = JsonDocument.Parse(verificationOutput); + } + catch (Exception exception) + { + throw new ApplicationException( + $"Unexpected torch verification output: {verificationOutput}", + exception + ); + } + + using (verificationDocument) + { + var root = verificationDocument.RootElement; + var version = root.TryGetProperty("version", out var versionElement) + ? versionElement.GetString() + : null; + var hipVersion = root.TryGetProperty("hip", out var hipElement) ? hipElement.GetString() : null; + var cudaAvailable = root.TryGetProperty("cuda", out var cudaElement) && cudaElement.GetBoolean(); + + if (string.IsNullOrWhiteSpace(hipVersion) || !cudaAvailable) + { + throw new ApplicationException( + $"Installed torch is not a usable ROCm build. Verification output: {verificationOutput}" + ); + } + + onConsoleOutput?.Invoke( + ProcessOutput.FromStdOutLine( + $"Torch verification: version={version}, hip={hipVersion}, cuda={cudaAvailable}" + ) + ); + } + } + + /// + /// Builds helper-owned ROCm launch variables from the resolved runtime context and package profile. + /// + private static IReadOnlyDictionary BuildHelperLaunchEnvironment( + RocmRuntimeContext runtimeContext, + RocmPackageProfile profile + ) + { + var environment = new Dictionary(); + + if (profile.NeedsTunableOpCache) + { + environment["PYTORCH_TUNABLEOP_ENABLED"] = "1"; + } + + if (profile.NeedsAotritonExperimental) + { + environment["TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL"] = "1"; + } + + if (profile.NeedsTritonOverrideArch && !string.IsNullOrWhiteSpace(runtimeContext.RuntimeGfxArch)) + { + environment["HSA_OVERRIDE_GFX_VERSION"] = runtimeContext.RuntimeGfxArch; + } + + return environment; + } + + /// + /// Merges helper-owned and package-specific launch environment variables using the profile overlay rules. + /// + private static IReadOnlyDictionary MergeLaunchEnvironment( + IReadOnlyDictionary helperEnvironment, + IReadOnlyDictionary packageEnvironment, + RocmEnvironmentOptions options + ) + { + var merged = new Dictionary(); + + IReadOnlyDictionary[] orderedSources = + options.OverlayPriority == RocmEnvironmentOverlayPriority.HelperThenUserThenPackage + ? new[] { helperEnvironment, packageEnvironment } + : new[] { helperEnvironment, packageEnvironment }; + + foreach (var source in orderedSources) + { + if (ReferenceEquals(source, packageEnvironment) && !options.IncludePackageOverrides) + continue; + + foreach (var pair in source) + { + merged[pair.Key] = pair.Value; + } + } + + return merged; + } } From 31b5955a8cf47a4475e99cd4666a350af9b8d1e1 Mon Sep 17 00:00:00 2001 From: NeuralFault Date: Fri, 1 May 2026 14:51:20 -0400 Subject: [PATCH 05/10] =?UTF-8?q?-=20refactor=20the=20shared=20ROCm=20help?= =?UTF-8?q?er=20into=20a=20synchronous=20compatibility/runtime/install/env?= =?UTF-8?q?ironment=20API=20and=20simplify=20the=20ROCm=20profile/context?= =?UTF-8?q?=20models=20around=20the=20helper=E2=80=99s=20real=20responsibi?= =?UTF-8?q?lities=20-=20add=20a=20centralized=20Windows=20ROCm=20support?= =?UTF-8?q?=20map=20so=20GPU=20detection,=20architecture=20support=20check?= =?UTF-8?q?s,=20and=20package=20index=20resolution=20all=20use=20the=20sam?= =?UTF-8?q?e=20source=20of=20truth=20-=20expand=20AMD=20architecture=20det?= =?UTF-8?q?ection=20to=20cover=20additional=20RDNA4,=20Steam=20Deck,=20RDN?= =?UTF-8?q?A1,=20and=20Vega-class=20GPUs=20used=20by=20the=20Windows=20ROC?= =?UTF-8?q?m=20support=20path=20-=20add=20a=20helper-managed=20Windows=20R?= =?UTF-8?q?OCm=20bootstrap=20flow=20that=20installs=20the=20ROCm=20runtime?= =?UTF-8?q?,=20initializes/reinitializes=20the=20SDK,=20aligns=20rocm-sdk-?= =?UTF-8?q?devel=20with=20the=20resolved=20torch=20build,=20and=20verifies?= =?UTF-8?q?=20both=20torch=20ROCm=20metadata=20and=20runtime=20availabilit?= =?UTF-8?q?y=20-=20centralize=20ROCm=20launch=20environment=20construction?= =?UTF-8?q?=20in=20the=20helper,=20including=20default=20MIOpen,=20allocat?= =?UTF-8?q?or,=20flash-attention,=20and=20AOTriton=20settings=20plus=20leg?= =?UTF-8?q?acy=20SDP=20fallback,=20RDNA1=20overrides,=20and=20user=20env?= =?UTF-8?q?=20override=20layering=20-=20switch=20ComfyUI=20to=20helper-dri?= =?UTF-8?q?ven=20Windows=20ROCm=20compatibility=20and=20launch=20env=20han?= =?UTF-8?q?dling,=20and=20default=20legacy=20Windows=20ROCm=20GPUs=20to=20?= =?UTF-8?q?quad=20cross-attention=20while=20keeping=20Comfy-specific=20MIO?= =?UTF-8?q?pen=20enablement=20as=20a=20preset=20-=20integrate=20Wan2GP=20w?= =?UTF-8?q?ith=20the=20shared=20Windows=20ROCm=20helper=20for=20install=20?= =?UTF-8?q?and=20launch=20flows,=20while=20updating=20its=20Linux=20ROCm?= =?UTF-8?q?=20path=20to=20use=20upstream=20rocm7.2=20torch/vision/audio=20?= =?UTF-8?q?installs=20-=20wire=20the=20ROCm=20helper=20through=20package?= =?UTF-8?q?=20construction=20and=20add=20focused=20test=20coverage=20for?= =?UTF-8?q?=20ROCm=20build/version=20parsing,=20runtime=20failure=20classi?= =?UTF-8?q?fication,=20and=20Windows=20ROCm=20support/index=20resolution?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../Helper/Factory/PackageFactory.cs | 11 +- .../Helper/HardwareInfo/GpuInfo.cs | 25 +- .../Helper/HardwareInfo/HardwareHelper.cs | 6 +- .../Models/Packages/ComfyUI.cs | 57 +- .../Models/Packages/Wan2GP.cs | 139 ++-- .../Models/Rocm/RocmEnvironmentOptions.cs | 65 +- .../Models/Rocm/RocmInstallContext.cs | 16 - .../Models/Rocm/RocmPackageProfile.cs | 37 +- .../Models/Rocm/RocmRuntimeContext.cs | 15 - .../Models/Rocm/RocmSdkPaths.cs | 22 - .../Models/Rocm/WindowsRocmSupport.cs | 46 ++ .../Services/Rocm/IRocmPackageHelper.cs | 48 +- .../Services/Rocm/RocmPackageHelper.cs | 677 +++++++++++------- .../Core/RocmPackageHelperTests.cs | 176 +++++ .../Helper/PackageFactoryTests.cs | 2 + 15 files changed, 823 insertions(+), 519 deletions(-) delete mode 100644 StabilityMatrix.Core/Models/Rocm/RocmSdkPaths.cs create mode 100644 StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs create mode 100644 StabilityMatrix.Tests/Core/RocmPackageHelperTests.cs diff --git a/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs b/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs index 118efa55..6a073986 100644 --- a/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs +++ b/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs @@ -4,6 +4,7 @@ using StabilityMatrix.Core.Models.Packages; using StabilityMatrix.Core.Python; using StabilityMatrix.Core.Services; +using StabilityMatrix.Core.Services.Rocm; namespace StabilityMatrix.Core.Helper.Factory; @@ -18,6 +19,7 @@ public class PackageFactory : IPackageFactory private readonly IUvManager uvManager; private readonly IPyInstallationManager pyInstallationManager; private readonly IPipWheelService pipWheelService; + private readonly IRocmPackageHelper rocmPackageHelper; /// /// Mapping of package.Name to package @@ -32,7 +34,9 @@ public PackageFactory( IPrerequisiteHelper prerequisiteHelper, IPyInstallationManager pyInstallationManager, IPyRunner pyRunner, - IPipWheelService pipWheelService + IUvManager uvManager, + IPipWheelService pipWheelService, + IRocmPackageHelper rocmPackageHelper ) { this.githubApiCache = githubApiCache; @@ -40,8 +44,10 @@ IPipWheelService pipWheelService this.downloadService = downloadService; this.prerequisiteHelper = prerequisiteHelper; this.pyRunner = pyRunner; + this.uvManager = uvManager; this.pyInstallationManager = pyInstallationManager; this.pipWheelService = pipWheelService; + this.rocmPackageHelper = rocmPackageHelper; this.basePackages = basePackages.ToDictionary(x => x.Name); } @@ -55,7 +61,8 @@ public BasePackage GetNewBasePackage(InstalledPackage installedPackage) downloadService, prerequisiteHelper, pyInstallationManager, - pipWheelService + pipWheelService, + rocmPackageHelper ), "Fooocus" => new Fooocus( githubApiCache, diff --git a/StabilityMatrix.Core/Helper/HardwareInfo/GpuInfo.cs b/StabilityMatrix.Core/Helper/HardwareInfo/GpuInfo.cs index eedcb556..0013f65b 100644 --- a/StabilityMatrix.Core/Helper/HardwareInfo/GpuInfo.cs +++ b/StabilityMatrix.Core/Helper/HardwareInfo/GpuInfo.cs @@ -1,4 +1,6 @@ -namespace StabilityMatrix.Core.Helper.HardwareInfo; +using StabilityMatrix.Core.Models.Rocm; + +namespace StabilityMatrix.Core.Helper.HardwareInfo; public record GpuInfo { @@ -62,11 +64,7 @@ public bool IsLegacyNvidiaGpu() public bool IsWindowsRocmSupportedGpu() { - var gfx = GetAmdGfxArch(); - if (gfx is null) - return false; - - return gfx.StartsWith("gfx110") || gfx.StartsWith("gfx120") || gfx.Equals("gfx1151"); + return WindowsRocmSupport.IsSupportedGpu(this); } public bool IsAmd => Name?.Contains("amd", StringComparison.OrdinalIgnoreCase) ?? false; @@ -84,7 +82,7 @@ public bool IsWindowsRocmSupportedGpu() return name switch { // RDNA4 - _ when Has("R9700") || Has("9070") => "gfx1201", + _ when Has("R9700") || Has("R9600") || Has("9070") => "gfx1201", _ when Has("9060") => "gfx1200", // RDNA3.5 APUs @@ -112,6 +110,9 @@ _ when Has("660M") || Has("680M") => "gfx1035", _ when Has("6300") || Has("6400") || Has("6450") || Has("6500") || Has("6550") || Has("6500M") => "gfx1034", + // RDNA2 Steam Deck APU + _ when Has("Van Gogh") || Has("Sephiroth") => "gfx1033", + // RDNA2 Navi23 _ when Has("6600") || Has("6650") || Has("6700S") || Has("6800S") || Has("6600M") => "gfx1032", @@ -121,6 +122,16 @@ _ when Has("6700") || Has("6750") || Has("6800M") || Has("6850M") => "gfx1031", // RDNA2 Navi21 (big die) _ when Has("6800") || Has("6900") || Has("6950") => "gfx1030", + // RDNA1 Navi10 XT (incl. Pro card) + _ when Has("5600") || Has("5700") || Has("v520") => "gfx1010", + + // RDNA1 Navi10 XTX + _ when Has("5500") => "gfx1012", + + // Vega/GCN5 Dedicated GPUs + _ when Has("pro vii") || HasNoSpace("provii") => "gfx90X", + _ when Has("rx vega") || Has("vega 64") || Has("vega 56") || Has("vega frontier") => "gfx900", + _ when Has("radeon vii") || HasNoSpace("radeonvii") => "gfx906", _ => null, }; diff --git a/StabilityMatrix.Core/Helper/HardwareInfo/HardwareHelper.cs b/StabilityMatrix.Core/Helper/HardwareInfo/HardwareHelper.cs index 8458c730..93f093d4 100644 --- a/StabilityMatrix.Core/Helper/HardwareInfo/HardwareHelper.cs +++ b/StabilityMatrix.Core/Helper/HardwareInfo/HardwareHelper.cs @@ -7,6 +7,7 @@ using Microsoft.Win32; using NLog; using StabilityMatrix.Core.Extensions; +using StabilityMatrix.Core.Models.Rocm; namespace StabilityMatrix.Core.Helper.HardwareInfo; @@ -316,12 +317,11 @@ public static bool HasAmdGpu() return IterGpuInfo().Any(gpu => gpu.IsAmd); } - public static bool HasWindowsRocmSupportedGpu() => - IterGpuInfo().Any(gpu => gpu is { IsAmd: true, Name: not null } && gpu.IsWindowsRocmSupportedGpu()); + public static bool HasWindowsRocmSupportedGpu() => IterGpuInfo().Any(WindowsRocmSupport.IsSupportedGpu); public static GpuInfo? GetWindowsRocmSupportedGpu() { - return IterGpuInfo().FirstOrDefault(gpu => gpu.IsWindowsRocmSupportedGpu()); + return IterGpuInfo().FirstOrDefault(WindowsRocmSupport.IsSupportedGpu); } public static bool HasIntelGpu() => IterGpuInfo().Any(gpu => gpu.IsIntel); diff --git a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs index e3f54269..1832a27a 100644 --- a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs +++ b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs @@ -45,20 +45,11 @@ public class ComfyUI( private static readonly RocmPackageProfile WindowsRocmProfile = new() { PackageName = "ComfyUI", - RequiresWindows = true, RequiresRocmSdk = true, - NeedsRuntimeGfxResolution = true, - NeedsAotritonExperimental = true, - NeedsTunableOpCache = true, ExtraInstallPipArgs = ["numpy<2"], PostInstallPipArgs = ["typing-extensions>=4.15.0"], UpgradePackages = true, - ExtraEnvironmentFactory = _ => new Dictionary - { - ["MIOPEN_FIND_MODE"] = "2", - ["PYTORCH_ALLOC_CONF"] = "max_split_size_mb:512,garbage_collection_threshold:0.8", - ["COMFYUI_ENABLE_MIOPEN"] = "1", - }, + EnvironmentOptions = new RocmEnvironmentOptions { Preset = RocmEnvironmentPreset.ComfyUi }, }; public override string Name => "ComfyUI"; @@ -287,7 +278,9 @@ public class ComfyUI( { Name = "Cross Attention Method", Type = LaunchOptionType.Bool, - InitialValue = "--use-pytorch-cross-attention", + InitialValue = ShouldDefaultToQuadCrossAttention() + ? "--use-quad-cross-attention" + : "--use-pytorch-cross-attention", Options = [ "--use-split-cross-attention", @@ -626,19 +619,29 @@ private bool HasWindowsRocmSupport() return false; if (rocmPackageHelper is null) - { - return SettingsManager.Settings.PreferredGpu?.IsWindowsRocmSupportedGpu() - ?? HardwareHelper.HasWindowsRocmSupportedGpu(); - } + return false; - var compatibility = rocmPackageHelper - .GetCompatibilityAsync(WindowsRocmProfile) - .GetAwaiter() - .GetResult(); + var compatibility = rocmPackageHelper.GetCompatibility(WindowsRocmProfile); return compatibility.IsCompatible; } + private bool ShouldDefaultToQuadCrossAttention() + { + if (!Compat.IsWindows || !HasWindowsRocmSupport()) + return false; + + var gpu = SettingsManager.Settings.PreferredGpu; + var gfxArch = WindowsRocmSupport.IsSupportedGpu(gpu) + ? gpu?.GetAmdGfxArch() + : HardwareHelper.GetWindowsRocmSupportedGpu()?.GetAmdGfxArch(); + + return !string.IsNullOrWhiteSpace(gfxArch) + && !gfxArch.StartsWith("gfx110", StringComparison.OrdinalIgnoreCase) + && !gfxArch.StartsWith("gfx115", StringComparison.OrdinalIgnoreCase) + && !gfxArch.StartsWith("gfx120", StringComparison.OrdinalIgnoreCase); + } + public override IPackageExtensionManager ExtensionManager => new ComfyExtensionManager(this, settingsManager); @@ -1003,19 +1006,15 @@ InstalledPackage installedPackage if (rocmPackageHelper is not null) { - var rocmEnvironment = rocmPackageHelper - .BuildLaunchEnvironmentAsync(installLocation, installedPackage, WindowsRocmProfile) - .GetAwaiter() - .GetResult(); + var rocmEnvironment = rocmPackageHelper.BuildLaunchEnvironment( + installLocation, + installedPackage, + WindowsRocmProfile + ); return env.SetItems(rocmEnvironment); } - // set some experimental speed improving env vars for Windows ROCm - return env.SetItem("PYTORCH_TUNABLEOP_ENABLED", "1") - .SetItem("MIOPEN_FIND_MODE", "2") - .SetItem("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "1") - .SetItem("PYTORCH_ALLOC_CONF", "max_split_size_mb:6144,garbage_collection_threshold:0.8") // greatly helps prevent GPU OOM and instability/driver timeouts/OS hard locks and decreases dependency on Tiled VAE at standard res's - .SetItem("COMFYUI_ENABLE_MIOPEN", "1"); // re-enables "cudnn" in ComfyUI as it's needed for MiOpen to function properly + return env; } } diff --git a/StabilityMatrix.Core/Models/Packages/Wan2GP.cs b/StabilityMatrix.Core/Models/Packages/Wan2GP.cs index 2a00a626..e11fd322 100644 --- a/StabilityMatrix.Core/Models/Packages/Wan2GP.cs +++ b/StabilityMatrix.Core/Models/Packages/Wan2GP.cs @@ -6,9 +6,11 @@ using StabilityMatrix.Core.Helper.HardwareInfo; using StabilityMatrix.Core.Models.FileInterfaces; using StabilityMatrix.Core.Models.Progress; +using StabilityMatrix.Core.Models.Rocm; using StabilityMatrix.Core.Processes; using StabilityMatrix.Core.Python; using StabilityMatrix.Core.Services; +using StabilityMatrix.Core.Services.Rocm; namespace StabilityMatrix.Core.Models.Packages; @@ -30,7 +32,8 @@ public class Wan2GP( IDownloadService downloadService, IPrerequisiteHelper prerequisiteHelper, IPyInstallationManager pyInstallationManager, - IPipWheelService pipWheelService + IPipWheelService pipWheelService, + IRocmPackageHelper? rocmPackageHelper = null ) : BaseGitPackage( githubApi, @@ -41,6 +44,14 @@ IPipWheelService pipWheelService pipWheelService ) { + private static readonly RocmPackageProfile WindowsRocmProfile = new() + { + PackageName = "Wan2GP", + RequiresRocmSdk = true, + UpgradePackages = true, + PostInstallPipArgs = ["hf-xet", "setuptools", "numpy==1.26.4"], + }; + public override string Name => "Wan2GP"; public override string DisplayName { get; set; } = "Wan2GP"; public override string Author => "deepbeepmeep"; @@ -64,7 +75,7 @@ IPipWheelService pipWheelService public override bool IsCompatible => HardwareHelper.HasNvidiaGpu() - || (Compat.IsWindows ? HardwareHelper.HasWindowsRocmSupportedGpu() : HardwareHelper.HasAmdGpu()); + || (Compat.IsWindows ? HasWindowsRocmSupport() : HardwareHelper.HasAmdGpu()); public override string MainBranch => "main"; public override bool ShouldIgnoreReleases => true; @@ -72,7 +83,7 @@ IPipWheelService pipWheelService public override Dictionary> SharedOutputFolders => new() { [SharedOutputType.Img2Vid] = ["outputs"] }; - // AMD ROCm requires Python 3.11, NVIDIA uses 3.10 + // Wan2GP currently uses Python 3.11 for ROCm and 3.10 for CUDA. public override PyVersion RecommendedPythonVersion => IsAmdRocm ? Python.PyInstallationManager.Python_3_11_13 : Python.PyInstallationManager.Python_3_10_17; @@ -86,6 +97,17 @@ IPipWheelService pipWheelService /// private bool IsAmdRocm => GetRecommendedTorchVersion() == TorchIndex.Rocm; + private bool HasWindowsRocmSupport() + { + if (!Compat.IsWindows) + return false; + + if (rocmPackageHelper is null) + return HardwareHelper.HasWindowsRocmSupportedGpu(); + + return rocmPackageHelper.GetCompatibility(WindowsRocmProfile).IsCompatible; + } + /// /// Python wrapper script that patches logging to also print to stdout/stderr, so /// StabilityMatrix can capture the output. Wan2GP logs through Gradio UI notifications @@ -213,8 +235,8 @@ public override TorchIndex GetRecommendedTorchVersion() ( Compat.IsWindows && ( - SettingsManager.Settings.PreferredGpu?.IsWindowsRocmSupportedGpu() - ?? HardwareHelper.HasWindowsRocmSupportedGpu() + WindowsRocmSupport.IsSupportedGpu(SettingsManager.Settings.PreferredGpu) + || HasWindowsRocmSupport() ) ) || ( @@ -256,7 +278,15 @@ public override async Task InstallPackage( if (torchIndex == TorchIndex.Rocm) { - await InstallAmdRocmAsync(venvRunner, progress, onConsoleOutput).ConfigureAwait(false); + await InstallAmdRocmAsync( + venvRunner, + installLocation, + installedPackage, + progress, + onConsoleOutput, + cancellationToken + ) + .ConfigureAwait(false); } else { @@ -359,68 +389,53 @@ await venvRunner private async Task InstallAmdRocmAsync( IPyVenvRunner venvRunner, + string installLocation, + InstalledPackage installedPackage, IProgress? progress, - Action? onConsoleOutput + Action? onConsoleOutput, + CancellationToken cancellationToken ) { - progress?.Report(new ProgressReport(-1f, "Upgrading pip...", isIndeterminate: true)); - await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false); - if (Compat.IsWindows) { - // Windows AMD ROCm - special TheRock wheels - progress?.Report( - new ProgressReport(-1f, "Installing PyTorch ROCm wheels...", isIndeterminate: true) - ); - - // Set environment variable for wheel filename check bypass - venvRunner.UpdateEnvironmentVariables(env => env.SetItem("UV_SKIP_WHEEL_FILENAME_CHECK", "1")); + if (rocmPackageHelper is null) + { + throw new InvalidOperationException( + "Windows ROCm installation for Wan2GP requires the shared ROCm helper." + ); + } - // Install PyTorch ROCm wheels from TheRock releases (Python 3.11) - await venvRunner - .PipInstall( - "https://github.com/scottt/rocm-TheRock/releases/download/v6.5.0rc-pytorch-gfx110x/torch-2.7.0a0+rocm_git3f903c3-cp311-cp311-win_amd64.whl", - onConsoleOutput + await rocmPackageHelper + .InstallWindowsNativePackageAsync( + venvRunner, + installLocation, + installedPackage, + WindowsRocmProfile, + progress, + onConsoleOutput, + cancellationToken ) .ConfigureAwait(false); - await venvRunner - .PipInstall( - "https://github.com/scottt/rocm-TheRock/releases/download/v6.5.0rc-pytorch-gfx110x/torchaudio-2.7.0a0+52638ef-cp311-cp311-win_amd64.whl", - onConsoleOutput - ) - .ConfigureAwait(false); + return; + } - await venvRunner - .PipInstall( - "https://github.com/scottt/rocm-TheRock/releases/download/v6.5.0rc-pytorch-gfx110x/torchvision-0.22.0+9eb57cd-cp311-cp311-win_amd64.whl", - onConsoleOutput - ) - .ConfigureAwait(false); + progress?.Report(new ProgressReport(-1f, "Upgrading pip...", isIndeterminate: true)); + await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false); - // Install requirements directly using -r flag (handles @ URL syntax properly) - progress?.Report(new ProgressReport(-1f, "Installing requirements...", isIndeterminate: true)); - await venvRunner.PipInstall("-r requirements.txt", onConsoleOutput).ConfigureAwait(false); - } - else - { - // Linux AMD ROCm - standard PyTorch ROCm - // Install requirements directly using -r flag (handles @ URL syntax properly) - progress?.Report(new ProgressReport(-1f, "Installing requirements...", isIndeterminate: true)); - await venvRunner.PipInstall("-r requirements.txt", onConsoleOutput).ConfigureAwait(false); - - // Install torch with ROCm index (force reinstall to ensure correct version) - progress?.Report(new ProgressReport(-1f, "Installing PyTorch ROCm...", isIndeterminate: true)); - var torchArgs = new PipInstallArgs() - .WithTorch("==2.7.0") - .WithTorchVision("==0.22.0") - .WithTorchAudio("==2.7.0") - .WithTorchExtraIndex("rocm6.3") - .AddArg("--force-reinstall") - .AddArg("--no-deps"); - - await venvRunner.PipInstall(torchArgs, onConsoleOutput).ConfigureAwait(false); - } + progress?.Report(new ProgressReport(-1f, "Installing requirements...", isIndeterminate: true)); + await venvRunner.PipInstall("-r requirements.txt", onConsoleOutput).ConfigureAwait(false); + + progress?.Report(new ProgressReport(-1f, "Installing PyTorch ROCm...", isIndeterminate: true)); + var torchArgs = new PipInstallArgs() + .WithTorch() + .WithTorchVision() + .WithTorchAudio() + .WithTorchExtraIndex("rocm7.2") + .AddArg("--force-reinstall") + .AddArg("--no-deps"); + + await venvRunner.PipInstall(torchArgs, onConsoleOutput).ConfigureAwait(false); // Install additional packages await venvRunner.PipInstall("hf-xet setuptools numpy==1.26.4", onConsoleOutput).ConfigureAwait(false); @@ -437,6 +452,16 @@ public override async Task RunPackage( await SetupVenv(installLocation, pythonVersion: PyVersion.Parse(installedPackage.PythonVersion)) .ConfigureAwait(false); + if (Compat.IsWindows && rocmPackageHelper is not null && HasWindowsRocmSupport()) + { + var rocmEnvironment = rocmPackageHelper.BuildLaunchEnvironment( + installLocation, + installedPackage, + WindowsRocmProfile + ); + VenvRunner.UpdateEnvironmentVariables(env => env.SetItems(rocmEnvironment)); + } + // Fix for distutils compatibility issue with Python 3.10 and setuptools VenvRunner.UpdateEnvironmentVariables(env => env.SetItem("SETUPTOOLS_USE_DISTUTILS", "stdlib")); diff --git a/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs b/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs index 11c2bbfb..21ff6e7d 100644 --- a/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs +++ b/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs @@ -1,33 +1,68 @@ namespace StabilityMatrix.Core.Models.Rocm; /// -/// Controls how helper-generated, package-specific, and user-defined environment variables -/// should be layered together once the helper has real behavior. +/// Controls how ROCm helper defaults, package-specific variables, and user overrides are layered at launch. /// public class RocmEnvironmentOptions { - /// - /// Determines the merge order used when multiple environment sources provide the same key. - /// - public RocmEnvironmentOverlayPriority OverlayPriority { get; init; } = - RocmEnvironmentOverlayPriority.HelperThenPackageThenUser; - /// /// When true, package-specific environment additions may be merged on top of helper defaults. /// public bool IncludePackageOverrides { get; init; } = true; /// - /// When true, user-defined Stability Matrix environment variables may be merged last. + /// When true, user-defined Stability Matrix environment variables may override helper/package defaults last. /// public bool IncludeUserOverrides { get; init; } = true; + + /// + /// Selects a package-oriented ROCm environment preset managed by the helper. + /// + public RocmEnvironmentPreset Preset { get; init; } = RocmEnvironmentPreset.None; + + /// + /// When set, overrides the default PyTorch allocator tuning string added by the ROCm helper. + /// + public string? PyTorchAllocConf { get; init; } = "max_split_size_mb:512,garbage_collection_threshold:0.8"; + + /// + /// When set, configures MIOpen find mode for helper-managed ROCm defaults. + /// + public string? MiopenFindMode { get; init; } = "2"; + + /// + /// When set, configures MIOpen search cutoff for helper-managed ROCm defaults. + /// + public string? MiopenSearchCutoff { get; init; } = "1"; + + /// + /// When set, configures MIOpen find enforcement behavior for helper-managed ROCm defaults. + /// + public string? MiopenFindEnforce { get; init; } = "3"; + + /// + /// When set, controls whether AMD Triton-backed flash attention is enabled by helper defaults. + /// + public string? FlashAttentionTritonAmdEnable { get; init; } = "TRUE"; + + /// + /// When true, helper-managed defaults will enable ROCm AOTriton on modern Windows ROCm architectures. + /// + public bool ApplyAotritonExperimental { get; init; } = true; + + /// + /// When true, helper-managed defaults will force math SDP on legacy ROCm architectures. + /// + public bool ApplyLegacySdpFallback { get; init; } = true; + + /// + /// When true, helper-managed defaults will apply the RDNA1 HSA override mask when needed. + /// + public bool ApplyRdna1Override { get; init; } = true; } -/// -/// Describes the intended precedence of environment sources for ROCm-enabled package launches. -/// -public enum RocmEnvironmentOverlayPriority +public enum RocmEnvironmentPreset { - HelperThenPackageThenUser, - HelperThenUserThenPackage, + None, + ComfyUi, } diff --git a/StabilityMatrix.Core/Models/Rocm/RocmInstallContext.cs b/StabilityMatrix.Core/Models/Rocm/RocmInstallContext.cs index d18d70d0..597eb4fe 100644 --- a/StabilityMatrix.Core/Models/Rocm/RocmInstallContext.cs +++ b/StabilityMatrix.Core/Models/Rocm/RocmInstallContext.cs @@ -1,5 +1,3 @@ -using StabilityMatrix.Core.Models; - namespace StabilityMatrix.Core.Models.Rocm; /// @@ -7,21 +5,7 @@ namespace StabilityMatrix.Core.Models.Rocm; /// public class RocmInstallContext { - public string? PreferredGfxArch { get; init; } - public string? RuntimeGfxArch { get; init; } public string? RocmPackageIndexUrl { get; init; } - - public string? RocmTorchIndexUrl { get; init; } - - public TorchIndex TorchIndex { get; init; } = TorchIndex.Rocm; - - public string? WheelCompatibilityHints { get; init; } - - public string? SdkRoot { get; init; } - - public RocmSdkPaths SdkPaths { get; init; } = new(); - - public IReadOnlyDictionary Environment { get; init; } = new Dictionary(); } diff --git a/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs b/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs index f2a9323e..a7baa675 100644 --- a/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs +++ b/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs @@ -13,32 +13,8 @@ public class RocmPackageProfile /// public string PackageName { get; init; } = string.Empty; - public bool RequiresWindows { get; init; } - public bool RequiresRocmSdk { get; init; } - public bool NeedsRuntimeGfxResolution { get; init; } - - public bool NeedsHipPath { get; init; } - - public bool NeedsRocmPath { get; init; } - - public bool NeedsTritonOverrideArch { get; init; } - - public bool NeedsRdna1Override { get; init; } - - public bool NeedsLegacySdpFallback { get; init; } - - public bool NeedsAotritonExperimental { get; init; } - - public bool NeedsTunableOpCache { get; init; } - - public bool NeedsTritonCache { get; init; } - - public bool NeedsMIOpenDbPaths { get; init; } - - public bool NeedsRocblasPaths { get; init; } - /// /// Requirement files to install after helper-owned ROCm runtime / torch bootstrap steps complete. /// @@ -69,12 +45,6 @@ public class RocmPackageProfile /// public bool ForceReinstallTorch { get; init; } = true; - /// - /// Optional callback for package-specific cache path variables. - /// The helper will eventually merge these with its own defaults. - /// - public Func>? CacheDirectoryFactory { get; init; } - /// /// Optional callback for package-specific environment variables derived from a resolved ROCm context. /// @@ -84,12 +54,7 @@ public Func< >? ExtraEnvironmentFactory { get; init; } /// - /// Optional progress message prefix or label that package code can surface during install/update work. - /// - public string? ProgressLabel { get; init; } - - /// - /// Controls how helper, package, and user-defined environment variables should be merged. + /// Controls whether package-specific environment variables should be layered on top of helper defaults. /// public RocmEnvironmentOptions EnvironmentOptions { get; init; } = new(); } diff --git a/StabilityMatrix.Core/Models/Rocm/RocmRuntimeContext.cs b/StabilityMatrix.Core/Models/Rocm/RocmRuntimeContext.cs index 87c88ba6..1fdda791 100644 --- a/StabilityMatrix.Core/Models/Rocm/RocmRuntimeContext.cs +++ b/StabilityMatrix.Core/Models/Rocm/RocmRuntimeContext.cs @@ -15,19 +15,4 @@ public class RocmRuntimeContext public GpuInfo? SelectedGpu { get; init; } public string? RuntimeGfxArch { get; init; } - - public bool IsLegacyGpu { get; init; } - - public bool IsRdna1 { get; init; } - - public string? HipPath { get; init; } - - public string? RocmPath { get; init; } - - public string? RocmSdkSitePackagesPath { get; init; } - - public RocmSdkPaths SdkPaths { get; init; } = new(); - - public IReadOnlyDictionary ResolvedEnvironment { get; init; } = - new Dictionary(); } diff --git a/StabilityMatrix.Core/Models/Rocm/RocmSdkPaths.cs b/StabilityMatrix.Core/Models/Rocm/RocmSdkPaths.cs deleted file mode 100644 index 5789744f..00000000 --- a/StabilityMatrix.Core/Models/Rocm/RocmSdkPaths.cs +++ /dev/null @@ -1,22 +0,0 @@ -namespace StabilityMatrix.Core.Models.Rocm; - -/// -/// Represents ROCm SDK-related paths resolved for a package install. -/// These values are intentionally plain data so package code can decide which paths matter. -/// -public class RocmSdkPaths -{ - public string? RocmRoot { get; init; } - - public string? HipPath { get; init; } - - public string? RocmPath { get; init; } - - public string? RocmSdkSitePackagesPath { get; init; } - - public string? MioPenDbPath { get; init; } - - public string? RocblasDbPath { get; init; } - - public string? RocblasLibraryPath { get; init; } -} diff --git a/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs b/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs new file mode 100644 index 00000000..dec46d04 --- /dev/null +++ b/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs @@ -0,0 +1,46 @@ +using StabilityMatrix.Core.Helper.HardwareInfo; + +namespace StabilityMatrix.Core.Models.Rocm; + +/// +/// Centralizes Windows ROCm support policy so hardware detection, package selection, +/// and ROCm installation all use the same architecture support map. +/// +public static class WindowsRocmSupport +{ + public static bool IsSupportedGpu(GpuInfo? gpu) + { + if (gpu is null || !gpu.IsAmd || string.IsNullOrWhiteSpace(gpu.Name)) + return false; + + return IsSupportedArchitecture(gpu.GetAmdGfxArch()); + } + + public static bool IsSupportedArchitecture(string? gfxArch) + { + return TryGetPackageIndexUrl(gfxArch) is not null; + } + + public static string? TryGetPackageIndexUrl(string? gfxArch) + { + return gfxArch switch + { + "gfx900" => "https://rocm.nightlies.amd.com/v2-staging/gfx900/", // Vega 10 + "gfx906" => "https://rocm.nightlies.amd.com/v2-staging/gfx906/", // Radeon VII, Vega 20 + "gfx90X" => "https://rocm.nightlies.amd.com/v2-staging/gfx90X/", // Radeon Pro VII + var s when s != null && s.StartsWith("gfx101", StringComparison.OrdinalIgnoreCase) => + "https://rocm.nightlies.amd.com/v2-staging/gfx101X-dgpu/", // RDNA1 (5000 series, Pro) + var s when s != null && s.StartsWith("gfx103", StringComparison.OrdinalIgnoreCase) => + "https://rocm.nightlies.amd.com/v2-staging/gfx103X-all/", // RDNA2 (6000 series, 6xxM Mobile, Steam Deck) + var s when s != null && s.StartsWith("gfx110", StringComparison.OrdinalIgnoreCase) => + "https://rocm.nightlies.amd.com/v2/gfx110X-all/", // RDNA3 (7000 series, 7xxM Mobile) + "gfx1150" => "https://rocm.nightlies.amd.com/v2-staging/gfx1150/", // RDNA3.5 (Strix/Gorgon Point) + "gfx1151" => "https://rocm.nightlies.amd.com/v2/gfx1151/", // RDNA3.5 (Strix Halo) + "gfx1152" => "https://rocm.nightlies.amd.com/v2-staging/gfx1152/", // RDNA3.5 (Kraken Point) + "gfx1153" => "https://rocm.nightlies.amd.com/v2-staging/gfx1153/", // RDNA3.5 (Medusa Point) + var s when s != null && s.StartsWith("gfx120", StringComparison.OrdinalIgnoreCase) => + "https://rocm.nightlies.amd.com/v2/gfx120X-all/", // RDNA4 (9000 series) + _ => null, + }; + } +} diff --git a/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs index 0fac954e..03b4ce0c 100644 --- a/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs +++ b/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs @@ -14,69 +14,33 @@ public interface IRocmPackageHelper /// /// Evaluates whether the current machine and package profile are compatible with ROCm. /// - Task GetCompatibilityAsync( - RocmPackageProfile profile, - CancellationToken cancellationToken = default - ); + RocmCompatibilityResult GetCompatibility(RocmPackageProfile profile); /// /// Resolves the runtime ROCm facts needed for package launch and environment construction. /// - Task ResolveRuntimeContextAsync( + RocmRuntimeContext ResolveRuntimeContext( string installLocation, InstalledPackage installedPackage, - RocmPackageProfile profile, - CancellationToken cancellationToken = default + RocmPackageProfile profile ); /// /// Resolves the ROCm facts needed during package installation or update operations. /// - Task ResolveInstallContextAsync( + RocmInstallContext ResolveInstallContext( string installLocation, InstalledPackage installedPackage, - RocmPackageProfile profile, - CancellationToken cancellationToken = default - ); - - /// - /// Builds an install-time environment dictionary from a resolved install context. - /// - IReadOnlyDictionary BuildInstallEnvironment( - string installLocation, - RocmInstallContext context, RocmPackageProfile profile ); - /// - /// Re-resolves ROCm install facts after a package update changes dependencies or runtime state. - /// - Task RefreshPackageAfterUpdateAsync( - string installLocation, - InstalledPackage installedPackage, - RocmPackageProfile profile, - CancellationToken cancellationToken = default - ); - /// /// Builds a launch-time environment dictionary from resolved ROCm runtime data. /// - Task> BuildLaunchEnvironmentAsync( - string installLocation, - InstalledPackage installedPackage, - RocmPackageProfile profile, - CancellationToken cancellationToken = default - ); - - /// - /// Applies a resolved launch environment to the provided Python venv runner. - /// - Task ApplyLaunchEnvironmentAsync( - IPyVenvRunner venvRunner, + IReadOnlyDictionary BuildLaunchEnvironment( string installLocation, InstalledPackage installedPackage, - RocmPackageProfile profile, - CancellationToken cancellationToken = default + RocmPackageProfile profile ); /// diff --git a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs index a2c700b6..19dced0b 100644 --- a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs +++ b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs @@ -22,49 +22,34 @@ namespace StabilityMatrix.Core.Services.Rocm; public class RocmPackageHelper(ISettingsManager settingsManager) : IRocmPackageHelper { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); - - private static readonly string[] UnsupportedRdna2ModelMarkers = - [ - "680m", - "660m", - "610m", - "rx6300", - "w6300", - "rx6400", - "w6400", - "rx6450", - "rx6550", - ]; + private static readonly StringComparer EnvComparer = StringComparer.OrdinalIgnoreCase; /// - public Task GetCompatibilityAsync( - RocmPackageProfile profile, - CancellationToken cancellationToken = default - ) + public RocmCompatibilityResult GetCompatibility(RocmPackageProfile profile) { - return Task.FromResult(BuildCompatibilityResult(profile)); + return BuildCompatibilityResult(profile); } /// - public Task ResolveRuntimeContextAsync( + public RocmRuntimeContext ResolveRuntimeContext( string installLocation, InstalledPackage installedPackage, - RocmPackageProfile profile, - CancellationToken cancellationToken = default + RocmPackageProfile profile ) { + _ = installLocation; + _ = installedPackage; + var compatibility = BuildCompatibilityResult(profile); if (!compatibility.IsCompatible) { - return Task.FromResult( - new RocmRuntimeContext - { - IsSupported = false, - FailureReason = compatibility.FailureReason, - SelectedGpu = compatibility.SelectedGpu, - RuntimeGfxArch = compatibility.ResolvedGfxArch, - } - ); + return new RocmRuntimeContext + { + IsSupported = false, + FailureReason = compatibility.FailureReason, + SelectedGpu = compatibility.SelectedGpu, + RuntimeGfxArch = compatibility.ResolvedGfxArch, + }; } var supportedAmdGpus = GetAmdGpuCandidates(forceRefresh: true) @@ -81,29 +66,23 @@ public Task ResolveRuntimeContextAsync( ?? selectedGpu?.GetAmdGfxArch() ?? GetSupportedFallbackGfxArch(supportedAmdGpus); - return Task.FromResult( - new RocmRuntimeContext - { - IsSupported = true, - SelectedGpu = selectedGpu, - RuntimeGfxArch = runtimeGfxArch, - IsLegacyGpu = IsLegacyArchitecture(runtimeGfxArch), - IsRdna1 = IsRdna1Architecture(runtimeGfxArch), - } - ); + return new RocmRuntimeContext + { + IsSupported = true, + SelectedGpu = selectedGpu, + RuntimeGfxArch = runtimeGfxArch, + }; } /// - public Task ResolveInstallContextAsync( + public RocmInstallContext ResolveInstallContext( string installLocation, InstalledPackage installedPackage, - RocmPackageProfile profile, - CancellationToken cancellationToken = default + RocmPackageProfile profile ) { _ = installLocation; _ = installedPackage; - _ = cancellationToken; var supportedAmdGpus = GetAmdGpuCandidates(forceRefresh: true) .Where(IsSupportedWindowsRocmGpu) @@ -115,65 +94,26 @@ public Task ResolveInstallContextAsync( ); var runtimeGfxArch = preferredGfxArch ?? GetSupportedFallbackGfxArch(supportedAmdGpus); - var windowsNativeIndexUrl = TryGetWindowsNativeRocmIndexUrl(runtimeGfxArch); - - return Task.FromResult( - new RocmInstallContext - { - PreferredGfxArch = preferredGfxArch, - RuntimeGfxArch = runtimeGfxArch, - RocmPackageIndexUrl = windowsNativeIndexUrl, - RocmTorchIndexUrl = windowsNativeIndexUrl, - } - ); - } - - /// - public IReadOnlyDictionary BuildInstallEnvironment( - string installLocation, - RocmInstallContext context, - RocmPackageProfile profile - ) - { - _ = installLocation; - _ = context; - _ = profile; - return new Dictionary(); - } + var windowsNativeIndexUrl = WindowsRocmSupport.TryGetPackageIndexUrl(runtimeGfxArch); - /// - public Task RefreshPackageAfterUpdateAsync( - string installLocation, - InstalledPackage installedPackage, - RocmPackageProfile profile, - CancellationToken cancellationToken = default - ) - { - return ResolveInstallContextAsync(installLocation, installedPackage, profile, cancellationToken); + return new RocmInstallContext + { + RuntimeGfxArch = runtimeGfxArch, + RocmPackageIndexUrl = windowsNativeIndexUrl, + }; } /// - public Task> BuildLaunchEnvironmentAsync( + public IReadOnlyDictionary BuildLaunchEnvironment( string installLocation, InstalledPackage installedPackage, - RocmPackageProfile profile, - CancellationToken cancellationToken = default + RocmPackageProfile profile ) { - _ = installLocation; - _ = installedPackage; - - var runtimeContext = ResolveRuntimeContextAsync( - installLocation, - installedPackage, - profile, - cancellationToken - ) - .GetAwaiter() - .GetResult(); + var runtimeContext = ResolveRuntimeContext(installLocation, installedPackage, profile); if (!runtimeContext.IsSupported) - return Task.FromResult>(new Dictionary()); + return new Dictionary(); var helperEnvironment = BuildHelperLaunchEnvironment(runtimeContext, profile); var packageEnvironment = @@ -185,27 +125,7 @@ public Task> BuildLaunchEnvironmentAsync( profile.EnvironmentOptions ); - return Task.FromResult>(mergedEnvironment); - } - - /// - public async Task ApplyLaunchEnvironmentAsync( - IPyVenvRunner venvRunner, - string installLocation, - InstalledPackage installedPackage, - RocmPackageProfile profile, - CancellationToken cancellationToken = default - ) - { - var environment = await BuildLaunchEnvironmentAsync( - installLocation, - installedPackage, - profile, - cancellationToken - ) - .ConfigureAwait(false); - - venvRunner.UpdateEnvironmentVariables(env => env.SetItems(environment)); + return mergedEnvironment; } /// @@ -219,7 +139,7 @@ public async Task InstallWindowsNativePackageAsync( CancellationToken cancellationToken = default ) { - var compatibility = await GetCompatibilityAsync(profile, cancellationToken).ConfigureAwait(false); + var compatibility = GetCompatibility(profile); if (!compatibility.IsCompatible) { throw new ApplicationException( @@ -228,18 +148,11 @@ public async Task InstallWindowsNativePackageAsync( ); } - var installContext = await ResolveInstallContextAsync( - installLocation, - installedPackage, - profile, - cancellationToken - ) - .ConfigureAwait(false); + var installContext = ResolveInstallContext(installLocation, installedPackage, profile); var rocmPackageIndexUrl = installContext.RocmPackageIndexUrl; - var rocmTorchIndexUrl = installContext.RocmTorchIndexUrl ?? rocmPackageIndexUrl; - if (string.IsNullOrWhiteSpace(rocmPackageIndexUrl) || string.IsNullOrWhiteSpace(rocmTorchIndexUrl)) + if (string.IsNullOrWhiteSpace(rocmPackageIndexUrl)) { throw new ApplicationException( $"No Windows ROCm Technical Preview index URL is available for '{installContext.RuntimeGfxArch ?? "unknown"}'." @@ -254,7 +167,7 @@ public async Task InstallWindowsNativePackageAsync( progress?.Report(new ProgressReport(-1f, "Installing ROCm runtime...", isIndeterminate: true)); var rocmRuntimeArgs = new PipInstallArgs() .AddKeyedArgs("--index-url", ["--index-url", rocmPackageIndexUrl]) - .AddArgs("rocm[devel,libraries]", "--no-warn-script-location"); + .AddArgs("rocm[devel,libraries]"); if (installedPackage.PipOverrides != null) { @@ -264,43 +177,10 @@ public async Task InstallWindowsNativePackageAsync( await venvRunner.PipInstall(rocmRuntimeArgs, onConsoleOutput).ConfigureAwait(false); progress?.Report(new ProgressReport(-1f, "Initializing ROCm SDK...", isIndeterminate: true)); - var rocmSdkExe = Path.Combine(installLocation, "venv", "Scripts", "rocm-sdk.exe"); - if (!File.Exists(rocmSdkExe)) - { - throw new FileNotFoundException("rocm-sdk.exe was not installed", rocmSdkExe); - } - - using var rocmSdkProcess = ProcessRunner.StartAnsiProcess( - rocmSdkExe, - ["init"], - installLocation, - onConsoleOutput - ); - - await rocmSdkProcess.WaitForExitAsync(cancellationToken).ConfigureAwait(false); - if (rocmSdkProcess.ExitCode != 0) - { - throw new ProcessException($"rocm-sdk init failed with code {rocmSdkProcess.ExitCode}"); - } - } - - progress?.Report(new ProgressReport(-1f, "Installing ROCm torch...", isIndeterminate: true)); - var torchArgs = new PipInstallArgs() - .AddKeyedArgs("--index-url", ["--index-url", rocmTorchIndexUrl]) - .AddArgs("torch", "torchaudio", "torchvision", "--no-warn-script-location"); - - if (profile.ForceReinstallTorch) - { - torchArgs = torchArgs.AddArg("--force-reinstall"); - } - - if (installedPackage.PipOverrides != null) - { - torchArgs = torchArgs.WithUserOverrides(installedPackage.PipOverrides); + await InitializeWindowsNativeRocmSdkAsync(installLocation, onConsoleOutput, cancellationToken) + .ConfigureAwait(false); } - await venvRunner.PipInstall(torchArgs, onConsoleOutput).ConfigureAwait(false); - progress?.Report( new ProgressReport(-1f, "Installing package requirements...", isIndeterminate: true) ); @@ -334,18 +214,56 @@ public async Task InstallWindowsNativePackageAsync( await venvRunner.PipInstall(requirementsPipArgs, onConsoleOutput).ConfigureAwait(false); - if (!profile.PostInstallPipArgs.Any()) - return; + progress?.Report(new ProgressReport(-1f, "Installing ROCm torch...", isIndeterminate: true)); + var torchArgs = new PipInstallArgs() + .AddArg("--pre") + .AddArg("--upgrade") + .AddKeyedArgs("--index-url", ["--index-url", rocmPackageIndexUrl]) + .WithTorch() + .WithTorchAudio() + .WithTorchVision(); + + if (profile.ForceReinstallTorch) + { + torchArgs = torchArgs.AddArg("--force-reinstall"); + } - var postInstallPipArgs = new PipInstallArgs([.. profile.PostInstallPipArgs]); if (installedPackage.PipOverrides != null) { - postInstallPipArgs = postInstallPipArgs.WithUserOverrides(installedPackage.PipOverrides); + torchArgs = torchArgs.WithUserOverrides(installedPackage.PipOverrides); } - await venvRunner.PipInstall(postInstallPipArgs, onConsoleOutput).ConfigureAwait(false); + await venvRunner.PipInstall(torchArgs, onConsoleOutput).ConfigureAwait(false); + + if (profile.RequiresRocmSdk) + { + await AlignRocmSdkDevelVersionAsync(venvRunner, rocmPackageIndexUrl, onConsoleOutput) + .ConfigureAwait(false); + + progress?.Report(new ProgressReport(-1f, "Reinitializing ROCm SDK...", isIndeterminate: true)); + await InitializeWindowsNativeRocmSdkAsync(installLocation, onConsoleOutput, cancellationToken) + .ConfigureAwait(false); + } + + if (profile.PostInstallPipArgs.Any()) + { + var postInstallPipArgs = new PipInstallArgs([.. profile.PostInstallPipArgs]); + if (installedPackage.PipOverrides != null) + { + postInstallPipArgs = postInstallPipArgs.WithUserOverrides(installedPackage.PipOverrides); + } + + await venvRunner.PipInstall(postInstallPipArgs, onConsoleOutput).ConfigureAwait(false); + } - await VerifyWindowsNativeTorchInstallAsync(venvRunner, onConsoleOutput).ConfigureAwait(false); + await VerifyWindowsNativeTorchInstallAsync(venvRunner, onConsoleOutput, cancellationToken) + .ConfigureAwait(false); + + if (profile.RequiresRocmSdk) + { + await VerifyWindowsNativeRocmRuntimeAsync(installLocation, onConsoleOutput, cancellationToken) + .ConfigureAwait(false); + } } /// @@ -354,15 +272,6 @@ public async Task InstallWindowsNativePackageAsync( /// private RocmCompatibilityResult BuildCompatibilityResult(RocmPackageProfile profile) { - if (profile.RequiresWindows && !Compat.IsWindows) - { - return new RocmCompatibilityResult - { - IsCompatible = false, - FailureReason = "This ROCm profile currently requires Windows.", - }; - } - var amdGpus = GetAmdGpuCandidates(forceRefresh: true).ToList(); if (amdGpus.Count == 0) { @@ -374,15 +283,6 @@ private RocmCompatibilityResult BuildCompatibilityResult(RocmPackageProfile prof } var preferredGpu = settingsManager.Settings.PreferredGpu; - if (preferredGpu is not null && IsExplicitlyUnsupportedRdna2Gpu(preferredGpu)) - { - return new RocmCompatibilityResult - { - IsCompatible = false, - FailureReason = $"Selected GPU '{preferredGpu.Name}' is unsupported for Windows ROCm.", - SelectedGpu = preferredGpu, - }; - } var supportedAmdGpus = amdGpus.Where(IsSupportedWindowsRocmGpu).ToList(); if (supportedAmdGpus.Count == 0) @@ -471,29 +371,10 @@ private static IReadOnlyList GetAmdGpuCandidates(bool forceRefresh = fa /// /// Determines whether a GPU is supported by the Windows ROCm install flow currently modeled by the helper. - /// Unsupported low-end RDNA2/APU models are filtered explicitly even when they identify as AMD hardware. /// private static bool IsSupportedWindowsRocmGpu(GpuInfo gpu) { - if (!gpu.IsAmd || string.IsNullOrWhiteSpace(gpu.Name)) - return false; - - if (IsExplicitlyUnsupportedRdna2Gpu(gpu)) - return false; - - return TryGetWindowsNativeRocmIndexUrl(gpu.GetAmdGfxArch()) is not null; - } - - /// - /// Identifies Windows ROCm-incompatible RDNA2 models that need to remain outside the supported GPU set. - /// - private static bool IsExplicitlyUnsupportedRdna2Gpu(GpuInfo gpu) - { - if (!gpu.IsAmd || string.IsNullOrWhiteSpace(gpu.Name)) - return false; - - var normalizedName = gpu.Name.Replace(" ", string.Empty, StringComparison.Ordinal).ToLowerInvariant(); - return UnsupportedRdna2ModelMarkers.Any(normalizedName.Contains); + return WindowsRocmSupport.IsSupportedGpu(gpu); } /// @@ -501,50 +382,7 @@ private static bool IsExplicitlyUnsupportedRdna2Gpu(GpuInfo gpu) /// private static bool IsSupportedWindowsRocmArchitecture(string? gfxArch) { - return TryGetWindowsNativeRocmIndexUrl(gfxArch) is not null; - } - - /// - /// Maps an AMD GFX architecture identifier to the Windows-native ROCm Technical Preview feed URL. - /// - private static string? TryGetWindowsNativeRocmIndexUrl(string? gfxArch) - { - return gfxArch switch - { - var s when s != null && s.StartsWith("gfx101", StringComparison.OrdinalIgnoreCase) => - "https://rocm.nightlies.amd.com/v2-staging/gfx101X-dgpu/", - var s when s != null && s.StartsWith("gfx103", StringComparison.OrdinalIgnoreCase) => - "https://rocm.nightlies.amd.com/v2-staging/gfx103X-dgpu/", - var s when s != null && s.StartsWith("gfx110", StringComparison.OrdinalIgnoreCase) => - "https://rocm.nightlies.amd.com/v2/gfx110X-all/", - "gfx1150" => "https://rocm.nightlies.amd.com/v2-staging/gfx1150/", - "gfx1151" => "https://rocm.nightlies.amd.com/v2/gfx1151/", - "gfx1152" => "https://rocm.nightlies.amd.com/v2-staging/gfx1152/", - "gfx1153" => "https://rocm.nightlies.amd.com/v2-staging/gfx1153/", - var s when s != null && s.StartsWith("gfx120", StringComparison.OrdinalIgnoreCase) => - "https://rocm.nightlies.amd.com/v2/gfx120X-all/", - _ => null, - }; - } - - /// - /// Returns true for architectures that need the legacy ROCm runtime path. - /// - private static bool IsLegacyArchitecture(string? gfxArch) - { - return gfxArch is not null - && ( - gfxArch.StartsWith("gfx101", StringComparison.OrdinalIgnoreCase) - || gfxArch.StartsWith("gfx103", StringComparison.OrdinalIgnoreCase) - ); - } - - /// - /// Returns true for RDNA1 architectures that need dedicated override handling. - /// - private static bool IsRdna1Architecture(string? gfxArch) - { - return gfxArch?.StartsWith("gfx101", StringComparison.OrdinalIgnoreCase) == true; + return WindowsRocmSupport.IsSupportedArchitecture(gfxArch); } /// @@ -552,20 +390,17 @@ private static bool IsRdna1Architecture(string? gfxArch) /// private static string GetUnsupportedGpuReason(IReadOnlyList amdGpus) { - if (amdGpus.Any(IsExplicitlyUnsupportedRdna2Gpu)) - { - return "Detected only unsupported AMD RDNA2 GPUs for Windows ROCm. Unsupported models include Radeon 680M/660M/610M and RX 6300/6400/6450/6550-class GPUs."; - } - + _ = amdGpus; return "No AMD GPU with a supported Windows ROCm architecture was detected."; } /// - /// Verifies that the installed torch build still reports a usable ROCm runtime after helper-managed installs complete. + /// Verifies that the installed torch build still reports usable ROCm metadata after helper-managed installs complete. /// private static async Task VerifyWindowsNativeTorchInstallAsync( IPyVenvRunner venvRunner, - Action? onConsoleOutput + Action? onConsoleOutput, + CancellationToken cancellationToken ) { var torchInfo = await venvRunner.PipShow("torch").ConfigureAwait(false); @@ -586,10 +421,16 @@ private static async Task VerifyWindowsNativeTorchInstallAsync( throw new ApplicationException("Torch verification produced no output."); } + var verificationJson = TryExtractJsonObject(verificationOutput); + if (string.IsNullOrWhiteSpace(verificationJson)) + { + throw new ApplicationException($"Unexpected torch verification output: {verificationOutput}"); + } + JsonDocument verificationDocument; try { - verificationDocument = JsonDocument.Parse(verificationOutput); + verificationDocument = JsonDocument.Parse(verificationJson); } catch (Exception exception) { @@ -608,66 +449,341 @@ private static async Task VerifyWindowsNativeTorchInstallAsync( var hipVersion = root.TryGetProperty("hip", out var hipElement) ? hipElement.GetString() : null; var cudaAvailable = root.TryGetProperty("cuda", out var cudaElement) && cudaElement.GetBoolean(); - if (string.IsNullOrWhiteSpace(hipVersion) || !cudaAvailable) + if (!IsUsableWindowsNativeTorchBuild(version, hipVersion)) { throw new ApplicationException( $"Installed torch is not a usable ROCm build. Verification output: {verificationOutput}" ); } + if (!cudaAvailable) + { + onConsoleOutput?.Invoke( + ProcessOutput.FromStdErrLine( + $"Torch verification warning: installed ROCm torch build reported cuda={cudaAvailable}; continuing because ROCm metadata was detected (version={version}, hip={hipVersion})." + ) + ); + } + onConsoleOutput?.Invoke( ProcessOutput.FromStdOutLine( $"Torch verification: version={version}, hip={hipVersion}, cuda={cudaAvailable}" ) ); } + + _ = cancellationToken; + } + + /// + /// Runs rocm-sdk init after the helper-managed runtime packages are installed so the Windows ROCm SDK can prepare the venv. + /// + private static async Task InitializeWindowsNativeRocmSdkAsync( + string installLocation, + Action? onConsoleOutput, + CancellationToken cancellationToken + ) + { + var rocmSdkExe = Path.Combine(installLocation, "venv", "Scripts", "rocm-sdk.exe"); + if (!File.Exists(rocmSdkExe)) + { + throw new FileNotFoundException("rocm-sdk.exe was not installed", rocmSdkExe); + } + + using var rocmSdkProcess = ProcessRunner.StartAnsiProcess( + rocmSdkExe, + ["init"], + installLocation, + onConsoleOutput + ); + + await rocmSdkProcess.WaitForExitAsync(cancellationToken).ConfigureAwait(false); + if (rocmSdkProcess.ExitCode != 0) + { + throw new ProcessException($"rocm-sdk init failed with code {rocmSdkProcess.ExitCode}"); + } + } + + /// + /// Uses AMD's bundled hipInfo.exe to confirm the installed Windows ROCm runtime can enumerate a ROCm-capable GPU. + /// + private static async Task VerifyWindowsNativeRocmRuntimeAsync( + string installLocation, + Action? onConsoleOutput, + CancellationToken cancellationToken + ) + { + var rocmSdkExe = Path.Combine(installLocation, "venv", "Scripts", "rocm-sdk.exe"); + if (!File.Exists(rocmSdkExe)) + { + throw new FileNotFoundException("rocm-sdk.exe was not installed", rocmSdkExe); + } + + var rocmBinResult = await ProcessRunner + .GetProcessResultAsync(rocmSdkExe, ["path", "--bin"], installLocation, useUtf8Encoding: true) + .ConfigureAwait(false); + + var rocmBinPath = (rocmBinResult.StandardOutput ?? string.Empty).Trim(); + if (!rocmBinResult.IsSuccessExitCode || string.IsNullOrWhiteSpace(rocmBinPath)) + { + var rocmBinOutput = CombineProcessOutput( + rocmBinResult.StandardOutput, + rocmBinResult.StandardError + ); + throw new ApplicationException( + $"ROCm runtime verification failed while resolving the ROCm SDK bin path. Output: {rocmBinOutput}" + ); + } + + var hipInfoExe = Path.Combine(rocmBinPath, $"hipInfo{Compat.ExeExtension}"); + if (!File.Exists(hipInfoExe)) + { + throw new FileNotFoundException( + "hipInfo.exe was not found in the ROCm SDK bin directory", + hipInfoExe + ); + } + + var hipInfoResult = await ProcessRunner + .GetProcessResultAsync( + hipInfoExe, + [], + installLocation, + new Dictionary { ["PATH"] = rocmBinPath }, + useUtf8Encoding: true + ) + .ConfigureAwait(false); + + var hipInfoOutput = CombineProcessOutput(hipInfoResult.StandardOutput, hipInfoResult.StandardError); + if (!hipInfoResult.IsSuccessExitCode) + { + var runtimeFailureReason = TryGetWindowsNativeRocmRuntimeFailureReason(hipInfoOutput); + throw new ApplicationException( + runtimeFailureReason is null + ? $"ROCm runtime verification failed while probing the installed runtime with hipInfo.exe. Output: {hipInfoOutput}" + : $"ROCm runtime verification failed: {runtimeFailureReason} Output: {hipInfoOutput}" + ); + } + + onConsoleOutput?.Invoke( + ProcessOutput.FromStdOutLine( + $"ROCm runtime verification succeeded via hipInfo.exe: {hipInfoOutput}" + ) + ); + + _ = cancellationToken; + } + + /// + /// Reinstalls rocm-sdk-devel to the resolved ROCm build version when the torch step downgrades the runtime stack. + /// + private static async Task AlignRocmSdkDevelVersionAsync( + IPyVenvRunner venvRunner, + string rocmPackageIndexUrl, + Action? onConsoleOutput + ) + { + var rocmInfo = await venvRunner.PipShow("rocm").ConfigureAwait(false); + var rocmSdkDevelInfo = await venvRunner.PipShow("rocm-sdk-devel").ConfigureAwait(false); + var torchInfo = await venvRunner.PipShow("torch").ConfigureAwait(false); + + var targetVersion = GetRocmSdkDevelAlignmentVersion( + rocmInfo?.Version, + rocmSdkDevelInfo?.Version, + torchInfo?.Version + ); + + if (string.IsNullOrWhiteSpace(targetVersion)) + return; + + onConsoleOutput?.Invoke( + ProcessOutput.FromStdErrLine( + $"Aligning rocm-sdk-devel from version={rocmSdkDevelInfo?.Version ?? "not-installed"} to version={targetVersion} to match the resolved ROCm torch/runtime build." + ) + ); + + var alignmentArgs = new PipInstallArgs() + .AddKeyedArgs("--index-url", ["--index-url", rocmPackageIndexUrl]) + .AddArg("--force-reinstall") + .AddArg($"rocm-sdk-devel=={targetVersion}"); + + await venvRunner.PipInstall(alignmentArgs, onConsoleOutput).ConfigureAwait(false); + } + + internal static bool IsUsableWindowsNativeTorchBuild(string? version, string? hipVersion) + { + if (!string.IsNullOrWhiteSpace(hipVersion)) + return true; + + return !string.IsNullOrWhiteSpace(version) + && version.Contains("rocm", StringComparison.OrdinalIgnoreCase); + } + + internal static string? GetRocmSdkDevelAlignmentVersion( + string? rocmVersion, + string? rocmSdkDevelVersion, + string? torchVersion = null + ) + { + var targetVersion = !string.IsNullOrWhiteSpace(rocmVersion) + ? rocmVersion + : TryExtractRocmBuildVersion(torchVersion); + + if (string.IsNullOrWhiteSpace(targetVersion)) + return null; + + return string.Equals(targetVersion, rocmSdkDevelVersion, StringComparison.OrdinalIgnoreCase) + ? null + : targetVersion; + } + + internal static string? TryGetWindowsNativeRocmRuntimeFailureReason(string? output) + { + if (string.IsNullOrWhiteSpace(output)) + return null; + + if (output.Contains("no ROCm-capable device is detected", StringComparison.OrdinalIgnoreCase)) + { + return "the installed ROCm runtime could not detect a ROCm-capable GPU on this system."; + } + + if (output.Contains("No WDDM adapters found", StringComparison.OrdinalIgnoreCase)) + { + return "the ROCm runtime could not find any compatible WDDM adapters for the current GPU/driver stack."; + } + + return null; + } + + internal static string? TryExtractRocmBuildVersion(string? torchVersion) + { + if (string.IsNullOrWhiteSpace(torchVersion)) + return null; + + var rocmMarkerIndex = torchVersion.IndexOf("rocm", StringComparison.OrdinalIgnoreCase); + if (rocmMarkerIndex < 0) + return null; + + var rocmBuildVersion = torchVersion[(rocmMarkerIndex + "rocm".Length)..].Trim(); + return string.IsNullOrWhiteSpace(rocmBuildVersion) ? null : rocmBuildVersion; + } + + internal static string? TryExtractJsonObject(string output) + { + if (string.IsNullOrWhiteSpace(output)) + return null; + + var trimmedOutput = output.Trim(); + + for (var index = 0; index < trimmedOutput.Length; index++) + { + if (trimmedOutput[index] != '{') + continue; + + try + { + using var document = JsonDocument.Parse(trimmedOutput[index..]); + return document.RootElement.GetRawText(); + } + catch (JsonException) { } + } + + return null; + } + + internal static string CombineProcessOutput(string? standardOutput, string? standardError) + { + var sections = new[] { standardOutput?.Trim(), standardError?.Trim() }.Where(section => + !string.IsNullOrWhiteSpace(section) + ); + + return string.Join(Environment.NewLine, sections); } /// /// Builds helper-owned ROCm launch variables from the resolved runtime context and package profile. /// - private static IReadOnlyDictionary BuildHelperLaunchEnvironment( + private IReadOnlyDictionary BuildHelperLaunchEnvironment( RocmRuntimeContext runtimeContext, RocmPackageProfile profile ) { - var environment = new Dictionary(); + var environment = new Dictionary(EnvComparer); + var options = profile.EnvironmentOptions; + var gfxArch = runtimeContext.RuntimeGfxArch; + + ApplyPresetLaunchEnvironment(environment, gfxArch, options); + + return environment; + } + + private void ApplyPresetLaunchEnvironment( + IDictionary environment, + string? gfxArch, + RocmEnvironmentOptions options + ) + { + SetIfNotNull(environment, "FLASH_ATTENTION_TRITON_AMD_ENABLE", options.FlashAttentionTritonAmdEnable); + SetIfNotNull(environment, "MIOPEN_FIND_MODE", options.MiopenFindMode); + SetIfNotNull(environment, "MIOPEN_SEARCH_CUTOFF", options.MiopenSearchCutoff); + SetIfNotNull(environment, "MIOPEN_FIND_ENFORCE", options.MiopenFindEnforce); + SetIfNotNull(environment, "PYTORCH_ALLOC_CONF", options.PyTorchAllocConf); - if (profile.NeedsTunableOpCache) + if (options.ApplyAotritonExperimental && IsModernWindowsRocmArchitecture(gfxArch)) { - environment["PYTORCH_TUNABLEOP_ENABLED"] = "1"; + environment["TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL"] = "1"; } - if (profile.NeedsAotritonExperimental) + if (!IsModernWindowsRocmArchitecture(gfxArch) && options.ApplyLegacySdpFallback) { - environment["TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL"] = "1"; + environment["TORCH_BACKENDS_CUDA_FLASH_SDP_ENABLED"] = "0"; + environment["TORCH_BACKENDS_CUDA_MEM_EFF_SDP_ENABLED"] = "0"; + environment["TORCH_BACKENDS_CUDA_MATH_SDP_ENABLED"] = "1"; } - if (profile.NeedsTritonOverrideArch && !string.IsNullOrWhiteSpace(runtimeContext.RuntimeGfxArch)) + if (options.ApplyRdna1Override && IsRdna1Architecture(gfxArch)) { - environment["HSA_OVERRIDE_GFX_VERSION"] = runtimeContext.RuntimeGfxArch; + environment["HSA_OVERRIDE_GFX_VERSION"] = "10.1.0"; } - return environment; + if (options.Preset == RocmEnvironmentPreset.ComfyUi && IsModernWindowsRocmArchitecture(gfxArch)) + { + environment["COMFYUI_ENABLE_MIOPEN"] = "1"; + } + } + + private static bool IsModernWindowsRocmArchitecture(string? gfxArch) + { + return gfxArch?.StartsWith("gfx110", StringComparison.OrdinalIgnoreCase) == true + || gfxArch?.StartsWith("gfx115", StringComparison.OrdinalIgnoreCase) == true + || gfxArch?.StartsWith("gfx120", StringComparison.OrdinalIgnoreCase) == true; + } + + private static bool IsRdna1Architecture(string? gfxArch) + { + return gfxArch?.StartsWith("gfx101", StringComparison.OrdinalIgnoreCase) == true; + } + + private static void SetIfNotNull(IDictionary environment, string key, string? value) + { + if (!string.IsNullOrWhiteSpace(value)) + { + environment[key] = value; + } } /// - /// Merges helper-owned and package-specific launch environment variables using the profile overlay rules. + /// Merges helper-owned and package-specific launch environment variables. /// - private static IReadOnlyDictionary MergeLaunchEnvironment( + private IReadOnlyDictionary MergeLaunchEnvironment( IReadOnlyDictionary helperEnvironment, IReadOnlyDictionary packageEnvironment, RocmEnvironmentOptions options ) { - var merged = new Dictionary(); + var merged = new Dictionary(EnvComparer); - IReadOnlyDictionary[] orderedSources = - options.OverlayPriority == RocmEnvironmentOverlayPriority.HelperThenUserThenPackage - ? new[] { helperEnvironment, packageEnvironment } - : new[] { helperEnvironment, packageEnvironment }; - - foreach (var source in orderedSources) + foreach (var source in new[] { helperEnvironment, packageEnvironment }) { if (ReferenceEquals(source, packageEnvironment) && !options.IncludePackageOverrides) continue; @@ -678,6 +794,17 @@ RocmEnvironmentOptions options } } + if ( + options.IncludeUserOverrides + && settingsManager.Settings.EnvironmentVariables is { Count: > 0 } userOverrides + ) + { + foreach (var pair in userOverrides) + { + merged[pair.Key] = pair.Value; + } + } + return merged; } } diff --git a/StabilityMatrix.Tests/Core/RocmPackageHelperTests.cs b/StabilityMatrix.Tests/Core/RocmPackageHelperTests.cs new file mode 100644 index 00000000..0afb7b86 --- /dev/null +++ b/StabilityMatrix.Tests/Core/RocmPackageHelperTests.cs @@ -0,0 +1,176 @@ +using System.Text.Json; +using StabilityMatrix.Core.Helper; +using StabilityMatrix.Core.Helper.HardwareInfo; +using StabilityMatrix.Core.Models.Rocm; +using StabilityMatrix.Core.Services.Rocm; + +namespace StabilityMatrix.Tests.Core; + +[TestClass] +public class RocmPackageHelperTests +{ + [TestMethod] + public void GetRocmSdkDevelAlignmentVersion_ReturnsRocmVersion_WhenVersionsMismatch() + { + var targetVersion = RocmPackageHelper.GetRocmSdkDevelAlignmentVersion( + rocmVersion: "7.13.0a20260416", + rocmSdkDevelVersion: "7.13.0a20260501" + ); + + Assert.AreEqual("7.13.0a20260416", targetVersion); + } + + [TestMethod] + public void GetRocmSdkDevelAlignmentVersion_ReturnsNull_WhenVersionsAlreadyMatch() + { + var targetVersion = RocmPackageHelper.GetRocmSdkDevelAlignmentVersion( + rocmVersion: "7.13.0a20260416", + rocmSdkDevelVersion: "7.13.0a20260416" + ); + + Assert.IsNull(targetVersion); + } + + [TestMethod] + public void GetRocmSdkDevelAlignmentVersion_FallsBackToTorchBuildVersion() + { + var targetVersion = RocmPackageHelper.GetRocmSdkDevelAlignmentVersion( + rocmVersion: null, + rocmSdkDevelVersion: "7.13.0a20260501", + torchVersion: "2.11.0+rocm7.13.0a20260416" + ); + + Assert.AreEqual("7.13.0a20260416", targetVersion); + } + + [TestMethod] + public void TryExtractRocmBuildVersion_ReturnsNull_WhenTorchVersionHasNoRocmTag() + { + var rocmBuildVersion = RocmPackageHelper.TryExtractRocmBuildVersion("2.11.0"); + + Assert.IsNull(rocmBuildVersion); + } + + [TestMethod] + public void TryExtractRocmBuildVersion_ReturnsVersionSuffix_WhenTorchVersionContainsRocmTag() + { + var rocmBuildVersion = RocmPackageHelper.TryExtractRocmBuildVersion("2.11.0+rocm7.13.0a20260416"); + + Assert.AreEqual("7.13.0a20260416", rocmBuildVersion); + } + + [TestMethod] + public void IsUsableWindowsNativeTorchBuild_ReturnsTrue_WhenHipMetadataExists() + { + var isUsable = RocmPackageHelper.IsUsableWindowsNativeTorchBuild( + version: "test-version", + hipVersion: "test-hip-version" + ); + + Assert.IsTrue(isUsable); + } + + [TestMethod] + public void IsUsableWindowsNativeTorchBuild_ReturnsTrue_WhenVersionContainsRocm() + { + var isUsable = RocmPackageHelper.IsUsableWindowsNativeTorchBuild( + version: "test-version+rocm", + hipVersion: null + ); + + Assert.IsTrue(isUsable); + } + + [TestMethod] + public void IsUsableWindowsNativeTorchBuild_ReturnsFalse_WhenNoRocmMetadataExists() + { + var isUsable = RocmPackageHelper.IsUsableWindowsNativeTorchBuild( + version: "test-version", + hipVersion: null + ); + + Assert.IsFalse(isUsable); + } + + [TestMethod] + public void TryExtractJsonObject_ReturnsJson_WhenOutputContainsDiagnosticPrefix() + { + const string output = + "warning: ROCm topology probe emitted diagnostic output" + + "\nwarning: continuing with torch verification" + + "\n{\"version\": \"test-version\", \"hip\": \"test-hip-version\", \"cuda\": false}"; + + var json = RocmPackageHelper.TryExtractJsonObject(output); + + Assert.IsNotNull(json); + + using var document = JsonDocument.Parse(json); + var root = document.RootElement; + + Assert.AreEqual("test-version", root.GetProperty("version").GetString()); + Assert.AreEqual("test-hip-version", root.GetProperty("hip").GetString()); + Assert.IsFalse(root.GetProperty("cuda").GetBoolean()); + } + + [TestMethod] + public void TryExtractJsonObject_ReturnsNull_WhenOutputContainsNoJson() + { + const string output = + "warning: ROCm topology probe emitted diagnostic output\n" + + "warning: no JSON payload was produced"; + + var json = RocmPackageHelper.TryExtractJsonObject(output); + + Assert.IsNull(json); + } + + [TestMethod] + public void TryGetWindowsNativeRocmRuntimeFailureReason_ReturnsDeviceDetectionMessage() + { + const string output = "checkHipErrors() HIP API error = 0100 \"no ROCm-capable device is detected\""; + + var reason = RocmPackageHelper.TryGetWindowsNativeRocmRuntimeFailureReason(output); + + Assert.AreEqual( + "the installed ROCm runtime could not detect a ROCm-capable GPU on this system.", + reason + ); + } + + [TestMethod] + public void TryGetWindowsNativeRocmRuntimeFailureReason_ReturnsWddmMessage() + { + const string output = "warning: No WDDM adapters found."; + + var reason = RocmPackageHelper.TryGetWindowsNativeRocmRuntimeFailureReason(output); + + Assert.AreEqual( + "the ROCm runtime could not find any compatible WDDM adapters for the current GPU/driver stack.", + reason + ); + } + + [TestMethod] + public void CombineProcessOutput_JoinsStdoutAndStderr() + { + var combined = RocmPackageHelper.CombineProcessOutput("stdout line", "stderr line"); + + Assert.AreEqual($"stdout line{Environment.NewLine}stderr line", combined); + } + + [TestMethod] + public void WindowsRocmSupport_TryGetPackageIndexUrl_ReturnsExpectedIndex_ForKrakenPoint() + { + var indexUrl = WindowsRocmSupport.TryGetPackageIndexUrl("gfx1152"); + + Assert.AreEqual("https://rocm.nightlies.amd.com/v2-staging/gfx1152/", indexUrl); + } + + [TestMethod] + public void WindowsRocmSupport_IsSupportedGpu_ReturnsTrue_ForSupportedAmdGpu() + { + var gpu = new GpuInfo { Name = "AMD Radeon RX 9070 XT", MemoryBytes = 16UL * Size.GiB }; + + Assert.IsTrue(WindowsRocmSupport.IsSupportedGpu(gpu)); + } +} diff --git a/StabilityMatrix.Tests/Helper/PackageFactoryTests.cs b/StabilityMatrix.Tests/Helper/PackageFactoryTests.cs index f7802703..bf45d60c 100644 --- a/StabilityMatrix.Tests/Helper/PackageFactoryTests.cs +++ b/StabilityMatrix.Tests/Helper/PackageFactoryTests.cs @@ -24,6 +24,8 @@ public void Setup() null!, null!, null!, + null!, + null!, null! ); } From bd7ddfd045820eefb35feee5184ff76dcba753f4 Mon Sep 17 00:00:00 2001 From: NeuralFault Date: Fri, 1 May 2026 22:02:47 -0400 Subject: [PATCH 06/10] refactor shared Windows ROCm policy and package launch defaults - centralize Windows ROCm architecture classification and legacy-attention fallback policy in WindowsRocmSupport - move ComfyUI-specific MIOpen env handling out of the helper and into package-owned ROCm config - reuse shared ROCm policy for ComfyUI quad-attention defaults and helper-managed AOTriton / math SDP / RDNA1 gates - remove dead ROCm preset plumbing and trim unused RocmPackageProfile surface - rename helper/package methods for clearer default-policy semantics --- .../Models/Packages/ComfyUI.cs | 42 +++++++++++-------- .../Models/Packages/Wan2GP.cs | 1 - .../Models/Rocm/RocmEnvironmentOptions.cs | 11 ----- .../Models/Rocm/RocmPackageProfile.cs | 5 --- .../Models/Rocm/WindowsRocmSupport.cs | 26 +++++++++++- .../Services/Rocm/RocmPackageHelper.cs | 27 +++--------- 6 files changed, 53 insertions(+), 59 deletions(-) diff --git a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs index 1832a27a..7bfd0c36 100644 --- a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs +++ b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs @@ -42,16 +42,6 @@ public class ComfyUI( { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); - private static readonly RocmPackageProfile WindowsRocmProfile = new() - { - PackageName = "ComfyUI", - RequiresRocmSdk = true, - ExtraInstallPipArgs = ["numpy<2"], - PostInstallPipArgs = ["typing-extensions>=4.15.0"], - UpgradePackages = true, - EnvironmentOptions = new RocmEnvironmentOptions { Preset = RocmEnvironmentPreset.ComfyUi }, - }; - public override string Name => "ComfyUI"; public override string DisplayName { get; set; } = "ComfyUI"; public override string Author => "comfyanonymous"; @@ -278,7 +268,7 @@ public class ComfyUI( { Name = "Cross Attention Method", Type = LaunchOptionType.Bool, - InitialValue = ShouldDefaultToQuadCrossAttention() + InitialValue = DefaultToQuadCrossAttention() ? "--use-quad-cross-attention" : "--use-pytorch-cross-attention", Options = @@ -610,9 +600,26 @@ public override TorchIndex GetRecommendedTorchVersion() return base.GetRecommendedTorchVersion(); } - /// + /// Windows ROCm install profile for ComfyUI. + private static readonly RocmPackageProfile WindowsRocmProfile = new() + { + RequiresRocmSdk = true, + ExtraInstallPipArgs = ["numpy<2"], + PostInstallPipArgs = ["typing-extensions>=4.15.0"], + UpgradePackages = true, + ExtraEnvironmentFactory = BuildComfyWindowsRocmEnvironment, + }; + + private static IReadOnlyDictionary BuildComfyWindowsRocmEnvironment( + RocmRuntimeContext runtimeContext + ) + { + return WindowsRocmSupport.IsModernArchitecture(runtimeContext.RuntimeGfxArch) + ? new Dictionary { ["COMFYUI_ENABLE_MIOPEN"] = "1" } + : new Dictionary(); + } + /// Uses the shared ROCm helper for Windows ROCm eligibility checks so ComfyUI does not maintain its own support matrix. - /// private bool HasWindowsRocmSupport() { if (!Compat.IsWindows) @@ -626,7 +633,9 @@ private bool HasWindowsRocmSupport() return compatibility.IsCompatible; } - private bool ShouldDefaultToQuadCrossAttention() + /// Defaults legacy Windows ROCm GPUs to quad cross-attention because PyTorch cross-attention is considerably slower + /// and not as supported on older AMD architectures. + private bool DefaultToQuadCrossAttention() { if (!Compat.IsWindows || !HasWindowsRocmSupport()) return false; @@ -636,10 +645,7 @@ private bool ShouldDefaultToQuadCrossAttention() ? gpu?.GetAmdGfxArch() : HardwareHelper.GetWindowsRocmSupportedGpu()?.GetAmdGfxArch(); - return !string.IsNullOrWhiteSpace(gfxArch) - && !gfxArch.StartsWith("gfx110", StringComparison.OrdinalIgnoreCase) - && !gfxArch.StartsWith("gfx115", StringComparison.OrdinalIgnoreCase) - && !gfxArch.StartsWith("gfx120", StringComparison.OrdinalIgnoreCase); + return WindowsRocmSupport.PreferLegacyAttentionFallback(gfxArch); } public override IPackageExtensionManager ExtensionManager => diff --git a/StabilityMatrix.Core/Models/Packages/Wan2GP.cs b/StabilityMatrix.Core/Models/Packages/Wan2GP.cs index e11fd322..0c91f23a 100644 --- a/StabilityMatrix.Core/Models/Packages/Wan2GP.cs +++ b/StabilityMatrix.Core/Models/Packages/Wan2GP.cs @@ -46,7 +46,6 @@ public class Wan2GP( { private static readonly RocmPackageProfile WindowsRocmProfile = new() { - PackageName = "Wan2GP", RequiresRocmSdk = true, UpgradePackages = true, PostInstallPipArgs = ["hf-xet", "setuptools", "numpy==1.26.4"], diff --git a/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs b/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs index 21ff6e7d..e0126170 100644 --- a/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs +++ b/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs @@ -15,11 +15,6 @@ public class RocmEnvironmentOptions /// public bool IncludeUserOverrides { get; init; } = true; - /// - /// Selects a package-oriented ROCm environment preset managed by the helper. - /// - public RocmEnvironmentPreset Preset { get; init; } = RocmEnvironmentPreset.None; - /// /// When set, overrides the default PyTorch allocator tuning string added by the ROCm helper. /// @@ -60,9 +55,3 @@ public class RocmEnvironmentOptions /// public bool ApplyRdna1Override { get; init; } = true; } - -public enum RocmEnvironmentPreset -{ - None, - ComfyUi, -} diff --git a/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs b/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs index a7baa675..d1abf16f 100644 --- a/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs +++ b/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs @@ -8,11 +8,6 @@ namespace StabilityMatrix.Core.Models.Rocm; /// public class RocmPackageProfile { - /// - /// Logical package name for diagnostics and profile-specific decisions. - /// - public string PackageName { get; init; } = string.Empty; - public bool RequiresRocmSdk { get; init; } /// diff --git a/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs b/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs index dec46d04..ee000b1b 100644 --- a/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs +++ b/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs @@ -3,8 +3,8 @@ namespace StabilityMatrix.Core.Models.Rocm; /// -/// Centralizes Windows ROCm support policy so hardware detection, package selection, -/// and ROCm installation all use the same architecture support map. +/// Centralizes Windows ROCm support and architecture policy so hardware detection, package selection, +/// installation, and shared launch decisions use the same support map. /// public static class WindowsRocmSupport { @@ -21,6 +21,28 @@ public static bool IsSupportedArchitecture(string? gfxArch) return TryGetPackageIndexUrl(gfxArch) is not null; } + public static bool IsModernArchitecture(string? gfxArch) + { + return gfxArch?.StartsWith("gfx110", StringComparison.OrdinalIgnoreCase) == true + || gfxArch?.StartsWith("gfx115", StringComparison.OrdinalIgnoreCase) == true + || gfxArch?.StartsWith("gfx120", StringComparison.OrdinalIgnoreCase) == true; + } + + public static bool IsLegacyArchitecture(string? gfxArch) + { + return IsSupportedArchitecture(gfxArch) && !IsModernArchitecture(gfxArch); + } + + public static bool PreferLegacyAttentionFallback(string? gfxArch) + { + return IsLegacyArchitecture(gfxArch); + } + + public static bool IsRdna1Architecture(string? gfxArch) + { + return gfxArch?.StartsWith("gfx101", StringComparison.OrdinalIgnoreCase) == true; + } + public static string? TryGetPackageIndexUrl(string? gfxArch) { return gfxArch switch diff --git a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs index 19dced0b..664e27d9 100644 --- a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs +++ b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs @@ -712,12 +712,12 @@ RocmPackageProfile profile var options = profile.EnvironmentOptions; var gfxArch = runtimeContext.RuntimeGfxArch; - ApplyPresetLaunchEnvironment(environment, gfxArch, options); + ApplyDefaultLaunchEnvironment(environment, gfxArch, options); return environment; } - private void ApplyPresetLaunchEnvironment( + private void ApplyDefaultLaunchEnvironment( IDictionary environment, string? gfxArch, RocmEnvironmentOptions options @@ -729,39 +729,22 @@ RocmEnvironmentOptions options SetIfNotNull(environment, "MIOPEN_FIND_ENFORCE", options.MiopenFindEnforce); SetIfNotNull(environment, "PYTORCH_ALLOC_CONF", options.PyTorchAllocConf); - if (options.ApplyAotritonExperimental && IsModernWindowsRocmArchitecture(gfxArch)) + if (options.ApplyAotritonExperimental && WindowsRocmSupport.IsModernArchitecture(gfxArch)) { environment["TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL"] = "1"; } - if (!IsModernWindowsRocmArchitecture(gfxArch) && options.ApplyLegacySdpFallback) + if (options.ApplyLegacySdpFallback && WindowsRocmSupport.IsLegacyArchitecture(gfxArch)) { environment["TORCH_BACKENDS_CUDA_FLASH_SDP_ENABLED"] = "0"; environment["TORCH_BACKENDS_CUDA_MEM_EFF_SDP_ENABLED"] = "0"; environment["TORCH_BACKENDS_CUDA_MATH_SDP_ENABLED"] = "1"; } - if (options.ApplyRdna1Override && IsRdna1Architecture(gfxArch)) + if (options.ApplyRdna1Override && WindowsRocmSupport.IsRdna1Architecture(gfxArch)) { environment["HSA_OVERRIDE_GFX_VERSION"] = "10.1.0"; } - - if (options.Preset == RocmEnvironmentPreset.ComfyUi && IsModernWindowsRocmArchitecture(gfxArch)) - { - environment["COMFYUI_ENABLE_MIOPEN"] = "1"; - } - } - - private static bool IsModernWindowsRocmArchitecture(string? gfxArch) - { - return gfxArch?.StartsWith("gfx110", StringComparison.OrdinalIgnoreCase) == true - || gfxArch?.StartsWith("gfx115", StringComparison.OrdinalIgnoreCase) == true - || gfxArch?.StartsWith("gfx120", StringComparison.OrdinalIgnoreCase) == true; - } - - private static bool IsRdna1Architecture(string? gfxArch) - { - return gfxArch?.StartsWith("gfx101", StringComparison.OrdinalIgnoreCase) == true; } private static void SetIfNotNull(IDictionary environment, string key, string? value) From a4fbb6445f6b9a30becf8870032bf1bbb1690357 Mon Sep 17 00:00:00 2001 From: NeuralFault Date: Fri, 1 May 2026 22:05:32 -0400 Subject: [PATCH 07/10] add comment for legacy AMD GPU support in cross attention method --- StabilityMatrix.Core/Models/Packages/ComfyUI.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs index 7bfd0c36..759406c3 100644 --- a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs +++ b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs @@ -269,7 +269,7 @@ public class ComfyUI( Name = "Cross Attention Method", Type = LaunchOptionType.Bool, InitialValue = DefaultToQuadCrossAttention() - ? "--use-quad-cross-attention" + ? "--use-quad-cross-attention" // For Legacy AMD GPUs. : "--use-pytorch-cross-attention", Options = [ From 0165a7baaf5b524bc3a8715c4194107bcf20b235 Mon Sep 17 00:00:00 2001 From: NeuralFault <65365345+NeuralFault@users.noreply.github.com> Date: Sat, 2 May 2026 18:32:35 -0400 Subject: [PATCH 08/10] Change forceRefresh parameter to false in GetAmdGpuCandidates Was used during debugging and was unintentionally left on. --- StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs index 664e27d9..7ff23bc2 100644 --- a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs +++ b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs @@ -272,7 +272,7 @@ await VerifyWindowsNativeRocmRuntimeAsync(installLocation, onConsoleOutput, canc /// private RocmCompatibilityResult BuildCompatibilityResult(RocmPackageProfile profile) { - var amdGpus = GetAmdGpuCandidates(forceRefresh: true).ToList(); + var amdGpus = GetAmdGpuCandidates(forceRefresh: false).ToList(); if (amdGpus.Count == 0) { return new RocmCompatibilityResult From 215991ff6c5ef94bd44cf2dfba8f5c389321c013 Mon Sep 17 00:00:00 2001 From: NeuralFault <65365345+NeuralFault@users.noreply.github.com> Date: Sat, 2 May 2026 18:46:49 -0400 Subject: [PATCH 09/10] Change exception type from ApplicationException to InvalidOperationException --- StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs index 7ff23bc2..a5382ac9 100644 --- a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs +++ b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs @@ -142,7 +142,7 @@ public async Task InstallWindowsNativePackageAsync( var compatibility = GetCompatibility(profile); if (!compatibility.IsCompatible) { - throw new ApplicationException( + throw new InvalidOperationException( compatibility.FailureReason ?? "Windows ROCm installation is not supported for the current machine." ); From 4cbf5852a4546343575540069b89aad506fb3e76 Mon Sep 17 00:00:00 2001 From: NeuralFault <65365345+NeuralFault@users.noreply.github.com> Date: Sat, 2 May 2026 18:53:54 -0400 Subject: [PATCH 10/10] Add rocmPackageHelper dependency to Wan2GP --- StabilityMatrix.Core/Helper/Factory/PackageFactory.cs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs b/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs index 6a073986..0cb0b808 100644 --- a/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs +++ b/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs @@ -287,7 +287,8 @@ public BasePackage GetNewBasePackage(InstalledPackage installedPackage) downloadService, prerequisiteHelper, pyInstallationManager, - pipWheelService + pipWheelService, + rocmPackageHelper ), _ => throw new ArgumentOutOfRangeException(nameof(installedPackage)), };