diff --git a/README.md b/README.md index 2b90aff5..98d08fbe 100644 --- a/README.md +++ b/README.md @@ -107,7 +107,7 @@ OTel instrumentation is opt-in, controlled by environment variables: | Variable | Description | Default | | --------------------------------- | --------------------------------------------------------------------------------------------------------------------- | --------------- | | `OTEL_EXPORTER_OTLP_ENDPOINT` | Base OTLP endpoint (e.g. `http://collector:4318`). If unset, no OTel setup occurs. | _(disabled)_ | -| `OTEL_SERVICE_NAME` | The `service.name` resource attribute. | `flagsmith-api` | +| `OTEL_SERVICE_NAME` | The `service.name` resource attribute. Defaults to `flagsmith-task-processor` when running the task processor. | `flagsmith-api` | | `OTEL_TRACING_EXCLUDED_URL_PATHS` | Comma-separated URL paths to exclude from tracing (e.g. `health/liveness,health/readiness`). | _(none)_ | Standard `OTEL_*` env vars (e.g. `OTEL_RESOURCE_ATTRIBUTES`, `OTEL_EXPORTER_OTLP_HEADERS`) are also respected by the OTel SDK. @@ -121,6 +121,7 @@ When `OTEL_EXPORTER_OTLP_ENDPOINT` is set, `ensure_cli_env()` sets up: - **psycopg2** (`Psycopg2Instrumentor`): creates child spans for each SQL query with `db.system`, `db.statement`, and `db.name` attributes. SQL commenter is enabled, adding trace context as SQL comments for database-side correlation. - **Redis** (`RedisInstrumentor`): creates child spans for each Redis command with `db.system` and `db.statement` attributes. - **Structured log export**: A structlog processor that emits each log event as both an OTLP log record and a span event (when an active span exists). +- **Task processor trace propagation**: When a task is enqueued via `TaskHandler.delay()`, the current W3C trace context (including baggage) is serialized into the task's `trace_context` field. When the task processor executes the task, the context is extracted and a child span is created, linking the task execution back to the originating request trace. Span attributes (`task_identifier`, `task_type`, `result`) match the Prometheus metric labels for cross-signal correlation. #### Emitting OTel log events via structlog diff --git a/pyproject.toml b/pyproject.toml index b4a34ff6..cd6fc818 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ optional-dependencies = { test-tools = [ "backoff (>=2.2.1,<3.0.0)", "django (>4,<6)", "django-health-check", + "opentelemetry-api (>=1.25,<2)", "prometheus-client (>=0.0.16)", ], flagsmith-schemas = [ "simplejson", diff --git a/src/common/core/main.py b/src/common/core/main.py index 9f393888..6178ba7a 100644 --- a/src/common/core/main.py +++ b/src/common/core/main.py @@ -66,7 +66,12 @@ def ensure_cli_env() -> typing.Generator[None, None, None]: setup_tracing, ) - service_name = env.str("OTEL_SERVICE_NAME", "flagsmith-api") + default_service_name = ( + "flagsmith-task-processor" + if "task-processor" in sys.argv + else "flagsmith-api" + ) + service_name = env.str("OTEL_SERVICE_NAME", default_service_name) log_provider = build_otel_log_provider( endpoint=f"{otel_endpoint}/v1/logs", service_name=service_name, diff --git a/src/common/test_tools/types.py b/src/common/test_tools/types.py index 63b3eccd..ac627dfd 100644 --- a/src/common/test_tools/types.py +++ b/src/common/test_tools/types.py @@ -20,7 +20,7 @@ def __call__( class RunTasksFixture(Protocol): def __call__( self, - num_tasks: int, + num_tasks: int = 1, ) -> "list[TaskRun]": ... diff --git a/src/task_processor/decorators.py b/src/task_processor/decorators.py index 0749f5eb..e396dae0 100644 --- a/src/task_processor/decorators.py +++ b/src/task_processor/decorators.py @@ -6,12 +6,13 @@ from django.conf import settings from django.db.transaction import on_commit from django.utils import timezone +from opentelemetry import propagate from task_processor import metrics, task_registry from task_processor.exceptions import InvalidArgumentsError, TaskQueueFullError from task_processor.models import RecurringTask, Task, TaskPriority from task_processor.task_run_method import TaskRunMethod -from task_processor.types import TaskCallable, TaskParameters +from task_processor.types import TaskCallable, TaskParameters, TraceContext from task_processor.utils import get_task_identifier_from_function logger = logging.getLogger(__name__) @@ -92,6 +93,8 @@ def delay( task_identifier=task_identifier ).inc() try: + carrier: TraceContext = {} + propagate.inject(carrier) task = Task.create( task_identifier=task_identifier, scheduled_for=delay_until or timezone.now(), @@ -100,6 +103,7 @@ def delay( timeout=self.timeout, args=args, kwargs=kwargs, + trace_context=carrier or None, ) except TaskQueueFullError as e: logger.warning(e) diff --git a/src/task_processor/migrations/0014_add_trace_context.py b/src/task_processor/migrations/0014_add_trace_context.py new file mode 100644 index 00000000..56bf31ca --- /dev/null +++ b/src/task_processor/migrations/0014_add_trace_context.py @@ -0,0 +1,23 @@ +# Generated by Django 5.2.12 on 2026-04-10 14:29 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("task_processor", "0013_add_last_picked_at"), + ] + + operations = [ + migrations.AddField( + model_name="recurringtask", + name="trace_context", + field=models.JSONField(blank=True, null=True), + ), + migrations.AddField( + model_name="task", + name="trace_context", + field=models.JSONField(blank=True, null=True), + ), + ] diff --git a/src/task_processor/models.py b/src/task_processor/models.py index 6d198a52..47851f92 100644 --- a/src/task_processor/models.py +++ b/src/task_processor/models.py @@ -10,7 +10,7 @@ from task_processor.exceptions import TaskQueueFullError from task_processor.managers import RecurringTaskManager, TaskManager from task_processor.task_registry import get_task, registered_tasks -from task_processor.types import TaskCallable +from task_processor.types import TaskCallable, TraceContext _django_json_encoder_default = DjangoJSONEncoder().default @@ -31,6 +31,7 @@ class AbstractBaseTask(models.Model): serialized_kwargs = models.TextField(blank=True, null=True) is_locked = models.BooleanField(default=False) timeout = models.DurationField(blank=True, null=True) + trace_context = models.JSONField(null=True, blank=True) class Meta: abstract = True @@ -112,6 +113,7 @@ def create( args: typing.Tuple[typing.Any, ...] | None = None, kwargs: typing.Dict[str, typing.Any] | None = None, timeout: timedelta | None = timedelta(seconds=60), + trace_context: TraceContext | None = None, ) -> "Task": if queue_size and cls._is_queue_full(task_identifier, queue_size): raise TaskQueueFullError( @@ -125,6 +127,7 @@ def create( serialized_args=cls.serialize_data(args or tuple()), serialized_kwargs=cls.serialize_data(kwargs or dict()), timeout=timeout, + trace_context=trace_context, ) @classmethod diff --git a/src/task_processor/processor.py b/src/task_processor/processor.py index daab3aee..2f72a2ef 100644 --- a/src/task_processor/processor.py +++ b/src/task_processor/processor.py @@ -4,9 +4,12 @@ from concurrent.futures import ThreadPoolExecutor from contextlib import ExitStack from datetime import timedelta +from importlib.metadata import version from django.conf import settings from django.utils import timezone +from opentelemetry import context as otel_context +from opentelemetry import propagate, trace from task_processor import metrics from task_processor.exceptions import TaskBackoffError @@ -130,6 +133,11 @@ def _run_task( result: str executor = None + extracted_ctx = propagate.extract(task.trace_context or {}) + tracer = trace.get_tracer("task_processor", version("flagsmith-common")) + span = tracer.start_span(task_identifier, context=extracted_ctx) + otel_token = otel_context.attach(trace.set_span_in_context(span, extracted_ctx)) + try: # Use explicit executor management to avoid blocking on shutdown # when tasks timeout but continue running in worker threads. @@ -151,6 +159,9 @@ def _run_task( # fall back to using repr. err_msg = str(e) or repr(e) + span.set_status(trace.StatusCode.ERROR, err_msg) + span.record_exception(e) + task.mark_failure() task_run.result = result = TaskResult.FAILURE.value @@ -199,4 +210,8 @@ def _run_task( metrics.flagsmith_task_processor_finished_tasks_total.labels(**labels).inc() + span.set_attributes(labels) + span.end() + otel_context.detach(otel_token) + return task, task_run diff --git a/src/task_processor/types.py b/src/task_processor/types.py index 3b32a81e..92e84e0d 100644 --- a/src/task_processor/types.py +++ b/src/task_processor/types.py @@ -5,6 +5,8 @@ TaskCallable: TypeAlias = Callable[TaskParameters, None] +TraceContext: TypeAlias = dict[str, str] + @dataclass class TaskProcessorConfig: diff --git a/tests/conftest.py b/tests/conftest.py index c7b2f566..afe35873 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -79,6 +79,12 @@ def otel_tracing() -> Generator[InMemorySpanExporter, None, None]: yield exporter +@pytest.fixture() +def span_exporter(otel_tracing: InMemorySpanExporter) -> InMemorySpanExporter: + otel_tracing.clear() + return otel_tracing + + @pytest.fixture(scope="session") def test_metric() -> prometheus_client.Counter: return prometheus_client.Counter( diff --git a/tests/integration/core/test_otel.py b/tests/integration/core/test_otel.py index 9a7be41b..3f9cb5d6 100644 --- a/tests/integration/core/test_otel.py +++ b/tests/integration/core/test_otel.py @@ -64,12 +64,6 @@ def setup_logging_fixture( structlog.reset_defaults() -@pytest.fixture() -def span_exporter(otel_tracing: InMemorySpanExporter) -> InMemorySpanExporter: - otel_tracing.clear() - return otel_tracing - - def test_structlog_otel_log_record__basic_event__body_event_name_severity_attributes( log_exporter: InMemoryLogExporter, ) -> None: diff --git a/tests/unit/common/core/test_main.py b/tests/unit/common/core/test_main.py index 7a75af7a..8d6d00d9 100644 --- a/tests/unit/common/core/test_main.py +++ b/tests/unit/common/core/test_main.py @@ -168,3 +168,70 @@ def test_ensure_cli_env__task_processor_in_argv__sets_run_by_processor( # When / Then with ensure_cli_env(): assert os.environ.get("RUN_BY_PROCESSOR") == "true" + + +def test_ensure_cli_env__task_processor__expected_otel_service_name( + monkeypatch: pytest.MonkeyPatch, + mocker: MockerFixture, +) -> None: + # Given + monkeypatch.setattr("sys.argv", ["flagsmith", "task-processor"]) + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://collector:4318") + + mock_build_log = mocker.patch( + "common.core.otel.build_otel_log_provider", + return_value=mocker.MagicMock(spec=LoggerProvider), + ) + mock_build_tracer = mocker.patch( + "common.core.otel.build_tracer_provider", + return_value=mocker.MagicMock(spec=TracerProvider), + ) + mocker.patch("common.core.otel.setup_tracing") + + # When + with ensure_cli_env(): + pass + + # Then + mock_build_log.assert_called_once_with( + endpoint="http://collector:4318/v1/logs", + service_name="flagsmith-task-processor", + ) + mock_build_tracer.assert_called_once_with( + endpoint="http://collector:4318/v1/traces", + service_name="flagsmith-task-processor", + ) + + +def test_ensure_cli_env__env_service_name__expected_otel_service_name( + monkeypatch: pytest.MonkeyPatch, + mocker: MockerFixture, +) -> None: + # Given + monkeypatch.setattr("sys.argv", ["flagsmith", "task-processor"]) + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://collector:4318") + monkeypatch.setenv("OTEL_SERVICE_NAME", "my-custom") + + mock_build_log = mocker.patch( + "common.core.otel.build_otel_log_provider", + return_value=mocker.MagicMock(spec=LoggerProvider), + ) + mock_build_tracer = mocker.patch( + "common.core.otel.build_tracer_provider", + return_value=mocker.MagicMock(spec=TracerProvider), + ) + mocker.patch("common.core.otel.setup_tracing") + + # When + with ensure_cli_env(): + pass + + # Then + mock_build_log.assert_called_once_with( + endpoint="http://collector:4318/v1/logs", + service_name="my-custom", + ) + mock_build_tracer.assert_called_once_with( + endpoint="http://collector:4318/v1/traces", + service_name="my-custom", + ) diff --git a/tests/unit/task_processor/test_unit_task_processor_decorators.py b/tests/unit/task_processor/test_unit_task_processor_decorators.py index d54a380a..952c019d 100644 --- a/tests/unit/task_processor/test_unit_task_processor_decorators.py +++ b/tests/unit/task_processor/test_unit_task_processor_decorators.py @@ -5,6 +5,7 @@ from unittest.mock import MagicMock import pytest +from opentelemetry import baggage, context, trace from pytest_django import DjangoCaptureOnCommitCallbacks from pytest_django.fixtures import SettingsWrapper from pytest_mock import MockerFixture @@ -252,3 +253,56 @@ def my_function(*args: typing.Any, **kwargs: typing.Any) -> None: # Then assert task assert task.priority == TaskPriority.HIGH + + +@pytest.mark.django_db +def test_delay__active_trace__persists_trace_context_on_task() -> None: + # Given + @register_task_handler() + def my_function() -> None: ... + + tracer = trace.get_tracer("test") + + # When + with tracer.start_as_current_span("test-request"): + task = my_function.delay() + + # Then + assert task is not None + assert task.trace_context is not None + assert "traceparent" in task.trace_context + + +@pytest.mark.django_db +def test_delay__no_active_trace__persists_empty_trace_context() -> None: + # Given + @register_task_handler() + def my_function() -> None: ... + + # When + task = my_function.delay() + + # Then + assert task is not None + assert task.trace_context is None + + +@pytest.mark.django_db +def test_delay__baggage__persists_baggage_in_trace_context() -> None: + # Given + @register_task_handler() + def my_function() -> None: ... + + tracer = trace.get_tracer("test") + ctx = baggage.set_baggage("amplitude.device_id", "device-123") + context.attach(ctx) + + # When + with tracer.start_as_current_span("test-request"): + task = my_function.delay() + + # Then + assert task is not None + assert task.trace_context is not None + assert "baggage" in task.trace_context + assert "amplitude.device_id=device-123" in task.trace_context["baggage"] diff --git a/tests/unit/task_processor/test_unit_task_processor_models.py b/tests/unit/task_processor/test_unit_task_processor_models.py index fd82ad31..6da31acc 100644 --- a/tests/unit/task_processor/test_unit_task_processor_models.py +++ b/tests/unit/task_processor/test_unit_task_processor_models.py @@ -117,3 +117,30 @@ def test_recurring_task_should_execute__first_run_time_before_midnight__returns_ # When & Then assert task.should_execute is True + + +@pytest.mark.parametrize( + "trace_context", + [ + pytest.param( + {"traceparent": "00-abcdef-123456-01", "baggage": "key=val"}, + id="with_trace_context", + ), + pytest.param(None, id="without_trace_context"), + ], +) +@pytest.mark.django_db +def test_task_create__trace_context__persists_expected( + trace_context: dict[str, str] | None, +) -> None: + # Given / When + task = Task.create( + task_identifier="test_task", + scheduled_for=timezone.now(), + trace_context=trace_context, + ) + task.save() + + # Then + task.refresh_from_db() + assert task.trace_context == trace_context diff --git a/tests/unit/task_processor/test_unit_task_processor_processor.py b/tests/unit/task_processor/test_unit_task_processor_processor.py index 92f2e734..bcc59b4b 100644 --- a/tests/unit/task_processor/test_unit_task_processor_processor.py +++ b/tests/unit/task_processor/test_unit_task_processor_processor.py @@ -8,10 +8,15 @@ from django.core.cache import cache from django.utils import timezone from freezegun import freeze_time +from opentelemetry import trace +from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( + InMemorySpanExporter, +) +from opentelemetry.trace import StatusCode from pytest_django.fixtures import SettingsWrapper from pytest_mock import MockerFixture -from common.test_tools.types import AssertMetricFixture +from common.test_tools.types import AssertMetricFixture, RunTasksFixture from task_processor.decorators import ( TaskHandler, register_recurring_task, @@ -1032,3 +1037,138 @@ def test_run_task__timeout__does_not_block( task.refresh_from_db(using=current_database) assert task.completed is False assert task.num_failures == 1 + + +PARENT_TRACE_ID = 0x0AF7651916CD43DD8448EB211C80319C +PARENT_SPAN_ID = 0xB7AD6B7169203331 +TRACEPARENT = f"00-{PARENT_TRACE_ID:032x}-{PARENT_SPAN_ID:016x}-01" + + +@pytest.mark.parametrize( + "trace_context, expected_parent_trace_id, expected_parent_span_id", + [ + pytest.param( + {"traceparent": TRACEPARENT}, + PARENT_TRACE_ID, + PARENT_SPAN_ID, + id="with_parent_span", + ), + pytest.param({}, None, None, id="empty_carrier"), + pytest.param(None, None, None, id="null"), + ], +) +def test_run_task__success__expected_span( + trace_context: dict[str, str] | None, + expected_parent_trace_id: int | None, + expected_parent_span_id: int | None, + run_tasks: RunTasksFixture, + dummy_task: TaskHandler[[str, str]], + span_exporter: InMemorySpanExporter, +) -> None: + # Given + task = Task.create( + dummy_task.task_identifier, + scheduled_for=timezone.now(), + trace_context=trace_context, + ) + task.save() + + # When + run_tasks() + + # Then + spans = span_exporter.get_finished_spans() + task_spans = [s for s in spans if s.name == dummy_task.task_identifier] + assert len(task_spans) == 1 + + task_span = task_spans[0] + assert task_span.attributes is not None + assert task_span.attributes["task_identifier"] == dummy_task.task_identifier + assert task_span.attributes["task_type"] == "standard" + assert task_span.attributes["result"] == "success" + assert getattr(task_span.parent, "trace_id", None) == expected_parent_trace_id + assert getattr(task_span.parent, "span_id", None) == expected_parent_span_id + + +@pytest.mark.parametrize( + "trace_context, expected_parent_trace_id, expected_parent_span_id", + [ + pytest.param( + {"traceparent": TRACEPARENT}, + PARENT_TRACE_ID, + PARENT_SPAN_ID, + id="with_parent_span", + ), + pytest.param({}, None, None, id="empty_carrier"), + pytest.param(None, None, None, id="null"), + ], +) +def test_run_task__failure__expected_span( + trace_context: dict[str, str] | None, + expected_parent_trace_id: int | None, + expected_parent_span_id: int | None, + run_tasks: RunTasksFixture, + raise_exception_task: TaskHandler[[str]], + span_exporter: InMemorySpanExporter, +) -> None: + # Given + task = Task.create( + raise_exception_task.task_identifier, + scheduled_for=timezone.now(), + args=("test error",), + trace_context=trace_context, + ) + task.save() + + # When + run_tasks() + + # Then + spans = span_exporter.get_finished_spans() + task_spans = [s for s in spans if s.name == raise_exception_task.task_identifier] + assert len(task_spans) == 1 + + task_span = task_spans[0] + assert task_span.status.status_code == StatusCode.ERROR + assert task_span.attributes is not None + assert ( + task_span.attributes["task_identifier"] == raise_exception_task.task_identifier + ) + assert task_span.attributes["task_type"] == "standard" + assert task_span.attributes["result"] == "failure" + assert getattr(task_span.parent, "trace_id", None) == expected_parent_trace_id + assert getattr(task_span.parent, "span_id", None) == expected_parent_span_id + + +def test_delay_and_run__active_trace__propagates_to_task_span( + run_tasks: RunTasksFixture, + span_exporter: InMemorySpanExporter, +) -> None: + # Given + @register_task_handler() + def _traced_task() -> None: + pass + + tracer = trace.get_tracer("test") + + # When — enqueue within an active span + with tracer.start_as_current_span("django-request") as request_span: + task = _traced_task.delay() + + assert task is not None + assert task.trace_context is not None + assert "traceparent" in task.trace_context + + # Run the task + run_tasks() + + # Then — verify full trace linkage + spans = span_exporter.get_finished_spans() + task_spans = [s for s in spans if s.name == _traced_task.task_identifier] + assert len(task_spans) == 1 + + task_span = task_spans[0] + request_ctx = request_span.get_span_context() + assert task_span.context.trace_id == request_ctx.trace_id + assert task_span.parent is not None + assert task_span.parent.span_id == request_ctx.span_id diff --git a/uv.lock b/uv.lock index 6d2b49a5..cecd6463 100644 --- a/uv.lock +++ b/uv.lock @@ -487,6 +487,7 @@ task-processor = [ { name = "backoff" }, { name = "django" }, { name = "django-health-check" }, + { name = "opentelemetry-api" }, { name = "prometheus-client" }, ] test-tools = [ @@ -532,6 +533,7 @@ requires-dist = [ { name = "gunicorn", marker = "extra == 'common-core'", specifier = ">=19.1" }, { name = "inflection", marker = "extra == 'common-core'" }, { name = "opentelemetry-api", marker = "extra == 'common-core'", specifier = ">=1.25,<2" }, + { name = "opentelemetry-api", marker = "extra == 'task-processor'", specifier = ">=1.25,<2" }, { name = "opentelemetry-exporter-otlp-proto-http", marker = "extra == 'common-core'", specifier = ">=1.25,<2" }, { name = "opentelemetry-instrumentation-django", marker = "extra == 'common-core'", specifier = ">=0.46b0,<1" }, { name = "opentelemetry-instrumentation-psycopg2", marker = "extra == 'common-core'", specifier = ">=0.46b0,<1" },