From bd924c5fc45bc609cc66796b4dd4e365e205c8d9 Mon Sep 17 00:00:00 2001 From: Tsering Paljor Date: Wed, 5 Nov 2025 12:58:12 +0400 Subject: [PATCH 1/2] Ensure equal/not_equal can accept any arg type (#4164) * Make equal/not_equal accept any arg type. * Lint fix --- backend/src/baserow/core/formula/argument_types.py | 8 ++++++++ .../baserow/core/formula/runtime_formula_types.py | 9 +++++---- .../modules/core/runtimeFormulaArgumentTypes.js | 10 ++++++++++ web-frontend/modules/core/runtimeFormulaTypes.js | 13 +++++++------ 4 files changed, 30 insertions(+), 10 deletions(-) diff --git a/backend/src/baserow/core/formula/argument_types.py b/backend/src/baserow/core/formula/argument_types.py index 2baea8bfc2..d2acb19597 100644 --- a/backend/src/baserow/core/formula/argument_types.py +++ b/backend/src/baserow/core/formula/argument_types.py @@ -105,3 +105,11 @@ def test(self, value): def parse(self, value): return ensure_string(value) + + +class AnyBaserowRuntimeFormulaArgumentType(BaserowRuntimeFormulaArgumentType): + def test(self, value): + return True + + def parse(self, value): + return value diff --git a/backend/src/baserow/core/formula/runtime_formula_types.py b/backend/src/baserow/core/formula/runtime_formula_types.py index a5c2ed6f38..4e21344e3a 100644 --- a/backend/src/baserow/core/formula/runtime_formula_types.py +++ b/backend/src/baserow/core/formula/runtime_formula_types.py @@ -9,6 +9,7 @@ from baserow.core.formula.argument_types import ( AddableBaserowRuntimeFormulaArgumentType, + AnyBaserowRuntimeFormulaArgumentType, BooleanBaserowRuntimeFormulaArgumentType, DateTimeBaserowRuntimeFormulaArgumentType, DictBaserowRuntimeFormulaArgumentType, @@ -113,8 +114,8 @@ def execute(self, context: FormulaContext, args: FormulaArgs): class RuntimeEqual(RuntimeFormulaFunction): type = "equal" args = [ - NumberBaserowRuntimeFormulaArgumentType(), - NumberBaserowRuntimeFormulaArgumentType(), + AnyBaserowRuntimeFormulaArgumentType(), + AnyBaserowRuntimeFormulaArgumentType(), ] def validate_number_of_args(self, args): @@ -127,8 +128,8 @@ def execute(self, context: FormulaContext, args: FormulaArgs): class RuntimeNotEqual(RuntimeFormulaFunction): type = "not_equal" args = [ - NumberBaserowRuntimeFormulaArgumentType(), - NumberBaserowRuntimeFormulaArgumentType(), + AnyBaserowRuntimeFormulaArgumentType(), + AnyBaserowRuntimeFormulaArgumentType(), ] def validate_number_of_args(self, args): diff --git a/web-frontend/modules/core/runtimeFormulaArgumentTypes.js b/web-frontend/modules/core/runtimeFormulaArgumentTypes.js index 902f741f72..ba22b85eff 100644 --- a/web-frontend/modules/core/runtimeFormulaArgumentTypes.js +++ b/web-frontend/modules/core/runtimeFormulaArgumentTypes.js @@ -121,3 +121,13 @@ export class TimezoneBaserowRuntimeFormulaArgumentType extends BaserowRuntimeFor return ensureString(value) } } + +export class AnyBaserowRuntimeFormulaArgumentType extends BaserowRuntimeFormulaArgumentType { + test(value) { + return true + } + + parse(value) { + return value + } +} diff --git a/web-frontend/modules/core/runtimeFormulaTypes.js b/web-frontend/modules/core/runtimeFormulaTypes.js index 7f52e8a505..bbffd6cf2e 100644 --- a/web-frontend/modules/core/runtimeFormulaTypes.js +++ b/web-frontend/modules/core/runtimeFormulaTypes.js @@ -6,6 +6,7 @@ import { ObjectBaserowRuntimeFormulaArgumentType, BooleanBaserowRuntimeFormulaArgumentType, TimezoneBaserowRuntimeFormulaArgumentType, + AnyBaserowRuntimeFormulaArgumentType, } from '@baserow/modules/core/runtimeFormulaArgumentTypes' import { InvalidFormulaArgumentType, @@ -487,8 +488,8 @@ export class RuntimeEqual extends RuntimeFormulaFunction { get args() { return [ - new NumberBaserowRuntimeFormulaArgumentType(), - new NumberBaserowRuntimeFormulaArgumentType(), + new AnyBaserowRuntimeFormulaArgumentType(), + new AnyBaserowRuntimeFormulaArgumentType(), ] } @@ -502,7 +503,7 @@ export class RuntimeEqual extends RuntimeFormulaFunction { } getExamples() { - return ['2 = 3 = false'] + return ['2 = 3 = false', '"foo" = "bar" = false', '"foo" = "foo" = true'] } } @@ -525,8 +526,8 @@ export class RuntimeNotEqual extends RuntimeFormulaFunction { get args() { return [ - new NumberBaserowRuntimeFormulaArgumentType(), - new NumberBaserowRuntimeFormulaArgumentType(), + new AnyBaserowRuntimeFormulaArgumentType(), + new AnyBaserowRuntimeFormulaArgumentType(), ] } @@ -540,7 +541,7 @@ export class RuntimeNotEqual extends RuntimeFormulaFunction { } getExamples() { - return ['2 != 3 = true'] + return ['2 != 3 = true', '"foo" != "foo" = false', '"foo" != "bar" = true'] } } From 4b93d373cb5b9831fff5d05714023e8bb0220b4c Mon Sep 17 00:00:00 2001 From: dimmur-brw Date: Wed, 5 Nov 2025 14:13:23 +0100 Subject: [PATCH 2/2] Add filters support for AI field (#4166) Add filters support for AI Field --- .../contrib/database/fields/registries.py | 13 + .../contrib/database/views/registries.py | 20 +- .../contrib/database/views/view_filters.py | 9 +- .../database/field/test_field_types.py | 10 +- .../local_baserow/test_service_types.py | 4 +- ...3801_add_filters_support_for_ai_field.json | 8 + .../src/baserow_premium/fields/field_types.py | 12 + .../fields/test_ai_field_filters.py | 872 ++++++++++++++++++ .../fields/test_ai_field_type.py | 13 + .../modules/baserow_premium/fieldTypes.js | 12 + web-frontend/modules/database/fieldTypes.js | 4 + web-frontend/modules/database/viewFilters.js | 42 +- .../test/unit/database/utils/field.spec.js | 15 +- .../unit/database/viewFiltersMatch.spec.js | 2 +- 14 files changed, 1013 insertions(+), 23 deletions(-) create mode 100644 changelog/entries/unreleased/feature/3801_add_filters_support_for_ai_field.json create mode 100644 premium/backend/tests/baserow_premium_tests/fields/test_ai_field_filters.py diff --git a/backend/src/baserow/contrib/database/fields/registries.py b/backend/src/baserow/contrib/database/fields/registries.py index f1d7a63e48..b9168b38da 100644 --- a/backend/src/baserow/contrib/database/fields/registries.py +++ b/backend/src/baserow/contrib/database/fields/registries.py @@ -1674,6 +1674,19 @@ def check_can_filter_by(self, field: Field) -> bool: return len(compatible_vft) > 0 + def get_compatible_filter_field_type(self, field: Field) -> "FieldType": + """ + Returns the canonical field type to use when determining compatibility with + view filters. By default this returns self, but field types that should be + treated as another field type for filtering can override this and return that + other field type instance. + + :param field: The concrete field instance being checked. + :return: The FieldType that should be used for filter compatibility. + """ + + return self + def check_can_order_by(self, field: Field, sort_type: str) -> bool: """ Override this method if this field type can sometimes be ordered or sometimes diff --git a/backend/src/baserow/contrib/database/views/registries.py b/backend/src/baserow/contrib/database/views/registries.py index 22bdf4544e..8c23ef9b2d 100644 --- a/backend/src/baserow/contrib/database/views/registries.py +++ b/backend/src/baserow/contrib/database/views/registries.py @@ -921,10 +921,11 @@ def get_filter(self, field_name, value): compatible_field_types: List[Union[str, Callable[["Field"], bool]]] = [] """ - Defines which field types are compatible with the filter. Only the supported ones - can be used in combination with the field. The values in this list can either be - the literal field_type.type string, or a callable which takes the field being - checked and returns True if compatible or False if not. + Defines which field types are compatible with the filter. The values in this list + can either be the literal `FieldType.type` string, or a callable taking the + concrete field and returning True/False. When checking compatibility the + `FieldType.get_compatible_filter_field_type(field)` indirection is applied first, + allowing field types to alias themselves to another type for filtering purposes. """ def default_filter_on_exception(self): @@ -1000,9 +1001,9 @@ def field_is_compatible(self, field: "Field") -> bool: Given a particular instance of a field returns a list of Type[FieldType] which are compatible with this particular field type. - Works by checking the field_type against this view filters list of compatible - field types or compatibility checking functions defined in - self.allowed_field_types. + Works by checking the field's canonical filter type (as returned by + `FieldType.get_compatible_filter_field_type(field)`) against this filter's + `compatible_field_types`. :param field: The field to check. :return: True if the field is compatible, False otherwise. @@ -1011,9 +1012,12 @@ def field_is_compatible(self, field: "Field") -> bool: from baserow.contrib.database.fields.registries import field_type_registry field_type = field_type_registry.get_by_model(field.specific_class) + # Allow field types to map themselves to another field type for filter + # compatibility checks. + compatible_field_type = field_type.get_compatible_filter_field_type(field) return any( - callable(t) and t(field) or t == field_type.type + (callable(t) and t(field)) or t == compatible_field_type.type for t in self.compatible_field_types ) diff --git a/backend/src/baserow/contrib/database/views/view_filters.py b/backend/src/baserow/contrib/database/views/view_filters.py index 89e94e55b2..429d0682fa 100644 --- a/backend/src/baserow/contrib/database/views/view_filters.py +++ b/backend/src/baserow/contrib/database/views/view_filters.py @@ -1104,7 +1104,8 @@ def get_filter(self, field_name, value, model_field, field): return Q() field_type = field_type_registry.get_by_model(field) - filter_function = self.filter_functions[field_type.type] + effective_field_type = field_type.get_compatible_filter_field_type(field) + filter_function = self.filter_functions[effective_field_type.type] return filter_function(field_name, value, model_field, field) def set_import_serialized_value(self, value, id_mapping): @@ -1160,7 +1161,8 @@ def get_filter(self, field_name, value: str, model_field, field): return self.default_filter_on_exception() field_type = field_type_registry.get_by_model(field) - filter_function = self.filter_functions[field_type.type] + effective_field_type = field_type.get_compatible_filter_field_type(field) + filter_function = self.filter_functions[effective_field_type.type] return filter_function(field_name, option_ids, model_field, field) def set_import_serialized_value(self, value: str | None, id_mapping: dict) -> str: @@ -1419,7 +1421,8 @@ def get_filter(self, field_name, value: str, model_field, field): return self.default_filter_on_exception() field_type = field_type_registry.get_by_model(field) - filter_function = self.filter_functions[field_type.type] + effective_field_type = field_type.get_compatible_filter_field_type(field) + filter_function = self.filter_functions[effective_field_type.type] return filter_function(field_name, option_ids, model_field, field) def set_import_serialized_value(self, value: str | None, id_mapping: dict) -> str: diff --git a/backend/tests/baserow/contrib/database/field/test_field_types.py b/backend/tests/baserow/contrib/database/field/test_field_types.py index c4319c66e6..bf6c128aba 100644 --- a/backend/tests/baserow/contrib/database/field/test_field_types.py +++ b/backend/tests/baserow/contrib/database/field/test_field_types.py @@ -304,7 +304,6 @@ "is_even_and_whole", ], "password": ["empty", "not_empty"], - "ai": [], } @@ -1216,7 +1215,14 @@ def test_field_type_prepare_db_value_with_invalid_values(data_fixture): field_type.prepare_value_for_db(field, test_payload) -@pytest.mark.parametrize("field_type", field_type_registry.get_all()) +@pytest.mark.parametrize( + "field_type", + [ + ft + for ft in field_type_registry.get_all() + if ft.type in COMPATIBLE_FIELD_TYPE_VIEW_FILTER_TYPES + ], +) @patch("baserow.contrib.database.fields.registries.FieldTypeRegistry.get_by_model") def test_field_type_check_can_filter_by(mock_get_by_model, field_type): mock_get_by_model.return_value = field_type diff --git a/backend/tests/baserow/contrib/integrations/local_baserow/test_service_types.py b/backend/tests/baserow/contrib/integrations/local_baserow/test_service_types.py index 3264046f57..5680831cd5 100644 --- a/backend/tests/baserow/contrib/integrations/local_baserow/test_service_types.py +++ b/backend/tests/baserow/contrib/integrations/local_baserow/test_service_types.py @@ -1050,7 +1050,7 @@ def reset_metadata(schema, field_name): "default": None, "searchable": True, "sortable": True, - "filterable": False, + "filterable": True, "original_type": "ai", "metadata": {}, "type": "string", @@ -1060,7 +1060,7 @@ def reset_metadata(schema, field_name): "default": None, "searchable": True, "sortable": True, - "filterable": False, + "filterable": True, "original_type": "ai", "metadata": {}, "properties": { diff --git a/changelog/entries/unreleased/feature/3801_add_filters_support_for_ai_field.json b/changelog/entries/unreleased/feature/3801_add_filters_support_for_ai_field.json new file mode 100644 index 0000000000..7e7d60a907 --- /dev/null +++ b/changelog/entries/unreleased/feature/3801_add_filters_support_for_ai_field.json @@ -0,0 +1,8 @@ +{ + "type": "feature", + "message": "Add filters support for AI field", + "domain": "database", + "issue_number": 3801, + "bullet_points": [], + "created_at": "2025-10-23" +} \ No newline at end of file diff --git a/premium/backend/src/baserow_premium/fields/field_types.py b/premium/backend/src/baserow_premium/fields/field_types.py index c960a09967..06165661cd 100644 --- a/premium/backend/src/baserow_premium/fields/field_types.py +++ b/premium/backend/src/baserow_premium/fields/field_types.py @@ -166,6 +166,18 @@ def contains_query(self, field_name, value, model_field, field): baserow_field_type = self.get_baserow_field_type(field) return baserow_field_type.contains_query(field_name, value, model_field, field) + def parse_filter_value(self, field, model_field, value): + baserow_field_type = self.get_baserow_field_type(field) + return baserow_field_type.parse_filter_value(field, model_field, value) + + def get_compatible_filter_field_type(self, field): + """ + For AI fields, return the underlying core field type used for filtering. + """ + + ai_field_instance = field.specific if hasattr(field, "specific") else field + return self.get_baserow_field_type(ai_field_instance) + def contains_word_query(self, field_name, value, model_field, field): baserow_field_type = self.get_baserow_field_type(field) return baserow_field_type.contains_word_query( diff --git a/premium/backend/tests/baserow_premium_tests/fields/test_ai_field_filters.py b/premium/backend/tests/baserow_premium_tests/fields/test_ai_field_filters.py new file mode 100644 index 0000000000..bf982c947e --- /dev/null +++ b/premium/backend/tests/baserow_premium_tests/fields/test_ai_field_filters.py @@ -0,0 +1,872 @@ +import pytest + +from baserow.contrib.database.rows.handler import RowHandler +from baserow.contrib.database.views.handler import ViewHandler +from baserow.contrib.database.views.registries import view_filter_type_registry + + +@pytest.mark.django_db +@pytest.mark.field_ai +def test_ai_field_text_output_supports_empty_filter(premium_data_fixture): + user = premium_data_fixture.create_user() + table = premium_data_fixture.create_database_table(user=user) + grid_view = premium_data_fixture.create_grid_view(table=table) + premium_data_fixture.register_fake_generate_ai_type() + + ai_field = premium_data_fixture.create_ai_field( + table=table, + order=1, + name="AI Text", + ai_generative_ai_type="test_generative_ai", + ai_generative_ai_model="test_1", + ai_output_type="text", + ai_prompt="'test'", + ) + + handler = RowHandler() + model = table.get_model() + + row_1 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": "Some text"} + ) + row_2 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": ""} + ) + row_3 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": None} + ) + + # Create an empty filter + view_handler = ViewHandler() + view_handler.create_filter( + user=user, + view=grid_view, + field=ai_field, + type_name="empty", + value="", + ) + + # Apply the filter + queryset = view_handler.get_queryset(grid_view) + + # Should return only rows with empty values + assert queryset.count() == 2 + assert row_2 in queryset + assert row_3 in queryset + assert row_1 not in queryset + + +@pytest.mark.django_db +@pytest.mark.field_ai +def test_ai_field_text_output_supports_not_empty_filter(premium_data_fixture): + user = premium_data_fixture.create_user() + table = premium_data_fixture.create_database_table(user=user) + grid_view = premium_data_fixture.create_grid_view(table=table) + premium_data_fixture.register_fake_generate_ai_type() + + ai_field = premium_data_fixture.create_ai_field( + table=table, + order=1, + name="AI Text", + ai_generative_ai_type="test_generative_ai", + ai_generative_ai_model="test_1", + ai_output_type="text", + ai_prompt="'test'", + ) + + handler = RowHandler() + model = table.get_model() + + row_1 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": "Some text"} + ) + row_2 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": ""} + ) + row_3 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": None} + ) + + # Create a not_empty filter + view_handler = ViewHandler() + view_handler.create_filter( + user=user, + view=grid_view, + field=ai_field, + type_name="not_empty", + value="", + ) + + # Apply the filter + queryset = view_handler.get_queryset(grid_view) + + # Should return only rows with non-empty values + assert queryset.count() == 1 + assert row_1 in queryset + assert row_2 not in queryset + assert row_3 not in queryset + + +@pytest.mark.django_db +@pytest.mark.field_ai +def test_ai_field_text_output_supports_contains_filter(premium_data_fixture): + user = premium_data_fixture.create_user() + table = premium_data_fixture.create_database_table(user=user) + grid_view = premium_data_fixture.create_grid_view(table=table) + premium_data_fixture.register_fake_generate_ai_type() + + ai_field = premium_data_fixture.create_ai_field( + table=table, + order=1, + name="AI Text", + ai_generative_ai_type="test_generative_ai", + ai_generative_ai_model="test_1", + ai_output_type="text", + ai_prompt="'test'", + ) + + handler = RowHandler() + model = table.get_model() + + row_1 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": "Hello world"} + ) + row_2 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": "Goodbye world"} + ) + row_3 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": "Test message"} + ) + + # Create a contains filter + view_handler = ViewHandler() + view_handler.create_filter( + user=user, + view=grid_view, + field=ai_field, + type_name="contains", + value="world", + ) + + # Apply the filter + queryset = view_handler.get_queryset(grid_view) + + # Should return only rows containing "world" + assert queryset.count() == 2 + assert row_1 in queryset + assert row_2 in queryset + assert row_3 not in queryset + + +@pytest.mark.django_db +@pytest.mark.field_ai +def test_ai_field_text_output_supports_contains_not_filter(premium_data_fixture): + user = premium_data_fixture.create_user() + table = premium_data_fixture.create_database_table(user=user) + grid_view = premium_data_fixture.create_grid_view(table=table) + premium_data_fixture.register_fake_generate_ai_type() + + ai_field = premium_data_fixture.create_ai_field( + table=table, + order=1, + name="AI Text", + ai_generative_ai_type="test_generative_ai", + ai_generative_ai_model="test_1", + ai_output_type="text", + ai_prompt="'test'", + ) + + handler = RowHandler() + model = table.get_model() + + row_1 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": "Hello world"} + ) + row_2 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": "Goodbye world"} + ) + row_3 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": "Test message"} + ) + + # Create a contains_not filter + view_handler = ViewHandler() + view_handler.create_filter( + user=user, + view=grid_view, + field=ai_field, + type_name="contains_not", + value="world", + ) + + # Apply the filter + queryset = view_handler.get_queryset(grid_view) + + # Should return only rows not containing "world" + assert queryset.count() == 1 + assert row_3 in queryset + assert row_1 not in queryset + assert row_2 not in queryset + + +@pytest.mark.django_db +@pytest.mark.field_ai +def test_ai_field_text_output_supports_contains_word_filter(premium_data_fixture): + user = premium_data_fixture.create_user() + table = premium_data_fixture.create_database_table(user=user) + grid_view = premium_data_fixture.create_grid_view(table=table) + premium_data_fixture.register_fake_generate_ai_type() + + ai_field = premium_data_fixture.create_ai_field( + table=table, + order=1, + name="AI Text", + ai_generative_ai_type="test_generative_ai", + ai_generative_ai_model="test_1", + ai_output_type="text", + ai_prompt="'test'", + ) + + handler = RowHandler() + model = table.get_model() + + row_1 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": "Hello world"} + ) + row_2 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": "worldwide coverage"} + ) + row_3 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": "Test world"} + ) + + # Create a contains_word filter + view_handler = ViewHandler() + view_handler.create_filter( + user=user, + view=grid_view, + field=ai_field, + type_name="contains_word", + value="world", + ) + + # Apply the filter + queryset = view_handler.get_queryset(grid_view) + + # Should return only rows containing "world" as a whole word + assert queryset.count() == 2 + assert row_1 in queryset + assert row_3 in queryset + assert row_2 not in queryset # "worldwide" doesn't contain "world" as whole word + + +@pytest.mark.django_db +@pytest.mark.field_ai +def test_ai_field_text_output_supports_equal_filter(premium_data_fixture): + user = premium_data_fixture.create_user() + table = premium_data_fixture.create_database_table(user=user) + grid_view = premium_data_fixture.create_grid_view(table=table) + premium_data_fixture.register_fake_generate_ai_type() + + ai_field = premium_data_fixture.create_ai_field( + table=table, + order=1, + name="AI Text", + ai_generative_ai_type="test_generative_ai", + ai_generative_ai_model="test_1", + ai_output_type="text", + ai_prompt="'test'", + ) + + handler = RowHandler() + model = table.get_model() + + row_1 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": "exact match"} + ) + row_2 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": "Exact match"} + ) + row_3 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": "different"} + ) + + # Create an equal filter + view_handler = ViewHandler() + view_handler.create_filter( + user=user, + view=grid_view, + field=ai_field, + type_name="equal", + value="exact match", + ) + + # Apply the filter + queryset = view_handler.get_queryset(grid_view) + + # Should return only the exact match (case-sensitive) + assert queryset.count() == 1 + assert row_1 in queryset + assert row_2 not in queryset + assert row_3 not in queryset + + +@pytest.mark.django_db +@pytest.mark.field_ai +def test_ai_field_text_output_equal_filter_with_numeric_string(premium_data_fixture): + user = premium_data_fixture.create_user() + table = premium_data_fixture.create_database_table(user=user) + grid_view = premium_data_fixture.create_grid_view(table=table) + premium_data_fixture.register_fake_generate_ai_type() + + ai_field = premium_data_fixture.create_ai_field( + table=table, + order=1, + name="AI Text", + ai_generative_ai_type="test_generative_ai", + ai_generative_ai_model="test_1", + ai_output_type="text", + ai_prompt="'test'", + ) + + handler = RowHandler() + model = table.get_model() + + row_1 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": "17"} + ) + row_2 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": "170"} + ) + row_3 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": "different"} + ) + + view_handler = ViewHandler() + view_handler.create_filter( + user=user, + view=grid_view, + field=ai_field, + type_name="equal", + value="17", + ) + + queryset = view_handler.get_queryset(grid_view) + + assert queryset.count() == 1 + assert row_1 in queryset + assert row_2 not in queryset + assert row_3 not in queryset + + +@pytest.mark.django_db +@pytest.mark.field_ai +def test_ai_field_text_output_supports_not_equal_filter(premium_data_fixture): + user = premium_data_fixture.create_user() + table = premium_data_fixture.create_database_table(user=user) + grid_view = premium_data_fixture.create_grid_view(table=table) + premium_data_fixture.register_fake_generate_ai_type() + + ai_field = premium_data_fixture.create_ai_field( + table=table, + order=1, + name="AI Text", + ai_generative_ai_type="test_generative_ai", + ai_generative_ai_model="test_1", + ai_output_type="text", + ai_prompt="'test'", + ) + + handler = RowHandler() + model = table.get_model() + + row_1 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": "exact match"} + ) + row_2 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": "Exact match"} + ) + row_3 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": "different"} + ) + + view_handler = ViewHandler() + view_handler.create_filter( + user=user, + view=grid_view, + field=ai_field, + type_name="not_equal", + value="exact match", + ) + + queryset = view_handler.get_queryset(grid_view) + + assert queryset.count() == 2 + assert row_1 not in queryset + assert row_2 in queryset + assert row_3 in queryset + + +@pytest.mark.django_db +@pytest.mark.field_ai +def test_ai_field_text_output_supports_length_is_lower_than_filter( + premium_data_fixture, +): + user = premium_data_fixture.create_user() + table = premium_data_fixture.create_database_table(user=user) + grid_view = premium_data_fixture.create_grid_view(table=table) + premium_data_fixture.register_fake_generate_ai_type() + + ai_field = premium_data_fixture.create_ai_field( + table=table, + order=1, + name="AI Text", + ai_generative_ai_type="test_generative_ai", + ai_generative_ai_model="test_1", + ai_output_type="text", + ai_prompt="'test'", + ) + + handler = RowHandler() + model = table.get_model() + + row_1 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": "Hi"} + ) + row_2 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": "Hello"} + ) + row_3 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": "Hello World"} + ) + + view_handler = ViewHandler() + view_handler.create_filter( + user=user, + view=grid_view, + field=ai_field, + type_name="length_is_lower_than", + value="6", + ) + + queryset = view_handler.get_queryset(grid_view) + + assert queryset.count() == 2 + assert row_1 in queryset + assert row_2 in queryset + assert row_3 not in queryset + + +@pytest.mark.django_db +@pytest.mark.field_ai +def test_ai_field_is_compatible_with_text_filters(premium_data_fixture): + user = premium_data_fixture.create_user() + table = premium_data_fixture.create_database_table(user=user) + premium_data_fixture.register_fake_generate_ai_type() + + ai_field = premium_data_fixture.create_ai_field( + table=table, + order=1, + name="AI Text", + ai_generative_ai_type="test_generative_ai", + ai_generative_ai_model="test_1", + ai_output_type="text", + ai_prompt="'test'", + ) + + text_filter_types = [ + "empty", + "not_empty", + "contains", + "contains_not", + "contains_word", + "doesnt_contain_word", + "equal", + "not_equal", + "length_is_lower_than", + ] + + for filter_type_name in text_filter_types: + filter_type = view_filter_type_registry.get(filter_type_name) + is_compatible = filter_type.field_is_compatible(ai_field) + assert is_compatible, f"AI field should be compatible with {filter_type_name}" + + +@pytest.mark.django_db +@pytest.mark.field_ai +def test_ai_field_choice_output_not_compatible_with_text_only_filters( + premium_data_fixture, +): + user = premium_data_fixture.create_user() + table = premium_data_fixture.create_database_table(user=user) + premium_data_fixture.register_fake_generate_ai_type() + + ai_field = premium_data_fixture.create_ai_field( + table=table, + order=1, + name="AI Choice", + ai_generative_ai_type="test_generative_ai", + ai_generative_ai_model="test_1", + ai_output_type="choice", + ai_prompt="'test'", + ) + + filter_type = view_filter_type_registry.get("length_is_lower_than") + is_compatible = filter_type.field_is_compatible(ai_field) + assert ( + not is_compatible + ), "AI field with choice output should NOT be compatible with length_is_lower_than" + + +@pytest.mark.django_db +@pytest.mark.field_ai +def test_ai_field_choice_output_is_compatible_with_select_filters( + premium_data_fixture, +): + user = premium_data_fixture.create_user() + table = premium_data_fixture.create_database_table(user=user) + premium_data_fixture.register_fake_generate_ai_type() + + ai_field = premium_data_fixture.create_ai_field( + table=table, + order=1, + name="AI Choice", + ai_generative_ai_type="test_generative_ai", + ai_generative_ai_model="test_1", + ai_output_type="choice", + ai_prompt="'test'", + ) + + select_filter_types = [ + "single_select_equal", + "single_select_not_equal", + "single_select_is_any_of", + "single_select_is_none_of", + "empty", + "not_empty", + ] + + for filter_type_name in select_filter_types: + filter_type = view_filter_type_registry.get(filter_type_name) + is_compatible = filter_type.field_is_compatible(ai_field) + assert ( + is_compatible + ), f"AI field with choice output should be compatible with {filter_type_name}" + + +@pytest.mark.django_db +@pytest.mark.field_ai +def test_ai_field_choice_output_supports_empty_filter(premium_data_fixture): + user = premium_data_fixture.create_user() + table = premium_data_fixture.create_database_table(user=user) + grid_view = premium_data_fixture.create_grid_view(table=table) + premium_data_fixture.register_fake_generate_ai_type() + + ai_field = premium_data_fixture.create_ai_field( + table=table, + order=1, + name="AI Choice", + ai_generative_ai_type="test_generative_ai", + ai_generative_ai_model="test_1", + ai_output_type="choice", + ai_prompt="'test'", + ) + + option_a = premium_data_fixture.create_select_option(field=ai_field, value="A") + option_b = premium_data_fixture.create_select_option(field=ai_field, value="B") + + handler = RowHandler() + model = table.get_model() + + row_1 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": option_a.id} + ) + row_2 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": None} + ) + row_3 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": option_b.id} + ) + + view_handler = ViewHandler() + view_handler.create_filter( + user=user, + view=grid_view, + field=ai_field, + type_name="empty", + value="", + ) + + queryset = view_handler.get_queryset(grid_view) + + assert queryset.count() == 1 + assert row_2 in queryset + assert row_1 not in queryset + assert row_3 not in queryset + + +@pytest.mark.django_db +@pytest.mark.field_ai +def test_ai_field_choice_output_supports_not_empty_filter(premium_data_fixture): + user = premium_data_fixture.create_user() + table = premium_data_fixture.create_database_table(user=user) + grid_view = premium_data_fixture.create_grid_view(table=table) + premium_data_fixture.register_fake_generate_ai_type() + + ai_field = premium_data_fixture.create_ai_field( + table=table, + order=1, + name="AI Choice", + ai_generative_ai_type="test_generative_ai", + ai_generative_ai_model="test_1", + ai_output_type="choice", + ai_prompt="'test'", + ) + + option_a = premium_data_fixture.create_select_option(field=ai_field, value="A") + option_b = premium_data_fixture.create_select_option(field=ai_field, value="B") + + handler = RowHandler() + model = table.get_model() + + row_1 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": option_a.id} + ) + row_2 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": None} + ) + row_3 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": option_b.id} + ) + + view_handler = ViewHandler() + view_handler.create_filter( + user=user, + view=grid_view, + field=ai_field, + type_name="not_empty", + value="", + ) + + queryset = view_handler.get_queryset(grid_view) + + assert queryset.count() == 2 + assert row_1 in queryset + assert row_3 in queryset + assert row_2 not in queryset + + +@pytest.mark.django_db +@pytest.mark.field_ai +def test_ai_field_choice_output_supports_single_select_equal_filter( + premium_data_fixture, +): + user = premium_data_fixture.create_user() + table = premium_data_fixture.create_database_table(user=user) + grid_view = premium_data_fixture.create_grid_view(table=table) + premium_data_fixture.register_fake_generate_ai_type() + + ai_field = premium_data_fixture.create_ai_field( + table=table, + order=1, + name="AI Choice", + ai_generative_ai_type="test_generative_ai", + ai_generative_ai_model="test_1", + ai_output_type="choice", + ai_prompt="'test'", + ) + + option_a = premium_data_fixture.create_select_option(field=ai_field, value="A") + option_b = premium_data_fixture.create_select_option(field=ai_field, value="B") + option_c = premium_data_fixture.create_select_option(field=ai_field, value="C") + + handler = RowHandler() + model = table.get_model() + + row_1 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": option_a.id} + ) + row_2 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": option_b.id} + ) + row_3 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": option_c.id} + ) + + view_handler = ViewHandler() + view_handler.create_filter( + user=user, + view=grid_view, + field=ai_field, + type_name="single_select_equal", + value=str(option_a.id), + ) + + queryset = view_handler.get_queryset(grid_view) + + assert queryset.count() == 1 + assert row_1 in queryset + assert row_2 not in queryset + assert row_3 not in queryset + + +@pytest.mark.django_db +@pytest.mark.field_ai +def test_ai_field_choice_output_supports_single_select_not_equal_filter( + premium_data_fixture, +): + user = premium_data_fixture.create_user() + table = premium_data_fixture.create_database_table(user=user) + grid_view = premium_data_fixture.create_grid_view(table=table) + premium_data_fixture.register_fake_generate_ai_type() + + ai_field = premium_data_fixture.create_ai_field( + table=table, + order=1, + name="AI Choice", + ai_generative_ai_type="test_generative_ai", + ai_generative_ai_model="test_1", + ai_output_type="choice", + ai_prompt="'test'", + ) + + option_a = premium_data_fixture.create_select_option(field=ai_field, value="A") + option_b = premium_data_fixture.create_select_option(field=ai_field, value="B") + option_c = premium_data_fixture.create_select_option(field=ai_field, value="C") + + handler = RowHandler() + model = table.get_model() + + row_1 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": option_a.id} + ) + row_2 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": option_b.id} + ) + row_3 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": option_c.id} + ) + + view_handler = ViewHandler() + view_handler.create_filter( + user=user, + view=grid_view, + field=ai_field, + type_name="single_select_not_equal", + value=str(option_a.id), + ) + + queryset = view_handler.get_queryset(grid_view) + + assert queryset.count() == 2 + assert row_1 not in queryset + assert row_2 in queryset + assert row_3 in queryset + + +@pytest.mark.django_db +@pytest.mark.field_ai +def test_ai_field_choice_output_supports_single_select_is_any_of_filter( + premium_data_fixture, +): + user = premium_data_fixture.create_user() + table = premium_data_fixture.create_database_table(user=user) + grid_view = premium_data_fixture.create_grid_view(table=table) + premium_data_fixture.register_fake_generate_ai_type() + + ai_field = premium_data_fixture.create_ai_field( + table=table, + order=1, + name="AI Choice", + ai_generative_ai_type="test_generative_ai", + ai_generative_ai_model="test_1", + ai_output_type="choice", + ai_prompt="'test'", + ) + + option_a = premium_data_fixture.create_select_option(field=ai_field, value="A") + option_b = premium_data_fixture.create_select_option(field=ai_field, value="B") + option_c = premium_data_fixture.create_select_option(field=ai_field, value="C") + + handler = RowHandler() + model = table.get_model() + + row_1 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": option_a.id} + ) + row_2 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": option_b.id} + ) + row_3 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": option_c.id} + ) + + view_handler = ViewHandler() + view_handler.create_filter( + user=user, + view=grid_view, + field=ai_field, + type_name="single_select_is_any_of", + value=f"{option_a.id},{option_c.id}", + ) + + queryset = view_handler.get_queryset(grid_view) + + assert queryset.count() == 2 + assert row_1 in queryset + assert row_2 not in queryset + assert row_3 in queryset + + +@pytest.mark.django_db +@pytest.mark.field_ai +def test_ai_field_choice_output_supports_single_select_is_none_of_filter( + premium_data_fixture, +): + user = premium_data_fixture.create_user() + table = premium_data_fixture.create_database_table(user=user) + grid_view = premium_data_fixture.create_grid_view(table=table) + premium_data_fixture.register_fake_generate_ai_type() + + ai_field = premium_data_fixture.create_ai_field( + table=table, + order=1, + name="AI Choice", + ai_generative_ai_type="test_generative_ai", + ai_generative_ai_model="test_1", + ai_output_type="choice", + ai_prompt="'test'", + ) + + option_a = premium_data_fixture.create_select_option(field=ai_field, value="A") + option_b = premium_data_fixture.create_select_option(field=ai_field, value="B") + option_c = premium_data_fixture.create_select_option(field=ai_field, value="C") + + handler = RowHandler() + model = table.get_model() + + row_1 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": option_a.id} + ) + row_2 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": option_b.id} + ) + row_3 = handler.create_row( + user=user, table=table, values={f"field_{ai_field.id}": option_c.id} + ) + + view_handler = ViewHandler() + view_handler.create_filter( + user=user, + view=grid_view, + field=ai_field, + type_name="single_select_is_none_of", + value=f"{option_a.id},{option_c.id}", + ) + + queryset = view_handler.get_queryset(grid_view) + + assert queryset.count() == 1 + assert row_1 not in queryset + assert row_2 in queryset + assert row_3 not in queryset diff --git a/premium/backend/tests/baserow_premium_tests/fields/test_ai_field_type.py b/premium/backend/tests/baserow_premium_tests/fields/test_ai_field_type.py index 2433272901..d21eff9bde 100644 --- a/premium/backend/tests/baserow_premium_tests/fields/test_ai_field_type.py +++ b/premium/backend/tests/baserow_premium_tests/fields/test_ai_field_type.py @@ -1193,3 +1193,16 @@ def test_ai_field_can_be_used_in_lookup_expression(premium_data_fixture): assert formula_field is not None assert formula_field.formula_type == "array" + + +@pytest.mark.django_db +@pytest.mark.field_ai +def test_ai_field_type_check_can_filter_by(premium_data_fixture): + user = premium_data_fixture.create_user() + table = premium_data_fixture.create_database_table(user=user) + premium_data_fixture.register_fake_generate_ai_type() + + ai_field = premium_data_fixture.create_ai_field(table=table, ai_output_type="text") + + ai_field_type = field_type_registry.get("ai") + assert ai_field_type.check_can_filter_by(ai_field) is True diff --git a/premium/web-frontend/modules/baserow_premium/fieldTypes.js b/premium/web-frontend/modules/baserow_premium/fieldTypes.js index df82757a81..a2eff7dc81 100644 --- a/premium/web-frontend/modules/baserow_premium/fieldTypes.js +++ b/premium/web-frontend/modules/baserow_premium/fieldTypes.js @@ -182,6 +182,14 @@ export class AIFieldType extends FieldType { return this.getBaserowFieldType(field).isEqual(field, value1, value2) } + parseFilterValue(field, filterValue) { + return this.getBaserowFieldType(field).parseFilterValue(field, filterValue) + } + + formatFilterValue(field, value) { + return this.getBaserowFieldType(field).formatFilterValue(field, value) + } + canBeReferencedByFormulaField(field) { return this.getBaserowFieldType(field).canBeReferencedByFormulaField(field) } @@ -214,6 +222,10 @@ export class AIFieldType extends FieldType { prepareValueForPaste(field, value) { return this.getBaserowFieldType(field).prepareValueForPaste(field, value) } + + getCompatibleFilterFieldType(field) { + return this.getBaserowFieldType(field) + } } export class PremiumFormulaFieldType extends FormulaFieldType { diff --git a/web-frontend/modules/database/fieldTypes.js b/web-frontend/modules/database/fieldTypes.js index 55a87be25b..992e5593ed 100644 --- a/web-frontend/modules/database/fieldTypes.js +++ b/web-frontend/modules/database/fieldTypes.js @@ -319,6 +319,10 @@ export class FieldType extends Registerable { return null } + getCompatibleFilterFieldType(field) { + return this + } + /** * In some cases, for example with the kanban view or the gallery view, we want to * only show the visible cards. In order to calculate the correct position of diff --git a/web-frontend/modules/database/viewFilters.js b/web-frontend/modules/database/viewFilters.js index c306615a9c..cfe2c6f408 100644 --- a/web-frontend/modules/database/viewFilters.js +++ b/web-frontend/modules/database/viewFilters.js @@ -128,12 +128,35 @@ export class ViewFilterType extends Registerable { } /** - * Returns if a given field is compatible with this view filter or not. Uses the - * list provided by getCompatibleFieldTypes to calculate this. + * Returns if a given field is compatible with this view filter or not. + * + * Checks compatibility by resolving the field's canonical filter type via + * getCompatibleFilterFieldType(field) and then matching against + * this.compatibleFieldTypes (strings or predicates). */ fieldIsCompatible(field) { - const valuesMap = this.getCompatibleFieldTypes().map((type) => [type, true]) - return this.getCompatibleFieldValue(field, valuesMap, false) + // Resolve the canonical field type for filter compatibility. + const fieldType = this.app.$registry.get('field', field.type) + const canonicalFieldType = fieldType + ? fieldType.getCompatibleFilterFieldType(field) + : fieldType + + const compareType = canonicalFieldType + ? canonicalFieldType.constructor.getType() + : field.type + + // Check static compatible list and predicates + for (const typeOrFunc of this.compatibleFieldTypes) { + if (typeOrFunc instanceof Function) { + if (typeOrFunc(field)) { + return true + } + } else if (compareType === typeOrFunc) { + return true + } + } + + return false } /** @@ -152,12 +175,21 @@ export class ViewFilterType extends Registerable { * @returns {any} The value that is compatible with the field or the notFoundValue. */ getCompatibleFieldValue(field, valuesMap, notFoundValue = null) { + // Resolve the canonical field type for filter compatibility. + const fieldType = this.app?.$registry?.get('field', field.type) + const canonicalFieldType = fieldType + ? fieldType.getCompatibleFilterFieldType(field) + : fieldType + const compareType = canonicalFieldType + ? canonicalFieldType.constructor.getType() + : field.type + for (const [typeOrFunc, value] of valuesMap) { if (typeOrFunc instanceof Function) { if (typeOrFunc(field)) { return value } - } else if (field.type === typeOrFunc) { + } else if (compareType === typeOrFunc) { return value } } diff --git a/web-frontend/test/unit/database/utils/field.spec.js b/web-frontend/test/unit/database/utils/field.spec.js index 36355bb68d..1c46de6346 100644 --- a/web-frontend/test/unit/database/utils/field.spec.js +++ b/web-frontend/test/unit/database/utils/field.spec.js @@ -3,8 +3,19 @@ import { hasCompatibleFilterTypes, } from '@baserow/modules/database/utils/field' import { EqualViewFilterType } from '@baserow/modules/database/viewFilters' +import { TestApp } from '@baserow/test/helpers/testApp' describe('test field utils', () => { + let testApp = null + + beforeAll(() => { + testApp = new TestApp() + }) + + afterEach((done) => { + testApp.afterEach().then(done) + }) + describe('getPrimaryOrFirstField', () => { it('should find the primary field in a list of fields', () => { const fields = [ @@ -53,7 +64,7 @@ describe('test field utils', () => { type: 'multiple_collaborators', primary: false, } - const filterType = new EqualViewFilterType() + const filterType = new EqualViewFilterType({ app: testApp._app }) expect(hasCompatibleFilterTypes(field, [filterType])).toBeFalsy() }) @@ -63,7 +74,7 @@ describe('test field utils', () => { type: 'text', primary: false, } - const filterType = new EqualViewFilterType() + const filterType = new EqualViewFilterType({ app: testApp._app }) expect(hasCompatibleFilterTypes(field, [filterType])).toBeTruthy() }) }) diff --git a/web-frontend/test/unit/database/viewFiltersMatch.spec.js b/web-frontend/test/unit/database/viewFiltersMatch.spec.js index 771dd0f274..99f75c81d2 100644 --- a/web-frontend/test/unit/database/viewFiltersMatch.spec.js +++ b/web-frontend/test/unit/database/viewFiltersMatch.spec.js @@ -2695,7 +2695,7 @@ describe('All Tests', () => { values.filterType ) const fieldType = new FormulaFieldType({ app: testApp._app }) - const field = { formula_type: 'url', formula: '' } + const field = { formula_type: 'url', formula: '', type: 'formula' } const result = filterClass.matches( values.rowValue !== undefined ? values.rowValue