dynamic sparse rewards #528
base: main
Changes from all commits
a8f6e49
fc6ec2b
fe41304
849b1c1
95a9546
12a0aad
@@ -86,20 +86,35 @@ class GenerateOutputs(BaseModel):
    reward: list[float]
    metrics: dict[str, list[float]] = Field(default_factory=dict)
    metadata: GenerateMetadata
    sparse_metrics: dict[str, list[bool]] | None = Field(default=None)
    # ^^ optional sparse tracking for multi-domain environments
    # When present, sparse_metrics[metric_name] indicates which rollout values should be
    # excluded from averaging (e.g., domain-specific metrics evaluated on irrelevant tasks).
    # True = sparse (exclude from average), False = relevant (include in average)
    # Example: chemistry_reward=[50.0, 0.0, 75.0] with sparse_metrics={"chemistry_reward": [False, True, False]}
    # would average to 62.5 instead of 41.7, excluding the irrelevant 0.0 score.
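As a quick sanity check of the example in that comment (illustration only, not code from this PR), the selective average works out like this:

```python
# Minimal sketch of the selective-averaging semantics described above.
values = [50.0, 0.0, 75.0]        # chemistry_reward across three rollouts
flags = [False, True, False]      # True = sparse, exclude from the average

relevant = [v for v, is_sparse in zip(values, flags) if not is_sparse]

print(sum(relevant) / len(relevant))  # 62.5  (sparse 0.0 excluded)
print(sum(values) / len(values))      # 41.666...  (naive average, ~41.7)
```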
class RolloutScore(BaseModel):
    """Pydantic model for rollout scores."""

    reward: float
    metrics: dict[str, float] = Field(default_factory=dict)
    sparse_metrics: set[str] | None = Field(default=None)
    # ^^ set of metric names that should be excluded from averaging for this rollout
    # Used by rubrics to mark domain-specific metrics as irrelevant for certain tasks
    # Example: {"chemistry_reward", "physics_reward"} when evaluating a finance task
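A hedged sketch of how a rubric might set this field on a single rollout; the class below simply mirrors the model from this diff (not the repo's actual import), and the metric names and values are invented for illustration:

```python
from pydantic import BaseModel, Field

# Mirrors the RolloutScore model from the diff above (sketch only).
class RolloutScore(BaseModel):
    reward: float
    metrics: dict[str, float] = Field(default_factory=dict)
    sparse_metrics: set[str] | None = Field(default=None)

# Scoring a finance task: chemistry/physics metrics still receive placeholder 0.0
# values, but are flagged sparse so they never drag down their domain averages.
score = RolloutScore(
    reward=0.8,
    metrics={"finance_reward": 0.8, "chemistry_reward": 0.0, "physics_reward": 0.0},
    sparse_metrics={"chemistry_reward", "physics_reward"},
)
```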
class RolloutScores(BaseModel):
    """Pydantic model for rubric outputs."""

    reward: list[float]
    metrics: dict[str, list[float]] = Field(default_factory=dict)
    sparse_metrics: dict[str, list[bool]] | None = Field(default=None)
Member
I don't like the name sparse_metrics here. To me, this implies that these are the actual float metrics after filtering. Would prefer a name that is indicative of the fact that these are boolean flags, maybe something like …
    # ^^ per-rollout exclusion flags for batch scoring
    # Maps metric names to lists of boolean flags (True = sparse, False = relevant)
    # Length matches the rollout lists in reward/metrics. Aggregated from individual RolloutScore.sparse_metrics

class ProcessedOutputs(BaseModel):
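The RolloutScores comment above says the flags are aggregated from the per-rollout RolloutScore.sparse_metrics sets. A sketch of that shape conversion, using a hypothetical helper rather than the PR's actual aggregation code:

```python
# Hypothetical helper: per-rollout name sets -> per-metric boolean lists.
def aggregate_sparse_flags(
    per_rollout: list[set[str] | None],
    metric_names: list[str],
) -> dict[str, list[bool]]:
    return {
        name: [names is not None and name in names for names in per_rollout]
        for name in metric_names
    }

flags = aggregate_sparse_flags(
    per_rollout=[{"chemistry_reward"}, None, {"chemistry_reward"}],
    metric_names=["chemistry_reward", "finance_reward"],
)
# {'chemistry_reward': [True, False, True], 'finance_reward': [False, False, False]}
```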
@@ -89,10 +89,62 @@ def print_results(results: GenerateOutputs, num_samples: int = 1):
        print(out)
    for k in results.metrics:
        v = results.metrics[k]
        print(f"{k}: avg - {sum(v) / len(v):.3f}, std - {np.std(v):.3f}")

        # selective averaging that excludes sparse values
        # only average over relevant (non-sparse) values
        # instead of including misleading zeros in the calculation
        if (
            hasattr(results, "sparse_metrics")
Member
This is always true?
            and results.sparse_metrics
            and k in results.sparse_metrics
        ):
            # filter out sparse values from averaging calculation
            # sparse_flags[i] = True means exclude rollout i from averaging
            sparse_flags = results.sparse_metrics[k]
Member
Ah yeah, look, here you call it sparse_flags as well, haha; this is already better than sparse_metrics.
            relevant_values = [
                val for val, is_sparse in zip(v, sparse_flags) if not is_sparse
            ]

            if relevant_values:
                # calculate statistics over only the relevant (non-sparse) values
                # this gives mathematically correct domain-specific averages
                avg = sum(relevant_values) / len(relevant_values)
                std = np.std(relevant_values)
                sparsity_info = f" (relevant: {len(relevant_values)}/{len(v)})"
                print(f"{k}: avg - {avg:.3f}, std - {std:.3f}{sparsity_info}")
            else:
                # all values marked sparse - no relevant data to average
                print(f"{k}: no relevant data (all values sparse)")
        else:
            # standard averaging for non-sparse metrics (backwards compatible)
            # this preserves existing behavior for environments without sparse metrics
            print(f"{k}: avg - {sum(v) / len(v):.3f}, std - {np.std(v):.3f}")

        # enhanced rollout display that shows sparsity clearly
        # Instead of showing misleading 0.0 values, display "-" for sparse metrics
        # This makes it immediately obvious which rollouts are relevant vs excluded
        for i in range(r):
            # rounded to 3 decimal places
            trials = [round(v[(i * n) + j], 3) for j in range(n)]
            if (
                hasattr(results, "sparse_metrics")
                and results.sparse_metrics
                and k in results.sparse_metrics
            ):
                # For sparse metrics: "-" indicates sparse (irrelevant), numbers show actual values
                # This visual distinction prevents confusion about which values contribute to averages
                sparse_flags = results.sparse_metrics[k]
                trials = []
                for j in range(n):
                    idx = (i * n) + j
                    if sparse_flags[idx]:
                        # sparse value - show "-" instead of 0.0 to indicate exclusion from averaging
                        trials.append("-")
                    else:
                        # non-sparse value - show actual computed score
                        trials.append(round(v[idx], 3))
            else:
                # standard rollout printing for non-sparse metrics (backwards compatible)
                # all values shown as numbers since none are excluded from averaging
                trials = [round(v[(i * n) + j], 3) for j in range(n)]
            out = f"r{i + 1}: {trials}"
            print(out)
Probably leftover from something else?
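For the chemistry_reward example from the models diff (assuming a single prompt, r = 1, with n = 3 rollouts — values not taken from the PR), the updated print_results would render output along these lines:

```
chemistry_reward: avg - 62.500, std - 12.500 (relevant: 2/3)
r1: [50.0, '-', 75.0]
```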