Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 44 additions & 37 deletions dagster_sqlmesh/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import inspect
import typing as t
from dataclasses import dataclass
from pathlib import Path
Expand All @@ -8,8 +7,7 @@
from sqlmesh.core.config import Config as MeshConfig
from sqlmesh.core.config.loader import load_configs

if t.TYPE_CHECKING:
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator


@dataclass
Expand All @@ -28,23 +26,54 @@ class SQLMeshContextConfig(Config):
running via dagster.

The config also manages the translator class used for converting SQLMesh
models to Dagster assets. You can specify a custom translator by setting
the translator_class_name field to the fully qualified class name.
models to Dagster assets and provides a consistent translator to
dagster-sqlmesh. The config must always be provided to the SQLMeshResource
in order for the integration to function correctly. For example, when
setting up the dagster Definitions, you must provide the
SQLMeshContextConfig as a resource along with the SQLMeshResource as
follows:

```python
sqlmesh_context_config = SQLMeshContextConfig(
path="/path/to/sqlmesh/project", gateway="local",
)

@sqlmesh_assets(
environment="dev", config=sqlmesh_context_config,
enabled_subsetting=True,
) def sqlmesh_project(
context: AssetExecutionContext, sqlmesh: SQLMeshResource,
sqlmesh_context_config: SQLMeshContextConfig
) -> t.Iterator[MaterializeResult[t.Any]]:
yield from sqlmesh.run(context, config=sqlmesh_config)


defs = Definitions(
assets=[sqlmesh_project], resources={
"sqlmesh": SQLMeshResource(), "sqlmesh_context_config":
sqlmesh_context_config,
},
)
```

In order to provide a custom translator, you will need to subclass this
class and return a different translator. However, due to the way that
dagster assets/jobs/ops are run, you will need to ensure that the custom
translator is _instantiated_ within the get_translator method rather than
simply returning an instance variable. This is because dagster will
serialize/deserialize the config object and any instance variables will
not be preserved. Therefore, any options you'd like to pass to the translator
must be serializable within your custom SQLMeshContextConfig subclass.

This class provides the minimum configuration required to run dagster-sqlmesh.
"""

path: str
gateway: str
config_override: dict[str, t.Any] | None = Field(default_factory=lambda: None)
translator_class_name: str = Field(
default="dagster_sqlmesh.translator.SQLMeshDagsterTranslator",
description="Fully qualified class name of the SQLMesh Dagster translator to use"
)

def get_translator(self) -> "SQLMeshDagsterTranslator":
"""Get a translator instance using the configured class name.

Imports and validates the translator class, then creates a new instance.
The class must inherit from SQLMeshDagsterTranslator.
def get_translator(self) -> SQLMeshDagsterTranslator:
"""Get a translator instance. Override this method to provide a custom translator.

Returns:
SQLMeshDagsterTranslator: A new instance of the configured translator class
Expand All @@ -53,29 +82,7 @@ def get_translator(self) -> "SQLMeshDagsterTranslator":
ValueError: If the imported object is not a class or does not inherit
from SQLMeshDagsterTranslator
"""
from importlib import import_module

from dagster_sqlmesh.translator import SQLMeshDagsterTranslator

module_name, class_name = self.translator_class_name.rsplit(".", 1)
module = import_module(module_name)
translator_class = getattr(module, class_name)

# Validate that the imported class inherits from SQLMeshDagsterTranslator
if not inspect.isclass(translator_class):
raise ValueError(
f"'{self.translator_class_name}' is not a class. "
f"Expected a class that inherits from SQLMeshDagsterTranslator."
)

if not issubclass(translator_class, SQLMeshDagsterTranslator):
raise ValueError(
f"Translator class '{self.translator_class_name}' must inherit from "
f"SQLMeshDagsterTranslator. Found class that inherits from: "
f"{[base.__name__ for base in translator_class.__bases__]}"
)

return translator_class()
return SQLMeshDagsterTranslator()

@property
def sqlmesh_config(self) -> MeshConfig:
Expand Down
22 changes: 15 additions & 7 deletions dagster_sqlmesh/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,13 +580,13 @@ def errors(self) -> list[Exception]:


