Skip to content
10 changes: 9 additions & 1 deletion nodescraper/plugins/inband/package/analyzer_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,15 @@
class PackageAnalyzerArgs(AnalyzerArgs):
exp_package_ver: Dict[str, Optional[str]] = Field(default_factory=dict)
regex_match: bool = False
# rocm_regex is optional and should be specified in plugin_config.json if needed
rocm_regex: Optional[str] = None
enable_rocm_regex: bool = False

@classmethod
def build_from_model(cls, datamodel: PackageDataModel) -> "PackageAnalyzerArgs":
return cls(exp_package_ver=datamodel.version_info)
# Use custom rocm_regex from collection_args if enable_rocm_regex is true
rocm_regex = None
if datamodel.enable_rocm_regex and datamodel.rocm_regex:
rocm_regex = datamodel.rocm_regex

return cls(exp_package_ver=datamodel.version_info, rocm_regex=rocm_regex)
60 changes: 57 additions & 3 deletions nodescraper/plugins/inband/package/package_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,11 @@
from nodescraper.models import TaskResult
from nodescraper.utils import get_exception_details

from .analyzer_args import PackageAnalyzerArgs
from .packagedata import PackageDataModel


class PackageCollector(InBandDataCollector[PackageDataModel, None]):
class PackageCollector(InBandDataCollector[PackageDataModel, PackageAnalyzerArgs]):
"""Collecting Package information from the system"""

DATA_MODEL = PackageDataModel
Expand Down Expand Up @@ -181,9 +182,34 @@ def _handle_command_failure(self, command_artifact: CommandArtifact):
self.result.message = "Failed to run Package Manager command"
self.result.status = ExecutionStatus.EXECUTION_FAILURE

def collect_data(self, args=None) -> tuple[TaskResult, Optional[PackageDataModel]]:
def _filter_rocm_packages(self, packages: dict[str, str], rocm_pattern: str) -> dict[str, str]:
"""Filter ROCm-related packages from a package dictionary.

This method searches package names for ROCm-related patterns and returns
only the matching packages.

Args:
packages (dict[str, str]): Dictionary with package names as keys and versions as values.
rocm_pattern (str): Regex pattern to match ROCm-related package names.

Returns:
dict[str, str]: Filtered dictionary containing only ROCm-related packages.
"""
rocm_packages = {}
pattern = re.compile(rocm_pattern, re.IGNORECASE)
for package_name, version in packages.items():
if pattern.search(package_name):
rocm_packages[package_name] = version
return rocm_packages

def collect_data(
self, args: Optional[PackageAnalyzerArgs] = None
) -> tuple[TaskResult, Optional[PackageDataModel]]:
"""Collect package information from the system.

Args:
args (Optional[PackageAnalyzerArgs]): Optional arguments containing ROCm regex pattern.

Returns:
tuple[TaskResult, Optional[PackageDataModel]]: tuple containing the task result and a PackageDataModel instance
with the collected package information, or None if there was an error.
Expand All @@ -205,8 +231,36 @@ def collect_data(self, args=None) -> tuple[TaskResult, Optional[PackageDataModel
self.result.message = "Unsupported OS"
self.result.status = ExecutionStatus.NOT_RAN
return self.result, None

# Filter and log ROCm packages if on Linux and rocm_regex is provided
if self.system_info.os_family == OSFamily.LINUX and packages:
# Get ROCm pattern from args if provided
rocm_pattern = args.rocm_regex if args else None
if rocm_pattern:
self.logger.info("Using rocm_pattern: %s", rocm_pattern)
rocm_packages = self._filter_rocm_packages(packages, rocm_pattern)
if rocm_packages:
self.result.message = (
f"Found {len(rocm_packages)} ROCm-related packages installed"
)
self.result.status = ExecutionStatus.OK
self._log_event(
category=EventCategory.OS,
description=f"Found {len(rocm_packages)} ROCm-related packages installed",
priority=EventPriority.INFO,
data={"rocm_packages": sorted(rocm_packages.keys())},
)
else:
self.logger.info("No rocm_regex provided, skipping ROCm package filtering")

# Extract rocm_regex and enable_rocm_regex from args if provided
rocm_regex = args.rocm_regex if (args and args.rocm_regex) else ""
enable_rocm_regex = getattr(args, "enable_rocm_regex", False) if args else False

try:
package_model = PackageDataModel(version_info=packages)
package_model = PackageDataModel(
version_info=packages, rocm_regex=rocm_regex, enable_rocm_regex=enable_rocm_regex
)
except ValidationError as val_err:
self._log_event(
category=EventCategory.RUNTIME,
Expand Down
4 changes: 4 additions & 0 deletions nodescraper/plugins/inband/package/packagedata.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ class PackageDataModel(DataModel):
Attributes:
version_info (dict[str, str]): The version information for the package
Key is the package name and value is the version of the package
rocm_regex (str): Regular expression pattern for ROCm package filtering
enable_rocm_regex (bool): Whether to use custom ROCm regex from collection_args
"""

