Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 19 additions & 20 deletions .github/workflows/python-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,32 @@ name: Test

on:
push:
branches: [ "main" ]
branches: ["main"]
pull_request:
branches: [ "main" ]
branches: ["main"]

permissions:
contents: read

jobs:
build:

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4
- name: Set up Python 3.12
uses: actions/setup-python@v3
with:
python-version: "3.12"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install ruff pytest
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Lint with ruff
run: |
# stop the build if there are Python syntax errors or undefined names
ruff check --output-format=github .
- name: Test with pytest
run: |
python -m pytest tests
- uses: actions/checkout@v4
- name: Set up Python 3.12
uses: actions/setup-python@v3
with:
python-version: "3.12"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pre-commit pytest
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Install pre-commit hooks
run: pre-commit install
- name: Run pre-commit hooks for linting and other checks
run: pre-commit run --all-files
- name: Test with pytest
run: |
python -m pytest tests
9 changes: 8 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@

default_stages: [pre-commit]
repos:
- repo: https://github.com/rbubley/mirrors-prettier
rev: "v3.8.1" # Use the sha / tag you want to point at
hooks:
- id: prettier
additional_dependencies:
- prettier@2.1.2
- "@prettier/plugin-xml@0.12.0"
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
hooks:
Expand All @@ -19,6 +26,6 @@ repos:
hooks:
# Run the linter.
- id: ruff-check
args: [ --fix ]
args: [--fix]
# Run the formatter.
- id: ruff-format
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
# Rule-Parser

DRAM LARK based Rule Parser for misciellaneous rules parsing tasks such as traits, summarize, and product.
3 changes: 3 additions & 0 deletions src/rules.lark
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,11 @@ list_lit: "[" expr ("," expr)+ "]" -> step_

name_ref: IDENT ":" IDENT -> qualified_name
| IDENT -> simple_name
| BTICK_NAME -> quoted_name
| IDENT ":" BTICK_NAME -> quoted_name

IDENT: /[A-Za-z0-9_\-\.]+/
BTICK_NAME: /`([^`\\]|\\.)*`/
NUMBER: /[0-9]+(\.[0-9]+)?/
STRING: ESCAPED_STRING

Expand Down
53 changes: 37 additions & 16 deletions src/rules.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from __future__ import annotations

import functools
from dataclasses import dataclass
from typing import Dict, Tuple, List, Optional, Set, Iterable
import operator
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Set, Tuple
import functools

import networkx as nx
import numpy as np
import polars as pl
from lark import Lark, LarkError, Transformer
from lark import Lark, Transformer, LarkError

OP_TO_EXPR = {
"gt": operator.gt,
Expand Down Expand Up @@ -71,6 +71,7 @@ class RuleError(Exception):
.str.split(",")
.list.eval(pl.element().str.strip_chars().str.split(" ").list.first())
),
"DEFAULT": lambda col: pl.col(col).cast(pl.Utf8).cast(pl.List(pl.Utf8)),
}

CALL_FUNCTIONS = {
Expand Down Expand Up @@ -264,6 +265,15 @@ def qualified_name(self, items):
db, value = items
return Name(value=str(value), db=str(db))

def quoted_name(self, items):
val_index = 0 if (len(items) == 1) else 1
items[val_index] = (
items[val_index][1:-1].replace(r"\`", "`").replace(r"\\", "\\")
) # remove backticks
if len(items) == 2:
return self.qualified_name(items)
return self.simple_name(items)

def number(self, items):
return Number(float(str(items[0])))

Expand Down Expand Up @@ -323,7 +333,7 @@ def from_rules(cls, *args, **kwargs) -> CompiledRules:
}

# Expand rules using expanded defs
# we need to hit again in case defs is empty (no alias col)
# we need to hit again in case defs is empty (no parent col)
# and we still need to add needed features from rules
features_by_rules = {k: set() for k in rules.keys()}
trees_by_rules = {k: nx.DiGraph() for k in rules.keys()}
Expand Down Expand Up @@ -351,15 +361,15 @@ def load_rules(
rules_path: str = None,
rules: pl.LazyFrame = None,
label_col: str = "name",
alias_col: str = "alias",
parent_col: str = "alias",
rules_col: str = "rule",
allow_visualize_functions: bool = False,
) -> Tuple[Dict[str, Expr], Dict[str, Expr]]:
"""
Assumes TSV has columns at least: name, rule
Assumes TSV has columns at least: name, parent, child
Convention used here:
- if `name` is non-empty: this is an OUTPUT RULE whose expression is in `rule`
- if `name` is empty and `alias` is non-empty: this is a DEFINITION macro: alias := rule
- if `name` is non-empty: this is an OUTPUT RULE whose expression is in `child`
- if `name` is empty and `parent` is non-empty: this is a DEFINITION macro: parent := child
If your file uses a slightly different convention, adjust this function only.
"""
assert (rules_path is not None) != (rules is not None), (
Expand All @@ -380,7 +390,7 @@ def load_rules(
f"rules TSV missing required columns {required}. Found: {lf.columns}"
)

has_alias_col = alias_col and (alias_col in cols)
has_parent_col = parent_col and (parent_col in cols)

with open(Path(__file__).parent.absolute() / "rules.lark") as f:
parser = Lark(
Expand Down Expand Up @@ -431,11 +441,11 @@ def parse_rule_expr(expr_str: str) -> Expr:
}

definitions = {}
if has_alias_col:
if has_parent_col:
definitions = {
a: b
for a, b in lf.filter(~pl.col(alias_col).is_null())
.select([pl.col(alias_col), pl.col(rules_col)])
for a, b in lf.filter(~pl.col(parent_col).is_null())
.select([pl.col(parent_col), pl.col(rules_col)])
.collect()
.iter_rows()
}
Expand Down Expand Up @@ -567,7 +577,7 @@ def recurse(expr: Expr, add_name_to_needed: bool = True) -> Expr:
out = PipeChain(calls=tuple(calls))
else:
raise TypeError(expr)
# Some nodes need validation and have to be done after rules are expanded
# Some nodes need validation and have to be done after children are expanded
try:
out.validate()
except Exception:
Expand All @@ -593,8 +603,19 @@ def build_present_map(
"""Build present_map of needed gene_ids from annotations DataFrame"""
additional_cols = additional_cols or []
besthit_cols = [col for col in besthit_cols if col in lf.columns]
besthit_cols.extend(
[
col
for col in lf.columns
if (col.endswith("_id") and col not in besthit_cols and col != "query_id")
]
)
for col in besthit_cols:
lf = lf.with_columns(ID_EXPR_DICT[col].alias(col)).explode(col)
if col not in ID_EXPR_DICT:
explode_col = ID_EXPR_DICT["DEFAULT"](col).alias(col)
else:
explode_col = ID_EXPR_DICT[col].alias(col)
lf = lf.with_columns(explode_col).explode(col)
lf = lf.select([sample_col] + besthit_cols + additional_cols)

# unpivot to long (sample, hit)
Expand Down Expand Up @@ -833,7 +854,7 @@ def not_(x: np.ndarray = None) -> np.ndarray | pl.Expr:
if isinstance(x, pl.Expr):
x = [x]
masks = x
if len(masks) == 0:
if len(masks) == 1:
mask = masks[0]
else:
mask = masks[0].and_(*[mask for mask in masks[1:]])
Expand Down
Loading