class SQLMeshResource(dg.ConfigurableResource):
config: SQLMeshContextConfig
is_testing: bool = False

def run(
self,
context: dg.AssetExecutionContext,
*,
config: SQLMeshContextConfig,
context_factory: ContextFactory[ContextCls] = DEFAULT_CONTEXT_FACTORY,
environment: str = "dev",
start: TimeLike | None = None,
Expand All @@ -606,7 +606,9 @@ def run(
logger = context.log

controller = self.get_controller(
context_factory=context_factory, log_override=logger
config=config,
context_factory=context_factory,
log_override=logger
)

with controller.instance(environment) as mesh:
Expand All @@ -618,7 +620,7 @@ def run(
[model.fqn for model, _ in mesh.non_external_models_dag()]
)
selected_models_set, models_map, select_models = (
self._get_selected_models_from_context(context=context, models=models)
self._get_selected_models_from_context(context=context, config=config, models=models)
)

if all_available_models == selected_models_set or select_models is None:
Expand All @@ -632,6 +634,7 @@ def run(

event_handler = self.create_event_handler(
context=context,
config=config,
models_map=models_map,
dag=dag,
prefix="sqlmesh: ",
Expand Down Expand Up @@ -686,13 +689,14 @@ def create_event_handler(
self,
*,
context: dg.AssetExecutionContext,
config: SQLMeshContextConfig,
dag: DAG[str],
models_map: dict[str, Model],
prefix: str,
is_testing: bool,
materializations_enabled: bool,
) -> DagsterSQLMeshEventHandler:
translator = self.config.get_translator()
translator = config.get_translator()
return DagsterSQLMeshEventHandler(
context=context,
dag=dag,
Expand All @@ -704,7 +708,10 @@ def create_event_handler(
)

def _get_selected_models_from_context(
self, context: dg.AssetExecutionContext, models: MappingProxyType[str, Model]
self,
context: dg.AssetExecutionContext,
config: SQLMeshContextConfig,
models: MappingProxyType[str, Model]
) -> tuple[set[str], dict[str, Model], list[str] | None]:
models_map = models.copy()
try:
Expand All @@ -718,7 +725,7 @@ def _get_selected_models_from_context(
else:
raise e

translator = self.config.get_translator()
translator = config.get_translator()
select_models: list[str] = []
models_map = {}
for key, model in models.items():
Expand All @@ -733,11 +740,12 @@ def _get_selected_models_from_context(

def get_controller(
self,
config: SQLMeshContextConfig,
context_factory: ContextFactory[ContextCls],
log_override: logging.Logger | None = None,
) -> DagsterSQLMeshController[ContextCls]:
return DagsterSQLMeshController.setup_with_config(
config=self.config,
config=config,
context_factory=context_factory,
log_override=log_override,
)
65 changes: 0 additions & 65 deletions dagster_sqlmesh/test_config.py

This file was deleted.

29 changes: 18 additions & 11 deletions dagster_sqlmesh/test_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,11 @@ def test_sqlmesh_resource_should_report_no_errors(
):
resource = sample_sqlmesh_resource_initialization.resource
dg_context = sample_sqlmesh_resource_initialization.dagster_context
context_config = sample_sqlmesh_resource_initialization.test_context.context_config

success = True
try:
for result in resource.run(dg_context):
for result in resource.run(context=dg_context, config=context_config):
pass
except PlanOrRunFailedError as e:
success = False
Expand All @@ -83,10 +84,11 @@ def test_sqlmesh_resource_properly_reports_errors(
)
resource = sqlmesh_resource_initialization.resource
dg_context = sqlmesh_resource_initialization.dagster_context
context_config = sqlmesh_resource_initialization.test_context.context_config

caught_failure = False
try:
for result in resource.run(dg_context):
for result in resource.run(context=dg_context, config=context_config):
pass
except PlanOrRunFailedError as e:
caught_failure = True
Expand All @@ -106,6 +108,7 @@ def test_sqlmesh_resource_properly_reports_errors_not_thrown(
):
dg_context = sample_sqlmesh_resource_initialization.dagster_context
resource = sample_sqlmesh_resource_initialization.resource
context_config = sample_sqlmesh_resource_initialization.test_context.context_config

def event_handler_factory(
*args: t.Any, **kwargs: t.Any
Expand All @@ -120,7 +123,7 @@ def event_handler_factory(

caught_failure = False
try:
for result in resource.run(dg_context):
for result in resource.run(context=dg_context, config=context_config):
pass
except PlanOrRunFailedError as e:
caught_failure = True
Expand Down Expand Up @@ -151,17 +154,19 @@ def test_sqlmesh_resource_should_properly_materialize_results_when_no_plan_is_ru
resource = sample_sqlmesh_resource_initialization.resource
dg_context = sample_sqlmesh_resource_initialization.dagster_context
dg_instance = sample_sqlmesh_resource_initialization.dagster_instance
context_config = sample_sqlmesh_resource_initialization.test_context.context_config

# First run should materialize all models
initial_results: list[dg.MaterializeResult] = []
for result in resource.run(dg_context):
for result in resource.run(context=dg_context, config=context_config):
initial_results.append(result)
assert result.asset_key is not None, "Expected asset key to be present."
dg_instance.report_runless_asset_event(dg.AssetMaterialization(
asset_key=result.asset_key,
metadata=result.metadata,
))

dg_instance.report_runless_asset_event(
dg.AssetMaterialization(
asset_key=result.asset_key,
metadata=result.metadata,
)
)

# All metadata times should be set to the same time
initial_times: set[float] = set()
Expand All @@ -180,7 +185,7 @@ def test_sqlmesh_resource_should_properly_materialize_results_when_no_plan_is_ru

# Second run should also materialize all models
second_results: list[dg.MaterializeResult] = []
for result in resource.run(dg_context):
for result in resource.run(context=dg_context, config=context_config):
second_results.append(result)

assert len(second_results) > 0, "Expected second results to be non-empty."
Expand All @@ -204,7 +209,9 @@ def test_sqlmesh_resource_should_properly_materialize_results_when_no_plan_is_ru
# Third run will restate the full model
third_results: list[dg.MaterializeResult] = []
for result in resource.run(
dg_context, restate_models=["sqlmesh_example.full_model"]
context=dg_context,
config=context_config,
restate_models=["sqlmesh_example.full_model"],
):
third_results.append(result)

Expand Down
12 changes: 6 additions & 6 deletions dagster_sqlmesh/testing/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ class TestSQLMeshResource(SQLMeshResource):
It allows for easy setup and teardown of the SQLMesh context.
"""

def __init__(self, config: SQLMeshContextConfig, is_testing: bool = False):
super().__init__(config=config, is_testing=is_testing)
def __init__(self, is_testing: bool = False):
super().__init__(is_testing=is_testing)
def default_event_handler_factory(*args: t.Any, **kwargs: t.Any) -> DagsterSQLMeshEventHandler:
"""Default event handler factory for the SQLMesh resource."""
return DagsterSQLMeshEventHandler(*args, **kwargs)
Expand All @@ -82,7 +82,9 @@ def create_event_handler(self, *args: t.Any, **kwargs: t.Any) -> DagsterSQLMeshE
DagsterSQLMeshEventHandler: The created event handler.
"""
# Ensure translator is passed to the event handler factory
kwargs['translator'] = self.config.get_translator()
# FIXME: this is a hack to deal with an older signature that didn't expected the config
config = t.cast(SQLMeshContextConfig, kwargs.pop("config"))
kwargs["translator"] = config.get_translator()
return self._event_handler_factory(*args, **kwargs)


Expand All @@ -99,9 +101,7 @@ def create_controller(self) -> DagsterSQLMeshController[Context]:
)

def create_resource(self) -> TestSQLMeshResource:
return TestSQLMeshResource(
config=self.context_config, is_testing=True,
)
return TestSQLMeshResource(is_testing=True)

def query(self, *args: t.Any, **kwargs: t.Any) -> list[t.Any]:
conn = duckdb.connect(self.db_path)
Expand Down
Loading