version_info: dict[str, str]
rocm_regex: str = ""
enable_rocm_regex: bool = False
4 changes: 4 additions & 0 deletions test/functional/fixtures/package_plugin_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
"global_args": {},
"plugins": {
"PackagePlugin": {
"collection_args": {
"rocm_regex": "rocm|hip|hsa|amdgpu",
"enable_rocm_regex": true
},
"analysis_args": {
"exp_package_ver": {
"gcc": "11.4.0"
Expand Down
109 changes: 109 additions & 0 deletions test/unit/plugin/test_package_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,112 @@ def test_bad_splits_ubuntu(collector, conn_mock, command_results):
]
res, _ = collector.collect_data()
assert res.status == ExecutionStatus.OK


def test_rocm_package_filtering_custom_regex(collector, conn_mock, command_results):
"""Test ROCm package filtering with custom regex pattern."""
from nodescraper.plugins.inband.package.analyzer_args import PackageAnalyzerArgs

# Mock Ubuntu system with ROCm packages
ubuntu_packages = """rocm-core 5.7.0
hip-runtime-amd 5.7.0
hsa-rocr 1.9.0
amdgpu-dkms 6.3.6
gcc 11.4.0
python3 3.10.12"""

conn_mock.run_command.side_effect = [
CommandArtifact(
command="",
exit_code=0,
stdout=command_results["ubuntu_rel"],
stderr="",
),
CommandArtifact(
command="",
exit_code=0,
stdout=ubuntu_packages,
stderr="",
),
]

# Use custom regex that only matches 'rocm' and 'hip'
args = PackageAnalyzerArgs(rocm_regex="rocm|hip")
res, data = collector.collect_data(args)
assert res.status == ExecutionStatus.OK
# Check that ROCm packages are found
assert "found 2 rocm-related packages" in res.message.lower()
assert data is not None


def test_rocm_package_filtering_no_matches(collector, conn_mock, command_results):
"""Test ROCm package filtering when no ROCm packages are installed."""
from nodescraper.plugins.inband.package.analyzer_args import PackageAnalyzerArgs

# Mock Ubuntu system without ROCm packages
ubuntu_packages = """gcc 11.4.0
python3 3.10.12
vim 8.2.3995"""

conn_mock.run_command.side_effect = [
CommandArtifact(
command="",
exit_code=0,
stdout=command_results["ubuntu_rel"],
stderr="",
),
CommandArtifact(
command="",
exit_code=0,
stdout=ubuntu_packages,
stderr="",
),
]

args = PackageAnalyzerArgs(rocm_regex="rocm|hip|hsa")
res, data = collector.collect_data(args)
assert res.status == ExecutionStatus.OK
# No ROCm packages found, so message should not mention them
assert "rocm" not in res.message.lower() or res.message == ""
assert data is not None
assert len(data.version_info) == 3


def test_filter_rocm_packages_method(collector):
"""Test _filter_rocm_packages method directly."""
packages = {
"rocm-core": "5.7.0",
"hip-runtime-amd": "5.7.0",
"hsa-rocr": "1.9.0",
"amdgpu-dkms": "6.3.6",
"gcc": "11.4.0",
"python3": "3.10.12",
}

# Test with default-like pattern
rocm_pattern = "rocm|hip|hsa|amdgpu"
filtered = collector._filter_rocm_packages(packages, rocm_pattern)

assert len(filtered) == 4
assert "rocm-core" in filtered
assert "hip-runtime-amd" in filtered
assert "hsa-rocr" in filtered
assert "amdgpu-dkms" in filtered
assert "gcc" not in filtered
assert "python3" not in filtered


def test_filter_rocm_packages_case_insensitive(collector):
"""Test that ROCm package filtering is case-insensitive."""
packages = {
"ROCM-Core": "5.7.0",
"HIP-Runtime-AMD": "5.7.0",
"gcc": "11.4.0",
}

rocm_pattern = "rocm|hip"
filtered = collector._filter_rocm_packages(packages, rocm_pattern)

assert len(filtered) == 2
assert "ROCM-Core" in filtered
assert "HIP-Runtime-AMD" in filtered