Skip to content

Commit 823abba

Browse files
authored
fix: fix integration for custom translators that take options (#57)
* fix: fixes issues with setting a custom translator that takes options * chore: update lock * fix: use custom translator properly in sample dagster * fix: remove test_config * fix: address more issues with moving config around
1 parent 5ef684d commit 823abba

File tree

7 files changed

+103
-136
lines changed

7 files changed

+103
-136
lines changed

dagster_sqlmesh/config.py

Lines changed: 44 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import inspect
21
import typing as t
32
from dataclasses import dataclass
43
from pathlib import Path
@@ -8,8 +7,7 @@
87
from sqlmesh.core.config import Config as MeshConfig
98
from sqlmesh.core.config.loader import load_configs
109

11-
if t.TYPE_CHECKING:
12-
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator
10+
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator
1311

1412

1513
@dataclass
@@ -28,23 +26,54 @@ class SQLMeshContextConfig(Config):
2826
running via dagster.
2927
3028
The config also manages the translator class used for converting SQLMesh
31-
models to Dagster assets. You can specify a custom translator by setting
32-
the translator_class_name field to the fully qualified class name.
29+
models to Dagster assets and provides a consistent translator to
30+
dagster-sqlmesh. The config must always be provided to the SQLMeshResource
31+
in order for the integration to function correctly. For example, when
32+
setting up the dagster Definitions, you must provide the
33+
SQLMeshContextConfig as a resource along with the SQLMeshResource as
34+
follows:
35+
36+
```python
37+
sqlmesh_context_config = SQLMeshContextConfig(
38+
path="/path/to/sqlmesh/project", gateway="local",
39+
)
40+
41+
@sqlmesh_assets(
42+
environment="dev", config=sqlmesh_context_config,
43+
enabled_subsetting=True,
44+
) def sqlmesh_project(
45+
context: AssetExecutionContext, sqlmesh: SQLMeshResource,
46+
sqlmesh_context_config: SQLMeshContextConfig
47+
) -> t.Iterator[MaterializeResult[t.Any]]:
48+
yield from sqlmesh.run(context, config=sqlmesh_config)
49+
50+
51+
defs = Definitions(
52+
assets=[sqlmesh_project], resources={
53+
"sqlmesh": SQLMeshResource(), "sqlmesh_context_config":
54+
sqlmesh_context_config,
55+
},
56+
)
57+
```
58+
59+
In order to provide a custom translator, you will need to subclass this
60+
class and return a different translator. However, due to the way that
61+
dagster assets/jobs/ops are run, you will need to ensure that the custom
62+
translator is _instantiated_ within the get_translator method rather than
63+
simply returning an instance variable. This is because dagster will
64+
serialize/deserialize the config object and any instance variables will
65+
not be preserved. Therefore, any options you'd like to pass to the translator
66+
must be serializable within your custom SQLMeshContextConfig subclass.
67+
68+
This class provides the minimum configuration required to run dagster-sqlmesh.
3369
"""
3470

3571
path: str
3672
gateway: str
3773
config_override: dict[str, t.Any] | None = Field(default_factory=lambda: None)
38-
translator_class_name: str = Field(
39-
default="dagster_sqlmesh.translator.SQLMeshDagsterTranslator",
40-
description="Fully qualified class name of the SQLMesh Dagster translator to use"
41-
)
4274

43-
def get_translator(self) -> "SQLMeshDagsterTranslator":
44-
"""Get a translator instance using the configured class name.
45-
46-
Imports and validates the translator class, then creates a new instance.
47-
The class must inherit from SQLMeshDagsterTranslator.
75+
def get_translator(self) -> SQLMeshDagsterTranslator:
76+
"""Get a translator instance. Override this method to provide a custom translator.
4877
4978
Returns:
5079
SQLMeshDagsterTranslator: A new instance of the configured translator class
@@ -53,29 +82,7 @@ def get_translator(self) -> "SQLMeshDagsterTranslator":
5382
ValueError: If the imported object is not a class or does not inherit
5483
from SQLMeshDagsterTranslator
5584
"""
56-
from importlib import import_module
57-
58-
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator
59-
60-
module_name, class_name = self.translator_class_name.rsplit(".", 1)
61-
module = import_module(module_name)
62-
translator_class = getattr(module, class_name)
63-
64-
# Validate that the imported class inherits from SQLMeshDagsterTranslator
65-
if not inspect.isclass(translator_class):
66-
raise ValueError(
67-
f"'{self.translator_class_name}' is not a class. "
68-
f"Expected a class that inherits from SQLMeshDagsterTranslator."
69-
)
70-
71-
if not issubclass(translator_class, SQLMeshDagsterTranslator):
72-
raise ValueError(
73-
f"Translator class '{self.translator_class_name}' must inherit from "
74-
f"SQLMeshDagsterTranslator. Found class that inherits from: "
75-
f"{[base.__name__ for base in translator_class.__bases__]}"
76-
)
77-
78-
return translator_class()
85+
return SQLMeshDagsterTranslator()
7986

8087
@property
8188
def sqlmesh_config(self) -> MeshConfig:

dagster_sqlmesh/resource.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -580,13 +580,13 @@ def errors(self) -> list[Exception]:
580580

581581

582582
class SQLMeshResource(dg.ConfigurableResource):
583-
config: SQLMeshContextConfig
584583
is_testing: bool = False
585584

586585
def run(
587586
self,
588587
context: dg.AssetExecutionContext,
589588
*,
589+
config: SQLMeshContextConfig,
590590
context_factory: ContextFactory[ContextCls] = DEFAULT_CONTEXT_FACTORY,
591591
environment: str = "dev",
592592
start: TimeLike | None = None,
@@ -606,7 +606,9 @@ def run(
606606
logger = context.log
607607

608608
controller = self.get_controller(
609-
context_factory=context_factory, log_override=logger
609+
config=config,
610+
context_factory=context_factory,
611+
log_override=logger
610612
)
611613

612614
with controller.instance(environment) as mesh:
@@ -618,7 +620,7 @@ def run(
618620
[model.fqn for model, _ in mesh.non_external_models_dag()]
619621
)
620622
selected_models_set, models_map, select_models = (
621-
self._get_selected_models_from_context(context=context, models=models)
623+
self._get_selected_models_from_context(context=context, config=config, models=models)
622624
)
623625

624626
if all_available_models == selected_models_set or select_models is None:
@@ -632,6 +634,7 @@ def run(
632634

633635
event_handler = self.create_event_handler(
634636
context=context,
637+
config=config,
635638
models_map=models_map,
636639
dag=dag,
637640
prefix="sqlmesh: ",
@@ -686,13 +689,14 @@ def create_event_handler(
686689
self,
687690
*,
688691
context: dg.AssetExecutionContext,
692+
config: SQLMeshContextConfig,
689693
dag: DAG[str],
690694
models_map: dict[str, Model],
691695
prefix: str,
692696
is_testing: bool,
693697
materializations_enabled: bool,
694698
) -> DagsterSQLMeshEventHandler:
695-
translator = self.config.get_translator()
699+
translator = config.get_translator()
696700
return DagsterSQLMeshEventHandler(
697701
context=context,
698702
dag=dag,
@@ -704,7 +708,10 @@ def create_event_handler(
704708
)
705709

706710
def _get_selected_models_from_context(
707-
self, context: dg.AssetExecutionContext, models: MappingProxyType[str, Model]
711+
self,
712+
context: dg.AssetExecutionContext,
713+
config: SQLMeshContextConfig,
714+
models: MappingProxyType[str, Model]
708715
) -> tuple[set[str], dict[str, Model], list[str] | None]:
709716
models_map = models.copy()
710717
try:
@@ -718,7 +725,7 @@ def _get_selected_models_from_context(
718725
else:
719726
raise e
720727

721-
translator = self.config.get_translator()
728+
translator = config.get_translator()
722729
select_models: list[str] = []
723730
models_map = {}
724731
for key, model in models.items():
@@ -733,11 +740,12 @@ def _get_selected_models_from_context(
733740

734741
def get_controller(
735742
self,
743+
config: SQLMeshContextConfig,
736744
context_factory: ContextFactory[ContextCls],
737745
log_override: logging.Logger | None = None,
738746
) -> DagsterSQLMeshController[ContextCls]:
739747
return DagsterSQLMeshController.setup_with_config(
740-
config=self.config,
748+
config=config,
741749
context_factory=context_factory,
742750
log_override=log_override,
743751
)

dagster_sqlmesh/test_config.py

Lines changed: 0 additions & 65 deletions
This file was deleted.

dagster_sqlmesh/test_resource.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,11 @@ def test_sqlmesh_resource_should_report_no_errors(
5959
):
6060
resource = sample_sqlmesh_resource_initialization.resource
6161
dg_context = sample_sqlmesh_resource_initialization.dagster_context
62+
context_config = sample_sqlmesh_resource_initialization.test_context.context_config
6263

6364
success = True
6465
try:
65-
for result in resource.run(dg_context):
66+
for result in resource.run(context=dg_context, config=context_config):
6667
pass
6768
except PlanOrRunFailedError as e:
6869
success = False
@@ -83,10 +84,11 @@ def test_sqlmesh_resource_properly_reports_errors(
8384
)
8485
resource = sqlmesh_resource_initialization.resource
8586
dg_context = sqlmesh_resource_initialization.dagster_context
87+
context_config = sqlmesh_resource_initialization.test_context.context_config
8688

8789
caught_failure = False
8890
try:
89-
for result in resource.run(dg_context):
91+
for result in resource.run(context=dg_context, config=context_config):
9092
pass
9193
except PlanOrRunFailedError as e:
9294
caught_failure = True
@@ -106,6 +108,7 @@ def test_sqlmesh_resource_properly_reports_errors_not_thrown(
106108
):
107109
dg_context = sample_sqlmesh_resource_initialization.dagster_context
108110
resource = sample_sqlmesh_resource_initialization.resource
111+
context_config = sample_sqlmesh_resource_initialization.test_context.context_config
109112

110113
def event_handler_factory(
111114
*args: t.Any, **kwargs: t.Any
@@ -120,7 +123,7 @@ def event_handler_factory(
120123

121124
caught_failure = False
122125
try:
123-
for result in resource.run(dg_context):
126+
for result in resource.run(context=dg_context, config=context_config):
124127
pass
125128
except PlanOrRunFailedError as e:
126129
caught_failure = True
@@ -151,17 +154,19 @@ def test_sqlmesh_resource_should_properly_materialize_results_when_no_plan_is_ru
151154
resource = sample_sqlmesh_resource_initialization.resource
152155
dg_context = sample_sqlmesh_resource_initialization.dagster_context
153156
dg_instance = sample_sqlmesh_resource_initialization.dagster_instance
157+
context_config = sample_sqlmesh_resource_initialization.test_context.context_config
154158

155159
# First run should materialize all models
156160
initial_results: list[dg.MaterializeResult] = []
157-
for result in resource.run(dg_context):
161+
for result in resource.run(context=dg_context, config=context_config):
158162
initial_results.append(result)
159163
assert result.asset_key is not None, "Expected asset key to be present."
160-
dg_instance.report_runless_asset_event(dg.AssetMaterialization(
161-
asset_key=result.asset_key,
162-
metadata=result.metadata,
163-
))
164-
164+
dg_instance.report_runless_asset_event(
165+
dg.AssetMaterialization(
166+
asset_key=result.asset_key,
167+
metadata=result.metadata,
168+
)
169+
)
165170

166171
# All metadata times should be set to the same time
167172
initial_times: set[float] = set()
@@ -180,7 +185,7 @@ def test_sqlmesh_resource_should_properly_materialize_results_when_no_plan_is_ru
180185

181186
# Second run should also materialize all models
182187
second_results: list[dg.MaterializeResult] = []
183-
for result in resource.run(dg_context):
188+
for result in resource.run(context=dg_context, config=context_config):
184189
second_results.append(result)
185190

186191
assert len(second_results) > 0, "Expected second results to be non-empty."
@@ -204,7 +209,9 @@ def test_sqlmesh_resource_should_properly_materialize_results_when_no_plan_is_ru
204209
# Third run will restate the full model
205210
third_results: list[dg.MaterializeResult] = []
206211
for result in resource.run(
207-
dg_context, restate_models=["sqlmesh_example.full_model"]
212+
context=dg_context,
213+
config=context_config,
214+
restate_models=["sqlmesh_example.full_model"],
208215
):
209216
third_results.append(result)
210217

dagster_sqlmesh/testing/context.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ class TestSQLMeshResource(SQLMeshResource):
5656
It allows for easy setup and teardown of the SQLMesh context.
5757
"""
5858

59-
def __init__(self, config: SQLMeshContextConfig, is_testing: bool = False):
60-
super().__init__(config=config, is_testing=is_testing)
59+
def __init__(self, is_testing: bool = False):
60+
super().__init__(is_testing=is_testing)
6161
def default_event_handler_factory(*args: t.Any, **kwargs: t.Any) -> DagsterSQLMeshEventHandler:
6262
"""Default event handler factory for the SQLMesh resource."""
6363
return DagsterSQLMeshEventHandler(*args, **kwargs)
@@ -82,7 +82,9 @@ def create_event_handler(self, *args: t.Any, **kwargs: t.Any) -> DagsterSQLMeshE
8282
DagsterSQLMeshEventHandler: The created event handler.
8383
"""
8484
# Ensure translator is passed to the event handler factory
85-
kwargs['translator'] = self.config.get_translator()
85+
# FIXME: this is a hack to deal with an older signature that didn't expected the config
86+
config = t.cast(SQLMeshContextConfig, kwargs.pop("config"))
87+
kwargs["translator"] = config.get_translator()
8688
return self._event_handler_factory(*args, **kwargs)
8789

8890

@@ -99,9 +101,7 @@ def create_controller(self) -> DagsterSQLMeshController[Context]:
99101
)
100102

101103
def create_resource(self) -> TestSQLMeshResource:
102-
return TestSQLMeshResource(
103-
config=self.context_config, is_testing=True,
104-
)
104+
return TestSQLMeshResource(is_testing=True)
105105

106106
def query(self, *args: t.Any, **kwargs: t.Any) -> list[t.Any]:
107107
conn = duckdb.connect(self.db_path)

0 commit comments

Comments
 (0)