diff --git a/python/private/pypi/pypi_repo_utils.bzl b/python/private/pypi/pypi_repo_utils.bzl index d8e320014f..8ec7bd1dbe 100644 --- a/python/private/pypi/pypi_repo_utils.bzl +++ b/python/private/pypi/pypi_repo_utils.bzl @@ -177,7 +177,6 @@ def _find_namespace_package_files(rctx, install_dir): to namespace packages. """ - repo_root = str(rctx.path(".")) + "/" namespace_package_files = [] for top_level_dir in install_dir.readdir(): if not is_importable_name(top_level_dir.basename): @@ -192,7 +191,9 @@ def _find_namespace_package_files(rctx, install_dir): if ("__path__ =" in content and "pkgutil" in content and "extend_path(" in content): - namespace_package_files.append(str(init_py).removeprefix(repo_root)) + namespace_package_files.append( + repo_utils.repo_root_relative_path(rctx, init_py), + ) return namespace_package_files diff --git a/python/private/repo_utils.bzl b/python/private/repo_utils.bzl index a558fa08e1..7ec45eda5b 100644 --- a/python/private/repo_utils.bzl +++ b/python/private/repo_utils.bzl @@ -334,7 +334,7 @@ def _mkdir(mrctx, path): repo_root = str(mrctx.path(".")) path_str = str(path) - if not path_str.startswith(repo_root): + if not _is_relative_to(mrctx, path_str, repo_root): mkdir_bin = mrctx.which("mkdir") if not mkdir_bin: return None @@ -348,6 +348,30 @@ def _mkdir(mrctx, path): mrctx.delete(placeholder) return path +def _norm_path(mrctx, p): + p = str(p) + + # Windows is case-insensitive + if _get_platforms_os_name(mrctx) == "windows": + return p.lower() + return p + +def _relative_to(mrctx, path, parent, fail = fail): + path_str = str(path) + parent_str = str(parent) + path_d = _norm_path(mrctx, path_str) + "/" + parent_d = _norm_path(mrctx, parent_str) + "/" + if path_d.startswith(parent_d): + return path_str[len(parent_str):].removeprefix("/") + else: + fail("{} is not relative to {}".format(path, parent)) + +def _is_relative_to(mrctx, path, parent): + """Tell if `path` is equal to or beneath `parent`.""" + path_d = _norm_path(mrctx, path) + "/" + parent_d = _norm_path(mrctx, parent) + "/" + return path_d.startswith(parent_d) + def _repo_root_relative_path(mrctx, path): """Takes a path object and returns a repo-relative path string. @@ -360,14 +384,7 @@ def _repo_root_relative_path(mrctx, path): """ repo_root = str(mrctx.path(".")) path_str = str(path) - relative_path = path_str[len(repo_root):] - if relative_path[0] != "/": - fail("{path} not under {repo_root}".format( - path = path, - repo_root = repo_root, - )) - relative_path = relative_path[1:] - return relative_path + return _relative_to(mrctx, path_str, repo_root) def _args_to_str(arguments): return " ".join([_arg_repr(a) for a in arguments]) @@ -516,6 +533,9 @@ repo_utils = struct( is_repo_debug_enabled = _is_repo_debug_enabled, logger = _logger, mkdir = _mkdir, + norm_path = _norm_path, + relative_to = _relative_to, + is_relative_to = _is_relative_to, repo_root_relative_path = _repo_root_relative_path, which_checked = _which_checked, which_unchecked = _which_unchecked, diff --git a/tests/repo_utils/BUILD.bazel b/tests/repo_utils/BUILD.bazel new file mode 100644 index 0000000000..74e8e37489 --- /dev/null +++ b/tests/repo_utils/BUILD.bazel @@ -0,0 +1,3 @@ +load(":repo_utils_test.bzl", "repo_utils_test_suite") + +repo_utils_test_suite(name = "repo_utils_tests") diff --git a/tests/repo_utils/repo_utils_test.bzl b/tests/repo_utils/repo_utils_test.bzl new file mode 100644 index 0000000000..ce9e48b5a6 --- /dev/null +++ b/tests/repo_utils/repo_utils_test.bzl @@ -0,0 +1,59 @@ +"""Unit tests for repo_utils.bzl.""" + +load("@rules_testing//lib:test_suite.bzl", "test_suite") +load("//python/private:repo_utils.bzl", "repo_utils") # buildifier: disable=bzl-visibility +load("//tests/support:mocks.bzl", "mocks") + +_tests = [] + +def _test_get_platforms_os_name(env): + mock_mrctx = mocks.rctx(os_name = "Mac OS X") + got = repo_utils.get_platforms_os_name(mock_mrctx) + env.expect.that_str(got).equals("osx") + +_tests.append(_test_get_platforms_os_name) + +def _test_relative_to(env): + mock_mrctx_linux = mocks.rctx(os_name = "linux") + mock_mrctx_win = mocks.rctx(os_name = "windows") + + # Case-sensitive matching (Linux) + got = repo_utils.relative_to(mock_mrctx_linux, "foo/bar/baz", "foo/bar") + env.expect.that_str(got).equals("baz") + + # Case-insensitive matching (Windows) + got = repo_utils.relative_to(mock_mrctx_win, "C:/Foo/Bar/Baz", "c:/foo/bar") + env.expect.that_str(got).equals("Baz") + + # Failure case + failures = [] + + def _mock_fail(msg): + failures.append(msg) + + repo_utils.relative_to(mock_mrctx_linux, "foo/bar/baz", "qux", fail = _mock_fail) + env.expect.that_collection(failures).contains_exactly(["foo/bar/baz is not relative to qux"]) + +_tests.append(_test_relative_to) + +def _test_is_relative_to(env): + mock_mrctx_linux = mocks.rctx(os_name = "linux") + mock_mrctx_win = mocks.rctx(os_name = "windows") + + # Case-sensitive matching (Linux) + env.expect.that_bool(repo_utils.is_relative_to(mock_mrctx_linux, "foo/bar/baz", "foo/bar")).equals(True) + env.expect.that_bool(repo_utils.is_relative_to(mock_mrctx_linux, "foo/bar/baz", "qux")).equals(False) + + # Case-insensitive matching (Windows) + env.expect.that_bool(repo_utils.is_relative_to(mock_mrctx_win, "C:/Foo/Bar/Baz", "c:/foo/bar")).equals(True) + env.expect.that_bool(repo_utils.is_relative_to(mock_mrctx_win, "C:/Foo/Bar/Baz", "D:/Foo")).equals(False) + +_tests.append(_test_is_relative_to) + +def repo_utils_test_suite(name): + """Create the test suite. + + Args: + name: the name of the test suite + """ + test_suite(name = name, basic_tests = _tests) diff --git a/tests/support/mocks.bzl b/tests/support/mocks.bzl new file mode 100644 index 0000000000..2a4ccd0fc4 --- /dev/null +++ b/tests/support/mocks.bzl @@ -0,0 +1,31 @@ +"""Mocks for testing.""" + +def _rctx(os_name = "linux", os_arch = "x86_64", environ = None, **kwargs): + """Creates a mock of repository_ctx or module_ctx. + + Args: + os_name: The OS name to mock (e.g., "linux", "Mac OS X", "windows"). + os_arch: The OS architecture to mock (e.g., "x86_64", "aarch64"). + environ: A dictionary representing the environment variables. + **kwargs: Additional attributes to add to the mock struct. + + Returns: + A struct mocking repository_ctx. + """ + if environ == None: + environ = {} + + attrs = { + "getenv": environ.get, + "os": struct( + name = os_name, + arch = os_arch, + ), + } + attrs.update(kwargs) + + return struct(**attrs) + +mocks = struct( + rctx = _rctx, +)