|
1 | 1 | import logging |
| 2 | +from typing import Optional |
| 3 | +import ast |
| 4 | +import re |
2 | 5 |
|
3 | | -import pandas as pd |
4 | | - |
| 6 | +from duckdb import DuckDBPyConnection, DuckDBPyRelation |
5 | 7 | from countess import VERSION |
6 | 8 | from countess.core.parameters import BooleanParam, PerColumnArrayParam, TextParam |
7 | | -from countess.core.plugins import PandasSimplePlugin |
| 9 | +from countess.core.plugins import DuckdbSimplePlugin |
| 10 | +from countess.utils.duckdb import duckdb_escape_identifier, duckdb_escape_literal |
8 | 11 |
|
9 | 12 | logger = logging.getLogger(__name__) |
10 | 13 |
|
11 | | - |
def process(df: pd.DataFrame, codes):
    """Apply each expression in *codes* to *df*, in order, and return the result.

    Every non-empty entry is evaluated with ``DataFrame.eval``:

    * a ``Series`` result is treated as a row filter: rows where the value
      compares equal to 0 (including ``False``) are dropped;
    * a ``DataFrame`` result (what an assignment such as ``c = a + b``
      produces) replaces the working dataframe;
    * anything else (for example the scalar produced by ``1 + 2``) is logged
      and ignored, so the working dataframe is never replaced by a
      non-dataframe value.

    Expressions which raise are logged and skipped; processing continues with
    the next expression.

    Args:
        df: the dataframe to transform. It is not modified in place.
        codes: iterable of expression strings; empty entries are skipped.

    Returns:
        The transformed dataframe.
    """
    for code in codes:
        if not code:
            continue

        try:
            result = df.eval(code)
        except Exception as exc:  # pylint: disable=W0718
            logger.warning("Exception", exc_info=exc)
            continue

        if isinstance(result, pd.Series):
            # This was a filter.  Select rows through a boolean mask rather
            # than a temporary "__filter" column, so an existing column of
            # that name is neither overwritten nor dropped.  The copy keeps
            # the result independent of the frame it was selected from.
            df = df.loc[result != 0].copy()
        elif isinstance(result, pd.DataFrame):
            # This was a column assignment.
            df = result
        else:
            # DataFrame.eval can also return scalars or arrays; those are
            # neither a filter nor a new dataframe.
            logger.warning("Expression %r did not produce a filter or columns; ignored", code)

    return df
33 | | - |
34 | | - |
35 | | -class ExpressionPlugin(PandasSimplePlugin): |
# Lookup tables used when translating a parsed Python expression into DuckDB
# expression text.  The dict keys are `ast` operator node classes and the
# values are the operator text written into the generated expression.

# Unary operators.
UNOPS = {
    ast.UAdd: "+",
    ast.USub: "-",
}

# Binary arithmetic operators.
BINOPS = {
    ast.Add: "+",
    ast.Mult: "*",
    ast.Div: "/",
    ast.Sub: "-",
    ast.FloorDiv: "//",
    ast.Mod: "%",
    ast.Pow: "**",
}

# Names of functions which may be called from an expression; the name is
# copied into the generated expression unchanged.
FUNCOPS = {
    "abs",
    "len",
    "sin",
    "cos",
    "tan",
    "sqrt",
    "log",
    "log2",
    "log10",
    "pow",
    "exp",
}

