-
Notifications
You must be signed in to change notification settings - Fork 118
refactor: set do_sample=True if a seed is set on HF backends #1149
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
psschwei
wants to merge
3
commits into
generative-computing:main
Choose a base branch
from
psschwei:sample-seed
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+205
−2
Open
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,180 @@ | ||
| """Unit tests for HuggingFace backend pure-logic helpers — no model load required.""" | ||
|
|
||
| from types import SimpleNamespace | ||
| from unittest.mock import patch | ||
|
|
||
| import pytest | ||
|
|
||
| pytest.importorskip("torch", reason="torch not installed — install mellea[hf]") | ||
| pytest.importorskip( | ||
| "transformers", reason="transformers not installed — install mellea[hf]" | ||
| ) | ||
| pytest.importorskip( | ||
| "llguidance", reason="llguidance not installed — install mellea[hf]" | ||
| ) | ||
|
|
||
| from mellea.backends import ModelOption | ||
| from mellea.backends.adapters import IntrinsicAdapter | ||
| from mellea.backends.huggingface import LocalHFBackend | ||
| from mellea.stdlib.components import Intrinsic, Message | ||
| from mellea.stdlib.context import ChatContext | ||
|
|
||
|
|
||
| class _FakeRewrittenRequest: | ||
| def __init__(self, temperature=None): | ||
| self.temperature = temperature | ||
|
|
||
| def model_copy(self, update): | ||
| copied = _FakeRewrittenRequest(self.temperature) | ||
| for key, value in update.items(): | ||
| setattr(copied, key, value) | ||
| return copied | ||
|
|
||
|
|
||
| class _FakeRewriter: | ||
| def __init__(self, *args, **kwargs): | ||
| pass | ||
|
|
||
| def transform(self, request_json, **intrinsic_kwargs): | ||
| return _FakeRewrittenRequest() | ||
|
|
||
|
|
||
| class _FakeResultProcessor: | ||
| def __init__(self, *args, **kwargs): | ||
| pass | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def stub_backend(): | ||
| """Return a stub with the attributes _make_backend_specific_and_remove reads. | ||
|
|
||
| Avoids constructing a real LocalHFBackend (which loads a model from the Hub). | ||
| """ | ||
| return SimpleNamespace( | ||
| from_mellea_model_opts_map={ | ||
| ModelOption.MAX_NEW_TOKENS: "max_new_tokens", | ||
| ModelOption.STOP_SEQUENCES: "stop_strings", | ||
| } | ||
| ) | ||
|
|
||
|
|
||
| def _call(stub, opts): | ||
| return LocalHFBackend._make_backend_specific_and_remove(stub, opts) | ||
|
|
||
|
|
||
| def _make_intrinsic_adapter_stub(): | ||
| adapter = IntrinsicAdapter.__new__(IntrinsicAdapter) | ||
| adapter.name = "answerability" | ||
| adapter.qualified_name = "answerability_alora" | ||
| adapter.config = {} | ||
| return adapter | ||
|
|
||
|
|
||
| def _make_intrinsic_backend_stub(stub_backend): | ||
| stub_backend.formatter = SimpleNamespace( | ||
| to_chat_messages=lambda linearized_ctx: [Message("user", "Is the sky blue?")] | ||
| ) | ||
| stub_backend._added_adapters = {} | ||
| stub_backend._tokenizer = object() | ||
| stub_backend._model = object() | ||
| stub_backend._get_hf_model_id = lambda: "stub-model" | ||
| stub_backend._make_backend_specific_and_remove = lambda opts: ( | ||
| LocalHFBackend._make_backend_specific_and_remove(stub_backend, opts) | ||
| ) | ||
| stub_backend.post_processing = lambda *args, **kwargs: None | ||
| stub_backend._generate_with_adapter_lock = ( | ||
| lambda adapter_name, generate_func, *args: generate_func(*args) | ||
| ) | ||
| return stub_backend | ||
|
|
||
|
|
||
| def test_seed_forces_do_sample_true(stub_backend): | ||
| """Issue #40: a seed alone must flip do_sample=True so it isn't ignored.""" | ||
| out = _call(stub_backend, {ModelOption.SEED: 42}) | ||
| assert out["do_sample"] is True | ||
|
|
||
|
|
||
| def test_nonzero_temperature_forces_do_sample_true(stub_backend): | ||
| out = _call(stub_backend, {ModelOption.TEMPERATURE: 0.7}) | ||
| assert out["do_sample"] is True | ||
| assert out["temperature"] == 0.7 | ||
|
|
||
|
|
||
| def test_zero_temperature_does_not_force_do_sample(stub_backend): | ||
| """temperature=0 means greedy; don't override do_sample.""" | ||
| out = _call(stub_backend, {ModelOption.TEMPERATURE: 0.0}) | ||
| assert "do_sample" not in out | ||
|
|
||
|
|
||
| def test_seed_with_zero_temperature_does_not_force_do_sample(stub_backend): | ||
| """temperature=0 wins over seed — do_sample=True with temperature=0 crashes transformers.""" | ||
| out = _call(stub_backend, {ModelOption.SEED: 42, ModelOption.TEMPERATURE: 0.0}) | ||
| assert "do_sample" not in out | ||
|
|
||
|
|
||
| def test_no_seed_no_temperature_leaves_do_sample_unset(stub_backend): | ||
| out = _call(stub_backend, {ModelOption.MAX_NEW_TOKENS: 32}) | ||
| assert "do_sample" not in out | ||
| assert out["max_new_tokens"] == 32 | ||
|
|
||
|
|
||
| def test_user_do_sample_is_not_overridden(stub_backend): | ||
| """If the caller explicitly set do_sample=False, respect it even with a seed.""" | ||
| out = _call(stub_backend, {ModelOption.SEED: 42, "do_sample": False}) | ||
| assert out["do_sample"] is False | ||
|
|
||
|
|
||
| def test_seed_sentinel_is_stripped(stub_backend): | ||
| """SEED is a Mellea sentinel and must not leak into the backend kwargs.""" | ||
| out = _call(stub_backend, {ModelOption.SEED: 42}) | ||
| assert ModelOption.SEED not in out | ||
|
|
||
|
|
||
| async def test_intrinsic_seed_with_zero_temperature_keeps_greedy(stub_backend): | ||
| """The intrinsic path must not let seed override explicit temperature=0.""" | ||
| backend = _make_intrinsic_backend_stub(stub_backend) | ||
| adapter = _make_intrinsic_adapter_stub() | ||
| captured = {} | ||
|
|
||
| def fake_transformers_inputs(rewritten, tokenizer, model): | ||
| assert rewritten.temperature == 0.0 | ||
| generate_input = {"input_tokens": object(), "do_sample": False} | ||
| captured["generate_input"] = generate_input | ||
| return generate_input, {} | ||
|
|
||
| def fake_generate_with_transformers(tokenizer, model, generate_input, other_input): | ||
| return object() | ||
|
|
||
| with ( | ||
| patch( | ||
| "mellea.backends.huggingface.get_adapter_for_intrinsic", | ||
| return_value=adapter, | ||
| ), | ||
| patch( | ||
| "mellea.backends.huggingface.granite_formatters.IntrinsicsRewriter", | ||
| _FakeRewriter, | ||
| ), | ||
| patch( | ||
| "mellea.backends.huggingface.granite_formatters.IntrinsicsResultProcessor", | ||
| _FakeResultProcessor, | ||
| ), | ||
| patch( | ||
| "mellea.formatters.granite.base.util.chat_completion_request_to_transformers_inputs", | ||
| side_effect=fake_transformers_inputs, | ||
| ), | ||
| patch( | ||
| "mellea.formatters.granite.base.util.generate_with_transformers", | ||
| side_effect=fake_generate_with_transformers, | ||
| ), | ||
| ): | ||
| output = await LocalHFBackend._generate_from_intrinsic( | ||
| backend, | ||
| Intrinsic("answerability"), | ||
| ChatContext().add(Message("user", "Is the sky blue?")), | ||
| model_options={ModelOption.SEED: 42, ModelOption.TEMPERATURE: 0.0}, | ||
| ) | ||
| assert output._generate is not None | ||
| await output._generate | ||
|
|
||
| assert captured["generate_input"]["do_sample"] is False | ||
| assert "temperature" not in captured["generate_input"] |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Small wording issue here — it should be "once sampling is enabled": it's
do_sample=Truepaired withtemperature=0that crashes transformers. Under greedy decoding (do_sample=False) temperature is simply ignored, so there's nothing invalid about it there.