Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from data_designer.config.seed import IndexRange, PartitionBlock, SamplingStrategy
from data_designer.engine.column_generators.generators.base import FromScratchColumnGenerator, GenerationStrategy
from data_designer.engine.column_generators.utils.errors import SeedDatasetError
from data_designer.engine.context import format_row_group_tag
from data_designer.engine.context import current_row_group_start_offset, format_row_group_tag
from data_designer.engine.dataset_builders.multi_column_configs import SeedDatasetMultiColumnConfig
from data_designer.engine.processing.utils import concat_datasets
from data_designer.logging import LOG_INDENT
Expand Down Expand Up @@ -43,7 +43,11 @@ def generate_from_scratch(self, num_records: int) -> pd.DataFrame:
if num_records <= 0:
raise ValueError("πŸ›‘ `num_records` must be positive.")

if self._batch_reader is None:
row_group_start_offset = current_row_group_start_offset.get()
if self.config.sampling_strategy == SamplingStrategy.ORDERED and row_group_start_offset is not None:
self._df_remaining = None
self._reset_batch_reader(num_records, record_offset=row_group_start_offset)
elif self._batch_reader is None:
self._reset_batch_reader(num_records)

return self._sample_records(num_records)
Expand Down Expand Up @@ -81,14 +85,36 @@ def _resolve_index_range(self) -> IndexRange | None:
index_range = self.config.selection_strategy.to_index_range(self._seed_dataset_size)
return index_range

def _reset_batch_reader(self, num_records: int) -> None:
def _reset_batch_reader(self, num_records: int, *, record_offset: int = 0) -> None:
shuffle = self.config.sampling_strategy == SamplingStrategy.SHUFFLE
self._batch_reader = self.resource_provider.seed_reader.create_batch_reader(
batch_size=num_records,
index_range=self._index_range,
index_range=self._index_range_at_offset(record_offset),
shuffle=shuffle,
)

def _index_range_at_offset(self, record_offset: int) -> IndexRange | None:
# ORDERED sampling cycles through the index range when more records are
# requested than the selection contains. ``record_offset`` is the count
# of records already produced for prior row groups, so it may exceed
# ``selected_size`` after one or more full cycles. Modulo by selection
# size gives the next read position within the current cycle; when it
# lands at 0 we fall back to the original range so the next read starts
# at ``selected_start`` like a fresh cycle.
if record_offset <= 0:
return self._index_range

selected_start = self._index_range.start if self._index_range is not None else 0
selected_end = self._index_range.end if self._index_range is not None else self._seed_dataset_size - 1
selected_size = selected_end - selected_start + 1
if selected_size <= 0:
return self._index_range

relative_offset = record_offset % selected_size
if relative_offset == 0:
return self._index_range
return IndexRange(start=selected_start + relative_offset, end=selected_end)

def _sample_records(self, num_records: int) -> pd.DataFrame:
logger.info(f"🌱 {format_row_group_tag()}Sampling {num_records} records from seed dataset")
logger.info(f"{LOG_INDENT}seed dataset size: {self._seed_dataset_size} records")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,15 @@

from contextvars import ContextVar

# Set by the async scheduler before executing each task.
# Set per row group by both engines: the async scheduler sets it before each
# task executes, and the sync engine's ``_run_batch`` sets it for each batch.
# Value: (current_rg_index, total_rg_count) or None.
current_row_group: ContextVar[tuple[int, int] | None] = ContextVar("current_row_group", default=None)

# Set while generating a row group. The value is the row group's planned start
# offset in the full dataset, including row groups skipped during resume.
current_row_group_start_offset: ContextVar[int | None] = ContextVar("current_row_group_start_offset", default=None)