# Comparison operators.  Note that Python's `==` is written as a single `=`.
COMPOP = {
    ast.Eq: "=",
    ast.NotEq: "!=",
    ast.Lt: "<",
    ast.LtE: "<=",
    ast.Gt: ">",
    ast.GtE: ">=",
}
| 18 | + |
| 19 | +def _transmogrify(ast_node): |
| 20 | + """Transform an AST node back into a string which can be parsed by DuckDB's expression |
| 21 | + parser. This is a pretty small subset of all the things you might write but on the |
| 22 | + other hand it saved actually writing a parser.""" |
| 23 | + # XXX might have to write a parser anyway since the AST parser handles decimal |
| 24 | + # literals badly. Worry about that later. |
| 25 | + if type(ast_node) is ast.Name: |
| 26 | + return duckdb_escape_identifier(ast_node.id) |
| 27 | + elif type(ast_node) is ast.Constant: |
| 28 | + return duckdb_escape_literal(ast_node.value) |
| 29 | + elif type(ast_node) is ast.UnaryOp and type(ast_node.op) in UNOPS: |
| 30 | + return "(" + UNOPS[type(ast_node.op)] + _transmogrify(ast_node.operand) + ")" |
| 31 | + elif type(ast_node) is ast.BinOp and type(ast_node.op) in BINOPS: |
| 32 | + return "(" + _transmogrify(ast_node.left) + BINOPS[type(ast_node.op)] + _transmogrify(ast_node.right) + ")" |
| 33 | + elif type(ast_node) is ast.Compare and all(type(op) in COMPOP for op in ast_node.ops): |
| 34 | + args = [ _transmogrify(x) for x in [ ast_node.left ] + ast_node.comparators ] |
| 35 | + return "(" + (" AND ".join( |
| 36 | + args[num] + COMPOP[type(op)] + args[num+1] |
| 37 | + for num, op in enumerate(ast_node.ops) |
| 38 | + )) + ")" |
| 39 | + elif type(ast_node) is ast.Call and ast_node.func.id in FUNCOPS: |
| 40 | + args = [ _transmogrify(x) for x in ast_node.args ] |
| 41 | + return ast_node.func.id + "(" + (",".join(args)) + ")" |
| 42 | + else: |
| 43 | + raise NotImplementedError(f"Unknown Node {ast_node}") |
| 44 | + |
| 45 | + |
class ExpressionPlugin(DuckdbSimplePlugin):
    """Plugin which adds columns computed from simple expressions.

    The `code` parameter is parsed as Python source.  Each top-level statement
    is translated into DuckDB expression text by `_transmogrify`:

    * an assignment `name = expr` produces a column called `name`;
    * a bare expression produces a column named after the expression's own
      text, with each run of non-word characters replaced by `_` and any
      trailing underscores removed.

    Statements which cannot be translated are logged and skipped.  Columns
    ticked in the `drop` parameter are left out of the output.
    """

    name = "Expression"
    description = "Apply simple expressions to each row"
    version = VERSION

    code = TextParam("Expressions")
    drop = PerColumnArrayParam("Drop Columns", BooleanParam("Drop"))

    # Projection fragments of the form "<expr> AS <name>", rebuilt by each
    # call to prepare().  None until prepare() has run.
    projection: Optional[list[str]] = None

    def prepare(self, *a) -> None:
        """Parse the `code` parameter and rebuild `self.projection`.

        Arguments are passed through unchanged to the parent class's
        `prepare`.  If the code does not parse, the error is logged and the
        projection is left empty.
        """
        super().prepare(*a)
        self.projection = []
        try:
            ast_root = ast.parse(self.code.value or "")
        except SyntaxError as exc:
            logger.debug("Syntax Error %s", exc)
            # There is no tree to walk, so stop here with an empty projection.
            return

        for ast_node in ast_root.body:
            try:
                if isinstance(ast_node, ast.Assign):
                    # `a = b = expr` has several targets; each plain-name
                    # target becomes its own output column.
                    for ast_target in ast_node.targets:
                        if isinstance(ast_target, ast.Name):
                            expr = _transmogrify(ast_node.value)
                            tgt = duckdb_escape_identifier(ast_target.id)
                            self.projection.append(f"{expr} AS {tgt}")
                elif isinstance(ast_node, ast.Expr):
                    # Derive a column name from the text of the expression.
                    tgt = duckdb_escape_identifier(re.sub(r'_+$', '', re.sub(r'\W+', '_', ast.unparse(ast_node))))
                    expr = _transmogrify(ast_node.value)
                    self.projection.append(f"{expr} AS {tgt}")

            except (NotImplementedError, KeyError, AttributeError) as exc:
                # AttributeError is included because translating a call whose
                # target is not a plain name may fail on the missing `.id`.
                logger.debug("Bad AST Node: %s %s", ast_node, exc)

    def execute(self, ddbc: DuckDBPyConnection, source: Optional[DuckDBPyRelation]) -> Optional[DuckDBPyRelation]:
        """Project `source` onto its retained columns plus the computed ones.

        Args:
            ddbc: the DuckDB connection (not used by this method).
            source: the input relation.

        Returns:
            The projected relation, or None when `source` is None.
        """
        if source is None:
            # NOTE(review): assumes "no input" should give "no output";
            # confirm against how the base class calls execute().
            return None

        column_params = dict(self.drop.get_column_params())
        if any(param.value for param in column_params.values()):
            # List the surviving columns explicitly.  A source column with no
            # drop parameter of its own is kept.
            kept = [
                duckdb_escape_identifier(column)
                for column in source.columns
                if column not in column_params or not column_params[column].value
            ]
        else:
            kept = ['*']

        # self.projection is still None if prepare() has not been called.
        sql = ', '.join(kept + (self.projection or []))
        logger.debug("ExpressionPlugin.execute %s", sql)
        return source.project(sql)
0 commit comments