Skip to content

Commit a63f2db

Browse files
bram2w and silvestrid authored
feat(AI Assistant): database formula generator tool (baserow#4085)
* Add database formula generator tool
* Address feedback
* Fix broken test
* Fix imports and created a flake8 plugin to prevent accidental imports
* make the tool a bit more flexible: create a formula in a different table if it's a better fit for the user request

---------

Co-authored-by: Davide Silvestri <davide@baserow.io>
1 parent 5f21e2d commit a63f2db

File tree

25 files changed

+1292
-359
lines changed

25 files changed

+1292
-359
lines changed

backend/.flake8

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
extend-ignore = E203, W503, F541, E501
33
max-doc-length = 88
44
per-file-ignores =
5-
tests/*: F841
6-
../premium/backend/tests/*: F841
7-
../enterprise/backend/tests/*: F841
8-
src/baserow/contrib/database/migrations/*: X1
9-
src/baserow/core/migrations/*: X1
10-
src/baserow/core/psycopg.py: BRP001
5+
tests/*: F841, BAI001
6+
../premium/backend/tests/*: F841, BAI001
7+
../enterprise/backend/tests/*: F841, BAI001
8+
src/baserow/contrib/database/migrations/*: BDC001
9+
src/baserow/core/migrations/*: BDC001
10+
src/baserow/core/psycopg.py: BPG001
1111
exclude =
1212
.git,
1313
__pycache__,
@@ -16,6 +16,7 @@ exclude =
1616

1717
[flake8:local-plugins]
1818
extension =
19-
X1 = flake8_baserow:DocstringPlugin
20-
BRP001 = flake8_baserow:BaserowPsycopgChecker
19+
BDC001 = flake8_baserow:DocstringPlugin
20+
BPG001 = flake8_baserow:BaserowPsycopgChecker
21+
BAI001 = flake8_baserow:BaserowAIImportsChecker
2122
paths = ./flake8_plugins

backend/flake8_plugins/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .flake8_baserow import DocstringPlugin, BaserowPsycopgChecker
1+
from .flake8_baserow import DocstringPlugin, BaserowPsycopgChecker, BaserowAIImportsChecker
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .docstring import Plugin as DocstringPlugin
22
from .psycopg import BaserowPsycopgChecker
3+
from .ai_imports import BaserowAIImportsChecker
34

4-
__all__ = ["DocstringPlugin", "BaserowPsycopgChecker"]
5+
__all__ = ["DocstringPlugin", "BaserowPsycopgChecker", "BaserowAIImportsChecker"]
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import ast
2+
from typing import Iterator, Tuple, Any
3+
4+
5+
class BaserowAIImportsChecker:
    """
    Flake8 plugin to ensure dspy and litellm are only imported locally within
    functions/methods, not at module level.

    An import is considered "global" when it has no enclosing ``def`` or
    ``async def``. Imports placed directly in a class body, or inside a
    module-level ``if``/``try``/``with`` block, are therefore reported too.
    Relative imports (e.g. ``from .dspy import x``) point at project-local
    modules and are never reported.
    """

    name = "flake8-baserow-ai-imports"
    version = "0.1.0"

    # Top-level packages that may only be imported inside functions/methods.
    _AI_PACKAGES = ("dspy", "litellm")
    # Node types that open a function scope; everything nested in them is local.
    _FUNCTION_NODES = (ast.FunctionDef, ast.AsyncFunctionDef)
    _MESSAGE = (
        "BAI001 {module} must be imported locally within functions/methods, "
        "not globally"
    )

    def __init__(self, tree: ast.AST, filename: str):
        self.tree = tree
        self.filename = filename

    def run(self) -> Iterator[Tuple[int, int, str, Any]]:
        """
        Check for global imports of dspy and litellm.

        Yields one ``(line, column, message, checker_type)`` tuple per offending
        module name, so ``import dspy, litellm`` produces two errors. The tree is
        traversed once; function bodies are never entered.
        """

        for node in self._iter_global_nodes():
            if isinstance(node, ast.Import):
                offenders = [
                    alias.name
                    for alias in node.names
                    if self._is_ai_module(alias.name)
                ]
            elif isinstance(node, ast.ImportFrom):
                # A non-zero level marks a relative import, which targets a
                # project-local module that merely shares the package name.
                is_absolute = not node.level
                offenders = (
                    [node.module]
                    if is_absolute and self._is_ai_module(node.module)
                    else []
                )
            else:
                continue

            for module_name in offenders:
                yield (
                    node.lineno,
                    node.col_offset,
                    self._MESSAGE.format(module=module_name),
                    type(self),
                )

    def _iter_global_nodes(self) -> Iterator[ast.AST]:
        """
        Yield every node that is not nested inside a function, level by level.

        The order matches ``ast.walk`` (breadth-first) restricted to nodes
        outside function bodies, so errors are reported in the same order as a
        full walk followed by a per-node scope check would report them.
        """

        level = [self.tree]
        while level:
            next_level = []
            for node in level:
                yield node
                # Statements below a function definition are local by
                # definition, so that whole subtree can be skipped.
                if not isinstance(node, self._FUNCTION_NODES):
                    next_level.extend(ast.iter_child_nodes(node))
            level = next_level

    def _is_ai_module(self, module_name: str) -> bool:
        """Check if the module is dspy or litellm (including submodules)."""

        if not module_name:
            return False
        return module_name.split(".", 1)[0] in self._AI_PACKAGES

    def _is_global_import(self, node: ast.AST) -> bool:
        """
        Check if an import node is at global scope.

        Returns True only if ``node`` is an import statement that belongs to
        ``self.tree`` and is not nested inside a function or method. Nodes that
        are not imports, or not part of the tree, give False.
        """

        if not isinstance(node, (ast.Import, ast.ImportFrom)):
            return False

        # The recursive search returns None when the node is not found; coerce
        # so the result always honours the declared ``bool`` return type.
        return bool(self._check_node_is_global(self.tree, node))

    def _check_node_is_global(
        self, root: ast.AST, target: ast.AST, in_function: bool = False
    ) -> "bool | None":
        """
        Recursively look for ``target`` below ``root`` and report its scope.

        :param root: The node where the search starts.
        :param target: The node to locate (compared by identity).
        :param in_function: Whether ``root`` is already inside a function.
        :return: True if ``target`` is found outside any function, False if it
            is found inside one, and None if it is not below ``root`` at all.
        """

        if root is target:
            return not in_function

        # Entering a function definition makes everything below it local.
        now_in_function = in_function or isinstance(root, self._FUNCTION_NODES)

        for child in ast.iter_child_nodes(root):
            found = self._check_node_is_global(child, target, now_in_function)
            if found is not None:
                return found

        return None

backend/flake8_plugins/flake8_baserow/docstring.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919

2020
DocstringType = Union[ast.Constant, ast.Str]
21-
ERR_MSG = "X1 - Baserow plugin: missing empty line after docstring"
21+
ERR_MSG = "BDC001 - Baserow plugin: missing empty line after docstring"
2222

2323

2424
class Token:
Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import ast
22
from typing import Iterator, Tuple, Any
33

4+
45
class BaserowPsycopgChecker:
5-
name = 'flake8-baserow-psycopg'
6-
version = '0.1.0'
6+
name = "flake8-baserow-psycopg"
7+
version = "0.1.0"
78

89
def __init__(self, tree: ast.AST, filename: str):
910
self.tree = tree
@@ -13,18 +14,18 @@ def run(self) -> Iterator[Tuple[int, int, str, Any]]:
1314
for node in ast.walk(self.tree):
1415
if isinstance(node, ast.Import):
1516
for alias in node.names:
16-
if alias.name in ('psycopg', 'psycopg2'):
17+
if alias.name in ("psycopg", "psycopg2"):
1718
yield (
1819
node.lineno,
1920
node.col_offset,
20-
'BRP001 Import psycopg/psycopg2 from baserow.core.psycopg instead',
21-
type(self)
21+
"BPG001 Import psycopg/psycopg2 from baserow.core.psycopg instead",
22+
type(self),
2223
)
2324
elif isinstance(node, ast.ImportFrom):
24-
if node.module in ('psycopg', 'psycopg2'):
25+
if node.module in ("psycopg", "psycopg2"):
2526
yield (
2627
node.lineno,
2728
node.col_offset,
28-
'BRP001 Import psycopg/psycopg2 from baserow.core.psycopg instead',
29-
type(self)
30-
)
29+
"BPG001 Import psycopg/psycopg2 from baserow.core.psycopg instead",
30+
type(self),
31+
)
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
import ast
2+
from flake8_baserow.ai_imports import BaserowAIImportsChecker
3+
4+
5+
def run_checker(code: str):
    """Parse ``code`` and return the list of errors the AI-imports checker yields."""
    return list(BaserowAIImportsChecker(ast.parse(code), "test.py").run())
10+
11+
12+
def test_global_dspy_import():
    """A plain module-level ``import dspy`` is reported exactly once."""
    source = """
import dspy
"""
    violations = run_checker(source)
    assert len(violations) == 1
    message = violations[0][2]
    assert "BAI001" in message
    assert "dspy" in message
21+
22+
23+
def test_global_litellm_import():
    """A plain module-level ``import litellm`` is reported exactly once."""
    source = """
import litellm
"""
    violations = run_checker(source)
    assert len(violations) == 1
    message = violations[0][2]
    assert "BAI001" in message
    assert "litellm" in message
32+
33+
34+
def test_global_dspy_from_import():
    """Module-level ``from dspy... import`` statements are each reported."""
    source = """
from dspy import ChainOfThought
from dspy.predict import Predict
"""
    violations = run_checker(source)
    assert len(violations) == 2
    assert all("BAI001" in message for _, _, message, _ in violations)
43+
44+
45+
def test_global_litellm_from_import():
    """Module-level ``from litellm... import`` statements are each reported."""
    source = """
from litellm import completion
from litellm.utils import get_llm_provider
"""
    violations = run_checker(source)
    assert len(violations) == 2
    assert all("BAI001" in message for _, _, message, _ in violations)
54+
55+
56+
def test_local_import_in_function():
    """Imports placed inside a function body produce no errors."""
    source = """
def my_function():
    import dspy
    import litellm
    from dspy import ChainOfThought
    from litellm import completion
    return dspy, litellm
"""
    violations = run_checker(source)
    assert not violations
68+
69+
70+
def test_local_import_in_method():
    """Imports placed inside a class method produce no errors."""
    source = """
class MyClass:
    def my_method(self):
        import dspy
        from litellm import completion
        return dspy.ChainOfThought()
"""
    violations = run_checker(source)
    assert not violations
81+
82+
83+
def test_local_import_in_async_function():
    """Imports placed inside an ``async def`` body produce no errors."""
    source = """
async def my_async_function():
    import dspy
    from litellm import acompletion
    return await acompletion()
"""
    violations = run_checker(source)
    assert not violations
93+
94+
95+
def test_mixed_global_and_local_imports():
    """Only the module-level imports are reported when both kinds are present."""
    source = """
import dspy # This should be flagged

def my_function():
    import litellm # This should be OK
    return litellm.completion()

from dspy import ChainOfThought # This should be flagged
"""
    violations = run_checker(source)
    assert len(violations) == 2
    assert all("BAI001" in message for _, _, message, _ in violations)
109+
110+
111+
def test_nested_function_imports():
    """Imports inside a function nested in another function produce no errors."""
    source = """
def outer_function():
    def inner_function():
        import dspy
        from litellm import completion
        return dspy, completion
    return inner_function()
"""
    violations = run_checker(source)
    assert not violations
123+
124+
125+
def test_other_imports_not_affected():
    """Module-level imports of unrelated packages produce no errors."""
    source = """
import os
import sys
from typing import List
from baserow.core.models import User
"""
    violations = run_checker(source)
    assert not violations
135+
136+
137+
def test_multiple_global_imports():
    """Every module-level AI import is reported; unrelated imports are not."""
    source = """
import dspy
import litellm
from dspy import ChainOfThought
from litellm import completion
import os # This should not be flagged
"""
    violations = run_checker(source)
    assert len(violations) == 4
    assert all("BAI001" in message for _, _, message, _ in violations)
149+
150+
151+
def test_import_with_alias():
    """Aliased module-level imports are reported; an aliased local one is not."""
    source = """
import dspy as d
import litellm as llm

def my_function():
    import dspy as local_d
    return local_d
"""
    violations = run_checker(source)
    assert len(violations) == 2
    assert all("BAI001" in message for _, _, message, _ in violations)
164+
165+
166+
def test_submodule_imports():
    """Submodule imports are reported at module level but not inside a function."""
    source = """
from dspy.teleprompt import BootstrapFewShot
from litellm.utils import token_counter

def my_function():
    from dspy.predict import Predict
    from litellm.integrations import log_event
    return Predict, log_event
"""
    violations = run_checker(source)
    assert len(violations) == 2
    assert all("BAI001" in message for _, _, message, _ in violations)
    # The source string starts with a newline, so the two global imports sit on
    # lines 2 and 3 of the parsed code.
    first, second = violations
    assert first[0] == 2
    assert second[0] == 3
183+
184+
185+
def test_class_method_and_staticmethod():
    """Imports inside decorated class methods and static methods produce no errors."""
    source = """
class MyClass:
    @classmethod
    def my_classmethod(cls):
        import dspy
        return dspy

    @staticmethod
    def my_staticmethod():
        from litellm import completion
        return completion
"""
    violations = run_checker(source)
    assert not violations
201+
202+
203+
def test_lambda_not_considered_function():
    """
    Test that a module-level lambda does not hide a global import.

    A lambda body cannot contain an import statement, so the only thing to
    verify is that having a lambda in the module does not stop the
    module-level import next to it from being flagged.
    """
    code = """
# Lambdas cannot contain import statements; this module-level lambda must not
# be treated as a function scope that shields the import below.
make_default = lambda: None
import dspy
"""
    errors = run_checker(code)
    assert len(errors) == 1
    assert "BAI001" in errors[0][2]
    assert "dspy" in errors[0][2]

0 commit comments

Comments
 (0)