From a03623b86d016d097d49661fc0b5f2809b5a6092 Mon Sep 17 00:00:00 2001 From: NeuralFault Date: Tue, 21 Apr 2026 19:02:05 -0400 Subject: [PATCH 01/18] Scaffold universal ROCm helper types and service - Add initial ROCm helper structure - Set up ROCm helper foundation Compile test successful. --- .../Models/Rocm/RocmCompatibilityResult.cs | 17 +++ .../Models/Rocm/RocmEnvironmentOptions.cs | 33 ++++++ .../Models/Rocm/RocmInstallContext.cs | 23 ++++ .../Models/Rocm/RocmPackageProfile.cs | 65 +++++++++++ .../Models/Rocm/RocmRuntimeContext.cs | 33 ++++++ .../Models/Rocm/RocmSdkPaths.cs | 22 ++++ .../Services/Rocm/IRocmPackageHelper.cs | 79 ++++++++++++++ .../Services/Rocm/RocmPackageHelper.cs | 103 ++++++++++++++++++ 8 files changed, 375 insertions(+) create mode 100644 StabilityMatrix.Core/Models/Rocm/RocmCompatibilityResult.cs create mode 100644 StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs create mode 100644 StabilityMatrix.Core/Models/Rocm/RocmInstallContext.cs create mode 100644 StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs create mode 100644 StabilityMatrix.Core/Models/Rocm/RocmRuntimeContext.cs create mode 100644 StabilityMatrix.Core/Models/Rocm/RocmSdkPaths.cs create mode 100644 StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs create mode 100644 StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs diff --git a/StabilityMatrix.Core/Models/Rocm/RocmCompatibilityResult.cs b/StabilityMatrix.Core/Models/Rocm/RocmCompatibilityResult.cs new file mode 100644 index 00000000..401f3ada --- /dev/null +++ b/StabilityMatrix.Core/Models/Rocm/RocmCompatibilityResult.cs @@ -0,0 +1,17 @@ +using StabilityMatrix.Core.Helper.HardwareInfo; + +namespace StabilityMatrix.Core.Models.Rocm; + +/// +/// Describes whether a package/profile is currently compatible with ROCm on the active machine. +/// +public class RocmCompatibilityResult +{ + public bool IsCompatible { get; init; } + + public string? FailureReason { get; init; } + + public GpuInfo? 
SelectedGpu { get; init; } + + public string? ResolvedGfxArch { get; init; } +} diff --git a/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs b/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs new file mode 100644 index 00000000..11c2bbfb --- /dev/null +++ b/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs @@ -0,0 +1,33 @@ +namespace StabilityMatrix.Core.Models.Rocm; + +/// +/// Controls how helper-generated, package-specific, and user-defined environment variables +/// should be layered together once the helper has real behavior. +/// +public class RocmEnvironmentOptions +{ + /// + /// Determines the merge order used when multiple environment sources provide the same key. + /// + public RocmEnvironmentOverlayPriority OverlayPriority { get; init; } = + RocmEnvironmentOverlayPriority.HelperThenPackageThenUser; + + /// + /// When true, package-specific environment additions may be merged on top of helper defaults. + /// + public bool IncludePackageOverrides { get; init; } = true; + + /// + /// When true, user-defined Stability Matrix environment variables may be merged last. + /// + public bool IncludeUserOverrides { get; init; } = true; +} + +/// +/// Describes the intended precedence of environment sources for ROCm-enabled package launches. +/// +public enum RocmEnvironmentOverlayPriority +{ + HelperThenPackageThenUser, + HelperThenUserThenPackage, +} diff --git a/StabilityMatrix.Core/Models/Rocm/RocmInstallContext.cs b/StabilityMatrix.Core/Models/Rocm/RocmInstallContext.cs new file mode 100644 index 00000000..2055dd71 --- /dev/null +++ b/StabilityMatrix.Core/Models/Rocm/RocmInstallContext.cs @@ -0,0 +1,23 @@ +using StabilityMatrix.Core.Models; + +namespace StabilityMatrix.Core.Models.Rocm; + +/// +/// Captures ROCm-related facts needed during package install or update flows. +/// +public class RocmInstallContext +{ + public string? PreferredGfxArch { get; init; } + + public string? 
RuntimeGfxArch { get; init; } + + public TorchIndex TorchIndex { get; init; } = TorchIndex.Rocm; + + public string? WheelCompatibilityHints { get; init; } + + public string? SdkRoot { get; init; } + + public RocmSdkPaths SdkPaths { get; init; } = new(); + + public IReadOnlyDictionary Environment { get; init; } = new Dictionary(); +} diff --git a/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs b/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs new file mode 100644 index 00000000..a15c247d --- /dev/null +++ b/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs @@ -0,0 +1,65 @@ +using StabilityMatrix.Core.Models.Progress; + +namespace StabilityMatrix.Core.Models.Rocm; + +/// +/// Declares what a package expects from the ROCm helper. +/// Package classes should describe intent here rather than hardcoding ROCm decisions inline. +/// +public class RocmPackageProfile +{ + /// + /// Logical package name for diagnostics and profile-specific decisions. + /// + public string PackageName { get; init; } = string.Empty; + + public bool RequiresWindows { get; init; } + + public bool RequiresRocmSdk { get; init; } + + public bool NeedsRuntimeGfxResolution { get; init; } + + public bool NeedsHipPath { get; init; } + + public bool NeedsRocmPath { get; init; } + + public bool NeedsTritonOverrideArch { get; init; } + + public bool NeedsRdna1Override { get; init; } + + public bool NeedsLegacySdpFallback { get; init; } + + public bool NeedsAotritonExperimental { get; init; } + + public bool NeedsTunableOpCache { get; init; } + + public bool NeedsTritonCache { get; init; } + + public bool NeedsMIOpenDbPaths { get; init; } + + public bool NeedsRocblasPaths { get; init; } + + /// + /// Optional callback for package-specific cache path variables. + /// The helper will eventually merge these with its own defaults. + /// + public Func>? 
CacheDirectoryFactory { get; init; } + + /// + /// Optional callback for package-specific environment variables derived from a resolved ROCm context. + /// + public Func< + RocmRuntimeContext, + IReadOnlyDictionary + >? ExtraEnvironmentFactory { get; init; } + + /// + /// Optional progress message prefix or label that package code can surface during install/update work. + /// + public string? ProgressLabel { get; init; } + + /// + /// Controls how helper, package, and user-defined environment variables should be merged. + /// + public RocmEnvironmentOptions EnvironmentOptions { get; init; } = new(); +} diff --git a/StabilityMatrix.Core/Models/Rocm/RocmRuntimeContext.cs b/StabilityMatrix.Core/Models/Rocm/RocmRuntimeContext.cs new file mode 100644 index 00000000..87c88ba6 --- /dev/null +++ b/StabilityMatrix.Core/Models/Rocm/RocmRuntimeContext.cs @@ -0,0 +1,33 @@ +using StabilityMatrix.Core.Helper.HardwareInfo; + +namespace StabilityMatrix.Core.Models.Rocm; + +/// +/// Captures resolved ROCm facts for a package launch or runtime decision. +/// This model is intended to separate hardware/runtime facts from package policy. +/// +public class RocmRuntimeContext +{ + public bool IsSupported { get; init; } + + public string? FailureReason { get; init; } + + public GpuInfo? SelectedGpu { get; init; } + + public string? RuntimeGfxArch { get; init; } + + public bool IsLegacyGpu { get; init; } + + public bool IsRdna1 { get; init; } + + public string? HipPath { get; init; } + + public string? RocmPath { get; init; } + + public string? 
RocmSdkSitePackagesPath { get; init; } + + public RocmSdkPaths SdkPaths { get; init; } = new(); + + public IReadOnlyDictionary ResolvedEnvironment { get; init; } = + new Dictionary(); +} diff --git a/StabilityMatrix.Core/Models/Rocm/RocmSdkPaths.cs b/StabilityMatrix.Core/Models/Rocm/RocmSdkPaths.cs new file mode 100644 index 00000000..5789744f --- /dev/null +++ b/StabilityMatrix.Core/Models/Rocm/RocmSdkPaths.cs @@ -0,0 +1,22 @@ +namespace StabilityMatrix.Core.Models.Rocm; + +/// +/// Represents ROCm SDK-related paths resolved for a package install. +/// These values are intentionally plain data so package code can decide which paths matter. +/// +public class RocmSdkPaths +{ + public string? RocmRoot { get; init; } + + public string? HipPath { get; init; } + + public string? RocmPath { get; init; } + + public string? RocmSdkSitePackagesPath { get; init; } + + public string? MioPenDbPath { get; init; } + + public string? RocblasDbPath { get; init; } + + public string? RocblasLibraryPath { get; init; } +} diff --git a/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs new file mode 100644 index 00000000..5b9383fc --- /dev/null +++ b/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs @@ -0,0 +1,79 @@ +using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Models.Rocm; +using StabilityMatrix.Core.Python; + +namespace StabilityMatrix.Core.Services.Rocm; + +/// +/// Defines the ROCm helper surface area shared by ROCm-capable packages. +/// +public interface IRocmPackageHelper +{ + /// + /// Evaluates whether the current machine and package profile are compatible with ROCm. + /// + Task GetCompatibilityAsync( + RocmPackageProfile profile, + CancellationToken cancellationToken = default + ); + + /// + /// Resolves the runtime ROCm facts needed for package launch and environment construction. 
+ /// + Task ResolveRuntimeContextAsync( + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile, + CancellationToken cancellationToken = default + ); + + /// + /// Resolves the ROCm facts needed during package installation or update operations. + /// + Task ResolveInstallContextAsync( + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile, + CancellationToken cancellationToken = default + ); + + /// + /// Builds an install-time environment dictionary from a resolved install context. + /// + IReadOnlyDictionary BuildInstallEnvironment( + string installLocation, + RocmInstallContext context, + RocmPackageProfile profile + ); + + /// + /// Re-resolves ROCm install facts after a package update changes dependencies or runtime state. + /// + Task RefreshPackageAfterUpdateAsync( + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile, + CancellationToken cancellationToken = default + ); + + /// + /// Builds a launch-time environment dictionary from resolved ROCm runtime data. + /// + Task> BuildLaunchEnvironmentAsync( + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile, + CancellationToken cancellationToken = default + ); + + /// + /// Applies a resolved launch environment to the provided Python venv runner. 
+ /// + Task ApplyLaunchEnvironmentAsync( + IPyVenvRunner venvRunner, + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile, + CancellationToken cancellationToken = default + ); +} diff --git a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs new file mode 100644 index 00000000..c353dc15 --- /dev/null +++ b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs @@ -0,0 +1,103 @@ +using System.Collections.Immutable; +using Injectio.Attributes; +using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Models.Rocm; +using StabilityMatrix.Core.Python; + +namespace StabilityMatrix.Core.Services.Rocm; + +/// +/// Provides the shared ROCm helper surface area used by ROCm-capable packages. +/// +[RegisterSingleton] +public class RocmPackageHelper : IRocmPackageHelper +{ + private const string NotImplementedMessage = "ROCm helper behavior has not been implemented yet."; + + /// + public Task GetCompatibilityAsync( + RocmPackageProfile profile, + CancellationToken cancellationToken = default + ) + { + return Task.FromResult( + new RocmCompatibilityResult { IsCompatible = false, FailureReason = NotImplementedMessage } + ); + } + + /// + public Task ResolveRuntimeContextAsync( + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile, + CancellationToken cancellationToken = default + ) + { + return Task.FromResult( + new RocmRuntimeContext { IsSupported = false, FailureReason = NotImplementedMessage } + ); + } + + /// + public Task ResolveInstallContextAsync( + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile, + CancellationToken cancellationToken = default + ) + { + return Task.FromResult(new RocmInstallContext()); + } + + /// + public IReadOnlyDictionary BuildInstallEnvironment( + string installLocation, + RocmInstallContext context, + RocmPackageProfile profile + ) + { + return new 
Dictionary(); + } + + /// + public Task RefreshPackageAfterUpdateAsync( + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile, + CancellationToken cancellationToken = default + ) + { + return Task.FromResult(new RocmInstallContext()); + } + + /// + public Task> BuildLaunchEnvironmentAsync( + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile, + CancellationToken cancellationToken = default + ) + { + return Task.FromResult>(new Dictionary()); + } + + /// + public async Task ApplyLaunchEnvironmentAsync( + IPyVenvRunner venvRunner, + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile, + CancellationToken cancellationToken = default + ) + { + var environment = await BuildLaunchEnvironmentAsync( + installLocation, + installedPackage, + profile, + cancellationToken + ) + .ConfigureAwait(false); + + venvRunner.UpdateEnvironmentVariables(env => env.SetItems(environment)); + } +} From 218aff9a764a6131cae0e63f43170053d6dc73ea Mon Sep 17 00:00:00 2001 From: NeuralFault Date: Tue, 21 Apr 2026 19:17:43 -0400 Subject: [PATCH 02/18] Implement ROCm GPU detection helper --- .../Services/Rocm/RocmPackageHelper.cs | 288 +++++++++++++++++- 1 file changed, 280 insertions(+), 8 deletions(-) diff --git a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs index c353dc15..1adb141e 100644 --- a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs +++ b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs @@ -1,8 +1,12 @@ using System.Collections.Immutable; using Injectio.Attributes; +using NLog; +using StabilityMatrix.Core.Helper; +using StabilityMatrix.Core.Helper.HardwareInfo; using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models.Rocm; using StabilityMatrix.Core.Python; +using StabilityMatrix.Core.Services; namespace StabilityMatrix.Core.Services.Rocm; @@ -10,9 +14,25 @@ namespace 
StabilityMatrix.Core.Services.Rocm; /// Provides the shared ROCm helper surface area used by ROCm-capable packages. /// [RegisterSingleton] -public class RocmPackageHelper : IRocmPackageHelper +public class RocmPackageHelper(ISettingsManager settingsManager) : IRocmPackageHelper { - private const string NotImplementedMessage = "ROCm helper behavior has not been implemented yet."; + private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); + + private static readonly string[] UnsupportedRdna2ModelMarkers = + [ + "680m", + "660m", + "610m", + "rx6300", + "w6300", + "rx6400", + "w6400", + "rx6450", + "rx6550", + ]; + + private const string EnvironmentNotImplementedMessage = + "ROCm helper environment composition has not been implemented yet."; /// public Task GetCompatibilityAsync( @@ -20,9 +40,7 @@ public Task GetCompatibilityAsync( CancellationToken cancellationToken = default ) { - return Task.FromResult( - new RocmCompatibilityResult { IsCompatible = false, FailureReason = NotImplementedMessage } - ); + return Task.FromResult(BuildCompatibilityResult(profile)); } /// @@ -33,8 +51,43 @@ public Task ResolveRuntimeContextAsync( CancellationToken cancellationToken = default ) { + var compatibility = BuildCompatibilityResult(profile); + if (!compatibility.IsCompatible) + { + return Task.FromResult( + new RocmRuntimeContext + { + IsSupported = false, + FailureReason = compatibility.FailureReason, + SelectedGpu = compatibility.SelectedGpu, + RuntimeGfxArch = compatibility.ResolvedGfxArch, + } + ); + } + + var supportedAmdGpus = GetAmdGpuCandidates(forceRefresh: true) + .Where(IsSupportedWindowsRocmGpu) + .ToList(); + + var selectedGpu = + compatibility.SelectedGpu + ?? TryResolvePreferredAmdGpu(supportedAmdGpus, settingsManager.Settings.PreferredGpu) + ?? supportedAmdGpus.FirstOrDefault(); + + var runtimeGfxArch = + compatibility.ResolvedGfxArch + ?? selectedGpu?.GetAmdGfxArch() + ?? 
GetSupportedFallbackGfxArch(supportedAmdGpus); + return Task.FromResult( - new RocmRuntimeContext { IsSupported = false, FailureReason = NotImplementedMessage } + new RocmRuntimeContext + { + IsSupported = true, + SelectedGpu = selectedGpu, + RuntimeGfxArch = runtimeGfxArch, + IsLegacyGpu = IsLegacyArchitecture(runtimeGfxArch), + IsRdna1 = IsRdna1Architecture(runtimeGfxArch), + } ); } @@ -46,7 +99,22 @@ public Task ResolveInstallContextAsync( CancellationToken cancellationToken = default ) { - return Task.FromResult(new RocmInstallContext()); + var supportedAmdGpus = GetAmdGpuCandidates(forceRefresh: true) + .Where(IsSupportedWindowsRocmGpu) + .ToList(); + + var preferredGfxArch = TryResolvePreferredAmdGfxArch( + supportedAmdGpus, + settingsManager.Settings.PreferredGpu + ); + + return Task.FromResult( + new RocmInstallContext + { + PreferredGfxArch = preferredGfxArch, + RuntimeGfxArch = preferredGfxArch ?? GetSupportedFallbackGfxArch(supportedAmdGpus), + } + ); } /// @@ -56,6 +124,9 @@ public IReadOnlyDictionary BuildInstallEnvironment( RocmPackageProfile profile ) { + _ = installLocation; + _ = context; + _ = profile; return new Dictionary(); } @@ -67,7 +138,7 @@ public Task RefreshPackageAfterUpdateAsync( CancellationToken cancellationToken = default ) { - return Task.FromResult(new RocmInstallContext()); + return ResolveInstallContextAsync(installLocation, installedPackage, profile, cancellationToken); } /// @@ -78,6 +149,10 @@ public Task> BuildLaunchEnvironmentAsync( CancellationToken cancellationToken = default ) { + _ = installLocation; + _ = installedPackage; + _ = profile; + _ = cancellationToken; return Task.FromResult>(new Dictionary()); } @@ -100,4 +175,201 @@ public async Task ApplyLaunchEnvironmentAsync( venvRunner.UpdateEnvironmentVariables(env => env.SetItems(environment)); } + + /// + /// Builds a compatibility result from the current machine state and package profile. 
+ /// This keeps the first ROCm helper slice focused on hardware capability and GPU selection only. + /// + private RocmCompatibilityResult BuildCompatibilityResult(RocmPackageProfile profile) + { + if (profile.RequiresWindows && !Compat.IsWindows) + { + return new RocmCompatibilityResult + { + IsCompatible = false, + FailureReason = "This ROCm profile currently requires Windows.", + }; + } + + var amdGpus = GetAmdGpuCandidates(forceRefresh: true).ToList(); + if (amdGpus.Count == 0) + { + return new RocmCompatibilityResult + { + IsCompatible = false, + FailureReason = "No AMD GPU was detected for ROCm evaluation.", + }; + } + + var preferredGpu = settingsManager.Settings.PreferredGpu; + if (preferredGpu is not null && IsExplicitlyUnsupportedRdna2Gpu(preferredGpu)) + { + return new RocmCompatibilityResult + { + IsCompatible = false, + FailureReason = $"Selected GPU '{preferredGpu.Name}' is unsupported for Windows ROCm.", + SelectedGpu = preferredGpu, + }; + } + + var supportedAmdGpus = amdGpus.Where(IsSupportedWindowsRocmGpu).ToList(); + if (supportedAmdGpus.Count == 0) + { + return new RocmCompatibilityResult + { + IsCompatible = false, + FailureReason = GetUnsupportedGpuReason(amdGpus), + }; + } + + var selectedGpu = + TryResolvePreferredAmdGpu(supportedAmdGpus, preferredGpu) ?? supportedAmdGpus.First(); + var resolvedGfxArch = selectedGpu.GetAmdGfxArch() ?? GetSupportedFallbackGfxArch(supportedAmdGpus); + + return new RocmCompatibilityResult + { + IsCompatible = !string.IsNullOrWhiteSpace(resolvedGfxArch), + FailureReason = string.IsNullOrWhiteSpace(resolvedGfxArch) + ? "No supported AMD GFX architecture could be resolved for ROCm." + : null, + SelectedGpu = selectedGpu, + ResolvedGfxArch = resolvedGfxArch, + }; + } + + /// + /// Returns AMD GPUs from Stability Matrix's internal hardware model. + /// This is the canonical GPU source for the ROCm helper and intentionally avoids package-local probing. 
+ /// + private static IReadOnlyList GetAmdGpuCandidates(bool forceRefresh = false) + { + return HardwareHelper.IterGpuInfo(forceRefresh).Where(gpu => gpu.IsAmd).ToList(); + } + + /// + /// Resolves the preferred AMD GPU when the configured preference is still present in the current hardware list. + /// + private static GpuInfo? TryResolvePreferredAmdGpu( + IEnumerable availableGpus, + GpuInfo? preferredGpu + ) + { + if (preferredGpu is null || !preferredGpu.IsAmd) + return null; + + var preferredMatch = availableGpus.FirstOrDefault(gpu => gpu.Equals(preferredGpu)); + if (preferredMatch is not null) + return preferredMatch; + + if (!string.IsNullOrWhiteSpace(preferredGpu.Name)) + { + Logger.Info( + "Preferred GPU {PreferredGpuName} was ignored for ROCm detection because it is not present in current hardware enumeration.", + preferredGpu.Name + ); + } + + return null; + } + + /// + /// Resolves the preferred AMD GFX architecture when the configured GPU is supported and currently present. + /// + private static string? TryResolvePreferredAmdGfxArch( + IEnumerable availableGpus, + GpuInfo? preferredGpu + ) + { + var resolvedPreferredGpu = TryResolvePreferredAmdGpu(availableGpus, preferredGpu); + return resolvedPreferredGpu is not null && IsSupportedWindowsRocmGpu(resolvedPreferredGpu) + ? resolvedPreferredGpu.GetAmdGfxArch() + : null; + } + + /// + /// Resolves the first supported AMD GFX architecture from the current machine state when no preferred GPU applies. + /// + private static string? GetSupportedFallbackGfxArch(IEnumerable availableGpus) + { + return availableGpus + .Where(IsSupportedWindowsRocmGpu) + .Select(gpu => gpu.GetAmdGfxArch()) + .FirstOrDefault(IsSupportedWindowsRocmArchitecture); + } + + /// + /// Determines whether a GPU is supported by the Windows ROCm install flow currently modeled by the helper. + /// Unsupported low-end RDNA2/APU models are filtered explicitly even when they identify as AMD hardware. 
+ /// + private static bool IsSupportedWindowsRocmGpu(GpuInfo gpu) + { + if (!gpu.IsAmd || string.IsNullOrWhiteSpace(gpu.Name)) + return false; + + if (IsExplicitlyUnsupportedRdna2Gpu(gpu)) + return false; + + return IsSupportedWindowsRocmArchitecture(gpu.GetAmdGfxArch()); + } + + /// + /// Identifies Windows ROCm-incompatible RDNA2 models that need to remain outside the supported GPU set. + /// + private static bool IsExplicitlyUnsupportedRdna2Gpu(GpuInfo gpu) + { + if (!gpu.IsAmd || string.IsNullOrWhiteSpace(gpu.Name)) + return false; + + var normalizedName = gpu.Name.Replace(" ", string.Empty, StringComparison.Ordinal).ToLowerInvariant(); + return UnsupportedRdna2ModelMarkers.Any(normalizedName.Contains); + } + + /// + /// Determines whether a resolved AMD GFX architecture falls inside the Windows ROCm support set currently modeled by the helper. + /// + private static bool IsSupportedWindowsRocmArchitecture(string? gfxArch) + { + return gfxArch switch + { + var s when s != null && s.StartsWith("gfx101", StringComparison.OrdinalIgnoreCase) => true, + var s when s != null && s.StartsWith("gfx103", StringComparison.OrdinalIgnoreCase) => true, + var s when s != null && s.StartsWith("gfx110", StringComparison.OrdinalIgnoreCase) => true, + "gfx1150" or "gfx1151" or "gfx1152" or "gfx1153" => true, + var s when s != null && s.StartsWith("gfx120", StringComparison.OrdinalIgnoreCase) => true, + _ => false, + }; + } + + /// + /// Returns true for architectures that need the legacy ROCm runtime path. + /// + private static bool IsLegacyArchitecture(string? gfxArch) + { + return gfxArch is not null + && ( + gfxArch.StartsWith("gfx101", StringComparison.OrdinalIgnoreCase) + || gfxArch.StartsWith("gfx103", StringComparison.OrdinalIgnoreCase) + ); + } + + /// + /// Returns true for RDNA1 architectures that need dedicated override handling. + /// + private static bool IsRdna1Architecture(string? 
gfxArch) + { + return gfxArch?.StartsWith("gfx101", StringComparison.OrdinalIgnoreCase) == true; + } + + /// + /// Produces a readable incompatibility reason when AMD hardware is present but not usable for Windows ROCm. + /// + private static string GetUnsupportedGpuReason(IReadOnlyList amdGpus) + { + if (amdGpus.Any(IsExplicitlyUnsupportedRdna2Gpu)) + { + return "Detected only unsupported AMD RDNA2 GPUs for Windows ROCm. Unsupported models include Radeon 680M/660M/610M and RX 6300/6400/6450/6550-class GPUs."; + } + + return "No AMD GPU with a supported Windows ROCm architecture was detected."; + } } From 70f852d523520f27d9c500e99462280112d313d2 Mon Sep 17 00:00:00 2001 From: NeuralFault Date: Tue, 21 Apr 2026 20:11:58 -0400 Subject: [PATCH 03/18] Initial ComfyUI.cs integration - Add initial ROCm helper calls/config - Removed pre-existing Windows ROCm blocks which will be obsolete following helper implementation --- .../Models/Packages/ComfyUI.cs | 156 +++++++----------- 1 file changed, 62 insertions(+), 94 deletions(-) diff --git a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs index a4c34649..a0b5fb44 100644 --- a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs +++ b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs @@ -13,9 +13,11 @@ using StabilityMatrix.Core.Models.Packages.Config; using StabilityMatrix.Core.Models.Packages.Extensions; using StabilityMatrix.Core.Models.Progress; +using StabilityMatrix.Core.Models.Rocm; using StabilityMatrix.Core.Processes; using StabilityMatrix.Core.Python; using StabilityMatrix.Core.Services; +using StabilityMatrix.Core.Services.Rocm; namespace StabilityMatrix.Core.Models.Packages; @@ -26,7 +28,8 @@ public class ComfyUI( IDownloadService downloadService, IPrerequisiteHelper prerequisiteHelper, IPyInstallationManager pyInstallationManager, - IPipWheelService pipWheelService + IPipWheelService pipWheelService, + IRocmPackageHelper? 
rocmPackageHelper = null ) : BaseGitPackage( githubApi, @@ -38,6 +41,14 @@ IPipWheelService pipWheelService ) { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); + + private static readonly RocmPackageProfile WindowsRocmProfile = new() + { + PackageName = "ComfyUI", + RequiresWindows = true, + NeedsRuntimeGfxResolution = true, + }; + public override string Name => "ComfyUI"; public override string DisplayName { get; set; } = "ComfyUI"; public override string Author => "comfyanonymous"; @@ -247,7 +258,7 @@ IPipWheelService pipWheelService Name = "Enable DirectML", Type = LaunchOptionType.Bool, InitialValue = - !HardwareHelper.HasWindowsRocmSupportedGpu() + !HasWindowsRocmSupport() && HardwareHelper.PreferDirectMLOrZluda() && this is not ComfyZluda, Options = ["--directml"], @@ -362,91 +373,34 @@ public override async Task InstallPackage( .ConfigureAwait(false); var torchIndex = options.PythonOptions.TorchIndex ?? GetRecommendedTorchVersion(); - var gfxArch = - SettingsManager.Settings.PreferredGpu?.GetAmdGfxArch() - ?? HardwareHelper.GetWindowsRocmSupportedGpu()?.GetAmdGfxArch(); - - // Special case for Windows ROCm Nightly builds - if ( - Compat.IsWindows - && !string.IsNullOrWhiteSpace(gfxArch) - && torchIndex is TorchIndex.Rocm - && options.PythonOptions.PythonVersion >= PyVersion.Parse("3.11.0") - ) - { - var config = new PipInstallConfig - { - RequirementsFilePaths = ["requirements.txt"], - ExtraPipArgs = ["numpy<2"], - SkipTorchInstall = true, - PostInstallPipArgs = ["typing-extensions>=4.15.0"], - }; - await StandardPipInstallProcessAsync( - venvRunner, - options, - installedPackage, - config, - onConsoleOutput, - progress, - cancellationToken - ) - .ConfigureAwait(false); - - progress?.Report( - new ProgressReport(-1f, "Installing ROCm nightly torch...", isIndeterminate: true) + var isLegacyNvidia = + torchIndex == TorchIndex.Cuda + && ( + SettingsManager.Settings.PreferredGpu?.IsLegacyNvidiaGpu() + ?? 
HardwareHelper.HasLegacyNvidiaGpu() ); - var indexUrl = gfxArch switch - { - "gfx1150" => "https://rocm.nightlies.amd.com/v2-staging/gfx1150", // Strix/Gorgon Point - "gfx1151" => "https://rocm.nightlies.amd.com/v2/gfx1151", // Strix Halo - _ when gfxArch.StartsWith("gfx110") => "https://rocm.nightlies.amd.com/v2/gfx110X-all", - _ when gfxArch.StartsWith("gfx120") => "https://rocm.nightlies.amd.com/v2/gfx120X-all", - _ => throw new ArgumentOutOfRangeException( - nameof(gfxArch), - $"Unsupported GFX Arch: {gfxArch}" - ), - }; - - var torchPipArgs = new PipInstallArgs() - .AddArgs("--pre", "--upgrade") - .WithTorch() - .WithTorchVision() - .WithTorchAudio() - .AddArgs("--index-url", indexUrl); - - await venvRunner.PipInstall(torchPipArgs, onConsoleOutput).ConfigureAwait(false); - } - else // Standard installation path for all other cases + + var config = new PipInstallConfig { - var isLegacyNvidia = - torchIndex == TorchIndex.Cuda - && ( - SettingsManager.Settings.PreferredGpu?.IsLegacyNvidiaGpu() - ?? HardwareHelper.HasLegacyNvidiaGpu() - ); + RequirementsFilePaths = ["requirements.txt"], + ExtraPipArgs = ["numpy<2"], + TorchaudioVersion = " ", // Request torchaudio without a specific version + CudaIndex = isLegacyNvidia ? "cu126" : "cu130", + RocmIndex = "rocm7.2", + UpgradePackages = true, + PostInstallPipArgs = ["typing-extensions>=4.15.0"], + }; - var config = new PipInstallConfig - { - RequirementsFilePaths = ["requirements.txt"], - ExtraPipArgs = ["numpy<2"], - TorchaudioVersion = " ", // Request torchaudio without a specific version - CudaIndex = isLegacyNvidia ? 
"cu126" : "cu130", - RocmIndex = "rocm7.2", - UpgradePackages = true, - PostInstallPipArgs = ["typing-extensions>=4.15.0"], - }; - - await StandardPipInstallProcessAsync( - venvRunner, - options, - installedPackage, - config, - onConsoleOutput, - progress, - cancellationToken - ) - .ConfigureAwait(false); - } + await StandardPipInstallProcessAsync( + venvRunner, + options, + installedPackage, + config, + onConsoleOutput, + progress, + cancellationToken + ) + .ConfigureAwait(false); try { @@ -613,13 +567,7 @@ public override TorchIndex GetRecommendedTorchVersion() { var preferRocm = (Compat.IsLinux && (SettingsManager.Settings.PreferredGpu?.IsAmd ?? HardwareHelper.PreferRocm())) - || ( - Compat.IsWindows - && ( - SettingsManager.Settings.PreferredGpu?.IsWindowsRocmSupportedGpu() - ?? HardwareHelper.HasWindowsRocmSupportedGpu() - ) - ); + || HasWindowsRocmSupport(); if (AvailableTorchIndices.Contains(TorchIndex.Rocm) && preferRocm) { @@ -629,6 +577,28 @@ public override TorchIndex GetRecommendedTorchVersion() return base.GetRecommendedTorchVersion(); } + /// + /// Uses the shared ROCm helper for Windows ROCm eligibility checks so ComfyUI does not maintain its own support matrix. + /// + private bool HasWindowsRocmSupport() + { + if (!Compat.IsWindows) + return false; + + if (rocmPackageHelper is null) + { + return SettingsManager.Settings.PreferredGpu?.IsWindowsRocmSupportedGpu() + ?? 
HardwareHelper.HasWindowsRocmSupportedGpu(); + } + + var compatibility = rocmPackageHelper + .GetCompatibilityAsync(WindowsRocmProfile) + .GetAwaiter() + .GetResult(); + + return compatibility.IsCompatible; + } + public override IPackageExtensionManager ExtensionManager => new ComfyExtensionManager(this, settingsManager); @@ -982,9 +952,7 @@ await PipWheelService private ImmutableDictionary GetEnvVars(ImmutableDictionary env) { // if we're not on windows or we don't have a windows rocm gpu, return original env - var hasRocmGpu = - SettingsManager.Settings.PreferredGpu?.IsWindowsRocmSupportedGpu() - ?? HardwareHelper.HasWindowsRocmSupportedGpu(); + var hasRocmGpu = HasWindowsRocmSupport(); if (!Compat.IsWindows || !hasRocmGpu) return env; From 84f7f95eeb5d1733f93f1f3341ffbb2ad79451e7 Mon Sep 17 00:00:00 2001 From: NeuralFault Date: Tue, 21 Apr 2026 21:15:57 -0400 Subject: [PATCH 04/18] implement helper-owned Windows ROCm install flow - Windows ROCm install/bootstrap logic into shared ROCm helper - Add gfx-family mapping for Windows-native TheRock ROCm URLs - Route ComfyUI Win Rocm installs through helper-resolved ROCm runtime, rocm-sdk, and pytorch setup - Prevent requirements.txt from overwriting helper-installed ROCm torch packages - Add helper-owned post-install torch verification and improve unsupported GPU failure handling --- .../Models/Packages/ComfyUI.cs | 98 +++-- .../Models/Rocm/RocmInstallContext.cs | 4 + .../Models/Rocm/RocmPackageProfile.cs | 30 ++ .../Services/Rocm/IRocmPackageHelper.cs | 15 + .../Services/Rocm/RocmPackageHelper.cs | 336 +++++++++++++++++- 5 files changed, 447 insertions(+), 36 deletions(-) diff --git a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs index a0b5fb44..e3f54269 100644 --- a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs +++ b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs @@ -46,7 +46,19 @@ public class ComfyUI( { PackageName = "ComfyUI", RequiresWindows = true, + 
RequiresRocmSdk = true, NeedsRuntimeGfxResolution = true, + NeedsAotritonExperimental = true, + NeedsTunableOpCache = true, + ExtraInstallPipArgs = ["numpy<2"], + PostInstallPipArgs = ["typing-extensions>=4.15.0"], + UpgradePackages = true, + ExtraEnvironmentFactory = _ => new Dictionary + { + ["MIOPEN_FIND_MODE"] = "2", + ["PYTORCH_ALLOC_CONF"] = "max_split_size_mb:512,garbage_collection_threshold:0.8", + ["COMFYUI_ENABLE_MIOPEN"] = "1", + }, }; public override string Name => "ComfyUI"; @@ -380,27 +392,51 @@ public override async Task InstallPackage( ?? HardwareHelper.HasLegacyNvidiaGpu() ); - var config = new PipInstallConfig + if (Compat.IsWindows && torchIndex == TorchIndex.Rocm && HasWindowsRocmSupport()) { - RequirementsFilePaths = ["requirements.txt"], - ExtraPipArgs = ["numpy<2"], - TorchaudioVersion = " ", // Request torchaudio without a specific version - CudaIndex = isLegacyNvidia ? "cu126" : "cu130", - RocmIndex = "rocm7.2", - UpgradePackages = true, - PostInstallPipArgs = ["typing-extensions>=4.15.0"], - }; + if (rocmPackageHelper is null) + { + throw new InvalidOperationException( + "Windows ROCm installation requires the shared ROCm helper to resolve gfx-specific index URLs." + ); + } - await StandardPipInstallProcessAsync( - venvRunner, - options, - installedPackage, - config, - onConsoleOutput, - progress, - cancellationToken - ) - .ConfigureAwait(false); + await rocmPackageHelper + .InstallWindowsNativePackageAsync( + venvRunner, + installLocation, + installedPackage, + WindowsRocmProfile, + progress, + onConsoleOutput, + cancellationToken + ) + .ConfigureAwait(false); + } + else + { + var config = new PipInstallConfig + { + RequirementsFilePaths = ["requirements.txt"], + ExtraPipArgs = ["numpy<2"], + TorchaudioVersion = " ", // Request torchaudio without a specific version + CudaIndex = isLegacyNvidia ? 
"cu126" : "cu130", + RocmIndex = "rocm7.2", + UpgradePackages = true, + PostInstallPipArgs = ["typing-extensions>=4.15.0"], + }; + + await StandardPipInstallProcessAsync( + venvRunner, + options, + installedPackage, + config, + onConsoleOutput, + progress, + cancellationToken + ) + .ConfigureAwait(false); + } try { @@ -433,7 +469,11 @@ await StandardPipInstallProcessAsync( SettingsManager.Settings.PreferredGpu?.IsBlackwellGpu() ?? HardwareHelper.HasBlackwellGpu(), WorkingDirectory = installLocation, - EnvironmentVariables = GetEnvVars(venvRunner.EnvironmentVariables), + EnvironmentVariables = GetEnvVars( + venvRunner.EnvironmentVariables, + installLocation, + installedPackage + ), }; await step.ExecuteAsync(progress).ConfigureAwait(false); @@ -483,7 +523,7 @@ public override async Task RunPackage( await SetupVenv(installLocation, pythonVersion: PyVersion.Parse(installedPackage.PythonVersion)) .ConfigureAwait(false); - VenvRunner.UpdateEnvironmentVariables(GetEnvVars); + VenvRunner.UpdateEnvironmentVariables(env => GetEnvVars(env, installLocation, installedPackage)); // Check for old NVIDIA driver version with cu130 installations var isNvidia = SettingsManager.Settings.PreferredGpu?.IsNvidia ?? HardwareHelper.HasNvidiaGpu(); @@ -949,7 +989,11 @@ await PipWheelService .ConfigureAwait(false); } - private ImmutableDictionary GetEnvVars(ImmutableDictionary env) + private ImmutableDictionary GetEnvVars( + ImmutableDictionary env, + string installLocation, + InstalledPackage installedPackage + ) { // if we're not on windows or we don't have a windows rocm gpu, return original env var hasRocmGpu = HasWindowsRocmSupport(); @@ -957,6 +1001,16 @@ private ImmutableDictionary GetEnvVars(ImmutableDictionary + /// Requirement files to install after helper-owned ROCm runtime / torch bootstrap steps complete. 
+ /// + public IEnumerable RequirementsFilePaths { get; init; } = ["requirements.txt"]; + + /// + /// Package requirement entries to exclude because the helper installs them from ROCm-specific feeds. + /// + public string RequirementsExcludePattern { get; init; } = @"(torch(vision|audio)?|xformers)([^a-z].*)?"; + + /// + /// Extra package-specific pip arguments to include when installing requirements after helper bootstrap. + /// + public IEnumerable ExtraInstallPipArgs { get; init; } = []; + + /// + /// Extra package-specific pip arguments to install after requirements and torch are complete. + /// + public IEnumerable PostInstallPipArgs { get; init; } = []; + + /// + /// When true, helper-managed requirements installs should use --upgrade. + /// + public bool UpgradePackages { get; init; } + + /// + /// When true, helper-managed torch installs should force reinstall the selected ROCm wheel set. + /// + public bool ForceReinstallTorch { get; init; } = true; + /// /// Optional callback for package-specific cache path variables. /// The helper will eventually merge these with its own defaults. diff --git a/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs index 5b9383fc..0fac954e 100644 --- a/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs +++ b/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs @@ -1,5 +1,7 @@ using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Models.Progress; using StabilityMatrix.Core.Models.Rocm; +using StabilityMatrix.Core.Processes; using StabilityMatrix.Core.Python; namespace StabilityMatrix.Core.Services.Rocm; @@ -76,4 +78,17 @@ Task ApplyLaunchEnvironmentAsync( RocmPackageProfile profile, CancellationToken cancellationToken = default ); + + /// + /// Performs the Windows-native ROCm bootstrap/install flow for a package using helper-resolved gfx-family feed URLs. 
+ /// + Task InstallWindowsNativePackageAsync( + IPyVenvRunner venvRunner, + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile, + IProgress? progress = null, + Action? onConsoleOutput = null, + CancellationToken cancellationToken = default + ); } diff --git a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs index 1adb141e..a2c700b6 100644 --- a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs +++ b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs @@ -1,10 +1,15 @@ using System.Collections.Immutable; +using System.Text.Json; using Injectio.Attributes; using NLog; +using StabilityMatrix.Core.Exceptions; using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Helper.HardwareInfo; using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Models.FileInterfaces; +using StabilityMatrix.Core.Models.Progress; using StabilityMatrix.Core.Models.Rocm; +using StabilityMatrix.Core.Processes; using StabilityMatrix.Core.Python; using StabilityMatrix.Core.Services; @@ -31,9 +36,6 @@ public class RocmPackageHelper(ISettingsManager settingsManager) : IRocmPackageH "rx6550", ]; - private const string EnvironmentNotImplementedMessage = - "ROCm helper environment composition has not been implemented yet."; - /// public Task GetCompatibilityAsync( RocmPackageProfile profile, @@ -99,6 +101,10 @@ public Task ResolveInstallContextAsync( CancellationToken cancellationToken = default ) { + _ = installLocation; + _ = installedPackage; + _ = cancellationToken; + var supportedAmdGpus = GetAmdGpuCandidates(forceRefresh: true) .Where(IsSupportedWindowsRocmGpu) .ToList(); @@ -108,11 +114,16 @@ public Task ResolveInstallContextAsync( settingsManager.Settings.PreferredGpu ); + var runtimeGfxArch = preferredGfxArch ?? 
GetSupportedFallbackGfxArch(supportedAmdGpus); + var windowsNativeIndexUrl = TryGetWindowsNativeRocmIndexUrl(runtimeGfxArch); + return Task.FromResult( new RocmInstallContext { PreferredGfxArch = preferredGfxArch, - RuntimeGfxArch = preferredGfxArch ?? GetSupportedFallbackGfxArch(supportedAmdGpus), + RuntimeGfxArch = runtimeGfxArch, + RocmPackageIndexUrl = windowsNativeIndexUrl, + RocmTorchIndexUrl = windowsNativeIndexUrl, } ); } @@ -151,9 +162,30 @@ public Task> BuildLaunchEnvironmentAsync( { _ = installLocation; _ = installedPackage; - _ = profile; - _ = cancellationToken; - return Task.FromResult>(new Dictionary()); + + var runtimeContext = ResolveRuntimeContextAsync( + installLocation, + installedPackage, + profile, + cancellationToken + ) + .GetAwaiter() + .GetResult(); + + if (!runtimeContext.IsSupported) + return Task.FromResult>(new Dictionary()); + + var helperEnvironment = BuildHelperLaunchEnvironment(runtimeContext, profile); + var packageEnvironment = + profile.ExtraEnvironmentFactory?.Invoke(runtimeContext) ?? new Dictionary(); + + var mergedEnvironment = MergeLaunchEnvironment( + helperEnvironment, + packageEnvironment, + profile.EnvironmentOptions + ); + + return Task.FromResult>(mergedEnvironment); } /// @@ -176,6 +208,146 @@ public async Task ApplyLaunchEnvironmentAsync( venvRunner.UpdateEnvironmentVariables(env => env.SetItems(environment)); } + /// + public async Task InstallWindowsNativePackageAsync( + IPyVenvRunner venvRunner, + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile, + IProgress? progress = null, + Action? onConsoleOutput = null, + CancellationToken cancellationToken = default + ) + { + var compatibility = await GetCompatibilityAsync(profile, cancellationToken).ConfigureAwait(false); + if (!compatibility.IsCompatible) + { + throw new ApplicationException( + compatibility.FailureReason + ?? "Windows ROCm installation is not supported for the current machine." 
+ ); + } + + var installContext = await ResolveInstallContextAsync( + installLocation, + installedPackage, + profile, + cancellationToken + ) + .ConfigureAwait(false); + + var rocmPackageIndexUrl = installContext.RocmPackageIndexUrl; + var rocmTorchIndexUrl = installContext.RocmTorchIndexUrl ?? rocmPackageIndexUrl; + + if (string.IsNullOrWhiteSpace(rocmPackageIndexUrl) || string.IsNullOrWhiteSpace(rocmTorchIndexUrl)) + { + throw new ApplicationException( + $"No Windows ROCm Technical Preview index URL is available for '{installContext.RuntimeGfxArch ?? "unknown"}'." + ); + } + + progress?.Report(new ProgressReport(-1f, "Upgrading pip...", isIndeterminate: true)); + await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false); + + if (profile.RequiresRocmSdk) + { + progress?.Report(new ProgressReport(-1f, "Installing ROCm runtime...", isIndeterminate: true)); + var rocmRuntimeArgs = new PipInstallArgs() + .AddKeyedArgs("--index-url", ["--index-url", rocmPackageIndexUrl]) + .AddArgs("rocm[devel,libraries]", "--no-warn-script-location"); + + if (installedPackage.PipOverrides != null) + { + rocmRuntimeArgs = rocmRuntimeArgs.WithUserOverrides(installedPackage.PipOverrides); + } + + await venvRunner.PipInstall(rocmRuntimeArgs, onConsoleOutput).ConfigureAwait(false); + + progress?.Report(new ProgressReport(-1f, "Initializing ROCm SDK...", isIndeterminate: true)); + var rocmSdkExe = Path.Combine(installLocation, "venv", "Scripts", "rocm-sdk.exe"); + if (!File.Exists(rocmSdkExe)) + { + throw new FileNotFoundException("rocm-sdk.exe was not installed", rocmSdkExe); + } + + using var rocmSdkProcess = ProcessRunner.StartAnsiProcess( + rocmSdkExe, + ["init"], + installLocation, + onConsoleOutput + ); + + await rocmSdkProcess.WaitForExitAsync(cancellationToken).ConfigureAwait(false); + if (rocmSdkProcess.ExitCode != 0) + { + throw new ProcessException($"rocm-sdk init failed with code {rocmSdkProcess.ExitCode}"); + } + } + + progress?.Report(new 
ProgressReport(-1f, "Installing ROCm torch...", isIndeterminate: true)); + var torchArgs = new PipInstallArgs() + .AddKeyedArgs("--index-url", ["--index-url", rocmTorchIndexUrl]) + .AddArgs("torch", "torchaudio", "torchvision", "--no-warn-script-location"); + + if (profile.ForceReinstallTorch) + { + torchArgs = torchArgs.AddArg("--force-reinstall"); + } + + if (installedPackage.PipOverrides != null) + { + torchArgs = torchArgs.WithUserOverrides(installedPackage.PipOverrides); + } + + await venvRunner.PipInstall(torchArgs, onConsoleOutput).ConfigureAwait(false); + + progress?.Report( + new ProgressReport(-1f, "Installing package requirements...", isIndeterminate: true) + ); + + var requirementsPipArgs = new PipInstallArgs([.. profile.ExtraInstallPipArgs]); + if (profile.UpgradePackages) + { + requirementsPipArgs = requirementsPipArgs.AddArg("--upgrade"); + } + + foreach (var relativePath in profile.RequirementsFilePaths) + { + var requirementsFile = new FilePath(venvRunner.WorkingDirectory ?? installLocation, relativePath); + if (!requirementsFile.Exists) + continue; + + var requirementsContent = await requirementsFile + .ReadAllTextAsync(cancellationToken) + .ConfigureAwait(false); + + requirementsPipArgs = requirementsPipArgs.WithParsedFromRequirementsTxt( + requirementsContent, + profile.RequirementsExcludePattern + ); + } + + if (installedPackage.PipOverrides != null) + { + requirementsPipArgs = requirementsPipArgs.WithUserOverrides(installedPackage.PipOverrides); + } + + await venvRunner.PipInstall(requirementsPipArgs, onConsoleOutput).ConfigureAwait(false); + + if (!profile.PostInstallPipArgs.Any()) + return; + + var postInstallPipArgs = new PipInstallArgs([.. 
profile.PostInstallPipArgs]); + if (installedPackage.PipOverrides != null) + { + postInstallPipArgs = postInstallPipArgs.WithUserOverrides(installedPackage.PipOverrides); + } + + await venvRunner.PipInstall(postInstallPipArgs, onConsoleOutput).ConfigureAwait(false); + + await VerifyWindowsNativeTorchInstallAsync(venvRunner, onConsoleOutput).ConfigureAwait(false); + } + /// /// Builds a compatibility result from the current machine state and package profile. /// This keeps the first ROCm helper slice focused on hardware capability and GPU selection only. @@ -309,7 +481,7 @@ private static bool IsSupportedWindowsRocmGpu(GpuInfo gpu) if (IsExplicitlyUnsupportedRdna2Gpu(gpu)) return false; - return IsSupportedWindowsRocmArchitecture(gpu.GetAmdGfxArch()); + return TryGetWindowsNativeRocmIndexUrl(gpu.GetAmdGfxArch()) is not null; } /// @@ -328,15 +500,30 @@ private static bool IsExplicitlyUnsupportedRdna2Gpu(GpuInfo gpu) /// Determines whether a resolved AMD GFX architecture falls inside the Windows ROCm support set currently modeled by the helper. /// private static bool IsSupportedWindowsRocmArchitecture(string? gfxArch) + { + return TryGetWindowsNativeRocmIndexUrl(gfxArch) is not null; + } + + /// + /// Maps an AMD GFX architecture identifier to the Windows-native ROCm Technical Preview feed URL. + /// + private static string? TryGetWindowsNativeRocmIndexUrl(string? 
gfxArch) { return gfxArch switch { - var s when s != null && s.StartsWith("gfx101", StringComparison.OrdinalIgnoreCase) => true, - var s when s != null && s.StartsWith("gfx103", StringComparison.OrdinalIgnoreCase) => true, - var s when s != null && s.StartsWith("gfx110", StringComparison.OrdinalIgnoreCase) => true, - "gfx1150" or "gfx1151" or "gfx1152" or "gfx1153" => true, - var s when s != null && s.StartsWith("gfx120", StringComparison.OrdinalIgnoreCase) => true, - _ => false, + var s when s != null && s.StartsWith("gfx101", StringComparison.OrdinalIgnoreCase) => + "https://rocm.nightlies.amd.com/v2-staging/gfx101X-dgpu/", + var s when s != null && s.StartsWith("gfx103", StringComparison.OrdinalIgnoreCase) => + "https://rocm.nightlies.amd.com/v2-staging/gfx103X-dgpu/", + var s when s != null && s.StartsWith("gfx110", StringComparison.OrdinalIgnoreCase) => + "https://rocm.nightlies.amd.com/v2/gfx110X-all/", + "gfx1150" => "https://rocm.nightlies.amd.com/v2-staging/gfx1150/", + "gfx1151" => "https://rocm.nightlies.amd.com/v2/gfx1151/", + "gfx1152" => "https://rocm.nightlies.amd.com/v2-staging/gfx1152/", + "gfx1153" => "https://rocm.nightlies.amd.com/v2-staging/gfx1153/", + var s when s != null && s.StartsWith("gfx120", StringComparison.OrdinalIgnoreCase) => + "https://rocm.nightlies.amd.com/v2/gfx120X-all/", + _ => null, }; } @@ -372,4 +559,125 @@ private static string GetUnsupportedGpuReason(IReadOnlyList amdGpus) return "No AMD GPU with a supported Windows ROCm architecture was detected."; } + + /// + /// Verifies that the installed torch build still reports a usable ROCm runtime after helper-managed installs complete. + /// + private static async Task VerifyWindowsNativeTorchInstallAsync( + IPyVenvRunner venvRunner, + Action? 
onConsoleOutput + ) + { + var torchInfo = await venvRunner.PipShow("torch").ConfigureAwait(false); + if (torchInfo is null) + { + throw new ApplicationException("torch was not installed after Windows ROCm setup."); + } + + var verificationResult = await venvRunner + .Run( + "-c \"import json, torch; print(json.dumps({'version': torch.__version__, 'hip': torch.version.hip, 'cuda': torch.cuda.is_available()}))\"" + ) + .ConfigureAwait(false); + + var verificationOutput = (verificationResult.StandardOutput ?? string.Empty).Trim(); + if (string.IsNullOrWhiteSpace(verificationOutput)) + { + throw new ApplicationException("Torch verification produced no output."); + } + + JsonDocument verificationDocument; + try + { + verificationDocument = JsonDocument.Parse(verificationOutput); + } + catch (Exception exception) + { + throw new ApplicationException( + $"Unexpected torch verification output: {verificationOutput}", + exception + ); + } + + using (verificationDocument) + { + var root = verificationDocument.RootElement; + var version = root.TryGetProperty("version", out var versionElement) + ? versionElement.GetString() + : null; + var hipVersion = root.TryGetProperty("hip", out var hipElement) ? hipElement.GetString() : null; + var cudaAvailable = root.TryGetProperty("cuda", out var cudaElement) && cudaElement.GetBoolean(); + + if (string.IsNullOrWhiteSpace(hipVersion) || !cudaAvailable) + { + throw new ApplicationException( + $"Installed torch is not a usable ROCm build. Verification output: {verificationOutput}" + ); + } + + onConsoleOutput?.Invoke( + ProcessOutput.FromStdOutLine( + $"Torch verification: version={version}, hip={hipVersion}, cuda={cudaAvailable}" + ) + ); + } + } + + /// + /// Builds helper-owned ROCm launch variables from the resolved runtime context and package profile. 
+ /// + private static IReadOnlyDictionary BuildHelperLaunchEnvironment( + RocmRuntimeContext runtimeContext, + RocmPackageProfile profile + ) + { + var environment = new Dictionary(); + + if (profile.NeedsTunableOpCache) + { + environment["PYTORCH_TUNABLEOP_ENABLED"] = "1"; + } + + if (profile.NeedsAotritonExperimental) + { + environment["TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL"] = "1"; + } + + if (profile.NeedsTritonOverrideArch && !string.IsNullOrWhiteSpace(runtimeContext.RuntimeGfxArch)) + { + environment["HSA_OVERRIDE_GFX_VERSION"] = runtimeContext.RuntimeGfxArch; + } + + return environment; + } + + /// + /// Merges helper-owned and package-specific launch environment variables using the profile overlay rules. + /// + private static IReadOnlyDictionary MergeLaunchEnvironment( + IReadOnlyDictionary helperEnvironment, + IReadOnlyDictionary packageEnvironment, + RocmEnvironmentOptions options + ) + { + var merged = new Dictionary(); + + IReadOnlyDictionary[] orderedSources = + options.OverlayPriority == RocmEnvironmentOverlayPriority.HelperThenUserThenPackage + ? 
new[] { helperEnvironment, packageEnvironment } + : new[] { helperEnvironment, packageEnvironment }; + + foreach (var source in orderedSources) + { + if (ReferenceEquals(source, packageEnvironment) && !options.IncludePackageOverrides) + continue; + + foreach (var pair in source) + { + merged[pair.Key] = pair.Value; + } + } + + return merged; + } } From 31b5955a8cf47a4475e99cd4666a350af9b8d1e1 Mon Sep 17 00:00:00 2001 From: NeuralFault Date: Fri, 1 May 2026 14:51:20 -0400 Subject: [PATCH 05/18] =?UTF-8?q?-=20refactor=20the=20shared=20ROCm=20help?= =?UTF-8?q?er=20into=20a=20synchronous=20compatibility/runtime/install/env?= =?UTF-8?q?ironment=20API=20and=20simplify=20the=20ROCm=20profile/context?= =?UTF-8?q?=20models=20around=20the=20helper=E2=80=99s=20real=20responsibi?= =?UTF-8?q?lities=20-=20add=20a=20centralized=20Windows=20ROCm=20support?= =?UTF-8?q?=20map=20so=20GPU=20detection,=20architecture=20support=20check?= =?UTF-8?q?s,=20and=20package=20index=20resolution=20all=20use=20the=20sam?= =?UTF-8?q?e=20source=20of=20truth=20-=20expand=20AMD=20architecture=20det?= =?UTF-8?q?ection=20to=20cover=20additional=20RDNA4,=20Steam=20Deck,=20RDN?= =?UTF-8?q?A1,=20and=20Vega-class=20GPUs=20used=20by=20the=20Windows=20ROC?= =?UTF-8?q?m=20support=20path=20-=20add=20a=20helper-managed=20Windows=20R?= =?UTF-8?q?OCm=20bootstrap=20flow=20that=20installs=20the=20ROCm=20runtime?= =?UTF-8?q?,=20initializes/reinitializes=20the=20SDK,=20aligns=20rocm-sdk-?= =?UTF-8?q?devel=20with=20the=20resolved=20torch=20build,=20and=20verifies?= =?UTF-8?q?=20both=20torch=20ROCm=20metadata=20and=20runtime=20availabilit?= =?UTF-8?q?y=20-=20centralize=20ROCm=20launch=20environment=20construction?= =?UTF-8?q?=20in=20the=20helper,=20including=20default=20MIOpen,=20allocat?= =?UTF-8?q?or,=20flash-attention,=20and=20AOTriton=20settings=20plus=20leg?= =?UTF-8?q?acy=20SDP=20fallback,=20RDNA1=20overrides,=20and=20user=20env?= =?UTF-8?q?=20override=20layering=20-=20switch=20ComfyUI=20to=20helper-dri?= 
=?UTF-8?q?ven=20Windows=20ROCm=20compatibility=20and=20launch=20env=20han?= =?UTF-8?q?dling,=20and=20default=20legacy=20Windows=20ROCm=20GPUs=20to=20?= =?UTF-8?q?quad=20cross-attention=20while=20keeping=20Comfy-specific=20MIO?= =?UTF-8?q?pen=20enablement=20as=20a=20preset=20-=20integrate=20Wan2GP=20w?= =?UTF-8?q?ith=20the=20shared=20Windows=20ROCm=20helper=20for=20install=20?= =?UTF-8?q?and=20launch=20flows,=20while=20updating=20its=20Linux=20ROCm?= =?UTF-8?q?=20path=20to=20use=20upstream=20rocm7.2=20torch/vision/audio=20?= =?UTF-8?q?installs=20-=20wire=20the=20ROCm=20helper=20through=20package?= =?UTF-8?q?=20construction=20and=20add=20focused=20test=20coverage=20for?= =?UTF-8?q?=20ROCm=20build/version=20parsing,=20runtime=20failure=20classi?= =?UTF-8?q?fication,=20and=20Windows=20ROCm=20support/index=20resolution?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../Helper/Factory/PackageFactory.cs | 11 +- .../Helper/HardwareInfo/GpuInfo.cs | 25 +- .../Helper/HardwareInfo/HardwareHelper.cs | 6 +- .../Models/Packages/ComfyUI.cs | 57 +- .../Models/Packages/Wan2GP.cs | 139 ++-- .../Models/Rocm/RocmEnvironmentOptions.cs | 65 +- .../Models/Rocm/RocmInstallContext.cs | 16 - .../Models/Rocm/RocmPackageProfile.cs | 37 +- .../Models/Rocm/RocmRuntimeContext.cs | 15 - .../Models/Rocm/RocmSdkPaths.cs | 22 - .../Models/Rocm/WindowsRocmSupport.cs | 46 ++ .../Services/Rocm/IRocmPackageHelper.cs | 48 +- .../Services/Rocm/RocmPackageHelper.cs | 677 +++++++++++------- .../Core/RocmPackageHelperTests.cs | 176 +++++ .../Helper/PackageFactoryTests.cs | 2 + 15 files changed, 823 insertions(+), 519 deletions(-) delete mode 100644 StabilityMatrix.Core/Models/Rocm/RocmSdkPaths.cs create mode 100644 StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs create mode 100644 StabilityMatrix.Tests/Core/RocmPackageHelperTests.cs diff --git a/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs b/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs 
index 118efa55..6a073986 100644 --- a/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs +++ b/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs @@ -4,6 +4,7 @@ using StabilityMatrix.Core.Models.Packages; using StabilityMatrix.Core.Python; using StabilityMatrix.Core.Services; +using StabilityMatrix.Core.Services.Rocm; namespace StabilityMatrix.Core.Helper.Factory; @@ -18,6 +19,7 @@ public class PackageFactory : IPackageFactory private readonly IUvManager uvManager; private readonly IPyInstallationManager pyInstallationManager; private readonly IPipWheelService pipWheelService; + private readonly IRocmPackageHelper rocmPackageHelper; /// /// Mapping of package.Name to package @@ -32,7 +34,9 @@ public PackageFactory( IPrerequisiteHelper prerequisiteHelper, IPyInstallationManager pyInstallationManager, IPyRunner pyRunner, - IPipWheelService pipWheelService + IUvManager uvManager, + IPipWheelService pipWheelService, + IRocmPackageHelper rocmPackageHelper ) { this.githubApiCache = githubApiCache; @@ -40,8 +44,10 @@ IPipWheelService pipWheelService this.downloadService = downloadService; this.prerequisiteHelper = prerequisiteHelper; this.pyRunner = pyRunner; + this.uvManager = uvManager; this.pyInstallationManager = pyInstallationManager; this.pipWheelService = pipWheelService; + this.rocmPackageHelper = rocmPackageHelper; this.basePackages = basePackages.ToDictionary(x => x.Name); } @@ -55,7 +61,8 @@ public BasePackage GetNewBasePackage(InstalledPackage installedPackage) downloadService, prerequisiteHelper, pyInstallationManager, - pipWheelService + pipWheelService, + rocmPackageHelper ), "Fooocus" => new Fooocus( githubApiCache, diff --git a/StabilityMatrix.Core/Helper/HardwareInfo/GpuInfo.cs b/StabilityMatrix.Core/Helper/HardwareInfo/GpuInfo.cs index eedcb556..0013f65b 100644 --- a/StabilityMatrix.Core/Helper/HardwareInfo/GpuInfo.cs +++ b/StabilityMatrix.Core/Helper/HardwareInfo/GpuInfo.cs @@ -1,4 +1,6 @@ -namespace StabilityMatrix.Core.Helper.HardwareInfo; 
+using StabilityMatrix.Core.Models.Rocm; + +namespace StabilityMatrix.Core.Helper.HardwareInfo; public record GpuInfo { @@ -62,11 +64,7 @@ public bool IsLegacyNvidiaGpu() public bool IsWindowsRocmSupportedGpu() { - var gfx = GetAmdGfxArch(); - if (gfx is null) - return false; - - return gfx.StartsWith("gfx110") || gfx.StartsWith("gfx120") || gfx.Equals("gfx1151"); + return WindowsRocmSupport.IsSupportedGpu(this); } public bool IsAmd => Name?.Contains("amd", StringComparison.OrdinalIgnoreCase) ?? false; @@ -84,7 +82,7 @@ public bool IsWindowsRocmSupportedGpu() return name switch { // RDNA4 - _ when Has("R9700") || Has("9070") => "gfx1201", + _ when Has("R9700") || Has("R9600") || Has("9070") => "gfx1201", _ when Has("9060") => "gfx1200", // RDNA3.5 APUs @@ -112,6 +110,9 @@ _ when Has("660M") || Has("680M") => "gfx1035", _ when Has("6300") || Has("6400") || Has("6450") || Has("6500") || Has("6550") || Has("6500M") => "gfx1034", + // RDNA2 Steam Deck APU + _ when Has("Van Gogh") || Has("Sephiroth") => "gfx1033", + // RDNA2 Navi23 _ when Has("6600") || Has("6650") || Has("6700S") || Has("6800S") || Has("6600M") => "gfx1032", @@ -121,6 +122,16 @@ _ when Has("6700") || Has("6750") || Has("6800M") || Has("6850M") => "gfx1031", // RDNA2 Navi21 (big die) _ when Has("6800") || Has("6900") || Has("6950") => "gfx1030", + // RDNA1 Navi10 XT (incl. 
Pro card) + _ when Has("5600") || Has("5700") || Has("v520") => "gfx1010", + + // RDNA1 Navi10 XTX + _ when Has("5500") => "gfx1012", + + // Vega/GCN5 Dedicated GPUs + _ when Has("pro vii") || HasNoSpace("provii") => "gfx90X", + _ when Has("rx vega") || Has("vega 64") || Has("vega 56") || Has("vega frontier") => "gfx900", + _ when Has("radeon vii") || HasNoSpace("radeonvii") => "gfx906", _ => null, }; diff --git a/StabilityMatrix.Core/Helper/HardwareInfo/HardwareHelper.cs b/StabilityMatrix.Core/Helper/HardwareInfo/HardwareHelper.cs index 8458c730..93f093d4 100644 --- a/StabilityMatrix.Core/Helper/HardwareInfo/HardwareHelper.cs +++ b/StabilityMatrix.Core/Helper/HardwareInfo/HardwareHelper.cs @@ -7,6 +7,7 @@ using Microsoft.Win32; using NLog; using StabilityMatrix.Core.Extensions; +using StabilityMatrix.Core.Models.Rocm; namespace StabilityMatrix.Core.Helper.HardwareInfo; @@ -316,12 +317,11 @@ public static bool HasAmdGpu() return IterGpuInfo().Any(gpu => gpu.IsAmd); } - public static bool HasWindowsRocmSupportedGpu() => - IterGpuInfo().Any(gpu => gpu is { IsAmd: true, Name: not null } && gpu.IsWindowsRocmSupportedGpu()); + public static bool HasWindowsRocmSupportedGpu() => IterGpuInfo().Any(WindowsRocmSupport.IsSupportedGpu); public static GpuInfo? 
GetWindowsRocmSupportedGpu() { - return IterGpuInfo().FirstOrDefault(gpu => gpu.IsWindowsRocmSupportedGpu()); + return IterGpuInfo().FirstOrDefault(WindowsRocmSupport.IsSupportedGpu); } public static bool HasIntelGpu() => IterGpuInfo().Any(gpu => gpu.IsIntel); diff --git a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs index e3f54269..1832a27a 100644 --- a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs +++ b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs @@ -45,20 +45,11 @@ public class ComfyUI( private static readonly RocmPackageProfile WindowsRocmProfile = new() { PackageName = "ComfyUI", - RequiresWindows = true, RequiresRocmSdk = true, - NeedsRuntimeGfxResolution = true, - NeedsAotritonExperimental = true, - NeedsTunableOpCache = true, ExtraInstallPipArgs = ["numpy<2"], PostInstallPipArgs = ["typing-extensions>=4.15.0"], UpgradePackages = true, - ExtraEnvironmentFactory = _ => new Dictionary - { - ["MIOPEN_FIND_MODE"] = "2", - ["PYTORCH_ALLOC_CONF"] = "max_split_size_mb:512,garbage_collection_threshold:0.8", - ["COMFYUI_ENABLE_MIOPEN"] = "1", - }, + EnvironmentOptions = new RocmEnvironmentOptions { Preset = RocmEnvironmentPreset.ComfyUi }, }; public override string Name => "ComfyUI"; @@ -287,7 +278,9 @@ public class ComfyUI( { Name = "Cross Attention Method", Type = LaunchOptionType.Bool, - InitialValue = "--use-pytorch-cross-attention", + InitialValue = ShouldDefaultToQuadCrossAttention() + ? "--use-quad-cross-attention" + : "--use-pytorch-cross-attention", Options = [ "--use-split-cross-attention", @@ -626,19 +619,29 @@ private bool HasWindowsRocmSupport() return false; if (rocmPackageHelper is null) - { - return SettingsManager.Settings.PreferredGpu?.IsWindowsRocmSupportedGpu() - ?? 
HardwareHelper.HasWindowsRocmSupportedGpu(); - } + return false; - var compatibility = rocmPackageHelper - .GetCompatibilityAsync(WindowsRocmProfile) - .GetAwaiter() - .GetResult(); + var compatibility = rocmPackageHelper.GetCompatibility(WindowsRocmProfile); return compatibility.IsCompatible; } + private bool ShouldDefaultToQuadCrossAttention() + { + if (!Compat.IsWindows || !HasWindowsRocmSupport()) + return false; + + var gpu = SettingsManager.Settings.PreferredGpu; + var gfxArch = WindowsRocmSupport.IsSupportedGpu(gpu) + ? gpu?.GetAmdGfxArch() + : HardwareHelper.GetWindowsRocmSupportedGpu()?.GetAmdGfxArch(); + + return !string.IsNullOrWhiteSpace(gfxArch) + && !gfxArch.StartsWith("gfx110", StringComparison.OrdinalIgnoreCase) + && !gfxArch.StartsWith("gfx115", StringComparison.OrdinalIgnoreCase) + && !gfxArch.StartsWith("gfx120", StringComparison.OrdinalIgnoreCase); + } + public override IPackageExtensionManager ExtensionManager => new ComfyExtensionManager(this, settingsManager); @@ -1003,19 +1006,15 @@ InstalledPackage installedPackage if (rocmPackageHelper is not null) { - var rocmEnvironment = rocmPackageHelper - .BuildLaunchEnvironmentAsync(installLocation, installedPackage, WindowsRocmProfile) - .GetAwaiter() - .GetResult(); + var rocmEnvironment = rocmPackageHelper.BuildLaunchEnvironment( + installLocation, + installedPackage, + WindowsRocmProfile + ); return env.SetItems(rocmEnvironment); } - // set some experimental speed improving env vars for Windows ROCm - return env.SetItem("PYTORCH_TUNABLEOP_ENABLED", "1") - .SetItem("MIOPEN_FIND_MODE", "2") - .SetItem("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "1") - .SetItem("PYTORCH_ALLOC_CONF", "max_split_size_mb:6144,garbage_collection_threshold:0.8") // greatly helps prevent GPU OOM and instability/driver timeouts/OS hard locks and decreases dependency on Tiled VAE at standard res's - .SetItem("COMFYUI_ENABLE_MIOPEN", "1"); // re-enables "cudnn" in ComfyUI as it's needed for MiOpen to function properly + 
return env; } } diff --git a/StabilityMatrix.Core/Models/Packages/Wan2GP.cs b/StabilityMatrix.Core/Models/Packages/Wan2GP.cs index 2a00a626..e11fd322 100644 --- a/StabilityMatrix.Core/Models/Packages/Wan2GP.cs +++ b/StabilityMatrix.Core/Models/Packages/Wan2GP.cs @@ -6,9 +6,11 @@ using StabilityMatrix.Core.Helper.HardwareInfo; using StabilityMatrix.Core.Models.FileInterfaces; using StabilityMatrix.Core.Models.Progress; +using StabilityMatrix.Core.Models.Rocm; using StabilityMatrix.Core.Processes; using StabilityMatrix.Core.Python; using StabilityMatrix.Core.Services; +using StabilityMatrix.Core.Services.Rocm; namespace StabilityMatrix.Core.Models.Packages; @@ -30,7 +32,8 @@ public class Wan2GP( IDownloadService downloadService, IPrerequisiteHelper prerequisiteHelper, IPyInstallationManager pyInstallationManager, - IPipWheelService pipWheelService + IPipWheelService pipWheelService, + IRocmPackageHelper? rocmPackageHelper = null ) : BaseGitPackage( githubApi, @@ -41,6 +44,14 @@ IPipWheelService pipWheelService pipWheelService ) { + private static readonly RocmPackageProfile WindowsRocmProfile = new() + { + PackageName = "Wan2GP", + RequiresRocmSdk = true, + UpgradePackages = true, + PostInstallPipArgs = ["hf-xet", "setuptools", "numpy==1.26.4"], + }; + public override string Name => "Wan2GP"; public override string DisplayName { get; set; } = "Wan2GP"; public override string Author => "deepbeepmeep"; @@ -64,7 +75,7 @@ IPipWheelService pipWheelService public override bool IsCompatible => HardwareHelper.HasNvidiaGpu() - || (Compat.IsWindows ? HardwareHelper.HasWindowsRocmSupportedGpu() : HardwareHelper.HasAmdGpu()); + || (Compat.IsWindows ? 
HasWindowsRocmSupport() : HardwareHelper.HasAmdGpu()); public override string MainBranch => "main"; public override bool ShouldIgnoreReleases => true; @@ -72,7 +83,7 @@ IPipWheelService pipWheelService public override Dictionary> SharedOutputFolders => new() { [SharedOutputType.Img2Vid] = ["outputs"] }; - // AMD ROCm requires Python 3.11, NVIDIA uses 3.10 + // Wan2GP currently uses Python 3.11 for ROCm and 3.10 for CUDA. public override PyVersion RecommendedPythonVersion => IsAmdRocm ? Python.PyInstallationManager.Python_3_11_13 : Python.PyInstallationManager.Python_3_10_17; @@ -86,6 +97,17 @@ IPipWheelService pipWheelService /// private bool IsAmdRocm => GetRecommendedTorchVersion() == TorchIndex.Rocm; + private bool HasWindowsRocmSupport() + { + if (!Compat.IsWindows) + return false; + + if (rocmPackageHelper is null) + return HardwareHelper.HasWindowsRocmSupportedGpu(); + + return rocmPackageHelper.GetCompatibility(WindowsRocmProfile).IsCompatible; + } + /// /// Python wrapper script that patches logging to also print to stdout/stderr, so /// StabilityMatrix can capture the output. Wan2GP logs through Gradio UI notifications @@ -213,8 +235,8 @@ public override TorchIndex GetRecommendedTorchVersion() ( Compat.IsWindows && ( - SettingsManager.Settings.PreferredGpu?.IsWindowsRocmSupportedGpu() - ?? 
HardwareHelper.HasWindowsRocmSupportedGpu() + WindowsRocmSupport.IsSupportedGpu(SettingsManager.Settings.PreferredGpu) + || HasWindowsRocmSupport() ) ) || ( @@ -256,7 +278,15 @@ public override async Task InstallPackage( if (torchIndex == TorchIndex.Rocm) { - await InstallAmdRocmAsync(venvRunner, progress, onConsoleOutput).ConfigureAwait(false); + await InstallAmdRocmAsync( + venvRunner, + installLocation, + installedPackage, + progress, + onConsoleOutput, + cancellationToken + ) + .ConfigureAwait(false); } else { @@ -359,68 +389,53 @@ await venvRunner private async Task InstallAmdRocmAsync( IPyVenvRunner venvRunner, + string installLocation, + InstalledPackage installedPackage, IProgress? progress, - Action? onConsoleOutput + Action? onConsoleOutput, + CancellationToken cancellationToken ) { - progress?.Report(new ProgressReport(-1f, "Upgrading pip...", isIndeterminate: true)); - await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false); - if (Compat.IsWindows) { - // Windows AMD ROCm - special TheRock wheels - progress?.Report( - new ProgressReport(-1f, "Installing PyTorch ROCm wheels...", isIndeterminate: true) - ); - - // Set environment variable for wheel filename check bypass - venvRunner.UpdateEnvironmentVariables(env => env.SetItem("UV_SKIP_WHEEL_FILENAME_CHECK", "1")); + if (rocmPackageHelper is null) + { + throw new InvalidOperationException( + "Windows ROCm installation for Wan2GP requires the shared ROCm helper." 
+ ); + } - // Install PyTorch ROCm wheels from TheRock releases (Python 3.11) - await venvRunner - .PipInstall( - "https://github.com/scottt/rocm-TheRock/releases/download/v6.5.0rc-pytorch-gfx110x/torch-2.7.0a0+rocm_git3f903c3-cp311-cp311-win_amd64.whl", - onConsoleOutput + await rocmPackageHelper + .InstallWindowsNativePackageAsync( + venvRunner, + installLocation, + installedPackage, + WindowsRocmProfile, + progress, + onConsoleOutput, + cancellationToken ) .ConfigureAwait(false); - await venvRunner - .PipInstall( - "https://github.com/scottt/rocm-TheRock/releases/download/v6.5.0rc-pytorch-gfx110x/torchaudio-2.7.0a0+52638ef-cp311-cp311-win_amd64.whl", - onConsoleOutput - ) - .ConfigureAwait(false); + return; + } - await venvRunner - .PipInstall( - "https://github.com/scottt/rocm-TheRock/releases/download/v6.5.0rc-pytorch-gfx110x/torchvision-0.22.0+9eb57cd-cp311-cp311-win_amd64.whl", - onConsoleOutput - ) - .ConfigureAwait(false); + progress?.Report(new ProgressReport(-1f, "Upgrading pip...", isIndeterminate: true)); + await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false); - // Install requirements directly using -r flag (handles @ URL syntax properly) - progress?.Report(new ProgressReport(-1f, "Installing requirements...", isIndeterminate: true)); - await venvRunner.PipInstall("-r requirements.txt", onConsoleOutput).ConfigureAwait(false); - } - else - { - // Linux AMD ROCm - standard PyTorch ROCm - // Install requirements directly using -r flag (handles @ URL syntax properly) - progress?.Report(new ProgressReport(-1f, "Installing requirements...", isIndeterminate: true)); - await venvRunner.PipInstall("-r requirements.txt", onConsoleOutput).ConfigureAwait(false); - - // Install torch with ROCm index (force reinstall to ensure correct version) - progress?.Report(new ProgressReport(-1f, "Installing PyTorch ROCm...", isIndeterminate: true)); - var torchArgs = new PipInstallArgs() - .WithTorch("==2.7.0") - 
.WithTorchVision("==0.22.0") - .WithTorchAudio("==2.7.0") - .WithTorchExtraIndex("rocm6.3") - .AddArg("--force-reinstall") - .AddArg("--no-deps"); - - await venvRunner.PipInstall(torchArgs, onConsoleOutput).ConfigureAwait(false); - } + progress?.Report(new ProgressReport(-1f, "Installing requirements...", isIndeterminate: true)); + await venvRunner.PipInstall("-r requirements.txt", onConsoleOutput).ConfigureAwait(false); + + progress?.Report(new ProgressReport(-1f, "Installing PyTorch ROCm...", isIndeterminate: true)); + var torchArgs = new PipInstallArgs() + .WithTorch() + .WithTorchVision() + .WithTorchAudio() + .WithTorchExtraIndex("rocm7.2") + .AddArg("--force-reinstall") + .AddArg("--no-deps"); + + await venvRunner.PipInstall(torchArgs, onConsoleOutput).ConfigureAwait(false); // Install additional packages await venvRunner.PipInstall("hf-xet setuptools numpy==1.26.4", onConsoleOutput).ConfigureAwait(false); @@ -437,6 +452,16 @@ public override async Task RunPackage( await SetupVenv(installLocation, pythonVersion: PyVersion.Parse(installedPackage.PythonVersion)) .ConfigureAwait(false); + if (Compat.IsWindows && rocmPackageHelper is not null && HasWindowsRocmSupport()) + { + var rocmEnvironment = rocmPackageHelper.BuildLaunchEnvironment( + installLocation, + installedPackage, + WindowsRocmProfile + ); + VenvRunner.UpdateEnvironmentVariables(env => env.SetItems(rocmEnvironment)); + } + // Fix for distutils compatibility issue with Python 3.10 and setuptools VenvRunner.UpdateEnvironmentVariables(env => env.SetItem("SETUPTOOLS_USE_DISTUTILS", "stdlib")); diff --git a/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs b/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs index 11c2bbfb..21ff6e7d 100644 --- a/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs +++ b/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs @@ -1,33 +1,68 @@ namespace StabilityMatrix.Core.Models.Rocm; /// -/// Controls how helper-generated, package-specific, and 
user-defined environment variables -/// should be layered together once the helper has real behavior. +/// Controls how ROCm helper defaults, package-specific variables, and user overrides are layered at launch. /// public class RocmEnvironmentOptions { - /// - /// Determines the merge order used when multiple environment sources provide the same key. - /// - public RocmEnvironmentOverlayPriority OverlayPriority { get; init; } = - RocmEnvironmentOverlayPriority.HelperThenPackageThenUser; - /// /// When true, package-specific environment additions may be merged on top of helper defaults. /// public bool IncludePackageOverrides { get; init; } = true; /// - /// When true, user-defined Stability Matrix environment variables may be merged last. + /// When true, user-defined Stability Matrix environment variables may override helper/package defaults last. /// public bool IncludeUserOverrides { get; init; } = true; + + /// + /// Selects a package-oriented ROCm environment preset managed by the helper. + /// + public RocmEnvironmentPreset Preset { get; init; } = RocmEnvironmentPreset.None; + + /// + /// When set, overrides the default PyTorch allocator tuning string added by the ROCm helper. + /// + public string? PyTorchAllocConf { get; init; } = "max_split_size_mb:512,garbage_collection_threshold:0.8"; + + /// + /// When set, configures MIOpen find mode for helper-managed ROCm defaults. + /// + public string? MiopenFindMode { get; init; } = "2"; + + /// + /// When set, configures MIOpen search cutoff for helper-managed ROCm defaults. + /// + public string? MiopenSearchCutoff { get; init; } = "1"; + + /// + /// When set, configures MIOpen find enforcement behavior for helper-managed ROCm defaults. + /// + public string? MiopenFindEnforce { get; init; } = "3"; + + /// + /// When set, controls whether AMD Triton-backed flash attention is enabled by helper defaults. + /// + public string? 
FlashAttentionTritonAmdEnable { get; init; } = "TRUE"; + + /// + /// When true, helper-managed defaults will enable ROCm AOTriton on modern Windows ROCm architectures. + /// + public bool ApplyAotritonExperimental { get; init; } = true; + + /// + /// When true, helper-managed defaults will force math SDP on legacy ROCm architectures. + /// + public bool ApplyLegacySdpFallback { get; init; } = true; + + /// + /// When true, helper-managed defaults will apply the RDNA1 HSA override mask when needed. + /// + public bool ApplyRdna1Override { get; init; } = true; } -/// -/// Describes the intended precedence of environment sources for ROCm-enabled package launches. -/// -public enum RocmEnvironmentOverlayPriority +public enum RocmEnvironmentPreset { - HelperThenPackageThenUser, - HelperThenUserThenPackage, + None, + ComfyUi, } diff --git a/StabilityMatrix.Core/Models/Rocm/RocmInstallContext.cs b/StabilityMatrix.Core/Models/Rocm/RocmInstallContext.cs index d18d70d0..597eb4fe 100644 --- a/StabilityMatrix.Core/Models/Rocm/RocmInstallContext.cs +++ b/StabilityMatrix.Core/Models/Rocm/RocmInstallContext.cs @@ -1,5 +1,3 @@ -using StabilityMatrix.Core.Models; - namespace StabilityMatrix.Core.Models.Rocm; /// @@ -7,21 +5,7 @@ namespace StabilityMatrix.Core.Models.Rocm; /// public class RocmInstallContext { - public string? PreferredGfxArch { get; init; } - public string? RuntimeGfxArch { get; init; } public string? RocmPackageIndexUrl { get; init; } - - public string? RocmTorchIndexUrl { get; init; } - - public TorchIndex TorchIndex { get; init; } = TorchIndex.Rocm; - - public string? WheelCompatibilityHints { get; init; } - - public string? 
SdkRoot { get; init; } - - public RocmSdkPaths SdkPaths { get; init; } = new(); - - public IReadOnlyDictionary Environment { get; init; } = new Dictionary(); } diff --git a/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs b/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs index f2a9323e..a7baa675 100644 --- a/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs +++ b/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs @@ -13,32 +13,8 @@ public class RocmPackageProfile /// public string PackageName { get; init; } = string.Empty; - public bool RequiresWindows { get; init; } - public bool RequiresRocmSdk { get; init; } - public bool NeedsRuntimeGfxResolution { get; init; } - - public bool NeedsHipPath { get; init; } - - public bool NeedsRocmPath { get; init; } - - public bool NeedsTritonOverrideArch { get; init; } - - public bool NeedsRdna1Override { get; init; } - - public bool NeedsLegacySdpFallback { get; init; } - - public bool NeedsAotritonExperimental { get; init; } - - public bool NeedsTunableOpCache { get; init; } - - public bool NeedsTritonCache { get; init; } - - public bool NeedsMIOpenDbPaths { get; init; } - - public bool NeedsRocblasPaths { get; init; } - /// /// Requirement files to install after helper-owned ROCm runtime / torch bootstrap steps complete. /// @@ -69,12 +45,6 @@ public class RocmPackageProfile /// public bool ForceReinstallTorch { get; init; } = true; - /// - /// Optional callback for package-specific cache path variables. - /// The helper will eventually merge these with its own defaults. - /// - public Func>? CacheDirectoryFactory { get; init; } - /// /// Optional callback for package-specific environment variables derived from a resolved ROCm context. /// @@ -84,12 +54,7 @@ public Func< >? ExtraEnvironmentFactory { get; init; } /// - /// Optional progress message prefix or label that package code can surface during install/update work. - /// - public string? 
ProgressLabel { get; init; } - - /// - /// Controls how helper, package, and user-defined environment variables should be merged. + /// Controls whether package-specific environment variables should be layered on top of helper defaults. /// public RocmEnvironmentOptions EnvironmentOptions { get; init; } = new(); } diff --git a/StabilityMatrix.Core/Models/Rocm/RocmRuntimeContext.cs b/StabilityMatrix.Core/Models/Rocm/RocmRuntimeContext.cs index 87c88ba6..1fdda791 100644 --- a/StabilityMatrix.Core/Models/Rocm/RocmRuntimeContext.cs +++ b/StabilityMatrix.Core/Models/Rocm/RocmRuntimeContext.cs @@ -15,19 +15,4 @@ public class RocmRuntimeContext public GpuInfo? SelectedGpu { get; init; } public string? RuntimeGfxArch { get; init; } - - public bool IsLegacyGpu { get; init; } - - public bool IsRdna1 { get; init; } - - public string? HipPath { get; init; } - - public string? RocmPath { get; init; } - - public string? RocmSdkSitePackagesPath { get; init; } - - public RocmSdkPaths SdkPaths { get; init; } = new(); - - public IReadOnlyDictionary ResolvedEnvironment { get; init; } = - new Dictionary(); } diff --git a/StabilityMatrix.Core/Models/Rocm/RocmSdkPaths.cs b/StabilityMatrix.Core/Models/Rocm/RocmSdkPaths.cs deleted file mode 100644 index 5789744f..00000000 --- a/StabilityMatrix.Core/Models/Rocm/RocmSdkPaths.cs +++ /dev/null @@ -1,22 +0,0 @@ -namespace StabilityMatrix.Core.Models.Rocm; - -/// -/// Represents ROCm SDK-related paths resolved for a package install. -/// These values are intentionally plain data so package code can decide which paths matter. -/// -public class RocmSdkPaths -{ - public string? RocmRoot { get; init; } - - public string? HipPath { get; init; } - - public string? RocmPath { get; init; } - - public string? RocmSdkSitePackagesPath { get; init; } - - public string? MioPenDbPath { get; init; } - - public string? RocblasDbPath { get; init; } - - public string? 
RocblasLibraryPath { get; init; } -} diff --git a/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs b/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs new file mode 100644 index 00000000..dec46d04 --- /dev/null +++ b/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs @@ -0,0 +1,46 @@ +using StabilityMatrix.Core.Helper.HardwareInfo; + +namespace StabilityMatrix.Core.Models.Rocm; + +/// +/// Centralizes Windows ROCm support policy so hardware detection, package selection, +/// and ROCm installation all use the same architecture support map. +/// +public static class WindowsRocmSupport +{ + public static bool IsSupportedGpu(GpuInfo? gpu) + { + if (gpu is null || !gpu.IsAmd || string.IsNullOrWhiteSpace(gpu.Name)) + return false; + + return IsSupportedArchitecture(gpu.GetAmdGfxArch()); + } + + public static bool IsSupportedArchitecture(string? gfxArch) + { + return TryGetPackageIndexUrl(gfxArch) is not null; + } + + public static string? TryGetPackageIndexUrl(string? gfxArch) + { + return gfxArch switch + { + "gfx900" => "https://rocm.nightlies.amd.com/v2-staging/gfx900/", // Vega 10 + "gfx906" => "https://rocm.nightlies.amd.com/v2-staging/gfx906/", // Radeon VII, Vega 20 + "gfx90X" => "https://rocm.nightlies.amd.com/v2-staging/gfx90X/", // Radeon Pro VII + var s when s != null && s.StartsWith("gfx101", StringComparison.OrdinalIgnoreCase) => + "https://rocm.nightlies.amd.com/v2-staging/gfx101X-dgpu/", // RDNA1 (5000 series, Pro) + var s when s != null && s.StartsWith("gfx103", StringComparison.OrdinalIgnoreCase) => + "https://rocm.nightlies.amd.com/v2-staging/gfx103X-all/", // RDNA2 (6000 series, 6xxM Mobile, Steam Deck) + var s when s != null && s.StartsWith("gfx110", StringComparison.OrdinalIgnoreCase) => + "https://rocm.nightlies.amd.com/v2/gfx110X-all/", // RDNA3 (7000 series, 7xxM Mobile) + "gfx1150" => "https://rocm.nightlies.amd.com/v2-staging/gfx1150/", // RDNA3.5 (Strix/Gorgon Point) + "gfx1151" => "https://rocm.nightlies.amd.com/v2/gfx1151/", // 
RDNA3.5 (Strix Halo) + "gfx1152" => "https://rocm.nightlies.amd.com/v2-staging/gfx1152/", // RDNA3.5 (Kraken Point) + "gfx1153" => "https://rocm.nightlies.amd.com/v2-staging/gfx1153/", // RDNA3.5 (Medusa Point) + var s when s != null && s.StartsWith("gfx120", StringComparison.OrdinalIgnoreCase) => + "https://rocm.nightlies.amd.com/v2/gfx120X-all/", // RDNA4 (9000 series) + _ => null, + }; + } +} diff --git a/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs index 0fac954e..03b4ce0c 100644 --- a/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs +++ b/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs @@ -14,69 +14,33 @@ public interface IRocmPackageHelper /// /// Evaluates whether the current machine and package profile are compatible with ROCm. /// - Task GetCompatibilityAsync( - RocmPackageProfile profile, - CancellationToken cancellationToken = default - ); + RocmCompatibilityResult GetCompatibility(RocmPackageProfile profile); /// /// Resolves the runtime ROCm facts needed for package launch and environment construction. /// - Task ResolveRuntimeContextAsync( + RocmRuntimeContext ResolveRuntimeContext( string installLocation, InstalledPackage installedPackage, - RocmPackageProfile profile, - CancellationToken cancellationToken = default + RocmPackageProfile profile ); /// /// Resolves the ROCm facts needed during package installation or update operations. /// - Task ResolveInstallContextAsync( + RocmInstallContext ResolveInstallContext( string installLocation, InstalledPackage installedPackage, - RocmPackageProfile profile, - CancellationToken cancellationToken = default - ); - - /// - /// Builds an install-time environment dictionary from a resolved install context. 
- /// - IReadOnlyDictionary BuildInstallEnvironment( - string installLocation, - RocmInstallContext context, RocmPackageProfile profile ); - /// - /// Re-resolves ROCm install facts after a package update changes dependencies or runtime state. - /// - Task RefreshPackageAfterUpdateAsync( - string installLocation, - InstalledPackage installedPackage, - RocmPackageProfile profile, - CancellationToken cancellationToken = default - ); - /// /// Builds a launch-time environment dictionary from resolved ROCm runtime data. /// - Task> BuildLaunchEnvironmentAsync( - string installLocation, - InstalledPackage installedPackage, - RocmPackageProfile profile, - CancellationToken cancellationToken = default - ); - - /// - /// Applies a resolved launch environment to the provided Python venv runner. - /// - Task ApplyLaunchEnvironmentAsync( - IPyVenvRunner venvRunner, + IReadOnlyDictionary BuildLaunchEnvironment( string installLocation, InstalledPackage installedPackage, - RocmPackageProfile profile, - CancellationToken cancellationToken = default + RocmPackageProfile profile ); /// diff --git a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs index a2c700b6..19dced0b 100644 --- a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs +++ b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs @@ -22,49 +22,34 @@ namespace StabilityMatrix.Core.Services.Rocm; public class RocmPackageHelper(ISettingsManager settingsManager) : IRocmPackageHelper { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); - - private static readonly string[] UnsupportedRdna2ModelMarkers = - [ - "680m", - "660m", - "610m", - "rx6300", - "w6300", - "rx6400", - "w6400", - "rx6450", - "rx6550", - ]; + private static readonly StringComparer EnvComparer = StringComparer.OrdinalIgnoreCase; /// - public Task GetCompatibilityAsync( - RocmPackageProfile profile, - CancellationToken cancellationToken = default - ) + public 
RocmCompatibilityResult GetCompatibility(RocmPackageProfile profile) { - return Task.FromResult(BuildCompatibilityResult(profile)); + return BuildCompatibilityResult(profile); } /// - public Task ResolveRuntimeContextAsync( + public RocmRuntimeContext ResolveRuntimeContext( string installLocation, InstalledPackage installedPackage, - RocmPackageProfile profile, - CancellationToken cancellationToken = default + RocmPackageProfile profile ) { + _ = installLocation; + _ = installedPackage; + var compatibility = BuildCompatibilityResult(profile); if (!compatibility.IsCompatible) { - return Task.FromResult( - new RocmRuntimeContext - { - IsSupported = false, - FailureReason = compatibility.FailureReason, - SelectedGpu = compatibility.SelectedGpu, - RuntimeGfxArch = compatibility.ResolvedGfxArch, - } - ); + return new RocmRuntimeContext + { + IsSupported = false, + FailureReason = compatibility.FailureReason, + SelectedGpu = compatibility.SelectedGpu, + RuntimeGfxArch = compatibility.ResolvedGfxArch, + }; } var supportedAmdGpus = GetAmdGpuCandidates(forceRefresh: true) @@ -81,29 +66,23 @@ public Task ResolveRuntimeContextAsync( ?? selectedGpu?.GetAmdGfxArch() ?? 
GetSupportedFallbackGfxArch(supportedAmdGpus); - return Task.FromResult( - new RocmRuntimeContext - { - IsSupported = true, - SelectedGpu = selectedGpu, - RuntimeGfxArch = runtimeGfxArch, - IsLegacyGpu = IsLegacyArchitecture(runtimeGfxArch), - IsRdna1 = IsRdna1Architecture(runtimeGfxArch), - } - ); + return new RocmRuntimeContext + { + IsSupported = true, + SelectedGpu = selectedGpu, + RuntimeGfxArch = runtimeGfxArch, + }; } /// - public Task ResolveInstallContextAsync( + public RocmInstallContext ResolveInstallContext( string installLocation, InstalledPackage installedPackage, - RocmPackageProfile profile, - CancellationToken cancellationToken = default + RocmPackageProfile profile ) { _ = installLocation; _ = installedPackage; - _ = cancellationToken; var supportedAmdGpus = GetAmdGpuCandidates(forceRefresh: true) .Where(IsSupportedWindowsRocmGpu) @@ -115,65 +94,26 @@ public Task ResolveInstallContextAsync( ); var runtimeGfxArch = preferredGfxArch ?? GetSupportedFallbackGfxArch(supportedAmdGpus); - var windowsNativeIndexUrl = TryGetWindowsNativeRocmIndexUrl(runtimeGfxArch); - - return Task.FromResult( - new RocmInstallContext - { - PreferredGfxArch = preferredGfxArch, - RuntimeGfxArch = runtimeGfxArch, - RocmPackageIndexUrl = windowsNativeIndexUrl, - RocmTorchIndexUrl = windowsNativeIndexUrl, - } - ); - } - - /// - public IReadOnlyDictionary BuildInstallEnvironment( - string installLocation, - RocmInstallContext context, - RocmPackageProfile profile - ) - { - _ = installLocation; - _ = context; - _ = profile; - return new Dictionary(); - } + var windowsNativeIndexUrl = WindowsRocmSupport.TryGetPackageIndexUrl(runtimeGfxArch); - /// - public Task RefreshPackageAfterUpdateAsync( - string installLocation, - InstalledPackage installedPackage, - RocmPackageProfile profile, - CancellationToken cancellationToken = default - ) - { - return ResolveInstallContextAsync(installLocation, installedPackage, profile, cancellationToken); + return new RocmInstallContext + { + 
RuntimeGfxArch = runtimeGfxArch, + RocmPackageIndexUrl = windowsNativeIndexUrl, + }; } /// - public Task> BuildLaunchEnvironmentAsync( + public IReadOnlyDictionary BuildLaunchEnvironment( string installLocation, InstalledPackage installedPackage, - RocmPackageProfile profile, - CancellationToken cancellationToken = default + RocmPackageProfile profile ) { - _ = installLocation; - _ = installedPackage; - - var runtimeContext = ResolveRuntimeContextAsync( - installLocation, - installedPackage, - profile, - cancellationToken - ) - .GetAwaiter() - .GetResult(); + var runtimeContext = ResolveRuntimeContext(installLocation, installedPackage, profile); if (!runtimeContext.IsSupported) - return Task.FromResult>(new Dictionary()); + return new Dictionary(); var helperEnvironment = BuildHelperLaunchEnvironment(runtimeContext, profile); var packageEnvironment = @@ -185,27 +125,7 @@ public Task> BuildLaunchEnvironmentAsync( profile.EnvironmentOptions ); - return Task.FromResult>(mergedEnvironment); - } - - /// - public async Task ApplyLaunchEnvironmentAsync( - IPyVenvRunner venvRunner, - string installLocation, - InstalledPackage installedPackage, - RocmPackageProfile profile, - CancellationToken cancellationToken = default - ) - { - var environment = await BuildLaunchEnvironmentAsync( - installLocation, - installedPackage, - profile, - cancellationToken - ) - .ConfigureAwait(false); - - venvRunner.UpdateEnvironmentVariables(env => env.SetItems(environment)); + return mergedEnvironment; } /// @@ -219,7 +139,7 @@ public async Task InstallWindowsNativePackageAsync( CancellationToken cancellationToken = default ) { - var compatibility = await GetCompatibilityAsync(profile, cancellationToken).ConfigureAwait(false); + var compatibility = GetCompatibility(profile); if (!compatibility.IsCompatible) { throw new ApplicationException( @@ -228,18 +148,11 @@ public async Task InstallWindowsNativePackageAsync( ); } - var installContext = await ResolveInstallContextAsync( - installLocation, 
- installedPackage, - profile, - cancellationToken - ) - .ConfigureAwait(false); + var installContext = ResolveInstallContext(installLocation, installedPackage, profile); var rocmPackageIndexUrl = installContext.RocmPackageIndexUrl; - var rocmTorchIndexUrl = installContext.RocmTorchIndexUrl ?? rocmPackageIndexUrl; - if (string.IsNullOrWhiteSpace(rocmPackageIndexUrl) || string.IsNullOrWhiteSpace(rocmTorchIndexUrl)) + if (string.IsNullOrWhiteSpace(rocmPackageIndexUrl)) { throw new ApplicationException( $"No Windows ROCm Technical Preview index URL is available for '{installContext.RuntimeGfxArch ?? "unknown"}'." @@ -254,7 +167,7 @@ public async Task InstallWindowsNativePackageAsync( progress?.Report(new ProgressReport(-1f, "Installing ROCm runtime...", isIndeterminate: true)); var rocmRuntimeArgs = new PipInstallArgs() .AddKeyedArgs("--index-url", ["--index-url", rocmPackageIndexUrl]) - .AddArgs("rocm[devel,libraries]", "--no-warn-script-location"); + .AddArgs("rocm[devel,libraries]"); if (installedPackage.PipOverrides != null) { @@ -264,43 +177,10 @@ public async Task InstallWindowsNativePackageAsync( await venvRunner.PipInstall(rocmRuntimeArgs, onConsoleOutput).ConfigureAwait(false); progress?.Report(new ProgressReport(-1f, "Initializing ROCm SDK...", isIndeterminate: true)); - var rocmSdkExe = Path.Combine(installLocation, "venv", "Scripts", "rocm-sdk.exe"); - if (!File.Exists(rocmSdkExe)) - { - throw new FileNotFoundException("rocm-sdk.exe was not installed", rocmSdkExe); - } - - using var rocmSdkProcess = ProcessRunner.StartAnsiProcess( - rocmSdkExe, - ["init"], - installLocation, - onConsoleOutput - ); - - await rocmSdkProcess.WaitForExitAsync(cancellationToken).ConfigureAwait(false); - if (rocmSdkProcess.ExitCode != 0) - { - throw new ProcessException($"rocm-sdk init failed with code {rocmSdkProcess.ExitCode}"); - } - } - - progress?.Report(new ProgressReport(-1f, "Installing ROCm torch...", isIndeterminate: true)); - var torchArgs = new PipInstallArgs() - 
.AddKeyedArgs("--index-url", ["--index-url", rocmTorchIndexUrl]) - .AddArgs("torch", "torchaudio", "torchvision", "--no-warn-script-location"); - - if (profile.ForceReinstallTorch) - { - torchArgs = torchArgs.AddArg("--force-reinstall"); - } - - if (installedPackage.PipOverrides != null) - { - torchArgs = torchArgs.WithUserOverrides(installedPackage.PipOverrides); + await InitializeWindowsNativeRocmSdkAsync(installLocation, onConsoleOutput, cancellationToken) + .ConfigureAwait(false); } - await venvRunner.PipInstall(torchArgs, onConsoleOutput).ConfigureAwait(false); - progress?.Report( new ProgressReport(-1f, "Installing package requirements...", isIndeterminate: true) ); @@ -334,18 +214,56 @@ public async Task InstallWindowsNativePackageAsync( await venvRunner.PipInstall(requirementsPipArgs, onConsoleOutput).ConfigureAwait(false); - if (!profile.PostInstallPipArgs.Any()) - return; + progress?.Report(new ProgressReport(-1f, "Installing ROCm torch...", isIndeterminate: true)); + var torchArgs = new PipInstallArgs() + .AddArg("--pre") + .AddArg("--upgrade") + .AddKeyedArgs("--index-url", ["--index-url", rocmPackageIndexUrl]) + .WithTorch() + .WithTorchAudio() + .WithTorchVision(); + + if (profile.ForceReinstallTorch) + { + torchArgs = torchArgs.AddArg("--force-reinstall"); + } - var postInstallPipArgs = new PipInstallArgs([.. 
profile.PostInstallPipArgs]); if (installedPackage.PipOverrides != null) { - postInstallPipArgs = postInstallPipArgs.WithUserOverrides(installedPackage.PipOverrides); + torchArgs = torchArgs.WithUserOverrides(installedPackage.PipOverrides); } - await venvRunner.PipInstall(postInstallPipArgs, onConsoleOutput).ConfigureAwait(false); + await venvRunner.PipInstall(torchArgs, onConsoleOutput).ConfigureAwait(false); + + if (profile.RequiresRocmSdk) + { + await AlignRocmSdkDevelVersionAsync(venvRunner, rocmPackageIndexUrl, onConsoleOutput) + .ConfigureAwait(false); + + progress?.Report(new ProgressReport(-1f, "Reinitializing ROCm SDK...", isIndeterminate: true)); + await InitializeWindowsNativeRocmSdkAsync(installLocation, onConsoleOutput, cancellationToken) + .ConfigureAwait(false); + } + + if (profile.PostInstallPipArgs.Any()) + { + var postInstallPipArgs = new PipInstallArgs([.. profile.PostInstallPipArgs]); + if (installedPackage.PipOverrides != null) + { + postInstallPipArgs = postInstallPipArgs.WithUserOverrides(installedPackage.PipOverrides); + } + + await venvRunner.PipInstall(postInstallPipArgs, onConsoleOutput).ConfigureAwait(false); + } - await VerifyWindowsNativeTorchInstallAsync(venvRunner, onConsoleOutput).ConfigureAwait(false); + await VerifyWindowsNativeTorchInstallAsync(venvRunner, onConsoleOutput, cancellationToken) + .ConfigureAwait(false); + + if (profile.RequiresRocmSdk) + { + await VerifyWindowsNativeRocmRuntimeAsync(installLocation, onConsoleOutput, cancellationToken) + .ConfigureAwait(false); + } } /// @@ -354,15 +272,6 @@ public async Task InstallWindowsNativePackageAsync( /// private RocmCompatibilityResult BuildCompatibilityResult(RocmPackageProfile profile) { - if (profile.RequiresWindows && !Compat.IsWindows) - { - return new RocmCompatibilityResult - { - IsCompatible = false, - FailureReason = "This ROCm profile currently requires Windows.", - }; - } - var amdGpus = GetAmdGpuCandidates(forceRefresh: true).ToList(); if (amdGpus.Count == 0) { 
@@ -374,15 +283,6 @@ private RocmCompatibilityResult BuildCompatibilityResult(RocmPackageProfile prof } var preferredGpu = settingsManager.Settings.PreferredGpu; - if (preferredGpu is not null && IsExplicitlyUnsupportedRdna2Gpu(preferredGpu)) - { - return new RocmCompatibilityResult - { - IsCompatible = false, - FailureReason = $"Selected GPU '{preferredGpu.Name}' is unsupported for Windows ROCm.", - SelectedGpu = preferredGpu, - }; - } var supportedAmdGpus = amdGpus.Where(IsSupportedWindowsRocmGpu).ToList(); if (supportedAmdGpus.Count == 0) @@ -471,29 +371,10 @@ private static IReadOnlyList GetAmdGpuCandidates(bool forceRefresh = fa /// /// Determines whether a GPU is supported by the Windows ROCm install flow currently modeled by the helper. - /// Unsupported low-end RDNA2/APU models are filtered explicitly even when they identify as AMD hardware. /// private static bool IsSupportedWindowsRocmGpu(GpuInfo gpu) { - if (!gpu.IsAmd || string.IsNullOrWhiteSpace(gpu.Name)) - return false; - - if (IsExplicitlyUnsupportedRdna2Gpu(gpu)) - return false; - - return TryGetWindowsNativeRocmIndexUrl(gpu.GetAmdGfxArch()) is not null; - } - - /// - /// Identifies Windows ROCm-incompatible RDNA2 models that need to remain outside the supported GPU set. - /// - private static bool IsExplicitlyUnsupportedRdna2Gpu(GpuInfo gpu) - { - if (!gpu.IsAmd || string.IsNullOrWhiteSpace(gpu.Name)) - return false; - - var normalizedName = gpu.Name.Replace(" ", string.Empty, StringComparison.Ordinal).ToLowerInvariant(); - return UnsupportedRdna2ModelMarkers.Any(normalizedName.Contains); + return WindowsRocmSupport.IsSupportedGpu(gpu); } /// @@ -501,50 +382,7 @@ private static bool IsExplicitlyUnsupportedRdna2Gpu(GpuInfo gpu) /// private static bool IsSupportedWindowsRocmArchitecture(string? gfxArch) { - return TryGetWindowsNativeRocmIndexUrl(gfxArch) is not null; - } - - /// - /// Maps an AMD GFX architecture identifier to the Windows-native ROCm Technical Preview feed URL. 
- /// - private static string? TryGetWindowsNativeRocmIndexUrl(string? gfxArch) - { - return gfxArch switch - { - var s when s != null && s.StartsWith("gfx101", StringComparison.OrdinalIgnoreCase) => - "https://rocm.nightlies.amd.com/v2-staging/gfx101X-dgpu/", - var s when s != null && s.StartsWith("gfx103", StringComparison.OrdinalIgnoreCase) => - "https://rocm.nightlies.amd.com/v2-staging/gfx103X-dgpu/", - var s when s != null && s.StartsWith("gfx110", StringComparison.OrdinalIgnoreCase) => - "https://rocm.nightlies.amd.com/v2/gfx110X-all/", - "gfx1150" => "https://rocm.nightlies.amd.com/v2-staging/gfx1150/", - "gfx1151" => "https://rocm.nightlies.amd.com/v2/gfx1151/", - "gfx1152" => "https://rocm.nightlies.amd.com/v2-staging/gfx1152/", - "gfx1153" => "https://rocm.nightlies.amd.com/v2-staging/gfx1153/", - var s when s != null && s.StartsWith("gfx120", StringComparison.OrdinalIgnoreCase) => - "https://rocm.nightlies.amd.com/v2/gfx120X-all/", - _ => null, - }; - } - - /// - /// Returns true for architectures that need the legacy ROCm runtime path. - /// - private static bool IsLegacyArchitecture(string? gfxArch) - { - return gfxArch is not null - && ( - gfxArch.StartsWith("gfx101", StringComparison.OrdinalIgnoreCase) - || gfxArch.StartsWith("gfx103", StringComparison.OrdinalIgnoreCase) - ); - } - - /// - /// Returns true for RDNA1 architectures that need dedicated override handling. - /// - private static bool IsRdna1Architecture(string? gfxArch) - { - return gfxArch?.StartsWith("gfx101", StringComparison.OrdinalIgnoreCase) == true; + return WindowsRocmSupport.IsSupportedArchitecture(gfxArch); } /// @@ -552,20 +390,17 @@ private static bool IsRdna1Architecture(string? gfxArch) /// private static string GetUnsupportedGpuReason(IReadOnlyList amdGpus) { - if (amdGpus.Any(IsExplicitlyUnsupportedRdna2Gpu)) - { - return "Detected only unsupported AMD RDNA2 GPUs for Windows ROCm. 
Unsupported models include Radeon 680M/660M/610M and RX 6300/6400/6450/6550-class GPUs."; - } - + _ = amdGpus; return "No AMD GPU with a supported Windows ROCm architecture was detected."; } /// - /// Verifies that the installed torch build still reports a usable ROCm runtime after helper-managed installs complete. + /// Verifies that the installed torch build still reports usable ROCm metadata after helper-managed installs complete. /// private static async Task VerifyWindowsNativeTorchInstallAsync( IPyVenvRunner venvRunner, - Action? onConsoleOutput + Action? onConsoleOutput, + CancellationToken cancellationToken ) { var torchInfo = await venvRunner.PipShow("torch").ConfigureAwait(false); @@ -586,10 +421,16 @@ private static async Task VerifyWindowsNativeTorchInstallAsync( throw new ApplicationException("Torch verification produced no output."); } + var verificationJson = TryExtractJsonObject(verificationOutput); + if (string.IsNullOrWhiteSpace(verificationJson)) + { + throw new ApplicationException($"Unexpected torch verification output: {verificationOutput}"); + } + JsonDocument verificationDocument; try { - verificationDocument = JsonDocument.Parse(verificationOutput); + verificationDocument = JsonDocument.Parse(verificationJson); } catch (Exception exception) { @@ -608,66 +449,341 @@ private static async Task VerifyWindowsNativeTorchInstallAsync( var hipVersion = root.TryGetProperty("hip", out var hipElement) ? hipElement.GetString() : null; var cudaAvailable = root.TryGetProperty("cuda", out var cudaElement) && cudaElement.GetBoolean(); - if (string.IsNullOrWhiteSpace(hipVersion) || !cudaAvailable) + if (!IsUsableWindowsNativeTorchBuild(version, hipVersion)) { throw new ApplicationException( $"Installed torch is not a usable ROCm build. 
Verification output: {verificationOutput}" ); } + if (!cudaAvailable) + { + onConsoleOutput?.Invoke( + ProcessOutput.FromStdErrLine( + $"Torch verification warning: installed ROCm torch build reported cuda={cudaAvailable}; continuing because ROCm metadata was detected (version={version}, hip={hipVersion})." + ) + ); + } + onConsoleOutput?.Invoke( ProcessOutput.FromStdOutLine( $"Torch verification: version={version}, hip={hipVersion}, cuda={cudaAvailable}" ) ); } + + _ = cancellationToken; + } + + /// + /// Runs rocm-sdk init after the helper-managed runtime packages are installed so the Windows ROCm SDK can prepare the venv. + /// + private static async Task InitializeWindowsNativeRocmSdkAsync( + string installLocation, + Action? onConsoleOutput, + CancellationToken cancellationToken + ) + { + var rocmSdkExe = Path.Combine(installLocation, "venv", "Scripts", "rocm-sdk.exe"); + if (!File.Exists(rocmSdkExe)) + { + throw new FileNotFoundException("rocm-sdk.exe was not installed", rocmSdkExe); + } + + using var rocmSdkProcess = ProcessRunner.StartAnsiProcess( + rocmSdkExe, + ["init"], + installLocation, + onConsoleOutput + ); + + await rocmSdkProcess.WaitForExitAsync(cancellationToken).ConfigureAwait(false); + if (rocmSdkProcess.ExitCode != 0) + { + throw new ProcessException($"rocm-sdk init failed with code {rocmSdkProcess.ExitCode}"); + } + } + + /// + /// Uses AMD's bundled hipInfo.exe to confirm the installed Windows ROCm runtime can enumerate a ROCm-capable GPU. + /// + private static async Task VerifyWindowsNativeRocmRuntimeAsync( + string installLocation, + Action? 
onConsoleOutput, + CancellationToken cancellationToken + ) + { + var rocmSdkExe = Path.Combine(installLocation, "venv", "Scripts", "rocm-sdk.exe"); + if (!File.Exists(rocmSdkExe)) + { + throw new FileNotFoundException("rocm-sdk.exe was not installed", rocmSdkExe); + } + + var rocmBinResult = await ProcessRunner + .GetProcessResultAsync(rocmSdkExe, ["path", "--bin"], installLocation, useUtf8Encoding: true) + .ConfigureAwait(false); + + var rocmBinPath = (rocmBinResult.StandardOutput ?? string.Empty).Trim(); + if (!rocmBinResult.IsSuccessExitCode || string.IsNullOrWhiteSpace(rocmBinPath)) + { + var rocmBinOutput = CombineProcessOutput( + rocmBinResult.StandardOutput, + rocmBinResult.StandardError + ); + throw new ApplicationException( + $"ROCm runtime verification failed while resolving the ROCm SDK bin path. Output: {rocmBinOutput}" + ); + } + + var hipInfoExe = Path.Combine(rocmBinPath, $"hipInfo{Compat.ExeExtension}"); + if (!File.Exists(hipInfoExe)) + { + throw new FileNotFoundException( + "hipInfo.exe was not found in the ROCm SDK bin directory", + hipInfoExe + ); + } + + var hipInfoResult = await ProcessRunner + .GetProcessResultAsync( + hipInfoExe, + [], + installLocation, + new Dictionary { ["PATH"] = rocmBinPath }, + useUtf8Encoding: true + ) + .ConfigureAwait(false); + + var hipInfoOutput = CombineProcessOutput(hipInfoResult.StandardOutput, hipInfoResult.StandardError); + if (!hipInfoResult.IsSuccessExitCode) + { + var runtimeFailureReason = TryGetWindowsNativeRocmRuntimeFailureReason(hipInfoOutput); + throw new ApplicationException( + runtimeFailureReason is null + ? $"ROCm runtime verification failed while probing the installed runtime with hipInfo.exe. 
Output: {hipInfoOutput}" + : $"ROCm runtime verification failed: {runtimeFailureReason} Output: {hipInfoOutput}" + ); + } + + onConsoleOutput?.Invoke( + ProcessOutput.FromStdOutLine( + $"ROCm runtime verification succeeded via hipInfo.exe: {hipInfoOutput}" + ) + ); + + _ = cancellationToken; + } + + /// + /// Reinstalls rocm-sdk-devel to the resolved ROCm build version when the torch step downgrades the runtime stack. + /// + private static async Task AlignRocmSdkDevelVersionAsync( + IPyVenvRunner venvRunner, + string rocmPackageIndexUrl, + Action? onConsoleOutput + ) + { + var rocmInfo = await venvRunner.PipShow("rocm").ConfigureAwait(false); + var rocmSdkDevelInfo = await venvRunner.PipShow("rocm-sdk-devel").ConfigureAwait(false); + var torchInfo = await venvRunner.PipShow("torch").ConfigureAwait(false); + + var targetVersion = GetRocmSdkDevelAlignmentVersion( + rocmInfo?.Version, + rocmSdkDevelInfo?.Version, + torchInfo?.Version + ); + + if (string.IsNullOrWhiteSpace(targetVersion)) + return; + + onConsoleOutput?.Invoke( + ProcessOutput.FromStdErrLine( + $"Aligning rocm-sdk-devel from version={rocmSdkDevelInfo?.Version ?? "not-installed"} to version={targetVersion} to match the resolved ROCm torch/runtime build." + ) + ); + + var alignmentArgs = new PipInstallArgs() + .AddKeyedArgs("--index-url", ["--index-url", rocmPackageIndexUrl]) + .AddArg("--force-reinstall") + .AddArg($"rocm-sdk-devel=={targetVersion}"); + + await venvRunner.PipInstall(alignmentArgs, onConsoleOutput).ConfigureAwait(false); + } + + internal static bool IsUsableWindowsNativeTorchBuild(string? version, string? hipVersion) + { + if (!string.IsNullOrWhiteSpace(hipVersion)) + return true; + + return !string.IsNullOrWhiteSpace(version) + && version.Contains("rocm", StringComparison.OrdinalIgnoreCase); + } + + internal static string? GetRocmSdkDevelAlignmentVersion( + string? rocmVersion, + string? rocmSdkDevelVersion, + string? 
torchVersion = null + ) + { + var targetVersion = !string.IsNullOrWhiteSpace(rocmVersion) + ? rocmVersion + : TryExtractRocmBuildVersion(torchVersion); + + if (string.IsNullOrWhiteSpace(targetVersion)) + return null; + + return string.Equals(targetVersion, rocmSdkDevelVersion, StringComparison.OrdinalIgnoreCase) + ? null + : targetVersion; + } + + internal static string? TryGetWindowsNativeRocmRuntimeFailureReason(string? output) + { + if (string.IsNullOrWhiteSpace(output)) + return null; + + if (output.Contains("no ROCm-capable device is detected", StringComparison.OrdinalIgnoreCase)) + { + return "the installed ROCm runtime could not detect a ROCm-capable GPU on this system."; + } + + if (output.Contains("No WDDM adapters found", StringComparison.OrdinalIgnoreCase)) + { + return "the ROCm runtime could not find any compatible WDDM adapters for the current GPU/driver stack."; + } + + return null; + } + + internal static string? TryExtractRocmBuildVersion(string? torchVersion) + { + if (string.IsNullOrWhiteSpace(torchVersion)) + return null; + + var rocmMarkerIndex = torchVersion.IndexOf("rocm", StringComparison.OrdinalIgnoreCase); + if (rocmMarkerIndex < 0) + return null; + + var rocmBuildVersion = torchVersion[(rocmMarkerIndex + "rocm".Length)..].Trim(); + return string.IsNullOrWhiteSpace(rocmBuildVersion) ? null : rocmBuildVersion; + } + + internal static string? TryExtractJsonObject(string output) + { + if (string.IsNullOrWhiteSpace(output)) + return null; + + var trimmedOutput = output.Trim(); + + for (var index = 0; index < trimmedOutput.Length; index++) + { + if (trimmedOutput[index] != '{') + continue; + + try + { + using var document = JsonDocument.Parse(trimmedOutput[index..]); + return document.RootElement.GetRawText(); + } + catch (JsonException) { } + } + + return null; + } + + internal static string CombineProcessOutput(string? standardOutput, string? 
standardError) + { + var sections = new[] { standardOutput?.Trim(), standardError?.Trim() }.Where(section => + !string.IsNullOrWhiteSpace(section) + ); + + return string.Join(Environment.NewLine, sections); } /// /// Builds helper-owned ROCm launch variables from the resolved runtime context and package profile. /// - private static IReadOnlyDictionary BuildHelperLaunchEnvironment( + private IReadOnlyDictionary BuildHelperLaunchEnvironment( RocmRuntimeContext runtimeContext, RocmPackageProfile profile ) { - var environment = new Dictionary(); + var environment = new Dictionary(EnvComparer); + var options = profile.EnvironmentOptions; + var gfxArch = runtimeContext.RuntimeGfxArch; + + ApplyPresetLaunchEnvironment(environment, gfxArch, options); + + return environment; + } + + private void ApplyPresetLaunchEnvironment( + IDictionary environment, + string? gfxArch, + RocmEnvironmentOptions options + ) + { + SetIfNotNull(environment, "FLASH_ATTENTION_TRITON_AMD_ENABLE", options.FlashAttentionTritonAmdEnable); + SetIfNotNull(environment, "MIOPEN_FIND_MODE", options.MiopenFindMode); + SetIfNotNull(environment, "MIOPEN_SEARCH_CUTOFF", options.MiopenSearchCutoff); + SetIfNotNull(environment, "MIOPEN_FIND_ENFORCE", options.MiopenFindEnforce); + SetIfNotNull(environment, "PYTORCH_ALLOC_CONF", options.PyTorchAllocConf); - if (profile.NeedsTunableOpCache) + if (options.ApplyAotritonExperimental && IsModernWindowsRocmArchitecture(gfxArch)) { - environment["PYTORCH_TUNABLEOP_ENABLED"] = "1"; + environment["TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL"] = "1"; } - if (profile.NeedsAotritonExperimental) + if (!IsModernWindowsRocmArchitecture(gfxArch) && options.ApplyLegacySdpFallback) { - environment["TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL"] = "1"; + environment["TORCH_BACKENDS_CUDA_FLASH_SDP_ENABLED"] = "0"; + environment["TORCH_BACKENDS_CUDA_MEM_EFF_SDP_ENABLED"] = "0"; + environment["TORCH_BACKENDS_CUDA_MATH_SDP_ENABLED"] = "1"; } - if (profile.NeedsTritonOverrideArch && 
!string.IsNullOrWhiteSpace(runtimeContext.RuntimeGfxArch)) + if (options.ApplyRdna1Override && IsRdna1Architecture(gfxArch)) { - environment["HSA_OVERRIDE_GFX_VERSION"] = runtimeContext.RuntimeGfxArch; + environment["HSA_OVERRIDE_GFX_VERSION"] = "10.1.0"; } - return environment; + if (options.Preset == RocmEnvironmentPreset.ComfyUi && IsModernWindowsRocmArchitecture(gfxArch)) + { + environment["COMFYUI_ENABLE_MIOPEN"] = "1"; + } + } + + private static bool IsModernWindowsRocmArchitecture(string? gfxArch) + { + return gfxArch?.StartsWith("gfx110", StringComparison.OrdinalIgnoreCase) == true + || gfxArch?.StartsWith("gfx115", StringComparison.OrdinalIgnoreCase) == true + || gfxArch?.StartsWith("gfx120", StringComparison.OrdinalIgnoreCase) == true; + } + + private static bool IsRdna1Architecture(string? gfxArch) + { + return gfxArch?.StartsWith("gfx101", StringComparison.OrdinalIgnoreCase) == true; + } + + private static void SetIfNotNull(IDictionary environment, string key, string? value) + { + if (!string.IsNullOrWhiteSpace(value)) + { + environment[key] = value; + } } /// - /// Merges helper-owned and package-specific launch environment variables using the profile overlay rules. + /// Merges helper-owned and package-specific launch environment variables. /// - private static IReadOnlyDictionary MergeLaunchEnvironment( + private IReadOnlyDictionary MergeLaunchEnvironment( IReadOnlyDictionary helperEnvironment, IReadOnlyDictionary packageEnvironment, RocmEnvironmentOptions options ) { - var merged = new Dictionary(); + var merged = new Dictionary(EnvComparer); - IReadOnlyDictionary[] orderedSources = - options.OverlayPriority == RocmEnvironmentOverlayPriority.HelperThenUserThenPackage - ? 
new[] { helperEnvironment, packageEnvironment } - : new[] { helperEnvironment, packageEnvironment }; - - foreach (var source in orderedSources) + foreach (var source in new[] { helperEnvironment, packageEnvironment }) { if (ReferenceEquals(source, packageEnvironment) && !options.IncludePackageOverrides) continue; @@ -678,6 +794,17 @@ RocmEnvironmentOptions options } } + if ( + options.IncludeUserOverrides + && settingsManager.Settings.EnvironmentVariables is { Count: > 0 } userOverrides + ) + { + foreach (var pair in userOverrides) + { + merged[pair.Key] = pair.Value; + } + } + return merged; } } diff --git a/StabilityMatrix.Tests/Core/RocmPackageHelperTests.cs b/StabilityMatrix.Tests/Core/RocmPackageHelperTests.cs new file mode 100644 index 00000000..0afb7b86 --- /dev/null +++ b/StabilityMatrix.Tests/Core/RocmPackageHelperTests.cs @@ -0,0 +1,176 @@ +using System.Text.Json; +using StabilityMatrix.Core.Helper; +using StabilityMatrix.Core.Helper.HardwareInfo; +using StabilityMatrix.Core.Models.Rocm; +using StabilityMatrix.Core.Services.Rocm; + +namespace StabilityMatrix.Tests.Core; + +[TestClass] +public class RocmPackageHelperTests +{ + [TestMethod] + public void GetRocmSdkDevelAlignmentVersion_ReturnsRocmVersion_WhenVersionsMismatch() + { + var targetVersion = RocmPackageHelper.GetRocmSdkDevelAlignmentVersion( + rocmVersion: "7.13.0a20260416", + rocmSdkDevelVersion: "7.13.0a20260501" + ); + + Assert.AreEqual("7.13.0a20260416", targetVersion); + } + + [TestMethod] + public void GetRocmSdkDevelAlignmentVersion_ReturnsNull_WhenVersionsAlreadyMatch() + { + var targetVersion = RocmPackageHelper.GetRocmSdkDevelAlignmentVersion( + rocmVersion: "7.13.0a20260416", + rocmSdkDevelVersion: "7.13.0a20260416" + ); + + Assert.IsNull(targetVersion); + } + + [TestMethod] + public void GetRocmSdkDevelAlignmentVersion_FallsBackToTorchBuildVersion() + { + var targetVersion = RocmPackageHelper.GetRocmSdkDevelAlignmentVersion( + rocmVersion: null, + rocmSdkDevelVersion: 
"7.13.0a20260501", + torchVersion: "2.11.0+rocm7.13.0a20260416" + ); + + Assert.AreEqual("7.13.0a20260416", targetVersion); + } + + [TestMethod] + public void TryExtractRocmBuildVersion_ReturnsNull_WhenTorchVersionHasNoRocmTag() + { + var rocmBuildVersion = RocmPackageHelper.TryExtractRocmBuildVersion("2.11.0"); + + Assert.IsNull(rocmBuildVersion); + } + + [TestMethod] + public void TryExtractRocmBuildVersion_ReturnsVersionSuffix_WhenTorchVersionContainsRocmTag() + { + var rocmBuildVersion = RocmPackageHelper.TryExtractRocmBuildVersion("2.11.0+rocm7.13.0a20260416"); + + Assert.AreEqual("7.13.0a20260416", rocmBuildVersion); + } + + [TestMethod] + public void IsUsableWindowsNativeTorchBuild_ReturnsTrue_WhenHipMetadataExists() + { + var isUsable = RocmPackageHelper.IsUsableWindowsNativeTorchBuild( + version: "test-version", + hipVersion: "test-hip-version" + ); + + Assert.IsTrue(isUsable); + } + + [TestMethod] + public void IsUsableWindowsNativeTorchBuild_ReturnsTrue_WhenVersionContainsRocm() + { + var isUsable = RocmPackageHelper.IsUsableWindowsNativeTorchBuild( + version: "test-version+rocm", + hipVersion: null + ); + + Assert.IsTrue(isUsable); + } + + [TestMethod] + public void IsUsableWindowsNativeTorchBuild_ReturnsFalse_WhenNoRocmMetadataExists() + { + var isUsable = RocmPackageHelper.IsUsableWindowsNativeTorchBuild( + version: "test-version", + hipVersion: null + ); + + Assert.IsFalse(isUsable); + } + + [TestMethod] + public void TryExtractJsonObject_ReturnsJson_WhenOutputContainsDiagnosticPrefix() + { + const string output = + "warning: ROCm topology probe emitted diagnostic output" + + "\nwarning: continuing with torch verification" + + "\n{\"version\": \"test-version\", \"hip\": \"test-hip-version\", \"cuda\": false}"; + + var json = RocmPackageHelper.TryExtractJsonObject(output); + + Assert.IsNotNull(json); + + using var document = JsonDocument.Parse(json); + var root = document.RootElement; + + Assert.AreEqual("test-version", 
root.GetProperty("version").GetString()); + Assert.AreEqual("test-hip-version", root.GetProperty("hip").GetString()); + Assert.IsFalse(root.GetProperty("cuda").GetBoolean()); + } + + [TestMethod] + public void TryExtractJsonObject_ReturnsNull_WhenOutputContainsNoJson() + { + const string output = + "warning: ROCm topology probe emitted diagnostic output\n" + + "warning: no JSON payload was produced"; + + var json = RocmPackageHelper.TryExtractJsonObject(output); + + Assert.IsNull(json); + } + + [TestMethod] + public void TryGetWindowsNativeRocmRuntimeFailureReason_ReturnsDeviceDetectionMessage() + { + const string output = "checkHipErrors() HIP API error = 0100 \"no ROCm-capable device is detected\""; + + var reason = RocmPackageHelper.TryGetWindowsNativeRocmRuntimeFailureReason(output); + + Assert.AreEqual( + "the installed ROCm runtime could not detect a ROCm-capable GPU on this system.", + reason + ); + } + + [TestMethod] + public void TryGetWindowsNativeRocmRuntimeFailureReason_ReturnsWddmMessage() + { + const string output = "warning: No WDDM adapters found."; + + var reason = RocmPackageHelper.TryGetWindowsNativeRocmRuntimeFailureReason(output); + + Assert.AreEqual( + "the ROCm runtime could not find any compatible WDDM adapters for the current GPU/driver stack.", + reason + ); + } + + [TestMethod] + public void CombineProcessOutput_JoinsStdoutAndStderr() + { + var combined = RocmPackageHelper.CombineProcessOutput("stdout line", "stderr line"); + + Assert.AreEqual($"stdout line{Environment.NewLine}stderr line", combined); + } + + [TestMethod] + public void WindowsRocmSupport_TryGetPackageIndexUrl_ReturnsExpectedIndex_ForKrakenPoint() + { + var indexUrl = WindowsRocmSupport.TryGetPackageIndexUrl("gfx1152"); + + Assert.AreEqual("https://rocm.nightlies.amd.com/v2-staging/gfx1152/", indexUrl); + } + + [TestMethod] + public void WindowsRocmSupport_IsSupportedGpu_ReturnsTrue_ForSupportedAmdGpu() + { + var gpu = new GpuInfo { Name = "AMD Radeon RX 9070 XT", 
MemoryBytes = 16UL * Size.GiB }; + + Assert.IsTrue(WindowsRocmSupport.IsSupportedGpu(gpu)); + } +} diff --git a/StabilityMatrix.Tests/Helper/PackageFactoryTests.cs b/StabilityMatrix.Tests/Helper/PackageFactoryTests.cs index f7802703..bf45d60c 100644 --- a/StabilityMatrix.Tests/Helper/PackageFactoryTests.cs +++ b/StabilityMatrix.Tests/Helper/PackageFactoryTests.cs @@ -24,6 +24,8 @@ public void Setup() null!, null!, null!, + null!, + null!, null! ); } From bd7ddfd045820eefb35feee5184ff76dcba753f4 Mon Sep 17 00:00:00 2001 From: NeuralFault Date: Fri, 1 May 2026 22:02:47 -0400 Subject: [PATCH 06/18] refactor shared Windows ROCm policy and package launch defaults - centralize Windows ROCm architecture classification and legacy-attention fallback policy in WindowsRocmSupport - move ComfyUI-specific MIOpen env handling out of the helper and into package-owned ROCm config - reuse shared ROCm policy for ComfyUI quad-attention defaults and helper-managed AOTriton / math SDP / RDNA1 gates - remove dead ROCm preset plumbing and trim unused RocmPackageProfile surface - rename helper/package methods for clearer default-policy semantics --- .../Models/Packages/ComfyUI.cs | 42 +++++++++++-------- .../Models/Packages/Wan2GP.cs | 1 - .../Models/Rocm/RocmEnvironmentOptions.cs | 11 ----- .../Models/Rocm/RocmPackageProfile.cs | 5 --- .../Models/Rocm/WindowsRocmSupport.cs | 26 +++++++++++- .../Services/Rocm/RocmPackageHelper.cs | 27 +++--------- 6 files changed, 53 insertions(+), 59 deletions(-) diff --git a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs index 1832a27a..7bfd0c36 100644 --- a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs +++ b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs @@ -42,16 +42,6 @@ public class ComfyUI( { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); - private static readonly RocmPackageProfile WindowsRocmProfile = new() - { - PackageName = "ComfyUI", - RequiresRocmSdk = true, - 
ExtraInstallPipArgs = ["numpy<2"], - PostInstallPipArgs = ["typing-extensions>=4.15.0"], - UpgradePackages = true, - EnvironmentOptions = new RocmEnvironmentOptions { Preset = RocmEnvironmentPreset.ComfyUi }, - }; - public override string Name => "ComfyUI"; public override string DisplayName { get; set; } = "ComfyUI"; public override string Author => "comfyanonymous"; @@ -278,7 +268,7 @@ public class ComfyUI( { Name = "Cross Attention Method", Type = LaunchOptionType.Bool, - InitialValue = ShouldDefaultToQuadCrossAttention() + InitialValue = DefaultToQuadCrossAttention() ? "--use-quad-cross-attention" : "--use-pytorch-cross-attention", Options = @@ -610,9 +600,26 @@ public override TorchIndex GetRecommendedTorchVersion() return base.GetRecommendedTorchVersion(); } - /// + /// Windows ROCm install profile for ComfyUI. + private static readonly RocmPackageProfile WindowsRocmProfile = new() + { + RequiresRocmSdk = true, + ExtraInstallPipArgs = ["numpy<2"], + PostInstallPipArgs = ["typing-extensions>=4.15.0"], + UpgradePackages = true, + ExtraEnvironmentFactory = BuildComfyWindowsRocmEnvironment, + }; + + private static IReadOnlyDictionary BuildComfyWindowsRocmEnvironment( + RocmRuntimeContext runtimeContext + ) + { + return WindowsRocmSupport.IsModernArchitecture(runtimeContext.RuntimeGfxArch) + ? new Dictionary { ["COMFYUI_ENABLE_MIOPEN"] = "1" } + : new Dictionary(); + } + /// Uses the shared ROCm helper for Windows ROCm eligibility checks so ComfyUI does not maintain its own support matrix. - /// private bool HasWindowsRocmSupport() { if (!Compat.IsWindows) @@ -626,7 +633,9 @@ private bool HasWindowsRocmSupport() return compatibility.IsCompatible; } - private bool ShouldDefaultToQuadCrossAttention() + /// Defaults legacy Windows ROCm GPUs to quad cross-attention because PyTorch cross-attention is considerably slower + /// and not as supported on older AMD architectures. 
+ private bool DefaultToQuadCrossAttention() { if (!Compat.IsWindows || !HasWindowsRocmSupport()) return false; @@ -636,10 +645,7 @@ private bool ShouldDefaultToQuadCrossAttention() ? gpu?.GetAmdGfxArch() : HardwareHelper.GetWindowsRocmSupportedGpu()?.GetAmdGfxArch(); - return !string.IsNullOrWhiteSpace(gfxArch) - && !gfxArch.StartsWith("gfx110", StringComparison.OrdinalIgnoreCase) - && !gfxArch.StartsWith("gfx115", StringComparison.OrdinalIgnoreCase) - && !gfxArch.StartsWith("gfx120", StringComparison.OrdinalIgnoreCase); + return WindowsRocmSupport.PreferLegacyAttentionFallback(gfxArch); } public override IPackageExtensionManager ExtensionManager => diff --git a/StabilityMatrix.Core/Models/Packages/Wan2GP.cs b/StabilityMatrix.Core/Models/Packages/Wan2GP.cs index e11fd322..0c91f23a 100644 --- a/StabilityMatrix.Core/Models/Packages/Wan2GP.cs +++ b/StabilityMatrix.Core/Models/Packages/Wan2GP.cs @@ -46,7 +46,6 @@ public class Wan2GP( { private static readonly RocmPackageProfile WindowsRocmProfile = new() { - PackageName = "Wan2GP", RequiresRocmSdk = true, UpgradePackages = true, PostInstallPipArgs = ["hf-xet", "setuptools", "numpy==1.26.4"], diff --git a/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs b/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs index 21ff6e7d..e0126170 100644 --- a/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs +++ b/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs @@ -15,11 +15,6 @@ public class RocmEnvironmentOptions /// public bool IncludeUserOverrides { get; init; } = true; - /// - /// Selects a package-oriented ROCm environment preset managed by the helper. - /// - public RocmEnvironmentPreset Preset { get; init; } = RocmEnvironmentPreset.None; - /// /// When set, overrides the default PyTorch allocator tuning string added by the ROCm helper. 
/// @@ -60,9 +55,3 @@ public class RocmEnvironmentOptions /// public bool ApplyRdna1Override { get; init; } = true; } - -public enum RocmEnvironmentPreset -{ - None, - ComfyUi, -} diff --git a/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs b/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs index a7baa675..d1abf16f 100644 --- a/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs +++ b/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs @@ -8,11 +8,6 @@ namespace StabilityMatrix.Core.Models.Rocm; /// public class RocmPackageProfile { - /// - /// Logical package name for diagnostics and profile-specific decisions. - /// - public string PackageName { get; init; } = string.Empty; - public bool RequiresRocmSdk { get; init; } /// diff --git a/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs b/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs index dec46d04..ee000b1b 100644 --- a/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs +++ b/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs @@ -3,8 +3,8 @@ namespace StabilityMatrix.Core.Models.Rocm; /// -/// Centralizes Windows ROCm support policy so hardware detection, package selection, -/// and ROCm installation all use the same architecture support map. +/// Centralizes Windows ROCm support and architecture policy so hardware detection, package selection, +/// installation, and shared launch decisions use the same support map. /// public static class WindowsRocmSupport { @@ -21,6 +21,28 @@ public static bool IsSupportedArchitecture(string? gfxArch) return TryGetPackageIndexUrl(gfxArch) is not null; } + public static bool IsModernArchitecture(string? gfxArch) + { + return gfxArch?.StartsWith("gfx110", StringComparison.OrdinalIgnoreCase) == true + || gfxArch?.StartsWith("gfx115", StringComparison.OrdinalIgnoreCase) == true + || gfxArch?.StartsWith("gfx120", StringComparison.OrdinalIgnoreCase) == true; + } + + public static bool IsLegacyArchitecture(string? 
gfxArch) + { + return IsSupportedArchitecture(gfxArch) && !IsModernArchitecture(gfxArch); + } + + public static bool PreferLegacyAttentionFallback(string? gfxArch) + { + return IsLegacyArchitecture(gfxArch); + } + + public static bool IsRdna1Architecture(string? gfxArch) + { + return gfxArch?.StartsWith("gfx101", StringComparison.OrdinalIgnoreCase) == true; + } + public static string? TryGetPackageIndexUrl(string? gfxArch) { return gfxArch switch diff --git a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs index 19dced0b..664e27d9 100644 --- a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs +++ b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs @@ -712,12 +712,12 @@ RocmPackageProfile profile var options = profile.EnvironmentOptions; var gfxArch = runtimeContext.RuntimeGfxArch; - ApplyPresetLaunchEnvironment(environment, gfxArch, options); + ApplyDefaultLaunchEnvironment(environment, gfxArch, options); return environment; } - private void ApplyPresetLaunchEnvironment( + private void ApplyDefaultLaunchEnvironment( IDictionary environment, string? 
gfxArch, RocmEnvironmentOptions options @@ -729,39 +729,22 @@ RocmEnvironmentOptions options SetIfNotNull(environment, "MIOPEN_FIND_ENFORCE", options.MiopenFindEnforce); SetIfNotNull(environment, "PYTORCH_ALLOC_CONF", options.PyTorchAllocConf); - if (options.ApplyAotritonExperimental && IsModernWindowsRocmArchitecture(gfxArch)) + if (options.ApplyAotritonExperimental && WindowsRocmSupport.IsModernArchitecture(gfxArch)) { environment["TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL"] = "1"; } - if (!IsModernWindowsRocmArchitecture(gfxArch) && options.ApplyLegacySdpFallback) + if (options.ApplyLegacySdpFallback && WindowsRocmSupport.IsLegacyArchitecture(gfxArch)) { environment["TORCH_BACKENDS_CUDA_FLASH_SDP_ENABLED"] = "0"; environment["TORCH_BACKENDS_CUDA_MEM_EFF_SDP_ENABLED"] = "0"; environment["TORCH_BACKENDS_CUDA_MATH_SDP_ENABLED"] = "1"; } - if (options.ApplyRdna1Override && IsRdna1Architecture(gfxArch)) + if (options.ApplyRdna1Override && WindowsRocmSupport.IsRdna1Architecture(gfxArch)) { environment["HSA_OVERRIDE_GFX_VERSION"] = "10.1.0"; } - - if (options.Preset == RocmEnvironmentPreset.ComfyUi && IsModernWindowsRocmArchitecture(gfxArch)) - { - environment["COMFYUI_ENABLE_MIOPEN"] = "1"; - } - } - - private static bool IsModernWindowsRocmArchitecture(string? gfxArch) - { - return gfxArch?.StartsWith("gfx110", StringComparison.OrdinalIgnoreCase) == true - || gfxArch?.StartsWith("gfx115", StringComparison.OrdinalIgnoreCase) == true - || gfxArch?.StartsWith("gfx120", StringComparison.OrdinalIgnoreCase) == true; - } - - private static bool IsRdna1Architecture(string? gfxArch) - { - return gfxArch?.StartsWith("gfx101", StringComparison.OrdinalIgnoreCase) == true; } private static void SetIfNotNull(IDictionary environment, string key, string? 
value) From a4fbb6445f6b9a30becf8870032bf1bbb1690357 Mon Sep 17 00:00:00 2001 From: NeuralFault Date: Fri, 1 May 2026 22:05:32 -0400 Subject: [PATCH 07/18] add comment for legacy AMD GPU support in cross attention method --- StabilityMatrix.Core/Models/Packages/ComfyUI.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs index 7bfd0c36..759406c3 100644 --- a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs +++ b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs @@ -269,7 +269,7 @@ public class ComfyUI( Name = "Cross Attention Method", Type = LaunchOptionType.Bool, InitialValue = DefaultToQuadCrossAttention() - ? "--use-quad-cross-attention" + ? "--use-quad-cross-attention" // For Legacy AMD GPUs. : "--use-pytorch-cross-attention", Options = [ From 0165a7baaf5b524bc3a8715c4194107bcf20b235 Mon Sep 17 00:00:00 2001 From: NeuralFault <65365345+NeuralFault@users.noreply.github.com> Date: Sat, 2 May 2026 18:32:35 -0400 Subject: [PATCH 08/18] Change forceRefresh parameter to false in GetAmdGpuCandidates Was used during debugging and was unintentionally left on. 
--- StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs index 664e27d9..7ff23bc2 100644 --- a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs +++ b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs @@ -272,7 +272,7 @@ await VerifyWindowsNativeRocmRuntimeAsync(installLocation, onConsoleOutput, canc /// private RocmCompatibilityResult BuildCompatibilityResult(RocmPackageProfile profile) { - var amdGpus = GetAmdGpuCandidates(forceRefresh: true).ToList(); + var amdGpus = GetAmdGpuCandidates(forceRefresh: false).ToList(); if (amdGpus.Count == 0) { return new RocmCompatibilityResult From 215991ff6c5ef94bd44cf2dfba8f5c389321c013 Mon Sep 17 00:00:00 2001 From: NeuralFault <65365345+NeuralFault@users.noreply.github.com> Date: Sat, 2 May 2026 18:46:49 -0400 Subject: [PATCH 09/18] Change exception type from ApplicationException to InvalidOperationException --- StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs index 7ff23bc2..a5382ac9 100644 --- a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs +++ b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs @@ -142,7 +142,7 @@ public async Task InstallWindowsNativePackageAsync( var compatibility = GetCompatibility(profile); if (!compatibility.IsCompatible) { - throw new ApplicationException( + throw new InvalidOperationException( compatibility.FailureReason ?? "Windows ROCm installation is not supported for the current machine." 
); From 4cbf5852a4546343575540069b89aad506fb3e76 Mon Sep 17 00:00:00 2001 From: NeuralFault <65365345+NeuralFault@users.noreply.github.com> Date: Sat, 2 May 2026 18:53:54 -0400 Subject: [PATCH 10/18] Add rocmPackageHelper dependency to Wan2GP --- StabilityMatrix.Core/Helper/Factory/PackageFactory.cs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs b/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs index 6a073986..0cb0b808 100644 --- a/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs +++ b/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs @@ -287,7 +287,8 @@ public BasePackage GetNewBasePackage(InstalledPackage installedPackage) downloadService, prerequisiteHelper, pyInstallationManager, - pipWheelService + pipWheelService, + rocmPackageHelper ), _ => throw new ArgumentOutOfRangeException(nameof(installedPackage)), }; From 3458ef1da0d5ccf2ef8d81782e6b76d72ea743dc Mon Sep 17 00:00:00 2001 From: NeuralFault Date: Thu, 7 May 2026 12:54:12 -0400 Subject: [PATCH 11/18] Add shared Windows ROCm helper compatibility support Add helper-managed ROCm torch compatibility profiles for Windows packages Enable ComfyUI to use shared ROCm dependency fallback behavior Add shared Windows ROCm launch notice and experimental-support messaging Align Wan2GP Windows ROCm disclaimer text with the shared helper messaging Lower helper-managed MIOPEN_FIND_ENFORCE default from 3 to 1 Add gfx103x borrowed dependency fallback using ROCm-hosted setuptools and mpmath from gfx103x-dgpu due to -all index missing compatible versions Add helper preinstall and supplemental dependency handling for ROCm torch installs Add temporary environment override support for helper-owned install steps --- .../Models/Packages/ComfyUI.cs | 32 +++ .../Models/Packages/Wan2GP.cs | 2 +- .../Models/Rocm/RocmEnvironmentOptions.cs | 2 +- .../Models/Rocm/RocmPackageProfile.cs | 7 + .../Models/Rocm/RocmTorchCompatibilityMode.cs | 17 ++ 
.../Services/Rocm/IRocmPackageHelper.cs | 5 + .../Services/Rocm/RocmPackageHelper.cs | 202 +++++++++++++++++- 7 files changed, 254 insertions(+), 13 deletions(-) create mode 100644 StabilityMatrix.Core/Models/Rocm/RocmTorchCompatibilityMode.cs diff --git a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs index 759406c3..4fef3fda 100644 --- a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs +++ b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs @@ -561,6 +561,8 @@ older torch index (e.g. cu128) } } + var handledFirstConsoleOutput = false; + VenvRunner.RunDetached( [Path.Combine(installLocation, options.Command ?? LaunchCommand), .. options.Arguments], HandleConsoleOutput, @@ -573,6 +575,12 @@ void HandleConsoleOutput(ProcessOutput s) { onConsoleOutput?.Invoke(s); + if (!handledFirstConsoleOutput) + { + handledFirstConsoleOutput = true; + EmitWindowsRocmLaunchNotice(installedPackage, onConsoleOutput); + } + if (!s.Text.Contains("To see the GUI go to", StringComparison.OrdinalIgnoreCase)) return; @@ -586,6 +594,29 @@ void HandleConsoleOutput(ProcessOutput s) } } + private void EmitWindowsRocmLaunchNotice( + InstalledPackage installedPackage, + Action? onConsoleOutput + ) + { + if (!ShouldShowWindowsRocmLaunchNotice(installedPackage) || rocmPackageHelper is null) + return; + + foreach (var line in rocmPackageHelper.GetWindowsLaunchNoticeLines()) + { + onConsoleOutput?.Invoke(ProcessOutput.FromStdOutLine($"{line}{Environment.NewLine}")); + } + } + + private bool ShouldShowWindowsRocmLaunchNotice(InstalledPackage installedPackage) + { + if (!Compat.IsWindows || rocmPackageHelper is null || !HasWindowsRocmSupport()) + return false; + + var torchIndex = installedPackage.PreferredTorchIndex ?? 
GetRecommendedTorchVersion(); + return torchIndex == TorchIndex.Rocm; + } + public override TorchIndex GetRecommendedTorchVersion() { var preferRocm = @@ -607,6 +638,7 @@ public override TorchIndex GetRecommendedTorchVersion() ExtraInstallPipArgs = ["numpy<2"], PostInstallPipArgs = ["typing-extensions>=4.15.0"], UpgradePackages = true, + TorchCompatibilityMode = RocmTorchCompatibilityMode.HelperManagedDependencyFallback, ExtraEnvironmentFactory = BuildComfyWindowsRocmEnvironment, }; diff --git a/StabilityMatrix.Core/Models/Packages/Wan2GP.cs b/StabilityMatrix.Core/Models/Packages/Wan2GP.cs index 0c91f23a..4e82a751 100644 --- a/StabilityMatrix.Core/Models/Packages/Wan2GP.cs +++ b/StabilityMatrix.Core/Models/Packages/Wan2GP.cs @@ -88,7 +88,7 @@ public class Wan2GP( public override string Disclaimer => IsAmdRocm && Compat.IsWindows - ? "AMD GPU support on Windows requires RX 7000 series or newer GPU" + ? "Windows AMD ROCm support is experimental. Please report any issues to Stability Matrix first so it can be determined whether the issue is package-specific.\nBecause this setup may not be officially supported by package developers, only contact upstream support for issues clearly caused by the package itself." : string.Empty; /// diff --git a/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs b/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs index e0126170..de5787dc 100644 --- a/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs +++ b/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs @@ -33,7 +33,7 @@ public class RocmEnvironmentOptions /// /// When set, configures MIOpen find enforcement behavior for helper-managed ROCm defaults. /// - public string? MiopenFindEnforce { get; init; } = "3"; + public string? MiopenFindEnforce { get; init; } = "1"; /// /// When set, controls whether AMD Triton-backed flash attention is enabled by helper defaults. 
diff --git a/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs b/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs index d1abf16f..2f18a07a 100644 --- a/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs +++ b/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs @@ -40,6 +40,13 @@ public class RocmPackageProfile /// public bool ForceReinstallTorch { get; init; } = true; + /// + /// Optional helper-managed compatibility mode for ROCm torch installation. + /// The package declares the intent here while the helper resolves any architecture-specific + /// fallback indexes or borrowed dependency rules internally. + /// + public RocmTorchCompatibilityMode TorchCompatibilityMode { get; init; } + /// /// Optional callback for package-specific environment variables derived from a resolved ROCm context. /// diff --git a/StabilityMatrix.Core/Models/Rocm/RocmTorchCompatibilityMode.cs b/StabilityMatrix.Core/Models/Rocm/RocmTorchCompatibilityMode.cs new file mode 100644 index 00000000..f8e9f7da --- /dev/null +++ b/StabilityMatrix.Core/Models/Rocm/RocmTorchCompatibilityMode.cs @@ -0,0 +1,17 @@ +namespace StabilityMatrix.Core.Models.Rocm; + +/// +/// High-level helper-managed compatibility modes for ROCm torch installation. +/// Package profiles should declare intent here and let the ROCm helper resolve any +/// architecture-specific index or dependency fallback details. +/// +public enum RocmTorchCompatibilityMode +{ + None, + + /// + /// Lets the helper apply built-in Windows ROCm dependency fallback rules when + /// specific TheRock indexes are missing compatible transitive dependency wheels. 
+ /// + HelperManagedDependencyFallback, +} diff --git a/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs index 03b4ce0c..78924250 100644 --- a/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs +++ b/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs @@ -43,6 +43,11 @@ IReadOnlyDictionary BuildLaunchEnvironment( RocmPackageProfile profile ); + /// + /// Returns shared Windows ROCm launch notice lines for helper-managed packages. + /// + IReadOnlyList GetWindowsLaunchNoticeLines(); + /// /// Performs the Windows-native ROCm bootstrap/install flow for a package using helper-resolved gfx-family feed URLs. /// diff --git a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs index 664e27d9..d2a89bbe 100644 --- a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs +++ b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs @@ -21,8 +21,37 @@ namespace StabilityMatrix.Core.Services.Rocm; [RegisterSingleton] public class RocmPackageHelper(ISettingsManager settingsManager) : IRocmPackageHelper { + private sealed class HelperManagedTorchInstallPolicy + { + public IReadOnlyList PreInstallPackageSpecifiers { get; init; } = []; + + public IReadOnlyList SupplementalPackageSpecifiers { get; init; } = []; + + public bool? UsePreRelease { get; init; } + + public bool? ForceReinstall { get; init; } + } + private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); private static readonly StringComparer EnvComparer = StringComparer.OrdinalIgnoreCase; + private static readonly string[] WindowsLaunchNoticeLines = + [ + "Stability Matrix Windows ROCm Notice: Windows AMD ROCm support is experimental. 
Please report any issues to Stability Matrix first so it can be determined whether the issue is package-specific.", + "Because this setup may not be officially supported by package developers, only contact upstream support for issues clearly caused by the package itself.", + ]; + private static readonly HelperManagedTorchInstallPolicy Rdna2BorrowedDependencyFallbackPolicy = new() + { + PreInstallPackageSpecifiers = + [ + "https://rocm.nightlies.amd.com/v2-staging/gfx103X-dgpu/setuptools-80.9.0-py3-none-any.whl", + "https://rocm.nightlies.amd.com/v2-staging/gfx103X-dgpu/mpmath-1.3.0-py3-none-any.whl", + ], + SupplementalPackageSpecifiers = + [ + "setuptools @ https://rocm.nightlies.amd.com/v2-staging/gfx103X-dgpu/setuptools-80.9.0-py3-none-any.whl", + "mpmath @ https://rocm.nightlies.amd.com/v2-staging/gfx103X-dgpu/mpmath-1.3.0-py3-none-any.whl", + ], + }; /// public RocmCompatibilityResult GetCompatibility(RocmPackageProfile profile) @@ -128,6 +157,12 @@ RocmPackageProfile profile return mergedEnvironment; } + /// + public IReadOnlyList GetWindowsLaunchNoticeLines() + { + return WindowsLaunchNoticeLines; + } + /// public async Task InstallWindowsNativePackageAsync( IPyVenvRunner venvRunner, @@ -162,19 +197,53 @@ public async Task InstallWindowsNativePackageAsync( progress?.Report(new ProgressReport(-1f, "Upgrading pip...", isIndeterminate: true)); await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false); + var torchInstallPolicy = GetApplicableTorchInstallPolicy(profile, installContext.RuntimeGfxArch); + if (profile.RequiresRocmSdk) { progress?.Report(new ProgressReport(-1f, "Installing ROCm runtime...", isIndeterminate: true)); - var rocmRuntimeArgs = new PipInstallArgs() - .AddKeyedArgs("--index-url", ["--index-url", rocmPackageIndexUrl]) - .AddArgs("rocm[devel,libraries]"); - - if (installedPackage.PipOverrides != null) - { - rocmRuntimeArgs = rocmRuntimeArgs.WithUserOverrides(installedPackage.PipOverrides); - } - await 
venvRunner.PipInstall(rocmRuntimeArgs, onConsoleOutput).ConfigureAwait(false); + await WithTemporaryEnvironmentOverrideAsync( + venvRunner, + "SETUPTOOLS_USE_DISTUTILS", + "setuptools", + async () => + { + if ( + torchInstallPolicy is not null + && torchInstallPolicy.PreInstallPackageSpecifiers.Count > 0 + ) + { + var runtimeBootstrapArgs = BuildSupplementalPreinstallArgs(torchInstallPolicy); + + if (installedPackage.PipOverrides != null) + { + runtimeBootstrapArgs = runtimeBootstrapArgs.WithUserOverrides( + installedPackage.PipOverrides + ); + } + + await venvRunner + .PipInstall(runtimeBootstrapArgs, onConsoleOutput) + .ConfigureAwait(false); + } + + var rocmRuntimeArgs = new PipInstallArgs() + .AddKeyedArgs("--index-url", ["--index-url", rocmPackageIndexUrl]) + .AddArg("--no-build-isolation") + .AddArgs("rocm[devel,libraries]"); + + if (installedPackage.PipOverrides != null) + { + rocmRuntimeArgs = rocmRuntimeArgs.WithUserOverrides( + installedPackage.PipOverrides + ); + } + + await venvRunner.PipInstall(rocmRuntimeArgs, onConsoleOutput).ConfigureAwait(false); + } + ) + .ConfigureAwait(false); progress?.Report(new ProgressReport(-1f, "Initializing ROCm SDK...", isIndeterminate: true)); await InitializeWindowsNativeRocmSdkAsync(installLocation, onConsoleOutput, cancellationToken) @@ -214,16 +283,52 @@ await InitializeWindowsNativeRocmSdkAsync(installLocation, onConsoleOutput, canc await venvRunner.PipInstall(requirementsPipArgs, onConsoleOutput).ConfigureAwait(false); + if (torchInstallPolicy is not null && torchInstallPolicy.PreInstallPackageSpecifiers.Count > 0) + { + progress?.Report( + new ProgressReport( + -1f, + "Installing ROCm torch bootstrap dependencies...", + isIndeterminate: true + ) + ); + + var preinstallArgs = BuildSupplementalPreinstallArgs(torchInstallPolicy); + + if (installedPackage.PipOverrides != null) + { + preinstallArgs = preinstallArgs.WithUserOverrides(installedPackage.PipOverrides); + } + + await venvRunner.PipInstall(preinstallArgs, 
onConsoleOutput).ConfigureAwait(false); + } + progress?.Report(new ProgressReport(-1f, "Installing ROCm torch...", isIndeterminate: true)); + var usePreRelease = torchInstallPolicy?.UsePreRelease ?? true; + var forceReinstall = torchInstallPolicy?.ForceReinstall ?? profile.ForceReinstallTorch; + var torchArgs = new PipInstallArgs() - .AddArg("--pre") .AddArg("--upgrade") .AddKeyedArgs("--index-url", ["--index-url", rocmPackageIndexUrl]) .WithTorch() .WithTorchAudio() .WithTorchVision(); - if (profile.ForceReinstallTorch) + if (usePreRelease) + { + torchArgs = torchArgs.AddArg("--pre"); + } + + if (torchInstallPolicy is not null && torchInstallPolicy.SupplementalPackageSpecifiers.Count > 0) + { + torchArgs = torchArgs.AddArgs( + torchInstallPolicy + .SupplementalPackageSpecifiers.Select(specifier => new Argument(specifier)) + .ToArray() + ); + } + + if (forceReinstall) { torchArgs = torchArgs.AddArg("--force-reinstall"); } @@ -394,6 +499,81 @@ private static string GetUnsupportedGpuReason(IReadOnlyList amdGpus) return "No AMD GPU with a supported Windows ROCm architecture was detected."; } + /// + /// Returns the optional torch install policy when the resolved runtime GFX architecture matches the + /// declared activation prefixes. + /// + private static HelperManagedTorchInstallPolicy? GetApplicableTorchInstallPolicy( + RocmPackageProfile profile, + string? runtimeGfxArch + ) + { + return profile.TorchCompatibilityMode switch + { + RocmTorchCompatibilityMode.None => null, + RocmTorchCompatibilityMode.HelperManagedDependencyFallback => + GetHelperManagedDependencyFallbackPolicy(runtimeGfxArch), + _ => null, + }; + } + + /// + /// Resolves the helper-owned dependency fallback policy for runtime architectures that need borrowed + /// transitive dependencies from a supplemental TheRock index. + /// + private static HelperManagedTorchInstallPolicy? GetHelperManagedDependencyFallbackPolicy( + string? 
runtimeGfxArch + ) + { + if (string.IsNullOrWhiteSpace(runtimeGfxArch)) + return null; + + return runtimeGfxArch.StartsWith("gfx103", StringComparison.OrdinalIgnoreCase) + ? Rdna2BorrowedDependencyFallbackPolicy + : null; + } + + /// + /// Builds a deterministic preinstall step for borrowed torch bootstrap dependencies using direct + /// wheel URLs so the main ROCm package resolution remains on the primary index. + /// + private static PipInstallArgs BuildSupplementalPreinstallArgs(HelperManagedTorchInstallPolicy policy) + { + return new PipInstallArgs() + .AddArg("--upgrade") + .AddArgs( + policy.PreInstallPackageSpecifiers.Select(specifier => new Argument(specifier)).ToArray() + ); + } + + /// + /// Temporarily overrides a venv environment variable for a helper-owned install step and restores the + /// previous value afterward. + /// + private static async Task WithTemporaryEnvironmentOverrideAsync( + IPyVenvRunner venvRunner, + string key, + string value, + Func action + ) + { + venvRunner.EnvironmentVariables.TryGetValue(key, out var originalValue); + var hadOriginalValue = venvRunner.EnvironmentVariables.ContainsKey(key); + + venvRunner.UpdateEnvironmentVariables(env => env.SetItem(key, value)); + + try + { + await action().ConfigureAwait(false); + } + finally + { + venvRunner.UpdateEnvironmentVariables(env => + hadOriginalValue ? env.SetItem(key, originalValue!) : env.Remove(key) + ); + } + } + /// /// Verifies that the installed torch build still reports usable ROCm metadata after helper-managed installs complete. 
/// From dc15e2d517f8982be1c98443071de920ec7efd75 Mon Sep 17 00:00:00 2001 From: NeuralFault Date: Thu, 7 May 2026 19:12:28 -0400 Subject: [PATCH 12/18] Refactor Windows ROCm helper around multi-arch TheRock installs - generalize the ROCm helper install path around a new multi-arch TheRock nightly repo index flow - unify ROCm and PyTorch installation into a single command that automatically selects the correct PyTorch build for the user's GPU architecture - centralize Windows ROCm GPU and architecture resolution behind one shared helper machine-state path - trim the ROCm helper API and remove obsolete package-side ROCm install/runtime duplication - deleted no longer used torch compatibility model since A3WebUI/reForge work was rolled back and abandoned to be visited again in the future for integration --- .../Helper/Factory/PackageFactory.cs | 3 +- .../Helper/HardwareInfo/GpuInfo.cs | 4 +- .../Models/Packages/ComfyUI.cs | 54 +- .../Models/Packages/Wan2GP.cs | 40 +- .../Models/Rocm/RocmInstallContext.cs | 2 +- .../Models/Rocm/RocmPackageProfile.cs | 15 +- .../Models/Rocm/RocmTorchCompatibilityMode.cs | 17 - .../Models/Rocm/WindowsRocmSupport.cs | 41 +- .../Services/Rocm/IRocmPackageHelper.cs | 26 +- .../Services/Rocm/RocmPackageHelper.cs | 546 +++--------------- .../Core/RocmPackageHelperTests.cs | 84 +-- 11 files changed, 154 insertions(+), 678 deletions(-) delete mode 100644 StabilityMatrix.Core/Models/Rocm/RocmTorchCompatibilityMode.cs diff --git a/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs b/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs index 6a073986..0cb0b808 100644 --- a/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs +++ b/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs @@ -287,7 +287,8 @@ public BasePackage GetNewBasePackage(InstalledPackage installedPackage) downloadService, prerequisiteHelper, pyInstallationManager, - pipWheelService + pipWheelService, + rocmPackageHelper ), _ => throw new
ArgumentOutOfRangeException(nameof(installedPackage)), }; diff --git a/StabilityMatrix.Core/Helper/HardwareInfo/GpuInfo.cs b/StabilityMatrix.Core/Helper/HardwareInfo/GpuInfo.cs index 0013f65b..2bb0f4a8 100644 --- a/StabilityMatrix.Core/Helper/HardwareInfo/GpuInfo.cs +++ b/StabilityMatrix.Core/Helper/HardwareInfo/GpuInfo.cs @@ -129,9 +129,9 @@ _ when Has("5600") || Has("5700") || Has("v520") => "gfx1010", _ when Has("5500") => "gfx1012", // Vega/GCN5 Dedicated GPUs - _ when Has("pro vii") || HasNoSpace("provii") => "gfx90X", _ when Has("rx vega") || Has("vega 64") || Has("vega 56") || Has("vega frontier") => "gfx900", - _ when Has("radeon vii") || HasNoSpace("radeonvii") => "gfx906", + _ when Has("radeon vii") || HasNoSpace("radeonvii") || Has("pro vii") || HasNoSpace("provii") => + "gfx906", _ => null, }; diff --git a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs index 4fef3fda..edca7d00 100644 --- a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs +++ b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs @@ -377,10 +377,13 @@ public override async Task InstallPackage( if (Compat.IsWindows && torchIndex == TorchIndex.Rocm && HasWindowsRocmSupport()) { + // This is an internal guard for a wiring/configuration failure. + // It can only trigger when Windows ROCm support was detected, but this ComfyUI instance was created + // without the shared ROCm helper (for example via a manual construction path that omitted the dependency). if (rocmPackageHelper is null) { throw new InvalidOperationException( - "Windows ROCm installation requires the shared ROCm helper to resolve gfx-specific index URLs." + "Windows ROCm installation encountered an internal configuration error [rocmPackageHelper is null]. Please restart Stability Matrix and try again. If the issue persists, please report it to Stability Matrix." ); } @@ -599,7 +602,10 @@ private void EmitWindowsRocmLaunchNotice( Action? 
onConsoleOutput ) { - if (!ShouldShowWindowsRocmLaunchNotice(installedPackage) || rocmPackageHelper is null) + if (rocmPackageHelper is null) + return; + + if (!ShouldShowWindowsRocmLaunchNotice(installedPackage)) return; foreach (var line in rocmPackageHelper.GetWindowsLaunchNoticeLines()) @@ -610,7 +616,7 @@ private void EmitWindowsRocmLaunchNotice( private bool ShouldShowWindowsRocmLaunchNotice(InstalledPackage installedPackage) { - if (!Compat.IsWindows || rocmPackageHelper is null || !HasWindowsRocmSupport()) + if (!Compat.IsWindows || !HasWindowsRocmSupport()) return false; var torchIndex = installedPackage.PreferredTorchIndex ?? GetRecommendedTorchVersion(); @@ -634,11 +640,9 @@ public override TorchIndex GetRecommendedTorchVersion() /// Windows ROCm install profile for ComfyUI. private static readonly RocmPackageProfile WindowsRocmProfile = new() { - RequiresRocmSdk = true, ExtraInstallPipArgs = ["numpy<2"], PostInstallPipArgs = ["typing-extensions>=4.15.0"], UpgradePackages = true, - TorchCompatibilityMode = RocmTorchCompatibilityMode.HelperManagedDependencyFallback, ExtraEnvironmentFactory = BuildComfyWindowsRocmEnvironment, }; @@ -654,30 +658,28 @@ RocmRuntimeContext runtimeContext /// Uses the shared ROCm helper for Windows ROCm eligibility checks so ComfyUI does not maintain its own support matrix. 
private bool HasWindowsRocmSupport() { - if (!Compat.IsWindows) - return false; - - if (rocmPackageHelper is null) - return false; + return GetWindowsRocmCompatibility().IsCompatible; + } - var compatibility = rocmPackageHelper.GetCompatibility(WindowsRocmProfile); + private RocmCompatibilityResult GetWindowsRocmCompatibility() + { + if (!Compat.IsWindows || rocmPackageHelper is null) + { + return new RocmCompatibilityResult { IsCompatible = false }; + } - return compatibility.IsCompatible; + return rocmPackageHelper.GetCompatibility(WindowsRocmProfile); } /// Defaults legacy Windows ROCm GPUs to quad cross-attention because PyTorch cross-attention is considerably slower /// and not as supported on older AMD architectures. private bool DefaultToQuadCrossAttention() { - if (!Compat.IsWindows || !HasWindowsRocmSupport()) + var compatibility = GetWindowsRocmCompatibility(); + if (!compatibility.IsCompatible) return false; - var gpu = SettingsManager.Settings.PreferredGpu; - var gfxArch = WindowsRocmSupport.IsSupportedGpu(gpu) - ? 
gpu?.GetAmdGfxArch() - : HardwareHelper.GetWindowsRocmSupportedGpu()?.GetAmdGfxArch(); - - return WindowsRocmSupport.PreferLegacyAttentionFallback(gfxArch); + return WindowsRocmSupport.PreferLegacyAttentionFallback(compatibility.ResolvedGfxArch); } public override IPackageExtensionManager ExtensionManager => @@ -1042,17 +1044,11 @@ InstalledPackage installedPackage if (!Compat.IsWindows || !hasRocmGpu) return env; - if (rocmPackageHelper is not null) - { - var rocmEnvironment = rocmPackageHelper.BuildLaunchEnvironment( - installLocation, - installedPackage, - WindowsRocmProfile - ); + if (rocmPackageHelper is null) + return env; - return env.SetItems(rocmEnvironment); - } + var rocmEnvironment = rocmPackageHelper.BuildLaunchEnvironment(WindowsRocmProfile); - return env; + return env.SetItems(rocmEnvironment); } } diff --git a/StabilityMatrix.Core/Models/Packages/Wan2GP.cs b/StabilityMatrix.Core/Models/Packages/Wan2GP.cs index 4e82a751..9150a61a 100644 --- a/StabilityMatrix.Core/Models/Packages/Wan2GP.cs +++ b/StabilityMatrix.Core/Models/Packages/Wan2GP.cs @@ -33,7 +33,7 @@ public class Wan2GP( IPrerequisiteHelper prerequisiteHelper, IPyInstallationManager pyInstallationManager, IPipWheelService pipWheelService, - IRocmPackageHelper? 
rocmPackageHelper = null + IRocmPackageHelper rocmPackageHelper ) : BaseGitPackage( githubApi, @@ -46,7 +46,6 @@ public class Wan2GP( { private static readonly RocmPackageProfile WindowsRocmProfile = new() { - RequiresRocmSdk = true, UpgradePackages = true, PostInstallPipArgs = ["hf-xet", "setuptools", "numpy==1.26.4"], }; @@ -98,13 +97,17 @@ public class Wan2GP( private bool HasWindowsRocmSupport() { - if (!Compat.IsWindows) - return false; + return GetWindowsRocmCompatibility().IsCompatible; + } - if (rocmPackageHelper is null) - return HardwareHelper.HasWindowsRocmSupportedGpu(); + private RocmCompatibilityResult GetWindowsRocmCompatibility() + { + if (!Compat.IsWindows) + { + return new RocmCompatibilityResult { IsCompatible = false }; + } - return rocmPackageHelper.GetCompatibility(WindowsRocmProfile).IsCompatible; + return rocmPackageHelper.GetCompatibility(WindowsRocmProfile); } /// @@ -231,13 +234,7 @@ public override TorchIndex GetRecommendedTorchVersion() { // Check for AMD ROCm support (Windows or Linux) var preferRocm = - ( - Compat.IsWindows - && ( - WindowsRocmSupport.IsSupportedGpu(SettingsManager.Settings.PreferredGpu) - || HasWindowsRocmSupport() - ) - ) + (Compat.IsWindows && HasWindowsRocmSupport()) || ( Compat.IsLinux && (SettingsManager.Settings.PreferredGpu?.IsAmd ?? HardwareHelper.PreferRocm()) @@ -397,13 +394,6 @@ CancellationToken cancellationToken { if (Compat.IsWindows) { - if (rocmPackageHelper is null) - { - throw new InvalidOperationException( - "Windows ROCm installation for Wan2GP requires the shared ROCm helper." 
- ); - } - await rocmPackageHelper .InstallWindowsNativePackageAsync( venvRunner, @@ -451,13 +441,9 @@ public override async Task RunPackage( await SetupVenv(installLocation, pythonVersion: PyVersion.Parse(installedPackage.PythonVersion)) .ConfigureAwait(false); - if (Compat.IsWindows && rocmPackageHelper is not null && HasWindowsRocmSupport()) + if (Compat.IsWindows && HasWindowsRocmSupport()) { - var rocmEnvironment = rocmPackageHelper.BuildLaunchEnvironment( - installLocation, - installedPackage, - WindowsRocmProfile - ); + var rocmEnvironment = rocmPackageHelper.BuildLaunchEnvironment(WindowsRocmProfile); VenvRunner.UpdateEnvironmentVariables(env => env.SetItems(rocmEnvironment)); } diff --git a/StabilityMatrix.Core/Models/Rocm/RocmInstallContext.cs b/StabilityMatrix.Core/Models/Rocm/RocmInstallContext.cs index 597eb4fe..daaeb849 100644 --- a/StabilityMatrix.Core/Models/Rocm/RocmInstallContext.cs +++ b/StabilityMatrix.Core/Models/Rocm/RocmInstallContext.cs @@ -7,5 +7,5 @@ public class RocmInstallContext { public string? RuntimeGfxArch { get; init; } - public string? RocmPackageIndexUrl { get; init; } + public string? MultiArchDeviceExtra { get; init; } } diff --git a/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs b/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs index 2f18a07a..cd28db7a 100644 --- a/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs +++ b/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs @@ -8,20 +8,18 @@ namespace StabilityMatrix.Core.Models.Rocm; /// public class RocmPackageProfile { - public bool RequiresRocmSdk { get; init; } - /// - /// Requirement files to install after helper-owned ROCm runtime / torch bootstrap steps complete. + /// Requirement files to install after helper-owned ROCm torch installation completes. /// public IEnumerable RequirementsFilePaths { get; init; } = ["requirements.txt"]; /// - /// Package requirement entries to exclude because the helper installs them from ROCm-specific feeds. 
+ /// Package requirement entries to exclude because the helper installs them from the ROCm multi-arch feed. /// public string RequirementsExcludePattern { get; init; } = @"(torch(vision|audio)?|xformers)([^a-z].*)?"; /// - /// Extra package-specific pip arguments to include when installing requirements after helper bootstrap. + /// Extra package-specific pip arguments to include when installing requirements before the helper-managed torch step. /// public IEnumerable ExtraInstallPipArgs { get; init; } = []; @@ -40,13 +38,6 @@ public class RocmPackageProfile /// public bool ForceReinstallTorch { get; init; } = true; - /// - /// Optional helper-managed compatibility mode for ROCm torch installation. - /// The package declares the intent here while the helper resolves any architecture-specific - /// fallback indexes or borrowed dependency rules internally. - /// - public RocmTorchCompatibilityMode TorchCompatibilityMode { get; init; } - /// /// Optional callback for package-specific environment variables derived from a resolved ROCm context. /// diff --git a/StabilityMatrix.Core/Models/Rocm/RocmTorchCompatibilityMode.cs b/StabilityMatrix.Core/Models/Rocm/RocmTorchCompatibilityMode.cs deleted file mode 100644 index f8e9f7da..00000000 --- a/StabilityMatrix.Core/Models/Rocm/RocmTorchCompatibilityMode.cs +++ /dev/null @@ -1,17 +0,0 @@ -namespace StabilityMatrix.Core.Models.Rocm; - -/// -/// High-level helper-managed compatibility modes for ROCm torch installation. -/// Package profiles should declare intent here and let the ROCm helper resolve any -/// architecture-specific index or dependency fallback details. -/// -public enum RocmTorchCompatibilityMode -{ - None, - - /// - /// Lets the helper apply built-in Windows ROCm dependency fallback rules when - /// specific TheRock indexes are missing compatible transitive dependency wheels. 
- /// - HelperManagedDependencyFallback, -} diff --git a/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs b/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs index ee000b1b..e784b615 100644 --- a/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs +++ b/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs @@ -8,6 +8,9 @@ namespace StabilityMatrix.Core.Models.Rocm; /// public static class WindowsRocmSupport { + public const string MultiArchPythonPackageIndexUrl = + "https://rocm.nightlies.amd.com/whl-staging-multi-arch/"; + public static bool IsSupportedGpu(GpuInfo? gpu) { if (gpu is null || !gpu.IsAmd || string.IsNullOrWhiteSpace(gpu.Name)) @@ -18,7 +21,7 @@ public static bool IsSupportedGpu(GpuInfo? gpu) public static bool IsSupportedArchitecture(string? gfxArch) { - return TryGetPackageIndexUrl(gfxArch) is not null; + return TryGetCanonicalArchitecture(gfxArch) is not null; } public static bool IsModernArchitecture(string? gfxArch) @@ -43,26 +46,28 @@ public static bool IsRdna1Architecture(string? gfxArch) return gfxArch?.StartsWith("gfx101", StringComparison.OrdinalIgnoreCase) == true; } - public static string? TryGetPackageIndexUrl(string? gfxArch) + public static string? TryGetCanonicalArchitecture(string? 
gfxArch) { - return gfxArch switch + if (string.IsNullOrWhiteSpace(gfxArch)) + return null; + + var normalizedArch = gfxArch.ToLowerInvariant(); + + return normalizedArch switch { - "gfx900" => "https://rocm.nightlies.amd.com/v2-staging/gfx900/", // Vega 10 - "gfx906" => "https://rocm.nightlies.amd.com/v2-staging/gfx906/", // Radeon VII, Vega 20 - "gfx90X" => "https://rocm.nightlies.amd.com/v2-staging/gfx90X/", // Radeon Pro VII - var s when s != null && s.StartsWith("gfx101", StringComparison.OrdinalIgnoreCase) => - "https://rocm.nightlies.amd.com/v2-staging/gfx101X-dgpu/", // RDNA1 (5000 series, Pro) - var s when s != null && s.StartsWith("gfx103", StringComparison.OrdinalIgnoreCase) => - "https://rocm.nightlies.amd.com/v2-staging/gfx103X-all/", // RDNA2 (6000 series, 6xxM Mobile, Steam Deck) - var s when s != null && s.StartsWith("gfx110", StringComparison.OrdinalIgnoreCase) => - "https://rocm.nightlies.amd.com/v2/gfx110X-all/", // RDNA3 (7000 series, 7xxM Mobile) - "gfx1150" => "https://rocm.nightlies.amd.com/v2-staging/gfx1150/", // RDNA3.5 (Strix/Gorgon Point) - "gfx1151" => "https://rocm.nightlies.amd.com/v2/gfx1151/", // RDNA3.5 (Strix Halo) - "gfx1152" => "https://rocm.nightlies.amd.com/v2-staging/gfx1152/", // RDNA3.5 (Kraken Point) - "gfx1153" => "https://rocm.nightlies.amd.com/v2-staging/gfx1153/", // RDNA3.5 (Medusa Point) - var s when s != null && s.StartsWith("gfx120", StringComparison.OrdinalIgnoreCase) => - "https://rocm.nightlies.amd.com/v2/gfx120X-all/", // RDNA4 (9000 series) + "gfx900" or "gfx906" or "gfx1150" or "gfx1151" or "gfx1152" or "gfx1153" => normalizedArch, + var s + when s.StartsWith("gfx101", StringComparison.Ordinal) + || s.StartsWith("gfx103", StringComparison.Ordinal) + || s.StartsWith("gfx110", StringComparison.Ordinal) + || s.StartsWith("gfx120", StringComparison.Ordinal) => normalizedArch, _ => null, }; } + + public static string? TryGetMultiArchDeviceExtra(string? 
gfxArch) + { + var canonicalArch = TryGetCanonicalArchitecture(gfxArch); + return canonicalArch is null ? null : $"device-{canonicalArch}"; + } } diff --git a/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs index 78924250..12c991c3 100644 --- a/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs +++ b/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs @@ -16,32 +16,10 @@ public interface IRocmPackageHelper /// RocmCompatibilityResult GetCompatibility(RocmPackageProfile profile); - /// - /// Resolves the runtime ROCm facts needed for package launch and environment construction. - /// - RocmRuntimeContext ResolveRuntimeContext( - string installLocation, - InstalledPackage installedPackage, - RocmPackageProfile profile - ); - - /// - /// Resolves the ROCm facts needed during package installation or update operations. - /// - RocmInstallContext ResolveInstallContext( - string installLocation, - InstalledPackage installedPackage, - RocmPackageProfile profile - ); - /// /// Builds a launch-time environment dictionary from resolved ROCm runtime data. /// - IReadOnlyDictionary BuildLaunchEnvironment( - string installLocation, - InstalledPackage installedPackage, - RocmPackageProfile profile - ); + IReadOnlyDictionary BuildLaunchEnvironment(RocmPackageProfile profile); /// /// Returns shared Windows ROCm launch notice lines for helper-managed packages. @@ -49,7 +27,7 @@ RocmPackageProfile profile IReadOnlyList GetWindowsLaunchNoticeLines(); /// - /// Performs the Windows-native ROCm bootstrap/install flow for a package using helper-resolved gfx-family feed URLs. + /// Performs the Windows-native ROCm install flow for a package using helper-resolved multi-arch device extras. 
/// Task InstallWindowsNativePackageAsync( IPyVenvRunner venvRunner, diff --git a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs index d2a89bbe..fc3d3bf2 100644 --- a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs +++ b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs @@ -1,4 +1,3 @@ -using System.Collections.Immutable; using System.Text.Json; using Injectio.Attributes; using NLog; @@ -21,17 +20,6 @@ namespace StabilityMatrix.Core.Services.Rocm; [RegisterSingleton] public class RocmPackageHelper(ISettingsManager settingsManager) : IRocmPackageHelper { - private sealed class HelperManagedTorchInstallPolicy - { - public IReadOnlyList PreInstallPackageSpecifiers { get; init; } = []; - - public IReadOnlyList SupplementalPackageSpecifiers { get; init; } = []; - - public bool? UsePreRelease { get; init; } - - public bool? ForceReinstall { get; init; } - } - private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); private static readonly StringComparer EnvComparer = StringComparer.OrdinalIgnoreCase; private static readonly string[] WindowsLaunchNoticeLines = @@ -39,107 +27,57 @@ private sealed class HelperManagedTorchInstallPolicy "Stability Matrix Windows ROCm Notice: Windows AMD ROCm support is experimental. 
Please report any issues to Stability Matrix first so it can be determined whether the issue is package-specific.", "Because this setup may not be officially supported by package developers, only contact upstream support for issues clearly caused by the package itself.", ]; - private static readonly HelperManagedTorchInstallPolicy Rdna2BorrowedDependencyFallbackPolicy = new() - { - PreInstallPackageSpecifiers = - [ - "https://rocm.nightlies.amd.com/v2-staging/gfx103X-dgpu/setuptools-80.9.0-py3-none-any.whl", - "https://rocm.nightlies.amd.com/v2-staging/gfx103X-dgpu/mpmath-1.3.0-py3-none-any.whl", - ], - SupplementalPackageSpecifiers = - [ - "setuptools @ https://rocm.nightlies.amd.com/v2-staging/gfx103X-dgpu/setuptools-80.9.0-py3-none-any.whl", - "mpmath @ https://rocm.nightlies.amd.com/v2-staging/gfx103X-dgpu/mpmath-1.3.0-py3-none-any.whl", - ], - }; /// public RocmCompatibilityResult GetCompatibility(RocmPackageProfile profile) { + _ = profile; return BuildCompatibilityResult(profile); } /// - public RocmRuntimeContext ResolveRuntimeContext( - string installLocation, - InstalledPackage installedPackage, - RocmPackageProfile profile - ) + private RocmRuntimeContext ResolveRuntimeContext(RocmPackageProfile profile) { - _ = installLocation; - _ = installedPackage; + _ = profile; - var compatibility = BuildCompatibilityResult(profile); - if (!compatibility.IsCompatible) + var state = ResolveWindowsMachineState(); + if (!state.IsCompatible) { return new RocmRuntimeContext { IsSupported = false, - FailureReason = compatibility.FailureReason, - SelectedGpu = compatibility.SelectedGpu, - RuntimeGfxArch = compatibility.ResolvedGfxArch, + FailureReason = state.FailureReason, + SelectedGpu = state.SelectedGpu, + RuntimeGfxArch = state.RuntimeGfxArch, }; } - var supportedAmdGpus = GetAmdGpuCandidates(forceRefresh: true) - .Where(IsSupportedWindowsRocmGpu) - .ToList(); - - var selectedGpu = - compatibility.SelectedGpu - ?? 
TryResolvePreferredAmdGpu(supportedAmdGpus, settingsManager.Settings.PreferredGpu) - ?? supportedAmdGpus.FirstOrDefault(); - - var runtimeGfxArch = - compatibility.ResolvedGfxArch - ?? selectedGpu?.GetAmdGfxArch() - ?? GetSupportedFallbackGfxArch(supportedAmdGpus); - return new RocmRuntimeContext { IsSupported = true, - SelectedGpu = selectedGpu, - RuntimeGfxArch = runtimeGfxArch, + SelectedGpu = state.SelectedGpu, + RuntimeGfxArch = state.RuntimeGfxArch, }; } /// - public RocmInstallContext ResolveInstallContext( - string installLocation, - InstalledPackage installedPackage, - RocmPackageProfile profile - ) + private RocmInstallContext ResolveInstallContext(RocmPackageProfile profile) { - _ = installLocation; - _ = installedPackage; - - var supportedAmdGpus = GetAmdGpuCandidates(forceRefresh: true) - .Where(IsSupportedWindowsRocmGpu) - .ToList(); - - var preferredGfxArch = TryResolvePreferredAmdGfxArch( - supportedAmdGpus, - settingsManager.Settings.PreferredGpu - ); + _ = profile; - var runtimeGfxArch = preferredGfxArch ?? 
GetSupportedFallbackGfxArch(supportedAmdGpus); - var windowsNativeIndexUrl = WindowsRocmSupport.TryGetPackageIndexUrl(runtimeGfxArch); + var state = ResolveWindowsMachineState(); return new RocmInstallContext { - RuntimeGfxArch = runtimeGfxArch, - RocmPackageIndexUrl = windowsNativeIndexUrl, + RuntimeGfxArch = state.RuntimeGfxArch, + MultiArchDeviceExtra = state.MultiArchDeviceExtra, }; } /// - public IReadOnlyDictionary BuildLaunchEnvironment( - string installLocation, - InstalledPackage installedPackage, - RocmPackageProfile profile - ) + public IReadOnlyDictionary BuildLaunchEnvironment(RocmPackageProfile profile) { - var runtimeContext = ResolveRuntimeContext(installLocation, installedPackage, profile); + var runtimeContext = ResolveRuntimeContext(profile); if (!runtimeContext.IsSupported) return new Dictionary(); @@ -183,73 +121,20 @@ public async Task InstallWindowsNativePackageAsync( ); } - var installContext = ResolveInstallContext(installLocation, installedPackage, profile); + var installContext = ResolveInstallContext(profile); - var rocmPackageIndexUrl = installContext.RocmPackageIndexUrl; + var multiArchDeviceExtra = installContext.MultiArchDeviceExtra; - if (string.IsNullOrWhiteSpace(rocmPackageIndexUrl)) + if (string.IsNullOrWhiteSpace(multiArchDeviceExtra)) { throw new ApplicationException( - $"No Windows ROCm Technical Preview index URL is available for '{installContext.RuntimeGfxArch ?? "unknown"}'." + $"No Windows ROCm multi-arch device extra is available for '{installContext.RuntimeGfxArch ?? "unknown"}'." 
); } progress?.Report(new ProgressReport(-1f, "Upgrading pip...", isIndeterminate: true)); await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false); - var torchInstallPolicy = GetApplicableTorchInstallPolicy(profile, installContext.RuntimeGfxArch); - - if (profile.RequiresRocmSdk) - { - progress?.Report(new ProgressReport(-1f, "Installing ROCm runtime...", isIndeterminate: true)); - - await WithTemporaryEnvironmentOverrideAsync( - venvRunner, - "SETUPTOOLS_USE_DISTUTILS", - "setuptools", - async () => - { - if ( - torchInstallPolicy is not null - && torchInstallPolicy.PreInstallPackageSpecifiers.Count > 0 - ) - { - var runtimeBootstrapArgs = BuildSupplementalPreinstallArgs(torchInstallPolicy); - - if (installedPackage.PipOverrides != null) - { - runtimeBootstrapArgs = runtimeBootstrapArgs.WithUserOverrides( - installedPackage.PipOverrides - ); - } - - await venvRunner - .PipInstall(runtimeBootstrapArgs, onConsoleOutput) - .ConfigureAwait(false); - } - - var rocmRuntimeArgs = new PipInstallArgs() - .AddKeyedArgs("--index-url", ["--index-url", rocmPackageIndexUrl]) - .AddArg("--no-build-isolation") - .AddArgs("rocm[devel,libraries]"); - - if (installedPackage.PipOverrides != null) - { - rocmRuntimeArgs = rocmRuntimeArgs.WithUserOverrides( - installedPackage.PipOverrides - ); - } - - await venvRunner.PipInstall(rocmRuntimeArgs, onConsoleOutput).ConfigureAwait(false); - } - ) - .ConfigureAwait(false); - - progress?.Report(new ProgressReport(-1f, "Initializing ROCm SDK...", isIndeterminate: true)); - await InitializeWindowsNativeRocmSdkAsync(installLocation, onConsoleOutput, cancellationToken) - .ConfigureAwait(false); - } - progress?.Report( new ProgressReport(-1f, "Installing package requirements...", isIndeterminate: true) ); @@ -283,52 +168,18 @@ await InitializeWindowsNativeRocmSdkAsync(installLocation, onConsoleOutput, canc await venvRunner.PipInstall(requirementsPipArgs, onConsoleOutput).ConfigureAwait(false); - if 
(torchInstallPolicy is not null && torchInstallPolicy.PreInstallPackageSpecifiers.Count > 0) - { - progress?.Report( - new ProgressReport( - -1f, - "Installing ROCm torch bootstrap dependencies...", - isIndeterminate: true - ) - ); - - var preinstallArgs = BuildSupplementalPreinstallArgs(torchInstallPolicy); - - if (installedPackage.PipOverrides != null) - { - preinstallArgs = preinstallArgs.WithUserOverrides(installedPackage.PipOverrides); - } - - await venvRunner.PipInstall(preinstallArgs, onConsoleOutput).ConfigureAwait(false); - } - progress?.Report(new ProgressReport(-1f, "Installing ROCm torch...", isIndeterminate: true)); - var usePreRelease = torchInstallPolicy?.UsePreRelease ?? true; - var forceReinstall = torchInstallPolicy?.ForceReinstall ?? profile.ForceReinstallTorch; var torchArgs = new PipInstallArgs() .AddArg("--upgrade") - .AddKeyedArgs("--index-url", ["--index-url", rocmPackageIndexUrl]) - .WithTorch() - .WithTorchAudio() - .WithTorchVision(); - - if (usePreRelease) - { - torchArgs = torchArgs.AddArg("--pre"); - } - - if (torchInstallPolicy is not null && torchInstallPolicy.SupplementalPackageSpecifiers.Count > 0) - { - torchArgs = torchArgs.AddArgs( - torchInstallPolicy - .SupplementalPackageSpecifiers.Select(specifier => new Argument(specifier)) - .ToArray() + .AddKeyedArgs("--index-url", ["--index-url", WindowsRocmSupport.MultiArchPythonPackageIndexUrl]) + .AddArgs( + new Argument($"torch[{multiArchDeviceExtra}]"), + new Argument($"torchvision[{multiArchDeviceExtra}]"), + new Argument("torchaudio") ); - } - if (forceReinstall) + if (profile.ForceReinstallTorch) { torchArgs = torchArgs.AddArg("--force-reinstall"); } @@ -339,17 +190,6 @@ await InitializeWindowsNativeRocmSdkAsync(installLocation, onConsoleOutput, canc } await venvRunner.PipInstall(torchArgs, onConsoleOutput).ConfigureAwait(false); - - if (profile.RequiresRocmSdk) - { - await AlignRocmSdkDevelVersionAsync(venvRunner, rocmPackageIndexUrl, onConsoleOutput) - .ConfigureAwait(false); - 
- progress?.Report(new ProgressReport(-1f, "Reinitializing ROCm SDK...", isIndeterminate: true)); - await InitializeWindowsNativeRocmSdkAsync(installLocation, onConsoleOutput, cancellationToken) - .ConfigureAwait(false); - } - if (profile.PostInstallPipArgs.Any()) { var postInstallPipArgs = new PipInstallArgs([.. profile.PostInstallPipArgs]); @@ -363,12 +203,6 @@ await InitializeWindowsNativeRocmSdkAsync(installLocation, onConsoleOutput, canc await VerifyWindowsNativeTorchInstallAsync(venvRunner, onConsoleOutput, cancellationToken) .ConfigureAwait(false); - - if (profile.RequiresRocmSdk) - { - await VerifyWindowsNativeRocmRuntimeAsync(installLocation, onConsoleOutput, cancellationToken) - .ConfigureAwait(false); - } } /// @@ -376,23 +210,35 @@ await VerifyWindowsNativeRocmRuntimeAsync(installLocation, onConsoleOutput, canc /// This keeps the first ROCm helper slice focused on hardware capability and GPU selection only. /// private RocmCompatibilityResult BuildCompatibilityResult(RocmPackageProfile profile) + { + _ = profile; + var state = ResolveWindowsMachineState(); + + return new RocmCompatibilityResult + { + IsCompatible = state.IsCompatible, + FailureReason = state.FailureReason, + SelectedGpu = state.SelectedGpu, + ResolvedGfxArch = state.RuntimeGfxArch, + }; + } + + private ResolvedWindowsRocmState ResolveWindowsMachineState() { var amdGpus = GetAmdGpuCandidates(forceRefresh: true).ToList(); if (amdGpus.Count == 0) { - return new RocmCompatibilityResult + return new ResolvedWindowsRocmState { IsCompatible = false, FailureReason = "No AMD GPU was detected for ROCm evaluation.", }; } - var preferredGpu = settingsManager.Settings.PreferredGpu; - var supportedAmdGpus = amdGpus.Where(IsSupportedWindowsRocmGpu).ToList(); if (supportedAmdGpus.Count == 0) { - return new RocmCompatibilityResult + return new ResolvedWindowsRocmState { IsCompatible = false, FailureReason = GetUnsupportedGpuReason(amdGpus), @@ -400,17 +246,22 @@ private RocmCompatibilityResult 
BuildCompatibilityResult(RocmPackageProfile prof } var selectedGpu = - TryResolvePreferredAmdGpu(supportedAmdGpus, preferredGpu) ?? supportedAmdGpus.First(); - var resolvedGfxArch = selectedGpu.GetAmdGfxArch() ?? GetSupportedFallbackGfxArch(supportedAmdGpus); + TryResolvePreferredAmdGpu(supportedAmdGpus, settingsManager.Settings.PreferredGpu) + ?? supportedAmdGpus.First(); + var runtimeGfxArch = + WindowsRocmSupport.TryGetCanonicalArchitecture(selectedGpu.GetAmdGfxArch()) + ?? GetSupportedFallbackGfxArch(supportedAmdGpus); + var isCompatible = !string.IsNullOrWhiteSpace(runtimeGfxArch); - return new RocmCompatibilityResult + return new ResolvedWindowsRocmState { - IsCompatible = !string.IsNullOrWhiteSpace(resolvedGfxArch), - FailureReason = string.IsNullOrWhiteSpace(resolvedGfxArch) - ? "No supported AMD GFX architecture could be resolved for ROCm." - : null, + IsCompatible = isCompatible, + FailureReason = isCompatible + ? null + : "No supported AMD GFX architecture could be resolved for ROCm.", SelectedGpu = selectedGpu, - ResolvedGfxArch = resolvedGfxArch, + RuntimeGfxArch = runtimeGfxArch, + MultiArchDeviceExtra = WindowsRocmSupport.TryGetMultiArchDeviceExtra(runtimeGfxArch), }; } @@ -459,7 +310,7 @@ private static IReadOnlyList GetAmdGpuCandidates(bool forceRefresh = fa { var resolvedPreferredGpu = TryResolvePreferredAmdGpu(availableGpus, preferredGpu); return resolvedPreferredGpu is not null && IsSupportedWindowsRocmGpu(resolvedPreferredGpu) - ? resolvedPreferredGpu.GetAmdGfxArch() + ? 
WindowsRocmSupport.TryGetCanonicalArchitecture(resolvedPreferredGpu.GetAmdGfxArch()) : null; } @@ -470,7 +321,7 @@ private static IReadOnlyList GetAmdGpuCandidates(bool forceRefresh = fa { return availableGpus .Where(IsSupportedWindowsRocmGpu) - .Select(gpu => gpu.GetAmdGfxArch()) + .Select(gpu => WindowsRocmSupport.TryGetCanonicalArchitecture(gpu.GetAmdGfxArch())) .FirstOrDefault(IsSupportedWindowsRocmArchitecture); } @@ -499,81 +350,6 @@ private static string GetUnsupportedGpuReason(IReadOnlyList amdGpus) return "No AMD GPU with a supported Windows ROCm architecture was detected."; } - /// - /// Returns the optional torch install policy when the resolved runtime GFX architecture matches the - /// declared activation prefixes. - /// - private static HelperManagedTorchInstallPolicy? GetApplicableTorchInstallPolicy( - RocmPackageProfile profile, - string? runtimeGfxArch - ) - { - return profile.TorchCompatibilityMode switch - { - RocmTorchCompatibilityMode.None => null, - RocmTorchCompatibilityMode.HelperManagedDependencyFallback => - GetHelperManagedDependencyFallbackPolicy(runtimeGfxArch), - _ => null, - }; - } - - /// - /// Resolves the helper-owned dependency fallback policy for runtime architectures that need borrowed - /// transitive dependencies from a supplemental TheRock index. - /// - private static HelperManagedTorchInstallPolicy? GetHelperManagedDependencyFallbackPolicy( - string? runtimeGfxArch - ) - { - if (string.IsNullOrWhiteSpace(runtimeGfxArch)) - return null; - - return runtimeGfxArch.StartsWith("gfx103", StringComparison.OrdinalIgnoreCase) - ? Rdna2BorrowedDependencyFallbackPolicy - : null; - } - - /// - /// Builds a deterministic preinstall step for borrowed torch bootstrap dependencies using direct - /// wheel URLs so the main ROCm package resolution remains on the primary index. 
- /// - private static PipInstallArgs BuildSupplementalPreinstallArgs(HelperManagedTorchInstallPolicy policy) - { - return new PipInstallArgs() - .AddArg("--upgrade") - .AddArgs( - policy.PreInstallPackageSpecifiers.Select(specifier => new Argument(specifier)).ToArray() - ); - } - - /// - /// Temporarily overrides a venv environment variable for a helper-owned install step and restores the - /// previous value afterward. - /// - private static async Task WithTemporaryEnvironmentOverrideAsync( - IPyVenvRunner venvRunner, - string key, - string value, - Func action - ) - { - venvRunner.EnvironmentVariables.TryGetValue(key, out var originalValue); - var hadOriginalValue = venvRunner.EnvironmentVariables.ContainsKey(key); - - venvRunner.UpdateEnvironmentVariables(env => env.SetItem(key, value)); - - try - { - await action().ConfigureAwait(false); - } - finally - { - venvRunner.UpdateEnvironmentVariables(env => - hadOriginalValue ? env.SetItem(key, originalValue!) : env.Remove(key) - ); - } - } - /// /// Verifies that the installed torch build still reports usable ROCm metadata after helper-managed installs complete. /// @@ -655,141 +431,6 @@ CancellationToken cancellationToken _ = cancellationToken; } - /// - /// Runs rocm-sdk init after the helper-managed runtime packages are installed so the Windows ROCm SDK can prepare the venv. - /// - private static async Task InitializeWindowsNativeRocmSdkAsync( - string installLocation, - Action? 
onConsoleOutput, - CancellationToken cancellationToken - ) - { - var rocmSdkExe = Path.Combine(installLocation, "venv", "Scripts", "rocm-sdk.exe"); - if (!File.Exists(rocmSdkExe)) - { - throw new FileNotFoundException("rocm-sdk.exe was not installed", rocmSdkExe); - } - - using var rocmSdkProcess = ProcessRunner.StartAnsiProcess( - rocmSdkExe, - ["init"], - installLocation, - onConsoleOutput - ); - - await rocmSdkProcess.WaitForExitAsync(cancellationToken).ConfigureAwait(false); - if (rocmSdkProcess.ExitCode != 0) - { - throw new ProcessException($"rocm-sdk init failed with code {rocmSdkProcess.ExitCode}"); - } - } - - /// - /// Uses AMD's bundled hipInfo.exe to confirm the installed Windows ROCm runtime can enumerate a ROCm-capable GPU. - /// - private static async Task VerifyWindowsNativeRocmRuntimeAsync( - string installLocation, - Action? onConsoleOutput, - CancellationToken cancellationToken - ) - { - var rocmSdkExe = Path.Combine(installLocation, "venv", "Scripts", "rocm-sdk.exe"); - if (!File.Exists(rocmSdkExe)) - { - throw new FileNotFoundException("rocm-sdk.exe was not installed", rocmSdkExe); - } - - var rocmBinResult = await ProcessRunner - .GetProcessResultAsync(rocmSdkExe, ["path", "--bin"], installLocation, useUtf8Encoding: true) - .ConfigureAwait(false); - - var rocmBinPath = (rocmBinResult.StandardOutput ?? string.Empty).Trim(); - if (!rocmBinResult.IsSuccessExitCode || string.IsNullOrWhiteSpace(rocmBinPath)) - { - var rocmBinOutput = CombineProcessOutput( - rocmBinResult.StandardOutput, - rocmBinResult.StandardError - ); - throw new ApplicationException( - $"ROCm runtime verification failed while resolving the ROCm SDK bin path. 
Output: {rocmBinOutput}" - ); - } - - var hipInfoExe = Path.Combine(rocmBinPath, $"hipInfo{Compat.ExeExtension}"); - if (!File.Exists(hipInfoExe)) - { - throw new FileNotFoundException( - "hipInfo.exe was not found in the ROCm SDK bin directory", - hipInfoExe - ); - } - - var hipInfoResult = await ProcessRunner - .GetProcessResultAsync( - hipInfoExe, - [], - installLocation, - new Dictionary { ["PATH"] = rocmBinPath }, - useUtf8Encoding: true - ) - .ConfigureAwait(false); - - var hipInfoOutput = CombineProcessOutput(hipInfoResult.StandardOutput, hipInfoResult.StandardError); - if (!hipInfoResult.IsSuccessExitCode) - { - var runtimeFailureReason = TryGetWindowsNativeRocmRuntimeFailureReason(hipInfoOutput); - throw new ApplicationException( - runtimeFailureReason is null - ? $"ROCm runtime verification failed while probing the installed runtime with hipInfo.exe. Output: {hipInfoOutput}" - : $"ROCm runtime verification failed: {runtimeFailureReason} Output: {hipInfoOutput}" - ); - } - - onConsoleOutput?.Invoke( - ProcessOutput.FromStdOutLine( - $"ROCm runtime verification succeeded via hipInfo.exe: {hipInfoOutput}" - ) - ); - - _ = cancellationToken; - } - - /// - /// Reinstalls rocm-sdk-devel to the resolved ROCm build version when the torch step downgrades the runtime stack. - /// - private static async Task AlignRocmSdkDevelVersionAsync( - IPyVenvRunner venvRunner, - string rocmPackageIndexUrl, - Action? 
onConsoleOutput - ) - { - var rocmInfo = await venvRunner.PipShow("rocm").ConfigureAwait(false); - var rocmSdkDevelInfo = await venvRunner.PipShow("rocm-sdk-devel").ConfigureAwait(false); - var torchInfo = await venvRunner.PipShow("torch").ConfigureAwait(false); - - var targetVersion = GetRocmSdkDevelAlignmentVersion( - rocmInfo?.Version, - rocmSdkDevelInfo?.Version, - torchInfo?.Version - ); - - if (string.IsNullOrWhiteSpace(targetVersion)) - return; - - onConsoleOutput?.Invoke( - ProcessOutput.FromStdErrLine( - $"Aligning rocm-sdk-devel from version={rocmSdkDevelInfo?.Version ?? "not-installed"} to version={targetVersion} to match the resolved ROCm torch/runtime build." - ) - ); - - var alignmentArgs = new PipInstallArgs() - .AddKeyedArgs("--index-url", ["--index-url", rocmPackageIndexUrl]) - .AddArg("--force-reinstall") - .AddArg($"rocm-sdk-devel=={targetVersion}"); - - await venvRunner.PipInstall(alignmentArgs, onConsoleOutput).ConfigureAwait(false); - } - internal static bool IsUsableWindowsNativeTorchBuild(string? version, string? hipVersion) { if (!string.IsNullOrWhiteSpace(hipVersion)) @@ -799,55 +440,6 @@ internal static bool IsUsableWindowsNativeTorchBuild(string? version, string? hi && version.Contains("rocm", StringComparison.OrdinalIgnoreCase); } - internal static string? GetRocmSdkDevelAlignmentVersion( - string? rocmVersion, - string? rocmSdkDevelVersion, - string? torchVersion = null - ) - { - var targetVersion = !string.IsNullOrWhiteSpace(rocmVersion) - ? rocmVersion - : TryExtractRocmBuildVersion(torchVersion); - - if (string.IsNullOrWhiteSpace(targetVersion)) - return null; - - return string.Equals(targetVersion, rocmSdkDevelVersion, StringComparison.OrdinalIgnoreCase) - ? null - : targetVersion; - } - - internal static string? TryGetWindowsNativeRocmRuntimeFailureReason(string? 
output) - { - if (string.IsNullOrWhiteSpace(output)) - return null; - - if (output.Contains("no ROCm-capable device is detected", StringComparison.OrdinalIgnoreCase)) - { - return "the installed ROCm runtime could not detect a ROCm-capable GPU on this system."; - } - - if (output.Contains("No WDDM adapters found", StringComparison.OrdinalIgnoreCase)) - { - return "the ROCm runtime could not find any compatible WDDM adapters for the current GPU/driver stack."; - } - - return null; - } - - internal static string? TryExtractRocmBuildVersion(string? torchVersion) - { - if (string.IsNullOrWhiteSpace(torchVersion)) - return null; - - var rocmMarkerIndex = torchVersion.IndexOf("rocm", StringComparison.OrdinalIgnoreCase); - if (rocmMarkerIndex < 0) - return null; - - var rocmBuildVersion = torchVersion[(rocmMarkerIndex + "rocm".Length)..].Trim(); - return string.IsNullOrWhiteSpace(rocmBuildVersion) ? null : rocmBuildVersion; - } - internal static string? TryExtractJsonObject(string output) { if (string.IsNullOrWhiteSpace(output)) @@ -871,15 +463,6 @@ internal static bool IsUsableWindowsNativeTorchBuild(string? version, string? hi return null; } - internal static string CombineProcessOutput(string? standardOutput, string? standardError) - { - var sections = new[] { standardOutput?.Trim(), standardError?.Trim() }.Where(section => - !string.IsNullOrWhiteSpace(section) - ); - - return string.Join(Environment.NewLine, sections); - } - /// /// Builds helper-owned ROCm launch variables from the resolved runtime context and package profile. /// @@ -970,4 +553,17 @@ RocmEnvironmentOptions options return merged; } + + private sealed class ResolvedWindowsRocmState + { + public bool IsCompatible { get; init; } + + public string? FailureReason { get; init; } + + public GpuInfo? SelectedGpu { get; init; } + + public string? RuntimeGfxArch { get; init; } + + public string? 
MultiArchDeviceExtra { get; init; } + } } diff --git a/StabilityMatrix.Tests/Core/RocmPackageHelperTests.cs b/StabilityMatrix.Tests/Core/RocmPackageHelperTests.cs index 0afb7b86..ad79a23b 100644 --- a/StabilityMatrix.Tests/Core/RocmPackageHelperTests.cs +++ b/StabilityMatrix.Tests/Core/RocmPackageHelperTests.cs @@ -10,53 +10,27 @@ namespace StabilityMatrix.Tests.Core; public class RocmPackageHelperTests { [TestMethod] - public void GetRocmSdkDevelAlignmentVersion_ReturnsRocmVersion_WhenVersionsMismatch() + public void WindowsRocmSupport_TryGetMultiArchDeviceExtra_ReturnsExpectedExtra_ForSupportedArch() { - var targetVersion = RocmPackageHelper.GetRocmSdkDevelAlignmentVersion( - rocmVersion: "7.13.0a20260416", - rocmSdkDevelVersion: "7.13.0a20260501" - ); - - Assert.AreEqual("7.13.0a20260416", targetVersion); - } - - [TestMethod] - public void GetRocmSdkDevelAlignmentVersion_ReturnsNull_WhenVersionsAlreadyMatch() - { - var targetVersion = RocmPackageHelper.GetRocmSdkDevelAlignmentVersion( - rocmVersion: "7.13.0a20260416", - rocmSdkDevelVersion: "7.13.0a20260416" - ); - - Assert.IsNull(targetVersion); - } - - [TestMethod] - public void GetRocmSdkDevelAlignmentVersion_FallsBackToTorchBuildVersion() - { - var targetVersion = RocmPackageHelper.GetRocmSdkDevelAlignmentVersion( - rocmVersion: null, - rocmSdkDevelVersion: "7.13.0a20260501", - torchVersion: "2.11.0+rocm7.13.0a20260416" - ); + var deviceExtra = WindowsRocmSupport.TryGetMultiArchDeviceExtra("gfx1201"); - Assert.AreEqual("7.13.0a20260416", targetVersion); + Assert.AreEqual("device-gfx1201", deviceExtra); } [TestMethod] - public void TryExtractRocmBuildVersion_ReturnsNull_WhenTorchVersionHasNoRocmTag() + public void WindowsRocmSupport_TryGetMultiArchDeviceExtra_ReturnsExpectedExtra_ForCanonicalVega20Arch() { - var rocmBuildVersion = RocmPackageHelper.TryExtractRocmBuildVersion("2.11.0"); + var deviceExtra = WindowsRocmSupport.TryGetMultiArchDeviceExtra("gfx906"); - Assert.IsNull(rocmBuildVersion); + 
Assert.AreEqual("device-gfx906", deviceExtra); } [TestMethod] - public void TryExtractRocmBuildVersion_ReturnsVersionSuffix_WhenTorchVersionContainsRocmTag() + public void WindowsRocmSupport_TryGetCanonicalArchitecture_ReturnsCanonicalArch_WhenAlreadyCanonical() { - var rocmBuildVersion = RocmPackageHelper.TryExtractRocmBuildVersion("2.11.0+rocm7.13.0a20260416"); + var canonicalArch = WindowsRocmSupport.TryGetCanonicalArchitecture("gfx906"); - Assert.AreEqual("7.13.0a20260416", rocmBuildVersion); + Assert.AreEqual("gfx906", canonicalArch); } [TestMethod] @@ -125,45 +99,11 @@ public void TryExtractJsonObject_ReturnsNull_WhenOutputContainsNoJson() } [TestMethod] - public void TryGetWindowsNativeRocmRuntimeFailureReason_ReturnsDeviceDetectionMessage() - { - const string output = "checkHipErrors() HIP API error = 0100 \"no ROCm-capable device is detected\""; - - var reason = RocmPackageHelper.TryGetWindowsNativeRocmRuntimeFailureReason(output); - - Assert.AreEqual( - "the installed ROCm runtime could not detect a ROCm-capable GPU on this system.", - reason - ); - } - - [TestMethod] - public void TryGetWindowsNativeRocmRuntimeFailureReason_ReturnsWddmMessage() - { - const string output = "warning: No WDDM adapters found."; - - var reason = RocmPackageHelper.TryGetWindowsNativeRocmRuntimeFailureReason(output); - - Assert.AreEqual( - "the ROCm runtime could not find any compatible WDDM adapters for the current GPU/driver stack.", - reason - ); - } - - [TestMethod] - public void CombineProcessOutput_JoinsStdoutAndStderr() - { - var combined = RocmPackageHelper.CombineProcessOutput("stdout line", "stderr line"); - - Assert.AreEqual($"stdout line{Environment.NewLine}stderr line", combined); - } - - [TestMethod] - public void WindowsRocmSupport_TryGetPackageIndexUrl_ReturnsExpectedIndex_ForKrakenPoint() + public void WindowsRocmSupport_TryGetMultiArchDeviceExtra_ReturnsExpectedExtra_ForKrakenPoint() { - var indexUrl = WindowsRocmSupport.TryGetPackageIndexUrl("gfx1152"); + var 
deviceExtra = WindowsRocmSupport.TryGetMultiArchDeviceExtra("gfx1152"); - Assert.AreEqual("https://rocm.nightlies.amd.com/v2-staging/gfx1152/", indexUrl); + Assert.AreEqual("device-gfx1152", deviceExtra); } [TestMethod] From 66ec7563c529880861b0f6723740b5f7d3438439 Mon Sep 17 00:00:00 2001 From: NeuralFault Date: Thu, 7 May 2026 21:03:28 -0400 Subject: [PATCH 13/18] SwarmUI integration for Win/ROCm EnVar passthrough. Modify aotriton support. - extract the shared Windows ComfyUI ROCm profile from package-local code - reuse the shared ROCm helper/profile for both direct ComfyUI launch behavior and SwarmUI self-launch pass-through - inject Windows ROCm ComfyUI env vars into SwarmUI launching so they propagate to the self-launched backend - Modified Torch ROCm AOTriton activation EnVar to exclude gfx1152/1153 due to no support yet. --- .../Helper/Factory/PackageFactory.cs | 3 +- .../Models/Packages/ComfyUI.cs | 24 ++------- .../Models/Packages/StableSwarm.cs | 50 ++++++++++++++++++- .../Models/Rocm/ComfyWindowsRocmProfile.cs | 23 +++++++++ .../Models/Rocm/WindowsRocmSupport.cs | 16 ++++++ .../Services/Rocm/RocmPackageHelper.cs | 2 +- 6 files changed, 94 insertions(+), 24 deletions(-) create mode 100644 StabilityMatrix.Core/Models/Rocm/ComfyWindowsRocmProfile.cs diff --git a/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs b/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs index 0cb0b808..3c031b2c 100644 --- a/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs +++ b/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs @@ -159,7 +159,8 @@ public BasePackage GetNewBasePackage(InstalledPackage installedPackage) downloadService, prerequisiteHelper, pyInstallationManager, - pipWheelService + pipWheelService, + rocmPackageHelper ), "automatic" => new VladAutomatic( githubApiCache, diff --git a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs index edca7d00..003a9846 100644 --- 
a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs +++ b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs @@ -392,7 +392,7 @@ await rocmPackageHelper venvRunner, installLocation, installedPackage, - WindowsRocmProfile, + ComfyWindowsRocmProfile.Profile, progress, onConsoleOutput, cancellationToken @@ -637,24 +637,6 @@ public override TorchIndex GetRecommendedTorchVersion() return base.GetRecommendedTorchVersion(); } - /// Windows ROCm install profile for ComfyUI. - private static readonly RocmPackageProfile WindowsRocmProfile = new() - { - ExtraInstallPipArgs = ["numpy<2"], - PostInstallPipArgs = ["typing-extensions>=4.15.0"], - UpgradePackages = true, - ExtraEnvironmentFactory = BuildComfyWindowsRocmEnvironment, - }; - - private static IReadOnlyDictionary BuildComfyWindowsRocmEnvironment( - RocmRuntimeContext runtimeContext - ) - { - return WindowsRocmSupport.IsModernArchitecture(runtimeContext.RuntimeGfxArch) - ? new Dictionary { ["COMFYUI_ENABLE_MIOPEN"] = "1" } - : new Dictionary(); - } - /// Uses the shared ROCm helper for Windows ROCm eligibility checks so ComfyUI does not maintain its own support matrix. 
private bool HasWindowsRocmSupport() { @@ -668,7 +650,7 @@ private RocmCompatibilityResult GetWindowsRocmCompatibility() return new RocmCompatibilityResult { IsCompatible = false }; } - return rocmPackageHelper.GetCompatibility(WindowsRocmProfile); + return rocmPackageHelper.GetCompatibility(ComfyWindowsRocmProfile.Profile); } /// Defaults legacy Windows ROCm GPUs to quad cross-attention because PyTorch cross-attention is considerably slower @@ -1047,7 +1029,7 @@ InstalledPackage installedPackage if (rocmPackageHelper is null) return env; - var rocmEnvironment = rocmPackageHelper.BuildLaunchEnvironment(WindowsRocmProfile); + var rocmEnvironment = rocmPackageHelper.BuildLaunchEnvironment(ComfyWindowsRocmProfile.Profile); return env.SetItems(rocmEnvironment); } diff --git a/StabilityMatrix.Core/Models/Packages/StableSwarm.cs b/StabilityMatrix.Core/Models/Packages/StableSwarm.cs index d71e52f5..b76f2a56 100644 --- a/StabilityMatrix.Core/Models/Packages/StableSwarm.cs +++ b/StabilityMatrix.Core/Models/Packages/StableSwarm.cs @@ -10,9 +10,11 @@ using StabilityMatrix.Core.Models.FileInterfaces; using StabilityMatrix.Core.Models.Packages.Config; using StabilityMatrix.Core.Models.Progress; +using StabilityMatrix.Core.Models.Rocm; using StabilityMatrix.Core.Processes; using StabilityMatrix.Core.Python; using StabilityMatrix.Core.Services; +using StabilityMatrix.Core.Services.Rocm; namespace StabilityMatrix.Core.Models.Packages; @@ -23,7 +25,8 @@ public class StableSwarm( IDownloadService downloadService, IPrerequisiteHelper prerequisiteHelper, IPyInstallationManager pyInstallationManager, - IPipWheelService pipWheelService + IPipWheelService pipWheelService, + IRocmPackageHelper rocmPackageHelper ) : BaseGitPackage( githubApi, @@ -407,6 +410,7 @@ public override async Task RunPackage( } aspEnvVars.Update(settingsManager.Settings.EnvironmentVariables); + aspEnvVars.Update(BuildLinkedComfyLaunchEnvironment()); // Windows ROCm ComfyUI env var pass-through void 
HandleConsoleOutput(ProcessOutput s) { @@ -563,6 +567,50 @@ await prerequisiteHelper .ConfigureAwait(false); } + /// + /// Resolves the Comfy backend that Swarm is expected to self-launch. + /// + private InstalledPackage? TryResolveLinkedComfyBackend() + { + return settingsManager.Settings.InstalledPackages.FirstOrDefault(x => + x.PackageName is nameof(ComfyUI) or "ComfyUI-Zluda" + ); + } + + /// + /// Builds ROCm launch environment variables for Swarm so they flow through to its self-launched Comfy backend. + /// + private IReadOnlyDictionary BuildLinkedComfyLaunchEnvironment() + { + var comfyPackage = TryResolveLinkedComfyBackend(); + if (comfyPackage is null || !ShouldInjectLinkedComfyRocmEnvironment(comfyPackage)) + { + return new Dictionary(); + } + + return rocmPackageHelper.BuildLaunchEnvironment(ComfyWindowsRocmProfile.Profile); + } + + /// + /// Returns true only when the linked backend is standard ComfyUI on a supported Windows ROCm path. + /// + private bool ShouldInjectLinkedComfyRocmEnvironment(InstalledPackage comfyPackage) + { + if (!Compat.IsWindows || comfyPackage.PackageName != nameof(ComfyUI)) + { + return false; + } + + var compatibility = rocmPackageHelper.GetCompatibility(ComfyWindowsRocmProfile.Profile); + if (!compatibility.IsCompatible) + { + return false; + } + + var selectedTorchIndex = comfyPackage.PreferredTorchIndex ?? 
TorchIndex.Rocm; + return selectedTorchIndex == TorchIndex.Rocm; + } + private Task SetupModelFoldersConfig(DirectoryPath installDirectory) { var settingsPath = GetSettingsPath(installDirectory); diff --git a/StabilityMatrix.Core/Models/Rocm/ComfyWindowsRocmProfile.cs b/StabilityMatrix.Core/Models/Rocm/ComfyWindowsRocmProfile.cs new file mode 100644 index 00000000..ea4035ec --- /dev/null +++ b/StabilityMatrix.Core/Models/Rocm/ComfyWindowsRocmProfile.cs @@ -0,0 +1,23 @@ +namespace StabilityMatrix.Core.Models.Rocm; + +/// +/// Shared Windows ROCm profile for Comfy backends launched either directly by Stability Matrix or indirectly via SwarmUI. +/// +public static class ComfyWindowsRocmProfile +{ + public static RocmPackageProfile Profile { get; } = + new() + { + ExtraInstallPipArgs = ["numpy<2"], + PostInstallPipArgs = ["typing-extensions>=4.15.0"], + UpgradePackages = true, + ExtraEnvironmentFactory = BuildEnvironment, + }; + + private static IReadOnlyDictionary BuildEnvironment(RocmRuntimeContext runtimeContext) + { + return WindowsRocmSupport.IsModernArchitecture(runtimeContext.RuntimeGfxArch) + ? new Dictionary { ["COMFYUI_ENABLE_MIOPEN"] = "1" } + : new Dictionary(); + } +} diff --git a/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs b/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs index e784b615..49fe5d76 100644 --- a/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs +++ b/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs @@ -11,6 +11,14 @@ public static class WindowsRocmSupport public const string MultiArchPythonPackageIndexUrl = "https://rocm.nightlies.amd.com/whl-staging-multi-arch/"; + // Used to exclude modern gfxarches from AOTriton activation EnVar as AOTriton does not currently support them. + // This is a temporary measure until AOTriton adds support for these architectures. 
+ private static readonly HashSet AotritonExperimentalExcludedArchitectures = + [ + "gfx1152", + "gfx1153", + ]; + public static bool IsSupportedGpu(GpuInfo? gpu) { if (gpu is null || !gpu.IsAmd || string.IsNullOrWhiteSpace(gpu.Name)) @@ -31,6 +39,14 @@ public static bool IsModernArchitecture(string? gfxArch) || gfxArch?.StartsWith("gfx120", StringComparison.OrdinalIgnoreCase) == true; } + public static bool SupportsAotritonExperimental(string? gfxArch) + { + var canonicalArch = TryGetCanonicalArchitecture(gfxArch); + return canonicalArch is not null + && IsModernArchitecture(canonicalArch) + && !AotritonExperimentalExcludedArchitectures.Contains(canonicalArch); + } + public static bool IsLegacyArchitecture(string? gfxArch) { return IsSupportedArchitecture(gfxArch) && !IsModernArchitecture(gfxArch); diff --git a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs index fc3d3bf2..9e21cc5c 100644 --- a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs +++ b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs @@ -492,7 +492,7 @@ RocmEnvironmentOptions options SetIfNotNull(environment, "MIOPEN_FIND_ENFORCE", options.MiopenFindEnforce); SetIfNotNull(environment, "PYTORCH_ALLOC_CONF", options.PyTorchAllocConf); - if (options.ApplyAotritonExperimental && WindowsRocmSupport.IsModernArchitecture(gfxArch)) + if (options.ApplyAotritonExperimental && WindowsRocmSupport.SupportsAotritonExperimental(gfxArch)) { environment["TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL"] = "1"; } From 6a1a5e4c8f993815cc6894e74a5e47d157a10c37 Mon Sep 17 00:00:00 2001 From: NeuralFault Date: Thu, 7 May 2026 21:25:48 -0400 Subject: [PATCH 14/18] Fix package link tests for ROCm helper DI --- StabilityMatrix.Tests/Models/Packages/PackageHelper.cs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/StabilityMatrix.Tests/Models/Packages/PackageHelper.cs b/StabilityMatrix.Tests/Models/Packages/PackageHelper.cs index 
b165031d..8e210690 100644 --- a/StabilityMatrix.Tests/Models/Packages/PackageHelper.cs +++ b/StabilityMatrix.Tests/Models/Packages/PackageHelper.cs @@ -6,6 +6,7 @@ using StabilityMatrix.Core.Models.Packages; using StabilityMatrix.Core.Python; using StabilityMatrix.Core.Services; +using StabilityMatrix.Core.Services.Rocm; namespace StabilityMatrix.Tests.Models.Packages; @@ -24,7 +25,8 @@ public static IEnumerable GetPackages() .AddSingleton(Substitute.For()) .AddSingleton(Substitute.For()) .AddSingleton(Substitute.For()) - .AddSingleton(Substitute.For()); + .AddSingleton(Substitute.For()) + .AddSingleton(Substitute.For()); var assembly = typeof(BasePackage).Assembly; var packageTypes = assembly From 08949a30b9b98a161eb155ec021f6c10c73731d0 Mon Sep 17 00:00:00 2001 From: NeuralFault Date: Sat, 9 May 2026 19:27:28 -0400 Subject: [PATCH 15/18] Added Sage Attention v1 package command with Win/ROCm gating and specific handling and patching sourced from ComfyUI-Zluda. Pulls VSBuild Tools prerequisite auto install path as a requirement. 
--- .../InstallWindowsRocmSageAttentionStep.cs | 190 ++++++++++++++++++ .../Models/Packages/ComfyUI.cs | 165 ++++++++++----- 2 files changed, 303 insertions(+), 52 deletions(-) create mode 100644 StabilityMatrix.Core/Models/PackageModification/InstallWindowsRocmSageAttentionStep.cs diff --git a/StabilityMatrix.Core/Models/PackageModification/InstallWindowsRocmSageAttentionStep.cs b/StabilityMatrix.Core/Models/PackageModification/InstallWindowsRocmSageAttentionStep.cs new file mode 100644 index 00000000..c410a888 --- /dev/null +++ b/StabilityMatrix.Core/Models/PackageModification/InstallWindowsRocmSageAttentionStep.cs @@ -0,0 +1,190 @@ +using StabilityMatrix.Core.Helper; +using StabilityMatrix.Core.Models.FileInterfaces; +using StabilityMatrix.Core.Models.Progress; +using StabilityMatrix.Core.Models.Rocm; +using StabilityMatrix.Core.Processes; +using StabilityMatrix.Core.Python; +using StabilityMatrix.Core.Services; +using StabilityMatrix.Core.Services.Rocm; + +namespace StabilityMatrix.Core.Models.PackageModification; + +public class InstallWindowsRocmSageAttentionStep( + IDownloadService downloadService, + IPyInstallationManager pyInstallationManager, + IPrerequisiteHelper prerequisiteHelper, + IRocmPackageHelper rocmPackageHelper +) : IPackageStep +{ + private const string TritonWindowsVersion = "3.6.0.post25"; + private const string SageAttentionVersion = "1.0.6"; + + private const string AttnQkInt8PerBlockUrl = + "https://raw.githubusercontent.com/patientx/ComfyUI-Zluda/refs/heads/master/comfy/customzluda/sa/attn_qk_int8_per_block.py"; + + private const string AttnQkInt8PerBlockCausalUrl = + "https://raw.githubusercontent.com/patientx/ComfyUI-Zluda/refs/heads/master/comfy/customzluda/sa/attn_qk_int8_per_block_causal.py"; + + private const string QuantPerBlockUrl = + "https://raw.githubusercontent.com/patientx/ComfyUI-Zluda/refs/heads/master/comfy/customzluda/sa/quant_per_block.py"; + + public required InstalledPackage InstalledPackage { get; init; } + public 
required DirectoryPath WorkingDirectory { get; init; } + public IReadOnlyDictionary? EnvironmentVariables { get; init; } + + public string ProgressTitle => "Installing Windows ROCm SageAttention"; + + public async Task ExecuteAsync(IProgress? progress = null) + { + if (!global::System.OperatingSystem.IsWindows()) + { + throw new PlatformNotSupportedException( + "Windows ROCm SageAttention installation is only supported on Windows." + ); + } + + if (!prerequisiteHelper.IsVcBuildToolsInstalled) + { + await prerequisiteHelper + .InstallPackageRequirements([PackagePrerequisite.VcBuildTools], progress: progress) + .ConfigureAwait(false); + } + + var compatibility = rocmPackageHelper.GetCompatibility(ComfyWindowsRocmProfile.Profile); + if (!compatibility.IsCompatible) + { + throw new InvalidOperationException( + compatibility.FailureReason + ?? "Windows ROCm SageAttention requires a supported Windows ROCm machine state." + ); + } + + if (InstalledPackage.FullPath is null) + { + throw new InvalidOperationException("Installed package path is not available."); + } + + var venvDir = WorkingDirectory.JoinDir("venv"); + if (!venvDir.Exists) + { + throw new DirectoryNotFoundException($"ComfyUI venv was not found at '{venvDir.FullPath}'."); + } + + var pyVersion = PyVersion.Parse(InstalledPackage.PythonVersion); + if (pyVersion.StringValue == "0.0.0") + { + pyVersion = PyInstallationManager.Python_3_10_11; + } + + var baseInstall = !string.IsNullOrWhiteSpace(InstalledPackage.PythonVersion) + ? 
new PyBaseInstall( + await pyInstallationManager.GetInstallationAsync(pyVersion).ConfigureAwait(false) + ) + : PyBaseInstall.Default; + + await using var venvRunner = baseInstall.CreateVenvRunner( + venvDir, + workingDirectory: WorkingDirectory, + environmentVariables: EnvironmentVariables + ); + + var torchInfo = await venvRunner.PipShow("torch").ConfigureAwait(false); + if (torchInfo is null) + { + throw new InvalidOperationException( + "torch is not installed in this ComfyUI environment. Install the Windows ROCm torch build first." + ); + } + + if (!RocmPackageHelper.IsUsableWindowsNativeTorchBuild(torchInfo.Version, null)) + { + throw new InvalidOperationException( + $"Installed torch is not a usable Windows ROCm build (detected version: {torchInfo.Version})." + ); + } + + progress?.Report( + new ProgressReport( + -1f, + "Installing triton-windows for Windows ROCm SageAttention...", + isIndeterminate: true + ) + ); + await venvRunner.PipInstall($"triton-windows=={TritonWindowsVersion}").ConfigureAwait(false); + + progress?.Report( + new ProgressReport(-1f, "Installing SageAttention for Windows ROCm...", isIndeterminate: true) + ); + await venvRunner.PipInstall($"--no-deps sageattention=={SageAttentionVersion}").ConfigureAwait(false); + + var sageAttentionDir = WorkingDirectory.JoinDir("venv", "Lib", "site-packages", "sageattention"); + if (!sageAttentionDir.Exists) + { + throw new DirectoryNotFoundException( + $"Installed SageAttention package path was not found at '{sageAttentionDir.FullPath}'." 
+ ); + } + + progress?.Report( + new ProgressReport(-1f, "Patching SageAttention for Windows ROCm...", isIndeterminate: true) + ); + + await DownloadAndReplaceFileAsync( + sageAttentionDir, + "attn_qk_int8_per_block.py", + AttnQkInt8PerBlockUrl, + progress + ) + .ConfigureAwait(false); + await DownloadAndReplaceFileAsync( + sageAttentionDir, + "attn_qk_int8_per_block_causal.py", + AttnQkInt8PerBlockCausalUrl, + progress + ) + .ConfigureAwait(false); + await DownloadAndReplaceFileAsync(sageAttentionDir, "quant_per_block.py", QuantPerBlockUrl, progress) + .ConfigureAwait(false); + } + + private async Task DownloadAndReplaceFileAsync( + DirectoryPath sageAttentionDir, + string fileName, + string sourceUrl, + IProgress? progress + ) + { + var targetFile = sageAttentionDir.JoinFile(fileName); + if (!targetFile.Exists) + { + throw new FileNotFoundException( + $"Expected SageAttention file '{fileName}' was not found.", + targetFile.FullPath + ); + } + + var backupFile = sageAttentionDir.JoinFile($"{fileName}.bak"); + if (!backupFile.Exists) + { + await backupFile + .WriteAllTextAsync(await targetFile.ReadAllTextAsync().ConfigureAwait(false)) + .ConfigureAwait(false); + } + + var tempFile = WorkingDirectory.JoinFile($"sm-rocm-sage-{fileName}.tmp"); + await downloadService.DownloadToFileAsync(sourceUrl, tempFile, progress).ConfigureAwait(false); + + try + { + var replacementContent = await tempFile.ReadAllTextAsync().ConfigureAwait(false); + await targetFile.WriteAllTextAsync(replacementContent).ConfigureAwait(false); + } + finally + { + if (tempFile.Exists) + { + await tempFile.DeleteAsync().ConfigureAwait(false); + } + } + } +} diff --git a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs index 003a9846..521548af 100644 --- a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs +++ b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs @@ -341,6 +341,17 @@ public override List GetExtraCommands() ); } + if (Compat.IsWindows && 
HasWindowsRocmSupport()) + { + commands.Add( + new ExtraPackageCommand + { + CommandName = "Install Sage Attention", + Command = InstallWindowsRocmSageAttention, + } + ); + } + if (!Compat.IsMacOS && SettingsManager.Settings.PreferredGpu?.ComputeCapabilityValue is >= 7.5m) { commands.Add( @@ -424,51 +435,55 @@ await StandardPipInstallProcessAsync( .ConfigureAwait(false); } - try + if (!(Compat.IsWindows && torchIndex == TorchIndex.Rocm && HasWindowsRocmSupport())) { - var sageVersion = await venvRunner.PipShow("sageattention").ConfigureAwait(false); - var torchVersion = await venvRunner.PipShow("torch").ConfigureAwait(false); - - if (torchVersion is not null && sageVersion is not null) + try { - var version = torchVersion.Version; - var plusPos = version.IndexOf('+'); - var index = plusPos >= 0 ? version[(plusPos + 1)..] : string.Empty; - var versionWithoutIndex = plusPos >= 0 ? version[..plusPos] : version; + var sageVersion = await venvRunner.PipShow("sageattention").ConfigureAwait(false); + var torchVersion = await venvRunner.PipShow("torch").ConfigureAwait(false); - if ( - !sageVersion.Version.Contains(index) || !sageVersion.Version.Contains(versionWithoutIndex) - ) + if (torchVersion is not null && sageVersion is not null) { - progress?.Report( - new ProgressReport(-1f, "Updating SageAttention...", isIndeterminate: true) - ); + var version = torchVersion.Version; + var plusPos = version.IndexOf('+'); + var index = plusPos >= 0 ? version[(plusPos + 1)..] : string.Empty; + var versionWithoutIndex = plusPos >= 0 ? version[..plusPos] : version; - var step = new InstallSageAttentionStep( - downloadService, - prerequisiteHelper, - pyInstallationManager + if ( + !sageVersion.Version.Contains(index) + || !sageVersion.Version.Contains(versionWithoutIndex) ) { - InstalledPackage = installedPackage, - IsBlackwellGpu = - SettingsManager.Settings.PreferredGpu?.IsBlackwellGpu() - ?? 
HardwareHelper.HasBlackwellGpu(), - WorkingDirectory = installLocation, - EnvironmentVariables = GetEnvVars( - venvRunner.EnvironmentVariables, - installLocation, - installedPackage - ), - }; - - await step.ExecuteAsync(progress).ConfigureAwait(false); + progress?.Report( + new ProgressReport(-1f, "Updating SageAttention...", isIndeterminate: true) + ); + + var step = new InstallSageAttentionStep( + downloadService, + prerequisiteHelper, + pyInstallationManager + ) + { + InstalledPackage = installedPackage, + IsBlackwellGpu = + SettingsManager.Settings.PreferredGpu?.IsBlackwellGpu() + ?? HardwareHelper.HasBlackwellGpu(), + WorkingDirectory = installLocation, + EnvironmentVariables = GetEnvVars( + venvRunner.EnvironmentVariables, + installLocation, + installedPackage + ), + }; + + await step.ExecuteAsync(progress).ConfigureAwait(false); + } } } - } - catch (Exception e) - { - Logger.Error(e, "Failed to verify/update SageAttention after installation"); + catch (Exception e) + { + Logger.Error(e, "Failed to verify/update SageAttention after installation"); + } } // Install Comfy Manager (built-in to ComfyUI) @@ -939,11 +954,59 @@ await PipWheelService if (runner.Failed) return; + await EnableSageAttentionAsync(installedPackage).ConfigureAwait(false); + } + + private async Task InstallWindowsRocmSageAttention(InstalledPackage? 
installedPackage) + { + if (installedPackage?.FullPath is null) + return; + + var runner = new PackageModificationRunner + { + ShowDialogOnStart = true, + ModificationCompleteMessage = "Windows ROCm SageAttention installed successfully", + }; + EventManager.Instance.OnPackageInstallProgressAdded(runner); + + var baseEnvironment = ImmutableDictionary.CreateRange(SettingsManager.Settings.EnvironmentVariables); + var environmentVariables = GetEnvVars(baseEnvironment, installedPackage.FullPath, installedPackage); + + await runner + .ExecuteSteps( + [ + new InstallWindowsRocmSageAttentionStep( + downloadService, + pyInstallationManager, + prerequisiteHelper, + rocmPackageHelper + ?? throw new InvalidOperationException( + "Windows ROCm SageAttention installation encountered an internal configuration error [rocmPackageHelper is null]." + ) + ) + { + InstalledPackage = installedPackage, + WorkingDirectory = new DirectoryPath(installedPackage.FullPath), + EnvironmentVariables = environmentVariables, + }, + ] + ) + .ConfigureAwait(false); + + if (runner.Failed) + return; + + await EnableSageAttentionAsync(installedPackage).ConfigureAwait(false); + } + + private async Task EnableSageAttentionAsync(InstalledPackage installedPackage) + { await using var transaction = settingsManager.BeginTransaction(); - var attentionOptions = transaction - .Settings.InstalledPackages.First(x => x.Id == installedPackage.Id) - .LaunchArgs?.Where(opt => opt.Name.Contains("attention")); + var packageInSettings = transaction.Settings.InstalledPackages.First(x => + x.Id == installedPackage.Id + ); + var attentionOptions = packageInSettings.LaunchArgs?.Where(opt => opt.Name.Contains("attention")); if (attentionOptions is not null) { foreach (var option in attentionOptions) @@ -952,9 +1015,9 @@ await PipWheelService } } - var sageAttention = transaction - .Settings.InstalledPackages.First(x => x.Id == installedPackage.Id) - .LaunchArgs?.FirstOrDefault(opt => opt.Name.Contains("sage-attention")); + var 
sageAttention = packageInSettings.LaunchArgs?.FirstOrDefault(opt => + opt.Name.Contains("sage-attention") + ); if (sageAttention is not null) { @@ -962,16 +1025,14 @@ await PipWheelService } else { - transaction - .Settings.InstalledPackages.First(x => x.Id == installedPackage.Id) - .LaunchArgs?.Add( - new LaunchOption - { - Name = "--use-sage-attention", - Type = LaunchOptionType.Bool, - OptionValue = true, - } - ); + packageInSettings.LaunchArgs?.Add( + new LaunchOption + { + Name = "--use-sage-attention", + Type = LaunchOptionType.Bool, + OptionValue = true, + } + ); } } From 1338612c6e64402fe7059ad5c09dd91d5de9e7e3 Mon Sep 17 00:00:00 2001 From: NeuralFault Date: Sat, 9 May 2026 22:08:03 -0400 Subject: [PATCH 16/18] ROCm package commands refactor. Flash Attention, Rocm Devel SDK, and bitsandbytes package commands. - replace the SageAttention-specific ROCm step with a generic Windows ROCm package-command step and shared command enum - move rocm-sdk-devel resolution and install logic into the ROCm helper, including nightly-date matching with fallback to the latest multi-arch build - gate Flash Attention to legacy ROCm architectures and bitsandbytes to Python 3.12 installs Non-ROCm Helper changes - Added an optional visibility predicate to extra package commands, enabling installed-package-instance-aware filtering so packages can declaratively hide commands that do not apply to the current install state (for example, the ComfyUI bitsandbytes command on non-Python-3.12 Windows ROCm installs). 
--- .../PackageManager/PackageCardViewModel.cs | 5 +- .../Models/ExtraPackageCommand.cs | 1 + ...> InstallWindowsRocmPackageCommandStep.cs} | 157 +++++++++++++++--- .../Models/Packages/ComfyUI.cs | 150 ++++++++++++++++- .../Services/Rocm/IRocmPackageHelper.cs | 11 ++ .../Services/Rocm/RocmPackageHelper.cs | 146 +++++++++++++++- 6 files changed, 434 insertions(+), 36 deletions(-) rename StabilityMatrix.Core/Models/PackageModification/{InstallWindowsRocmSageAttentionStep.cs => InstallWindowsRocmPackageCommandStep.cs} (57%) diff --git a/StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs index 4b5252ac..2fbc50c6 100644 --- a/StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs @@ -170,7 +170,10 @@ partial void OnPackageChanged(InstalledPackage? value) // Set the extra commands if available from the package var packageExtraCommands = basePackage?.GetExtraCommands(); - ExtraCommands = packageExtraCommands?.Count > 0 ? packageExtraCommands : null; + var visibleExtraCommands = packageExtraCommands + ?.Where(command => command.IsVisible?.Invoke(value) ?? true) + .ToList(); + ExtraCommands = visibleExtraCommands?.Count > 0 ? visibleExtraCommands : null; runningPackageService.RunningPackages.CollectionChanged += RunningPackagesOnCollectionChanged; EventManager.Instance.PackageRelaunchRequested += InstanceOnPackageRelaunchRequested; diff --git a/StabilityMatrix.Core/Models/ExtraPackageCommand.cs b/StabilityMatrix.Core/Models/ExtraPackageCommand.cs index 87398a8c..fa7219db 100644 --- a/StabilityMatrix.Core/Models/ExtraPackageCommand.cs +++ b/StabilityMatrix.Core/Models/ExtraPackageCommand.cs @@ -4,4 +4,5 @@ public class ExtraPackageCommand { public required string CommandName { get; set; } public required Func Command { get; set; } + public Func? 
IsVisible { get; set; } } diff --git a/StabilityMatrix.Core/Models/PackageModification/InstallWindowsRocmSageAttentionStep.cs b/StabilityMatrix.Core/Models/PackageModification/InstallWindowsRocmPackageCommandStep.cs similarity index 57% rename from StabilityMatrix.Core/Models/PackageModification/InstallWindowsRocmSageAttentionStep.cs rename to StabilityMatrix.Core/Models/PackageModification/InstallWindowsRocmPackageCommandStep.cs index c410a888..a59f2bc5 100644 --- a/StabilityMatrix.Core/Models/PackageModification/InstallWindowsRocmSageAttentionStep.cs +++ b/StabilityMatrix.Core/Models/PackageModification/InstallWindowsRocmPackageCommandStep.cs @@ -9,13 +9,27 @@ namespace StabilityMatrix.Core.Models.PackageModification; -public class InstallWindowsRocmSageAttentionStep( +public enum WindowsRocmPackageCommandType +{ + SageAttention, + DevelopmentSdk, + BitsAndBytes, + FlashAttention, +} + +public class InstallWindowsRocmPackageCommandStep( IDownloadService downloadService, IPyInstallationManager pyInstallationManager, IPrerequisiteHelper prerequisiteHelper, IRocmPackageHelper rocmPackageHelper ) : IPackageStep { + private const string BitsAndBytesWheelUrl = + "https://github.com/0xDELUXA/bitsandbytes_win_rocm/releases/download/0.50.0.dev0-py3-rocm7-win_amd64_all/bitsandbytes-0.50.0.dev0-cp312-cp312-win_amd64.whl"; + private const string AmdAiterWheelUrl = + "https://github.com/0xDELUXA/flash-attention/releases/download/v2.8.4_win-rocm/amd_aiter-0.0.0-py3-none-win_amd64.whl"; + private const string FlashAttentionWheelUrl = + "https://github.com/0xDELUXA/flash-attention/releases/download/v2.8.4_win-rocm/flash_attn-2.8.4-py3-none-win_amd64.whl"; private const string TritonWindowsVersion = "3.6.0.post25"; private const string SageAttentionVersion = "1.0.6"; @@ -30,32 +44,24 @@ IRocmPackageHelper rocmPackageHelper public required InstalledPackage InstalledPackage { get; init; } public required DirectoryPath WorkingDirectory { get; init; } + public required 
WindowsRocmPackageCommandType CommandType { get; init; } public IReadOnlyDictionary? EnvironmentVariables { get; init; } - public string ProgressTitle => "Installing Windows ROCm SageAttention"; + public string ProgressTitle => CommandType switch + { + WindowsRocmPackageCommandType.SageAttention => "Installing Windows ROCm SageAttention", + WindowsRocmPackageCommandType.DevelopmentSdk => "Installing Windows ROCm Development SDK", + WindowsRocmPackageCommandType.BitsAndBytes => "Installing Windows ROCm bitsandbytes", + WindowsRocmPackageCommandType.FlashAttention => "Installing Windows ROCm Flash Attention", + _ => "Running Windows ROCm package command", + }; public async Task ExecuteAsync(IProgress? progress = null) { if (!global::System.OperatingSystem.IsWindows()) { throw new PlatformNotSupportedException( - "Windows ROCm SageAttention installation is only supported on Windows." - ); - } - - if (!prerequisiteHelper.IsVcBuildToolsInstalled) - { - await prerequisiteHelper - .InstallPackageRequirements([PackagePrerequisite.VcBuildTools], progress: progress) - .ConfigureAwait(false); - } - - var compatibility = rocmPackageHelper.GetCompatibility(ComfyWindowsRocmProfile.Profile); - if (!compatibility.IsCompatible) - { - throw new InvalidOperationException( - compatibility.FailureReason - ?? "Windows ROCm SageAttention requires a supported Windows ROCm machine state." + "Windows ROCm package commands are only supported on Windows." ); } @@ -88,20 +94,66 @@ await pyInstallationManager.GetInstallationAsync(pyVersion).ConfigureAwait(false environmentVariables: EnvironmentVariables ); - var torchInfo = await venvRunner.PipShow("torch").ConfigureAwait(false); - if (torchInfo is null) + switch (CommandType) { - throw new InvalidOperationException( - "torch is not installed in this ComfyUI environment. Install the Windows ROCm torch build first." 
- ); + case WindowsRocmPackageCommandType.SageAttention: + await ExecuteSageAttentionAsync(venvRunner, progress).ConfigureAwait(false); + break; + case WindowsRocmPackageCommandType.DevelopmentSdk: + await ExecuteDevelopmentSdkAsync(venvRunner, progress).ConfigureAwait(false); + break; + case WindowsRocmPackageCommandType.BitsAndBytes: + await ExecuteBitsAndBytesAsync(venvRunner, pyVersion, progress).ConfigureAwait(false); + break; + case WindowsRocmPackageCommandType.FlashAttention: + await ExecuteFlashAttentionAsync(venvRunner, progress).ConfigureAwait(false); + break; + default: + throw new InvalidOperationException( + $"Unsupported Windows ROCm package command type: {CommandType}." + ); } + } - if (!RocmPackageHelper.IsUsableWindowsNativeTorchBuild(torchInfo.Version, null)) + private void EnsureRocmCompatibility() + { + var compatibility = rocmPackageHelper.GetCompatibility(ComfyWindowsRocmProfile.Profile); + if (!compatibility.IsCompatible) { throw new InvalidOperationException( - $"Installed torch is not a usable Windows ROCm build (detected version: {torchInfo.Version})." + compatibility.FailureReason + ?? "Windows ROCm package commands require a supported Windows ROCm machine state." ); } + } + + private async Task EnsureVcBuildToolsAsync(IProgress? progress) + { + if (!prerequisiteHelper.IsVcBuildToolsInstalled) + { + await prerequisiteHelper + .InstallPackageRequirements([PackagePrerequisite.VcBuildTools], progress: progress) + .ConfigureAwait(false); + } + } + + private async Task ExecuteDevelopmentSdkAsync( + IPyVenvRunner venvRunner, + IProgress? progress + ) + { + EnsureRocmCompatibility(); + await rocmPackageHelper.EnsureWindowsSdkDevelAsync(venvRunner, progress).ConfigureAwait(false); + } + + private async Task ExecuteSageAttentionAsync( + IPyVenvRunner venvRunner, + IProgress? 
progress + ) + { + EnsureRocmCompatibility(); + await EnsureVcBuildToolsAsync(progress).ConfigureAwait(false); + await rocmPackageHelper.EnsureWindowsSdkDevelAsync(venvRunner, progress).ConfigureAwait(false); progress?.Report( new ProgressReport( @@ -147,6 +199,57 @@ await DownloadAndReplaceFileAsync(sageAttentionDir, "quant_per_block.py", QuantP .ConfigureAwait(false); } + private async Task ExecuteBitsAndBytesAsync( + IPyVenvRunner venvRunner, + PyVersion pyVersion, + IProgress? progress + ) + { + EnsureRocmCompatibility(); + + if (pyVersion.Major != 3 || pyVersion.Minor != 12) + { + throw new InvalidOperationException( + $"Windows ROCm bitsandbytes is only supported on Python 3.12.x (detected version: {pyVersion})." + ); + } + + progress?.Report( + new ProgressReport( + -1f, + "Installing bitsandbytes for Windows ROCm...", + isIndeterminate: true + ) + ); + await venvRunner.PipInstall(BitsAndBytesWheelUrl).ConfigureAwait(false); + } + + private async Task ExecuteFlashAttentionAsync( + IPyVenvRunner venvRunner, + IProgress? 
progress + ) + { + EnsureRocmCompatibility(); + + progress?.Report( + new ProgressReport( + -1f, + "Installing Flash Attention dependencies for Windows ROCm...", + isIndeterminate: true + ) + ); + await venvRunner.PipInstall(AmdAiterWheelUrl).ConfigureAwait(false); + + progress?.Report( + new ProgressReport( + -1f, + "Installing Flash Attention for Windows ROCm...", + isIndeterminate: true + ) + ); + await venvRunner.PipInstall(FlashAttentionWheelUrl).ConfigureAwait(false); + } + private async Task DownloadAndReplaceFileAsync( DirectoryPath sageAttentionDir, string fileName, @@ -187,4 +290,4 @@ await backupFile } } } -} +} \ No newline at end of file diff --git a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs index 521548af..6c2db121 100644 --- a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs +++ b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs @@ -346,10 +346,45 @@ public override List GetExtraCommands() commands.Add( new ExtraPackageCommand { - CommandName = "Install Sage Attention", + CommandName = "Install Triton and SageAttention (ROCm)", Command = InstallWindowsRocmSageAttention, } ); + + commands.Add( + new ExtraPackageCommand + { + CommandName = "Install Flash Attention (ROCm)", + Command = InstallWindowsRocmFlashAttention, + IsVisible = _ => + WindowsRocmSupport.IsLegacyArchitecture( + GetWindowsRocmCompatibility().ResolvedGfxArch + ), + } + ); + + commands.Add( + new ExtraPackageCommand + { + CommandName = "Install ROCm Development SDK", + Command = InstallWindowsRocmDevelopmentSdk, + } + ); + + commands.Add( + new ExtraPackageCommand + { + CommandName = "Install bitsandbytes (ROCm)", + Command = InstallWindowsRocmBitsAndBytes, + IsVisible = installedPackage => + { + if (!PyVersion.TryParse(installedPackage.PythonVersion, out var pyVersion)) + return false; + + return pyVersion.Major == 3 && pyVersion.Minor == 12; + }, + } + ); } if (!Compat.IsMacOS && 
SettingsManager.Settings.PreferredGpu?.ComputeCapabilityValue is >= 7.5m) @@ -975,7 +1010,7 @@ private async Task InstallWindowsRocmSageAttention(InstalledPackage? installedPa await runner .ExecuteSteps( [ - new InstallWindowsRocmSageAttentionStep( + new InstallWindowsRocmPackageCommandStep( downloadService, pyInstallationManager, prerequisiteHelper, @@ -985,6 +1020,7 @@ await runner ) ) { + CommandType = WindowsRocmPackageCommandType.SageAttention, InstalledPackage = installedPackage, WorkingDirectory = new DirectoryPath(installedPackage.FullPath), EnvironmentVariables = environmentVariables, @@ -999,6 +1035,116 @@ await runner await EnableSageAttentionAsync(installedPackage).ConfigureAwait(false); } + private async Task InstallWindowsRocmDevelopmentSdk(InstalledPackage? installedPackage) + { + if (installedPackage?.FullPath is null) + return; + + var runner = new PackageModificationRunner + { + ShowDialogOnStart = true, + ModificationCompleteMessage = "Windows ROCm Development SDK installed successfully", + }; + EventManager.Instance.OnPackageInstallProgressAdded(runner); + + await runner + .ExecuteSteps( + [ + new InstallWindowsRocmPackageCommandStep( + downloadService, + pyInstallationManager, + prerequisiteHelper, + rocmPackageHelper + ?? throw new InvalidOperationException( + "Windows ROCm SDK installation encountered an internal configuration error [rocmPackageHelper is null]." + ) + ) + { + CommandType = WindowsRocmPackageCommandType.DevelopmentSdk, + InstalledPackage = installedPackage, + WorkingDirectory = new DirectoryPath(installedPackage.FullPath), + }, + ] + ) + .ConfigureAwait(false); + } + + private async Task InstallWindowsRocmBitsAndBytes(InstalledPackage? 
installedPackage) + { + if (installedPackage?.FullPath is null) + return; + + var runner = new PackageModificationRunner + { + ShowDialogOnStart = true, + ModificationCompleteMessage = "Windows ROCm bitsandbytes installed successfully", + }; + EventManager.Instance.OnPackageInstallProgressAdded(runner); + + var baseEnvironment = ImmutableDictionary.CreateRange(SettingsManager.Settings.EnvironmentVariables); + var environmentVariables = GetEnvVars(baseEnvironment, installedPackage.FullPath, installedPackage); + + await runner + .ExecuteSteps( + [ + new InstallWindowsRocmPackageCommandStep( + downloadService, + pyInstallationManager, + prerequisiteHelper, + rocmPackageHelper + ?? throw new InvalidOperationException( + "Windows ROCm bitsandbytes installation encountered an internal configuration error [rocmPackageHelper is null]." + ) + ) + { + CommandType = WindowsRocmPackageCommandType.BitsAndBytes, + InstalledPackage = installedPackage, + WorkingDirectory = new DirectoryPath(installedPackage.FullPath), + EnvironmentVariables = environmentVariables, + }, + ] + ) + .ConfigureAwait(false); + } + + private async Task InstallWindowsRocmFlashAttention(InstalledPackage? installedPackage) + { + if (installedPackage?.FullPath is null) + return; + + var runner = new PackageModificationRunner + { + ShowDialogOnStart = true, + ModificationCompleteMessage = "Windows ROCm Flash Attention installed successfully", + }; + EventManager.Instance.OnPackageInstallProgressAdded(runner); + + var baseEnvironment = ImmutableDictionary.CreateRange(SettingsManager.Settings.EnvironmentVariables); + var environmentVariables = GetEnvVars(baseEnvironment, installedPackage.FullPath, installedPackage); + + await runner + .ExecuteSteps( + [ + new InstallWindowsRocmPackageCommandStep( + downloadService, + pyInstallationManager, + prerequisiteHelper, + rocmPackageHelper + ?? 
throw new InvalidOperationException( + "Windows ROCm Flash Attention installation encountered an internal configuration error [rocmPackageHelper is null]." + ) + ) + { + CommandType = WindowsRocmPackageCommandType.FlashAttention, + InstalledPackage = installedPackage, + WorkingDirectory = new DirectoryPath(installedPackage.FullPath), + EnvironmentVariables = environmentVariables, + }, + ] + ) + .ConfigureAwait(false); + } + private async Task EnableSageAttentionAsync(InstalledPackage installedPackage) { await using var transaction = settingsManager.BeginTransaction(); diff --git a/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs index 12c991c3..167bac9a 100644 --- a/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs +++ b/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs @@ -26,6 +26,17 @@ public interface IRocmPackageHelper /// IReadOnlyList GetWindowsLaunchNoticeLines(); + /// + /// Ensures a usable Windows ROCm SDK devel package is installed from the ROCm multi-arch index, + /// preferring the same nightly build date as the installed torch build and falling back to the latest available build. + /// + Task EnsureWindowsSdkDevelAsync( + IPyVenvRunner venvRunner, + IProgress? progress = null, + Action? onConsoleOutput = null, + CancellationToken cancellationToken = default + ); + /// /// Performs the Windows-native ROCm install flow for a package using helper-resolved multi-arch device extras. 
/// diff --git a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs index 8dc1c38e..29d39d1f 100644 --- a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs +++ b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs @@ -22,20 +22,26 @@ public class RocmPackageHelper(ISettingsManager settingsManager) : IRocmPackageH { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); private static readonly StringComparer EnvComparer = StringComparer.OrdinalIgnoreCase; + private const string RocmSdkDevelPackageName = "rocm-sdk-devel"; private static readonly string[] WindowsLaunchNoticeLines = [ "Stability Matrix Windows ROCm Notice: Windows AMD ROCm support is experimental. Please report any issues to Stability Matrix first so it can be determined whether the issue is package-specific.", "Because this setup may not be officially supported by package developers, only contact upstream support for issues clearly caused by the package itself.", ]; - /// + /// + /// Evaluates the current Windows machine state for the given package profile and returns the resolved ROCm compatibility result. + /// public RocmCompatibilityResult GetCompatibility(RocmPackageProfile profile) { _ = profile; return BuildCompatibilityResult(profile); } - /// + /// + /// Resolves launch-time ROCm runtime details from the current Windows machine state. + /// This is used to build helper-managed environment variables for package launch. + /// private RocmRuntimeContext ResolveRuntimeContext(RocmPackageProfile profile) { _ = profile; @@ -60,7 +66,10 @@ private RocmRuntimeContext ResolveRuntimeContext(RocmPackageProfile profile) }; } - /// + /// + /// Resolves install-time ROCm package selection details from the current Windows machine state. + /// This includes the canonical runtime GFX architecture and the matching multi-arch device extra. 
+ /// private RocmInstallContext ResolveInstallContext(RocmPackageProfile profile) { _ = profile; @@ -74,7 +83,10 @@ private RocmInstallContext ResolveInstallContext(RocmPackageProfile profile) }; } - /// + /// + /// Builds the final launch environment for a ROCm-capable package by combining helper defaults, + /// package-specific environment values, and optional user overrides. + /// public IReadOnlyDictionary BuildLaunchEnvironment(RocmPackageProfile profile) { var runtimeContext = ResolveRuntimeContext(profile); @@ -95,13 +107,111 @@ public IReadOnlyDictionary BuildLaunchEnvironment(RocmPackagePro return mergedEnvironment; } - /// + /// + /// Returns the shared informational notice lines shown when launching Windows ROCm packages. + /// public IReadOnlyList GetWindowsLaunchNoticeLines() { return WindowsLaunchNoticeLines; } - /// + /// + /// Ensures rocm-sdk-devel is installed from the ROCm multi-arch index. + /// It prefers a build whose nightly date matches the installed ROCm torch build and falls back to the latest available build when no exact match is available. + /// + public async Task EnsureWindowsSdkDevelAsync( + IPyVenvRunner venvRunner, + IProgress? progress = null, + Action? onConsoleOutput = null, + CancellationToken cancellationToken = default + ) + { + var torchInfo = await venvRunner.PipShow("torch").ConfigureAwait(false); + if (torchInfo is null) + { + throw new InvalidOperationException( + "torch is not installed in this environment. Install the Windows ROCm torch build first." + ); + } + + if (!IsUsableWindowsNativeTorchBuild(torchInfo.Version, null)) + { + throw new InvalidOperationException( + $"Installed torch is not a usable Windows ROCm build (detected version: {torchInfo.Version})." 
+ ); + } + + var nightlyBuildDateToken = TryGetNightlyBuildDateToken(torchInfo.Version); + var installedRocmSdkDevel = await venvRunner.PipShow(RocmSdkDevelPackageName).ConfigureAwait(false); + if ( + !string.IsNullOrWhiteSpace(nightlyBuildDateToken) + && HasNightlyBuildDateToken(installedRocmSdkDevel?.Version, nightlyBuildDateToken) + ) + { + return; + } + + var indexResult = await venvRunner + .PipIndex(RocmSdkDevelPackageName, WindowsRocmSupport.MultiArchPythonPackageIndexUrl) + .ConfigureAwait(false); + + var latestVersion = indexResult?.AvailableVersions.FirstOrDefault(); + var matchingVersion = string.IsNullOrWhiteSpace(nightlyBuildDateToken) + ? null + : indexResult?.AvailableVersions.FirstOrDefault(version => + HasNightlyBuildDateToken(version, nightlyBuildDateToken) + ); + var versionToInstall = matchingVersion ?? latestVersion; + + if (string.IsNullOrWhiteSpace(versionToInstall)) + { + throw new InvalidOperationException( + $"No {RocmSdkDevelPackageName} builds were found on the ROCm multi-arch index." + ); + } + + if (!string.IsNullOrWhiteSpace(matchingVersion)) + { + progress?.Report( + new ProgressReport( + -1f, + $"Installing {RocmSdkDevelPackageName} {matchingVersion} for Windows ROCm...", + isIndeterminate: true + ) + ); + } + else + { + progress?.Report( + new ProgressReport( + -1f, + $"Falling back to latest available {RocmSdkDevelPackageName} build {versionToInstall} for Windows ROCm...", + isIndeterminate: true + ) + ); + } + + await venvRunner + .PipInstall( + new PipInstallArgs() + .AddArg("--upgrade") + .AddKeyedArgs( + "--index-url", + ["--index-url", WindowsRocmSupport.MultiArchPythonPackageIndexUrl] + ) + .AddArg($"{RocmSdkDevelPackageName}=={versionToInstall}"), + onConsoleOutput + ) + .ConfigureAwait(false); + + _ = cancellationToken; + } + + /// + /// Performs the shared Windows-native ROCm install flow for helper-managed packages. 
+ /// This installs package requirements, the ROCm torch wheel set from the multi-arch index, + /// and then verifies that the resulting torch installation reports usable ROCm metadata. + /// public async Task InstallWindowsNativePackageAsync( IPyVenvRunner venvRunner, string installLocation, @@ -440,6 +550,30 @@ internal static bool IsUsableWindowsNativeTorchBuild(string? version, string? hi && version.Contains("rocm", StringComparison.OrdinalIgnoreCase); } + private static string? TryGetNightlyBuildDateToken(string? version) + { + if (string.IsNullOrWhiteSpace(version)) + return null; + + var devIndex = version.IndexOf("dev", StringComparison.OrdinalIgnoreCase); + if (devIndex < 0) + return null; + + var startIndex = devIndex + 3; + if (version.Length < startIndex + 8) + return null; + + var token = version.Substring(startIndex, 8); + return token.All(char.IsDigit) ? token : null; + } + + private static bool HasNightlyBuildDateToken(string? version, string nightlyBuildDateToken) + { + return !string.IsNullOrWhiteSpace(version) + && !string.IsNullOrWhiteSpace(nightlyBuildDateToken) + && version.Contains($"dev{nightlyBuildDateToken}", StringComparison.OrdinalIgnoreCase); + } + internal static string? 
TryExtractJsonObject(string output) { if (string.IsNullOrWhiteSpace(output)) From f7b0d5639b2efccb3ccd2831ea823edc523862b4 Mon Sep 17 00:00:00 2001 From: NeuralFault Date: Sun, 10 May 2026 19:22:37 -0400 Subject: [PATCH 17/18] Handle missing and prerelease ROCm SDK lookups in pip runners PyVenvRunner/UvVenvRunner changes: - treat pip show package-not-found results as a normal missing-package case instead of throwing from the venv runners - treat pip index versions no-match results as empty lookups instead of surfacing low-level process failures - add optional prerelease support to PipIndex so callers can query prerelease-only feeds ROCm Helper Change: - enable prerelease lookup for rocm-sdk-devel on the ROCm multi-arch index so nightly SDK builds are discoverable during Windows ROCm package commands. --- StabilityMatrix.Core/Python/IPyVenvRunner.cs | 6 +- StabilityMatrix.Core/Python/PyVenvRunner.cs | 70 +++++++++++++------ StabilityMatrix.Core/Python/UvVenvRunner.cs | 53 ++++++++++---- .../Services/Rocm/RocmPackageHelper.cs | 6 +- 4 files changed, 98 insertions(+), 37 deletions(-) diff --git a/StabilityMatrix.Core/Python/IPyVenvRunner.cs b/StabilityMatrix.Core/Python/IPyVenvRunner.cs index 6c1600b7..1b36418b 100644 --- a/StabilityMatrix.Core/Python/IPyVenvRunner.cs +++ b/StabilityMatrix.Core/Python/IPyVenvRunner.cs @@ -91,7 +91,11 @@ Task Setup( /// /// Run a pip index command, return result as PipIndexResult. /// - Task PipIndex(string packageName, string? indexUrl = null); + Task PipIndex( + string packageName, + string? indexUrl = null, + bool includePrerelease = false + ); /// /// Run a custom install command. Waits for the process to exit. 
diff --git a/StabilityMatrix.Core/Python/PyVenvRunner.cs b/StabilityMatrix.Core/Python/PyVenvRunner.cs index 85fa0697..ff283fd9 100644 --- a/StabilityMatrix.Core/Python/PyVenvRunner.cs +++ b/StabilityMatrix.Core/Python/PyVenvRunner.cs @@ -332,10 +332,9 @@ public async Task> PipList() StringSplitOptions.TrimEntries | StringSplitOptions.RemoveEmptyEntries ) .Select(line => line.Trim()) - .FirstOrDefault( - line => - line.StartsWith("[", StringComparison.OrdinalIgnoreCase) - && line.EndsWith("]", StringComparison.OrdinalIgnoreCase) + .FirstOrDefault(line => + line.StartsWith("[", StringComparison.OrdinalIgnoreCase) + && line.EndsWith("]", StringComparison.OrdinalIgnoreCase) ); if (jsonLine is null) @@ -370,6 +369,17 @@ public async Task> PipList() ) .ConfigureAwait(false); + var packageNotFound = + result.StandardOutput?.Contains("Package(s) not found", StringComparison.OrdinalIgnoreCase) + == true + || result.StandardError?.Contains("Package(s) not found", StringComparison.OrdinalIgnoreCase) + == true; + + if (packageNotFound) + { + return null; + } + // Check return code if (result.ExitCode != 0) { @@ -378,9 +388,11 @@ public async Task> PipList() ); } - if (result.StandardOutput!.StartsWith("WARNING: Package(s) not found:")) + if (string.IsNullOrWhiteSpace(result.StandardOutput)) { - return null; + throw new ProcessException( + $"pip show returned no output for package '{packageName}': {result.StandardError}" + ); } return PipShowResult.Parse(result.StandardOutput); @@ -389,7 +401,11 @@ public async Task> PipList() /// /// Run a pip index command, return result as PipIndexResult. /// - public async Task PipIndex(string packageName, string? indexUrl = null) + public async Task PipIndex( + string packageName, + string? 
indexUrl = null, + bool includePrerelease = false + ) { if (!File.Exists(PipPath)) { @@ -413,10 +429,30 @@ public async Task> PipList() args = args.AddKeyedArgs("--index-url", ["--index-url", indexUrl]); } + if (includePrerelease) + { + args = args.AddArg("--pre"); + } + var result = await ProcessRunner .GetProcessResultAsync(PythonPath, args, WorkingDirectory?.FullPath, EnvironmentVariables) .ConfigureAwait(false); + var noMatchingDistribution = + result.StandardOutput?.Contains( + "No matching distribution found", + StringComparison.OrdinalIgnoreCase + ) == true + || result.StandardError?.Contains( + "No matching distribution found", + StringComparison.OrdinalIgnoreCase + ) == true; + + if (noMatchingDistribution || string.IsNullOrWhiteSpace(result.StandardOutput)) + { + return null; + } + // Check return code if (result.ExitCode != 0) { @@ -425,16 +461,6 @@ public async Task> PipList() ); } - if ( - string.IsNullOrEmpty(result.StandardOutput) - || result - .StandardOutput!.SplitLines() - .Any(l => l.StartsWith("ERROR: No matching distribution found")) - ) - { - return null; - } - return PipIndexResult.Parse(result.StandardOutput); } @@ -617,11 +643,11 @@ public void RunDetached( { // ReSharper disable once StringLiteralTypo var code = $""" - from importlib.metadata import entry_points - - results = entry_points(group='console_scripts', name='{entryPointName}') - print(tuple(results)[0].value, end='') - """; + from importlib.metadata import entry_points + + results = entry_points(group='console_scripts', name='{entryPointName}') + print(tuple(results)[0].value, end='') + """; var result = await Run($"-c \"{code}\"").ConfigureAwait(false); if (result.ExitCode == 0 && !string.IsNullOrWhiteSpace(result.StandardOutput)) diff --git a/StabilityMatrix.Core/Python/UvVenvRunner.cs b/StabilityMatrix.Core/Python/UvVenvRunner.cs index 53a295ab..6fa69fd6 100644 --- a/StabilityMatrix.Core/Python/UvVenvRunner.cs +++ b/StabilityMatrix.Core/Python/UvVenvRunner.cs @@ -386,6 
+386,17 @@ public async Task> PipList() ) .ConfigureAwait(false); + var packageNotFound = + result.StandardOutput?.Contains("Package(s) not found", StringComparison.OrdinalIgnoreCase) + == true + || result.StandardError?.Contains("Package(s) not found", StringComparison.OrdinalIgnoreCase) + == true; + + if (packageNotFound) + { + return null; + } + // Check return code if (result.ExitCode != 0) { @@ -394,9 +405,11 @@ public async Task> PipList() ); } - if (result.StandardOutput!.StartsWith("WARNING: Package(s) not found:")) + if (string.IsNullOrWhiteSpace(result.StandardOutput)) { - return null; + throw new ProcessException( + $"pip show returned no output for package '{packageName}': {result.StandardError}" + ); } return PipShowResult.Parse(result.StandardOutput); @@ -405,7 +418,11 @@ public async Task> PipList() /// /// Run a pip index command, return result as PipIndexResult. /// - public async Task PipIndex(string packageName, string? indexUrl = null) + public async Task PipIndex( + string packageName, + string? 
indexUrl = null, + bool includePrerelease = false + ) { if (!File.Exists(PipPath)) { @@ -429,10 +446,30 @@ public async Task> PipList() args = args.AddKeyedArgs("--index-url", ["--index-url", indexUrl]); } + if (includePrerelease) + { + args = args.AddArg("--pre"); + } + var result = await ProcessRunner .GetProcessResultAsync(PythonPath, args, WorkingDirectory?.FullPath, EnvironmentVariables) .ConfigureAwait(false); + var noMatchingDistribution = + result.StandardOutput?.Contains( + "No matching distribution found", + StringComparison.OrdinalIgnoreCase + ) == true + || result.StandardError?.Contains( + "No matching distribution found", + StringComparison.OrdinalIgnoreCase + ) == true; + + if (noMatchingDistribution || string.IsNullOrWhiteSpace(result.StandardOutput)) + { + return null; + } + // Check return code if (result.ExitCode != 0) { @@ -441,16 +478,6 @@ public async Task> PipList() ); } - if ( - string.IsNullOrEmpty(result.StandardOutput) - || result - .StandardOutput!.SplitLines() - .Any(l => l.StartsWith("ERROR: No matching distribution found")) - ) - { - return null; - } - return PipIndexResult.Parse(result.StandardOutput); } diff --git a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs index 29d39d1f..46e26aec 100644 --- a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs +++ b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs @@ -152,7 +152,11 @@ public async Task EnsureWindowsSdkDevelAsync( } var indexResult = await venvRunner - .PipIndex(RocmSdkDevelPackageName, WindowsRocmSupport.MultiArchPythonPackageIndexUrl) + .PipIndex( + RocmSdkDevelPackageName, + WindowsRocmSupport.MultiArchPythonPackageIndexUrl, + includePrerelease: true + ) .ConfigureAwait(false); var latestVersion = indexResult?.AvailableVersions.FirstOrDefault(); From 1e653e21f148b3368826d3e1e3daff84e66dcfed Mon Sep 17 00:00:00 2001 From: NeuralFault Date: Mon, 11 May 2026 19:55:49 -0400 Subject: [PATCH 18/18] 
Remove obsolete Windows ROCm HSA GFX override, updated GPUinfo gfxarch translation. Removed setting HSA_OVERRIDE_GFX_VERSION in helper-managed Windows ROCm launch environments, removed the now unused RDNA1 override option from ROCm Environment settings, and deleted RDNA1 arch helper from WindowsRocmSupport.cs. This is due to relying on per-GPU ROCm package selection instead of per-GFX-family masking since switching to Multi-Arch repo. --- .../Helper/HardwareInfo/GpuInfo.cs | 16 +++++++++++----- .../Models/Rocm/RocmEnvironmentOptions.cs | 5 ----- .../Models/Rocm/WindowsRocmSupport.cs | 5 ----- .../Services/Rocm/RocmPackageHelper.cs | 5 ----- 4 files changed, 11 insertions(+), 20 deletions(-) diff --git a/StabilityMatrix.Core/Helper/HardwareInfo/GpuInfo.cs b/StabilityMatrix.Core/Helper/HardwareInfo/GpuInfo.cs index 2bb0f4a8..6a791e7a 100644 --- a/StabilityMatrix.Core/Helper/HardwareInfo/GpuInfo.cs +++ b/StabilityMatrix.Core/Helper/HardwareInfo/GpuInfo.cs @@ -98,11 +98,14 @@ _ when Has("740M") || Has("760M") || Has("780M") || Has("Z1") || Has("Z2") => "g _ when Has("7400") || Has("7500") || Has("7600") || Has("7650") || Has("7700S") => "gfx1102", // RDNA3 dGPU Navi32 - _ when Has("7700") || Has("RX 7800") || HasNoSpace("RX7800") => "gfx1101", + _ when Has("7700") || Has("RX 7800") || Has("v710)") || HasNoSpace("RX7800") => "gfx1101", // RDNA3 dGPU Navi31 (incl.
Pro card) - _ when Has("5600") || Has("5700") || Has("v520") => "gfx1010", + _ when Has("6800") || Has("6900") || Has("6950") || Has("v620") => "gfx1030", // RDNA1 Navi10 XTX _ when Has("5500") => "gfx1012", + //RDNA1 Pro Card + _ when Has("v520") => "gfx1011", + + // RDNA1 Navi10 XT + _ when Has("5600") || Has("5700") => "gfx1010", + // Vega/GCN5 Dedicated GPUs _ when Has("rx vega") || Has("vega 64") || Has("vega 56") || Has("vega frontier") => "gfx900", _ when Has("radeon vii") || HasNoSpace("radeonvii") || Has("pro vii") || HasNoSpace("provii") => diff --git a/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs b/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs index de5787dc..88216ec1 100644 --- a/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs +++ b/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs @@ -49,9 +49,4 @@ public class RocmEnvironmentOptions /// When true, helper-managed defaults will force math SDP on legacy ROCm architectures. /// public bool ApplyLegacySdpFallback { get; init; } = true; - - /// - /// When true, helper-managed defaults will apply the RDNA1 HSA override mask when needed. - /// - public bool ApplyRdna1Override { get; init; } = true; } diff --git a/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs b/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs index 49fe5d76..b6533cab 100644 --- a/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs +++ b/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs @@ -57,11 +57,6 @@ public static bool PreferLegacyAttentionFallback(string? gfxArch) return IsLegacyArchitecture(gfxArch); } - public static bool IsRdna1Architecture(string? gfxArch) - { - return gfxArch?.StartsWith("gfx101", StringComparison.OrdinalIgnoreCase) == true; - } - public static string? TryGetCanonicalArchitecture(string? 
gfxArch) { if (string.IsNullOrWhiteSpace(gfxArch)) diff --git a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs index 46e26aec..ba4c1e24 100644 --- a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs +++ b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs @@ -641,11 +641,6 @@ RocmEnvironmentOptions options environment["TORCH_BACKENDS_CUDA_MEM_EFF_SDP_ENABLED"] = "0"; environment["TORCH_BACKENDS_CUDA_MATH_SDP_ENABLED"] = "1"; } - - if (options.ApplyRdna1Override && WindowsRocmSupport.IsRdna1Architecture(gfxArch)) - { - environment["HSA_OVERRIDE_GFX_VERSION"] = "10.1.0"; - } } private static void SetIfNotNull(IDictionary environment, string key, string? value)