Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 80 additions & 16 deletions problems/linalg/qr_py/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,21 @@ def _band_mask(n: int, bandwidth: int, device: torch.device) -> torch.Tensor:
return (idx[:, None] - idx[None, :]).abs() <= bandwidth


def generate_input(batch: int, n: int, cond: int, seed: int, case: str = "dense") -> input_t:
assert batch > 0, "batch must be positive"
assert n > 0, "n must be positive"
assert cond >= 0, "cond must be non-negative"

device = "cuda" if torch.cuda.is_available() else "cpu"
gen = torch.Generator(device=device)
gen.manual_seed(seed)
_MIXED_PROFILES = (
"dense",
"rankdef",
"nearrank",
"clustered",
"band",
"rowscale",
"nearcollinear",
)
_MIXED_WEIGHTS = (6.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0)

case = case.lower()
a = torch.randn((batch, n, n), device=device, dtype=torch.float32, generator=gen)

def _apply_case(a: torch.Tensor, case: str, cond: int, gen: torch.Generator) -> torch.Tensor:
batch, n = a.shape[0], a.shape[-1]
device = a.device
if case == "dense":
a = _apply_column_scaling(a, cond)
elif case == "upper":
Expand Down Expand Up @@ -83,6 +86,48 @@ def generate_input(batch: int, n: int, cond: int, seed: int, case: str = "dense"
a = scales.reshape(1, n, 1) * a
else:
raise ValueError(f"unknown QR test case: {case}")
return a


def _generate_mixed(a: torch.Tensor, cond: int, gen: torch.Generator) -> torch.Tensor:
batch = a.shape[0]
device = a.device
weights = torch.tensor(_MIXED_WEIGHTS, dtype=torch.float32, device=device)
labels = torch.multinomial(weights, batch, replacement=True, generator=gen)

if batch >= 2:
is_dense = labels == 0
if not bool(is_dense.any()):
labels[int(torch.randint(0, batch, (1,), device=device, generator=gen))] = 0
elif bool(is_dense.all()):
pos = int(torch.randint(0, batch, (1,), device=device, generator=gen))
labels[pos] = int(
torch.randint(1, len(_MIXED_PROFILES), (1,), device=device, generator=gen)
)

for idx, profile in enumerate(_MIXED_PROFILES):
mask = labels == idx
if bool(mask.any()):
a[mask] = _apply_case(a[mask], profile, cond, gen)
return a


def generate_input(batch: int, n: int, cond: int, seed: int, case: str = "dense") -> input_t:
assert batch > 0, "batch must be positive"
assert n > 0, "n must be positive"
assert cond >= 0, "cond must be non-negative"

device = "cuda" if torch.cuda.is_available() else "cpu"
gen = torch.Generator(device=device)
gen.manual_seed(seed)

case = case.lower()
a = torch.randn((batch, n, n), device=device, dtype=torch.float32, generator=gen)

if case == "mixed":
a = _generate_mixed(a, cond, gen)
else:
a = _apply_case(a, case, cond, gen)

return a.contiguous()

Expand Down Expand Up @@ -143,27 +188,44 @@ def check_implementation(data: input_t, output: output_t) -> tuple[bool, str]:

q = torch.linalg.householder_product(h, tau)
r = torch.triu(h)
if not torch.isfinite(q).all().item():
return False, "Q materialized from `(H, tau)` contains NaN or Inf"
if not torch.isfinite(r).all().item():
return False, "R extracted from `triu(H)` contains NaN or Inf"

a_check = a.double()
q_check = q.double()
r_check = r.double()
projected = q_check.transpose(-1, -2) @ a_check
factor_residual = _matrix_l1_norm(r_check - projected).amax()
factor_scale = _matrix_l1_norm(a_check).amax()
if not torch.isfinite(projected).all().item():
return False, "Q.T @ A contains NaN or Inf"

factor_residual = _matrix_l1_norm(r_check - projected)
factor_scale = _matrix_l1_norm(a_check)
factor_allowed = factor_rtol * factor_scale
factor_scaled = _scaled_residual(factor_residual, factor_scale, n)
if factor_residual.item() > factor_allowed.item():
if not torch.isfinite(factor_scaled).all().item():
return False, "R - Q.T @ A residual produced NaN or Inf"
factor_failed = factor_residual > factor_allowed
if bool(factor_failed.any().item()):
worst = int(factor_scaled.argmax().item())
return False, (
"R - Q.T @ A is too large: "
f"residual={factor_residual.item():.3g}, allowed={factor_allowed.item():.3g}, "
f"scaled={factor_scaled.item():.3g}"
f"matrix={worst}, residual={factor_residual[worst].item():.3g}, "
f"allowed={factor_allowed[worst].item():.3g}, "
f"scaled={factor_scaled[worst].item():.3g}"
)

eye = torch.eye(n, device=a.device, dtype=torch.float64).expand(batch, n, n)
qtq = q_check.transpose(-1, -2) @ q_check
if not torch.isfinite(qtq).all().item():
return False, "Q.T @ Q contains NaN or Inf"
orth_residual = _matrix_l1_norm(qtq - eye).amax()
orth_scale = _matrix_l1_norm(eye).amax()
orth_allowed = orth_rtol * orth_scale
orth_scaled = _scaled_residual(orth_residual, orth_scale, n)
if not torch.isfinite(orth_scaled).all().item():
return False, "Q.T @ Q residual produced NaN or Inf"
if orth_residual.item() > orth_allowed.item():
return False, (
"Q is not orthogonal enough: "
Expand All @@ -177,14 +239,16 @@ def check_implementation(data: input_t, output: output_t) -> tuple[bool, str]:
tri_scaled = _scaled_residual(tri_residual, tri_scale, n)

recon = q_check @ r_check
if not torch.isfinite(recon).all().item():
return False, "Q @ R contains NaN or Inf"
recon_residual = _matrix_l1_norm(recon - a_check).amax()
recon_scale = _matrix_l1_norm(a_check).amax()
recon_scaled = _scaled_residual(recon_residual, recon_scale, n)

return True, (
f"factor_rtol={factor_rtol:.3g}; "
f"orth_rtol={orth_rtol:.3g}; "
f"scaled_factor_residual={factor_scaled.item():.3g}; "
f"scaled_factor_residual={factor_scaled.amax().item():.3g}; "
f"scaled_reconstruction_residual={recon_scaled.item():.3g}; "
f"scaled_triangular_residual={tri_scaled.item():.3g}; "
f"scaled_orthogonality_residual={orth_scaled.item():.3g}; "
Expand Down
16 changes: 16 additions & 0 deletions problems/linalg/qr_py/task.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ description: |
structure, such as rank-deficient, near-rank-deficient, banded, row-scaled,
near-collinear, upper-triangular, or clustered-scale inputs.

The `mixed` case builds a heterogeneous batch: each matrix is independently
assigned a conditioning profile at a random seeded position in the batch. This
mirrors the optimizer-statistics regime, where factors batched into one call
can have very different conditioning, rather than all sharing one structure.
Benchmarks include both `mixed` batches and fully ill-conditioned homogeneous
batches, so conditioning robustness is ranked, not only gated. Each matrix in
the batch must be factored correctly on its own merits.

Correctness is a hard gate against the original FP32 input and the FP32
`torch.geqrf` compact-factor contract. Low-bit FP16, FP8, or NVFP4 work is
allowed only as an internal implementation strategy: returned factors must
Expand Down Expand Up @@ -89,6 +97,9 @@ tests:
- {"batch": 2, "n": 2048, "cond": 2, "seed": 224466, "case": "dense"}
- {"batch": 2, "n": 2048, "cond": 0, "seed": 224467, "case": "rankdef"}
- {"batch": 1, "n": 4096, "cond": 0, "seed": 75343, "case": "upper"}
- {"batch": 16, "n": 512, "cond": 2, "seed": 32530, "case": "mixed"}
- {"batch": 4, "n": 1024, "cond": 2, "seed": 4332, "case": "mixed"}
- {"batch": 2, "n": 2048, "cond": 2, "seed": 224468, "case": "mixed"}

benchmarks:
- {"batch": 20, "n": 32, "cond": 1, "seed": 43214}
Expand All @@ -98,3 +109,8 @@ benchmarks:
- {"batch": 60, "n": 1024, "cond": 2, "seed": 75342}
- {"batch": 8, "n": 2048, "cond": 1, "seed": 224466}
- {"batch": 2, "n": 4096, "cond": 1, "seed": 32412}
- {"batch": 640, "n": 512, "cond": 2, "seed": 770001, "case": "mixed"}
- {"batch": 60, "n": 1024, "cond": 2, "seed": 770002, "case": "mixed"}
- {"batch": 640, "n": 512, "cond": 0, "seed": 770003, "case": "rankdef"}
- {"batch": 640, "n": 512, "cond": 0, "seed": 770004, "case": "clustered"}
- {"batch": 60, "n": 1024, "cond": 0, "seed": 770005, "case": "nearrank"}