def format_row_group_tag() -> str:
"""Return a '(x/X) ' prefix if a row group context is active, else ''."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
RequestAdmissionConfigSnapshot,
RowGroupAdmission,
)
from data_designer.engine.context import current_row_group
from data_designer.engine.context import current_row_group, current_row_group_start_offset
from data_designer.engine.dataset_builders.errors import DatasetGenerationError
from data_designer.engine.dataset_builders.multi_column_configs import MultiColumnConfig
from data_designer.engine.dataset_builders.scheduling.completion import CompletionTracker, FrontierDelta
Expand Down Expand Up @@ -170,6 +170,8 @@ def __init__(
progress_bar: bool = False,
scheduler_event_sink: SchedulerAdmissionEventSink | None = None,
run_id: str | None = None,
row_group_start_offsets: dict[int, int] | None = None,
initial_completed_records: int = 0,
adaptive_row_group_admission: bool = False,
adaptive_row_group_initial_target: int = 1,
request_pressure_provider: RequestPressureSnapshotProvider | None = None,
Expand Down Expand Up @@ -288,11 +290,16 @@ def __init__(

# Pre-compute row-group sizes for O(1) lookup
self._rg_size_map: dict[int, int] = dict(row_groups)
self._rg_start_offset_map: dict[int, int] = row_group_start_offsets or self._build_row_group_start_offsets(
row_groups
)
self._max_concurrent_row_groups = max_concurrent_row_groups
self._max_in_flight_tasks = max_in_flight_tasks
self._max_model_task_admission = max_model_task_admission
self._num_records = num_records
self._buffer_size = buffer_size
self._scheduled_records = sum(size for _, size in row_groups)
self._initial_completed_records = initial_completed_records
self._observed_max_row_groups_in_flight = 0
self._observed_max_task_leases_by_resource: dict[str, int] = {}
self._observed_max_queued_by_group: dict[str, int] = {}
Expand Down Expand Up @@ -324,6 +331,15 @@ def __init__(
self._progress_bar = StickyProgressBar() if progress_bar else None
self._reporter = self._setup_async_progress_reporter(num_records, buffer_size, progress_interval)

@staticmethod
def _build_row_group_start_offsets(row_groups: list[tuple[int, int]]) -> dict[int, int]:
offsets: dict[int, int] = {}
next_offset = 0
for rg_id, rg_size in row_groups:
offsets[rg_id] = next_offset
next_offset += rg_size
return offsets

def _setup_async_progress_reporter(
self,
num_records: int,
Expand All @@ -342,6 +358,7 @@ def _setup_async_progress_reporter(
total_records=task_counts[col],
label=f"column '{col}'",
quiet=True,
initial_completed=self._initial_completed_records,
)

if not trackers:
Expand Down Expand Up @@ -1017,7 +1034,7 @@ async def run(self) -> None:

with self._progress_bar or contextlib.nullcontext():
if self._reporter:
self._reporter.log_start(num_row_groups=num_rgs)
self._reporter.log_start(num_row_groups=num_rgs, scheduled_records=self._scheduled_records)

self._emit_scheduler_event("scheduler_job_started", diagnostics=self._scheduler_job_diagnostics())
self._emit_scheduler_health_snapshot("start")
Expand Down Expand Up @@ -1550,6 +1567,7 @@ async def _execute_task_inner(self, task: Task, lease: TaskAdmissionLease, task_
"""Core task execution logic."""
num_rgs = len(self._row_groups)
token = current_row_group.set((task.row_group, num_rgs))
start_offset_token = current_row_group_start_offset.set(self._rg_start_offset_map.get(task.row_group))
group = lease.item.group
identity_hash = hashlib.sha1("\0".join(group.key.identity).encode()).hexdigest()[:16]
correlation_token = runtime_correlation_provider.set(
Expand All @@ -1567,6 +1585,7 @@ async def _execute_task_inner(self, task: Task, lease: TaskAdmissionLease, task_
await self._execute_task_inner_impl(task, lease, task_execution_id)
finally:
runtime_correlation_provider.reset(correlation_token)
current_row_group_start_offset.reset(start_offset_token)
current_row_group.reset(token)

async def _execute_task_inner_impl(self, task: Task, lease: TaskAdmissionLease, task_execution_id: str) -> None:
Expand Down
Loading
Loading