diff --git a/src/maxtext/inference/maxengine/maxengine.py b/src/maxtext/inference/maxengine/maxengine.py index 4f15c28ca8..5bd220f4e1 100644 --- a/src/maxtext/inference/maxengine/maxengine.py +++ b/src/maxtext/inference/maxengine/maxengine.py @@ -32,6 +32,7 @@ from jax.experimental.layout import DeviceLocalLayout as DLL # type: ignore from flax import linen as nn +from flax import nnx from flax import struct from flax.linen import partitioning as nn_partitioning import flax @@ -44,8 +45,10 @@ from maxtext.inference.page_manager import PageManager, PageState from maxtext.multimodal import processor as mm_processor from maxtext.utils import lora_utils +from maxtext.utils import max_logging from maxtext.utils import max_utils from maxtext.utils import maxtext_utils +from maxtext.utils import model_creation_utils from maxtext.common.gcloud_stub import jetstream, is_decoupled from maxtext.common.common_types import MODEL_MODE_PREFILL, DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_AUTOREGRESSIVE @@ -111,10 +114,33 @@ def __init__(self, config: Any, devices: Any | None = None): devices_array = maxtext_utils.create_device_mesh(config=config, devices=devices) self._mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) - # MaxEngine serves Linen-format inference checkpoints; the surface stays - # Linen-shaped via transformer_as_linen regardless of pure_nnx. + # Model and Optimizer definition. quant = quantizations.configure_quantization(config) - self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL) + if config.pure_nnx: + # We need both PREFILL and AR abstract models because the cache vars inherit + # CACHE_BATCH_PREFILL vs CACHE_BATCH from the construction model_mode, and + # bulk_insert searches for the substring "cache_batch" in the AR-mode names. + # Calling nnx.eval_shape directly (instead of create_nnx_abstract_model) avoids + # the jax.set_mesh wrap that trips Flax 0.12.6 on logical-only axes like "norm". + _create_model = model_creation_utils.get_nnx_create_model_fn(config, mesh=self._mesh, model_mode=MODEL_MODE_PREFILL) + _create_model_ar = model_creation_utils.get_nnx_create_model_fn( + config, mesh=self._mesh, model_mode=MODEL_MODE_AUTOREGRESSIVE + ) + with nn_partitioning.axis_rules(config.logical_axis_rules): + abstract_model = nnx.eval_shape(_create_model) + abstract_model_ar = nnx.eval_shape(_create_model_ar) + self.model = abstract_model + self.model_ar = abstract_model_ar + # 3-way split so JIT bodies can pass (params, cache, rest) separately to + # nnx.merge. `rest` (RNG state etc.) is materialized in load_params. + graphdef, _, _, _ = nnx.split(abstract_model, nnx.Param, nnx.Cache, ...) + self.graphdef = graphdef + self._create_model_fn = _create_model + self._nnx_rest_state = None + else: + self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL) + self.graphdef = None + self._create_model_fn = None self.replicated_sharding = jax.sharding.NamedSharding(self._mesh, P(None)) self.abstract_params = None @@ -140,6 +166,65 @@ def print_stats(self, label: str): max_utils.print_mem_stats(label) max_utils.print_cpu_ram_stats(label) + # NNX cache adapter: bulk_insert / _insert_jit / _maybe_stack_* switch on + # path[-1].key (e.g. "cached_prefill_key"). NNX state would expose ".value" at + # that position, so we convert NNX state <-> plain dict at the JIT boundary + # via to_pure_dict / replace_by_pure_dict. The cache helpers stay unchanged. + + def _nnx_cache_state_template(self, mode: str = MODEL_MODE_PREFILL) -> Any: + """Empty nnx.State template for the model's nnx.Cache vars (PREFILL=batch 1, AR=batch N).""" + src = self.model if mode == MODEL_MODE_PREFILL else self.model_ar + _, cache_state, _ = nnx.split(src, nnx.Cache, ...) + return cache_state + + def _nnx_init_cache_dict(self, mode: str = MODEL_MODE_PREFILL) -> dict: + """Zero-filled pure-dict cache matching the abstract NNX model.""" + src = self.model if mode == MODEL_MODE_PREFILL else self.model_ar + _, cache_state, _ = nnx.split(src, nnx.Cache, ...) + cache_dict = cache_state.to_pure_dict() + return jax.tree.map(lambda x: jnp.zeros(x.shape, x.dtype), cache_dict) + + def _nnx_run_model( + self, + params, + cache_dict, + decoder_input_tokens, + decoder_positions, + *, + decoder_segment_ids=None, + enable_dropout=False, + model_mode, + previous_chunk=None, + true_length=None, + slot=None, + page_state=None, + encoder_images=None, + encoder_image_masks=None, + encoder_audios=None, + ): + """NNX equivalent of `model.apply(..., mutable=["cache"])`. Returns (logits, new_cache_dict).""" + cache_state = self._nnx_cache_state_template(mode=model_mode) + nnx.replace_by_pure_dict(cache_state, cache_dict) + # copy=True avoids reusing Variable objects across traces (TraceContextError), + # mirroring the workaround in train.py's diff_wrapper. + model = nnx.merge(self.graphdef, params, cache_state, self._nnx_rest_state, copy=True) + logits = model( + decoder_input_tokens, + decoder_positions, + decoder_segment_ids=decoder_segment_ids, + encoder_images=encoder_images, + encoder_image_masks=encoder_image_masks, + encoder_audios=encoder_audios, + enable_dropout=enable_dropout, + model_mode=model_mode, + previous_chunk=previous_chunk, + true_length=true_length, + slot=slot, + page_state=page_state, + ) + new_cache = nnx.state(model, nnx.Cache).to_pure_dict() + return logits, new_cache + def generate_aot( self, params: Params, decode_state: DecodeState, rng: PRNGKeyType | None = None ): # returns (new_decode_state, result_tokens) @@ -223,6 +308,9 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar if rng is None: rng = jax.random.PRNGKey(0) + if self.config.pure_nnx: + return self._load_params_nnx(params=params, rng=rng) + if self.model.quant and self.config.checkpoint_is_quantized: print("Loading from the quantized checkpoint...") self.model.quant.quant_mode = quantizations.get_quant_mode("serve") @@ -282,11 +370,80 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar return params + def _load_params_nnx(self, params, rng): + """NNX equivalent of load_params: returns an nnx.Param state and populates KV cache shardings.""" + if self.model.quant is not None: + raise NotImplementedError("pure_nnx + quantization not yet supported. Use pure_nnx=False.") + + if params: + print("Resharding given NNX params") + _, params_abs, _ = nnx.split(self.model, nnx.Param, ...) + target_shardings = jax.tree.map( + lambda x: x.sharding if hasattr(x, "sharding") else None, + params_abs, + is_leaf=lambda x: isinstance(x, nnx.Variable), + ) + params_state = jax.device_put(params, target_shardings) + # Build a concrete model once to capture a real `rest` (RNG vars) for nnx.merge. + # Wasteful but simple — the from_pretrained branch below avoids this. + with nn_partitioning.axis_rules(self.config.logical_axis_rules): + concrete_model = self._create_model_fn() + graphdef, _, _, rest_state = nnx.split(concrete_model, nnx.Param, nnx.Cache, ...) + self.graphdef = graphdef + self._nnx_rest_state = rest_state + del concrete_model + else: + max_logging.log("Loading NNX params via from_pretrained") + with self._mesh: + nnx_model = model_creation_utils.from_pretrained( + self.config, mesh=self._mesh, model_mode=MODEL_MODE_AUTOREGRESSIVE + ) + # Refresh graphdef from the concrete loaded model so subsequent merges line up. + graphdef, params_state, _, rest_state = nnx.split(nnx_model, nnx.Param, nnx.Cache, ...) + self.graphdef = graphdef + self._nnx_rest_state = rest_state + del nnx_model + + self.abstract_params = jax.tree.map( + lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding) + if isinstance(x, jax.Array) + else None, + params_state, + ) + + self.prefill_kv_cache_annotations = maxtext_utils.get_prefill_kv_cache_annotations_nnx( + self.model, self.config, self._mesh + ) + self.prefill_kv_cache_shardings = jax.tree.map( + lambda x: jax.sharding.NamedSharding(self._mesh, x), + self.prefill_kv_cache_annotations, + ) + if self.config.stack_prefill_result_cache: + # With scan_layers=True the NNX cache leaves are already stacked on axis 0, + # so the engine's manual-stack helper (which assumes an unstacked Linen tree) + # doesn't apply. Wiring this up cleanly is a Phase-2 follow-up. + raise NotImplementedError("pure_nnx + stack_prefill_result_cache=True not yet supported.") + # AR-mode abstract model so axis names use CACHE_BATCH (not CACHE_BATCH_PREFILL); + # bulk_insert / _insert_jit search for "cache_batch" in the per-leaf logical axes. + self.kv_cache_annotations = maxtext_utils.get_kv_cache_annotations_nnx(self.model_ar, self.config, self._mesh) + self.kv_cache_shardings = jax.tree.map( + lambda x: jax.sharding.NamedSharding(self._mesh, x), + self.kv_cache_annotations, + ) + # state_mesh_annotations is unused on the NNX path; callers reading it + # (e.g. set_engine_vars_from_base_engine) need to be NNX-aware first. + self.state_mesh_annotations = None + + self.print_stats("After load_params (NNX)") + return params_state + def load_single_adapter(self, adapter_path): """ Load Single adapter from adapter_path. Expect adapter_config.json and LoRA adapter weights at this path within subdirectory `/0/items`. """ + if self.config.pure_nnx: + raise NotImplementedError("pure_nnx + LoRA not yet supported. Use pure_nnx=False.") adapter_config_path = os.path.join(adapter_path, "adapter_config.json") adapter_weights_path = os.path.join(adapter_path, "0", "items") @@ -322,6 +479,8 @@ def quantize_params(self, state, rng: PRNGKeyType | None = None): """Forward pass to quantize decode params.""" if rng is None: rng = jax.random.PRNGKey(0) + if self.config.pure_nnx: + raise NotImplementedError("pure_nnx + quantize_params not yet supported.") self.model.quant.quant_mode = quantizations.get_quant_mode("convert") @@ -476,7 +635,10 @@ def _prefill_jit( if existing_prefix is not None: if not self.use_chunked_prefill: raise ValueError("Using chunked prefill is needed for existing_prefix.") - input_params = params | {"cache": existing_prefix.cache} + # NNX threads existing_prefix.cache via the nnx_cache local below; only + # the Linen path merges cache into input_params (params is a dict there). + if not self.config.pure_nnx: + input_params = params | {"cache": existing_prefix.cache} start_position = existing_prefix.common_prefix_tokens.shape[0] # TODO(yuyanpeng): rename previous_chunk previous_chunk = jnp.expand_dims(existing_prefix.common_prefix_tokens, 0) @@ -508,24 +670,48 @@ def _prefill_jit( sequence_indicator = jnp.expand_dims(one_d_output, 0) rng, new_rng = jax.random.split(rng) - with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - flat_logits, new_vars = self.model.apply( - input_params, - input_tokens, - positions, - encoder_images=images, - encoder_image_masks=image_masks, - encoder_audios=audio_values, - decoder_segment_ids=sequence_indicator, - enable_dropout=False, - model_mode=MODEL_MODE_PREFILL, - rngs={"params": new_rng}, - mutable=["cache"], - previous_chunk=previous_chunk, - true_length=true_length, - slot=slot, - page_state=page_state, + if self.config.pure_nnx: + # Prefill always operates on batch=1 (one padded prompt at a time). + nnx_cache = ( + existing_prefix.cache if existing_prefix is not None else self._nnx_init_cache_dict(mode=MODEL_MODE_PREFILL) ) + with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + flat_logits, new_cache_dict = self._nnx_run_model( + params=input_params, + cache_dict=nnx_cache, + decoder_input_tokens=input_tokens, + decoder_positions=positions, + decoder_segment_ids=sequence_indicator, + encoder_images=images, + encoder_image_masks=image_masks, + encoder_audios=audio_values, + enable_dropout=False, + model_mode=MODEL_MODE_PREFILL, + previous_chunk=previous_chunk, + true_length=true_length, + slot=slot, + page_state=page_state, + ) + new_vars = {"cache": new_cache_dict} + else: + with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + flat_logits, new_vars = self.model.apply( + input_params, + input_tokens, + positions, + encoder_images=images, + encoder_image_masks=image_masks, + encoder_audios=audio_values, + decoder_segment_ids=sequence_indicator, + enable_dropout=False, + model_mode=MODEL_MODE_PREFILL, + rngs={"params": new_rng}, + mutable=["cache"], + previous_chunk=previous_chunk, + true_length=true_length, + slot=slot, + page_state=page_state, + ) if return_prompt_logp: prompt_logp = inference_utils.prompt_logprobs_from_prefill(flat_logits, input_tokens, true_length) else: @@ -734,6 +920,9 @@ def _prefill_multisampling_jit( prefilling stage. The number of tokens is specified by num_samples. """ + if self.config.pure_nnx: + raise NotImplementedError("pure_nnx + prefill_multisampling not yet supported. Use pure_nnx=False.") + input_tokens = jnp.expand_dims(padded_tokens, 0) # [BATCH, SEQUENCE] positions = jnp.expand_dims(jnp.arange(0, input_tokens.shape[1]), 0) @@ -859,6 +1048,9 @@ def prefill_concat( if existing_prefix: raise ValueError("We don't know what to do with existing_prefix") + if self.config.pure_nnx: + raise NotImplementedError("pure_nnx + prefill_concat not yet supported. Use pure_nnx=False.") + if rng is None: rng = jax.random.PRNGKey(0) input_tokens = jnp.expand_dims(padded_tokens, 0) # [BATCH, SEQUENCE] @@ -1028,17 +1220,30 @@ def _generate_jit( previous_token = decode_state["tokens"] rng, new_rng = jax.random.split(rng) # run one step generation - with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - out_logits, new_vars = self.model.apply( - params | {"cache": decode_state["cache"]}, - previous_token, - decode_state["next_pos"], - enable_dropout=False, - model_mode=MODEL_MODE_AUTOREGRESSIVE, - rngs={"params": new_rng}, - mutable=["cache"], - page_state=page_state, - ) + if self.config.pure_nnx: + with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + out_logits, new_cache_dict = self._nnx_run_model( + params=params, + cache_dict=decode_state["cache"], + decoder_input_tokens=previous_token, + decoder_positions=decode_state["next_pos"], + enable_dropout=False, + model_mode=MODEL_MODE_AUTOREGRESSIVE, + page_state=page_state, + ) + new_vars = {"cache": new_cache_dict} + else: + with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + out_logits, new_vars = self.model.apply( + params | {"cache": decode_state["cache"]}, + previous_token, + decode_state["next_pos"], + enable_dropout=False, + model_mode=MODEL_MODE_AUTOREGRESSIVE, + rngs={"params": new_rng}, + mutable=["cache"], + page_state=page_state, + ) out_logits = jax.lax.with_sharding_constraint(out_logits, self.replicated_sharding) new_cache = jax.lax.with_sharding_constraint(new_vars["cache"], self.kv_cache_shardings) # sampling tokens @@ -1596,6 +1801,9 @@ def init_decode_state( if self.config.attention == "paged" and self.page_manager is not None: page_state = self.page_manager.get_initial_page_state() # pytype: disable=attribute-error + if self.config.pure_nnx: + return self._init_decode_state_nnx(rng=rng, page_state=page_state) + # pylint: disable=unused-argument def init(abstract_params, page_state): x = jnp.ones( @@ -1689,6 +1897,51 @@ def is_lp(k): zeroed = max_utils.unbox_logicallypartioned(init_state) return zeroed + def _init_decode_state_nnx(self, rng, page_state) -> DecodeState: + """NNX equivalent of init_decode_state. Returns a decode_state dict with a pure-dict cache.""" + del rng, page_state # cache shape comes from the abstract model + batch = int(self.config.per_device_batch_size * self.mesh.size) + vocab = self.config.vocab_size + + # AR-mode cache so the batch dim matches generate's input shape. + cache_dict_abs = self._nnx_init_cache_dict(mode=MODEL_MODE_AUTOREGRESSIVE) + + @functools.partial(jax.jit, out_shardings=(self.kv_cache_shardings,)) + def _init_cache(): + return (jax.tree.map(lambda x: jnp.zeros(x.shape, x.dtype), cache_dict_abs),) + + (cache,) = _init_cache() + + # Per-leaf logical axes for bulk_insert's "cache_batch" lookup. Use model_ar + # so segment_id leaves carry CACHE_BATCH (under PREFILL they'd carry + # CACHE_BATCH_PREFILL, which doesn't contain the "cache_batch" substring). + _, cache_state, _ = nnx.split(self.model_ar, nnx.Cache, ...) + + def _logical_axes_for(var): + # Flax 0.12.6 renamed "sharding" to "out_sharding"; older code may still + # use "sharding_names". Try all three. + meta = var.get_metadata() if hasattr(var, "get_metadata") else {} + out = meta.get("out_sharding") or meta.get("sharding") or meta.get("sharding_names") + if out is None: + return () + return (out,) if isinstance(out, str) else tuple(out) + + annotations_state = jax.tree.map( + _logical_axes_for, + cache_state, + is_leaf=lambda v: isinstance(v, nnx.Variable), + ) + self.kv_cache_annotations_named = annotations_state.to_pure_dict() + + return { + "logits": jnp.zeros((batch, 1, vocab), dtype=jnp.float32), + "cache": cache, + "next_pos": jnp.zeros((batch, 1), dtype=jnp.int32), + "generated_tokens": jnp.zeros((batch, 1), dtype=jnp.int32), + "tokens": jnp.zeros((batch, 1), dtype=jnp.int32), + "token_logp": jnp.zeros((batch, 1), dtype=jnp.float32), + } + @property def max_concurrent_decodes(self) -> int: """Free slots.""" diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index 6d62b981dd..ffde9e3607 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -1815,6 +1815,30 @@ def init_kv_cache(model, config): return state_mesh_annotations +def _nnx_cache_partition_specs(abstract_model, config, mesh): + """Per-leaf PartitionSpec tree for the abstract model's nnx.Cache vars. + + Returned as a pure dict so the engine can wrap it in NamedSharding the same + way it does for the Linen helpers below. + """ + _, cache_state, _ = nnx.split(abstract_model, nnx.Cache, ...) + # get_nnx_named_sharding_with_scan_axis reads logical axis rules from the + # active flax partitioning context, so wrap. + with nn_partitioning.axis_rules(config.logical_axis_rules): + named_state = get_nnx_named_sharding_with_scan_axis(cache_state, mesh) + return jax.tree.map(lambda s: s.spec, named_state.to_pure_dict()) + + +def get_prefill_kv_cache_annotations_nnx(abstract_model, config, mesh): + """NNX equivalent of get_prefill_kv_cache_annotations.""" + return _nnx_cache_partition_specs(abstract_model, config, mesh) + + +def get_kv_cache_annotations_nnx(abstract_model, config, mesh): + """NNX equivalent of get_kv_cache_annotations.""" + return _nnx_cache_partition_specs(abstract_model, config, mesh) + + def save_quantized_checkpoint_if_configured(config, params): """Save quantized checkpoint if configured""" assert config.quantization, "quantization must be configured" diff --git a/tests/integration/maxengine_test.py b/tests/integration/maxengine_test.py index eb4a7729d6..34a091131c 100644 --- a/tests/integration/maxengine_test.py +++ b/tests/integration/maxengine_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" Tests for the maxengine """ +"""Tests for the maxengine""" import functools import sys @@ -23,6 +23,8 @@ from jax.sharding import Mesh import numpy as np import pytest +from flax import nnx +from flax.linen import partitioning as nn_partitioning from maxtext.configs import pyconfig from maxtext.common.common_types import DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_PREFILL from maxtext.layers import quantizations @@ -31,6 +33,7 @@ from maxtext.inference.maxengine import maxengine from maxtext.models import models from maxtext.utils import maxtext_utils +from maxtext.utils import model_creation_utils from tests.utils.test_helpers import get_test_config_path pytestmark = [pytest.mark.external_serving] @@ -163,6 +166,97 @@ def test_basic_decode(self): self.assertEqual(result_token.data.ndim, 2) self.assertEqual(result_token.data.shape[1], 3) + def _init_nnx_pyconfig(self, **kwargs): + """init_pyconfig with NNX flags on.""" + return self.init_pyconfig(pure_nnx=True, enable_nnx=True, pure_nnx_decoder=True, **kwargs) + + def _build_nnx_params(self, cfg, mesh): + """Materialize an NNX Transformer and return its nnx.Param state.""" + _create_model = model_creation_utils.get_nnx_create_model_fn(cfg, mesh=mesh, model_mode=MODEL_MODE_PREFILL) + with nn_partitioning.axis_rules(cfg.logical_axis_rules): + model = _create_model() + _, params_state, _ = nnx.split(model, nnx.Param, ...) + return params_state + + def test_init_nnx(self): + """NNX engine init exposes graphdef + abstract Transformer.""" + cfg = self._init_nnx_pyconfig() + engine = maxengine.MaxEngine(cfg, jax.devices()) + self.assertIsNotNone(engine.graphdef) + self.assertIsNotNone(engine.model) + self.assertEqual(type(engine.model).__name__, "Transformer") + + def test_basic_prefill_nnx(self): + """NNX prefill returns a Linen-shape result dict with finite values.""" + cfg = self._init_nnx_pyconfig() + devices_array = maxtext_utils.create_device_mesh(cfg) + mesh = Mesh(devices_array, cfg.mesh_axes) + params_state = self._build_nnx_params(cfg, mesh) + + input_tokens = jnp.array([1, 306, 5360, 304, 0, 0, 0, 0]) + true_length = 4 + engine = maxengine.MaxEngine(cfg, jax.devices()) + params = engine.load_params(params=params_state) + prefill_result, first_token = engine.prefill(params=params, padded_tokens=input_tokens, true_length=true_length) + + self.assertEqual(prefill_result["generated_tokens"], jnp.array([0])) + self.assertEqual(prefill_result["tokens"].size, 1) + self.assertTrue(jnp.array_equal(first_token.data.size, 3)) + self.assertEqual(first_token.log_prob.shape, (1, 1)) + self.assertIn("cache", prefill_result) + self.assertIsInstance(prefill_result["cache"], dict) + # Catch silent NaN/inf from a bad nnx.merge or cache round-trip. + self.assertTrue(jnp.all(jnp.isfinite(prefill_result["logits"]))) + cache_leaves, _ = jax.tree.flatten(prefill_result["cache"]) + for leaf in cache_leaves: + self.assertTrue(jnp.all(jnp.isfinite(leaf)), msg=f"non-finite cache leaf, shape={leaf.shape}") + # scan_layers=True (default in test config) ⇒ leading axis is num_decoder_layers. + for leaf in cache_leaves: + self.assertEqual(leaf.shape[0], cfg.num_decoder_layers, msg=f"layer-axis mismatch, got shape={leaf.shape}") + + def test_basic_decode_nnx(self): + """NNX prefill → insert → 4 generate steps. Verifies next_pos advances and logits stay finite.""" + cfg = self._init_nnx_pyconfig() + devices_array = maxtext_utils.create_device_mesh(cfg) + mesh = Mesh(devices_array, cfg.mesh_axes) + params_state = self._build_nnx_params(cfg, mesh) + + input_tokens = jnp.array([1, 306, 5360, 304]) + engine = maxengine.MaxEngine(cfg, jax.devices()) + params = engine.load_params(params=params_state) + decode_state = engine.init_decode_state() + prefill_result, _ = engine.prefill(params=params, padded_tokens=input_tokens, true_length=4) + decode_state = engine.insert(prefill_result, decode_state, slot=0) + + # 4 steps is enough to catch off-by-one cache pointer bugs. + initial_next_pos = int(decode_state["next_pos"][0, 0]) + for step in range(4): + decode_state, result_token = engine.generate(params=params, decode_state=decode_state) + self.assertEqual(result_token.log_prob.ndim, 2) + self.assertEqual(result_token.log_prob.shape[1], 1) + self.assertEqual(result_token.data.ndim, 2) + self.assertEqual(result_token.data.shape[1], 3) + self.assertTrue(jnp.all(jnp.isfinite(decode_state["logits"]))) + self.assertEqual( + int(decode_state["next_pos"][0, 0]), + initial_next_pos + step + 1, + msg=f"next_pos didn't advance at step {step}", + ) + + def test_quantize_raises_for_nnx(self): + """pure_nnx + quantization raises NotImplementedError.""" + cfg = self._init_nnx_pyconfig(quantization="int8") + engine = maxengine.MaxEngine(cfg, jax.devices()) + with self.assertRaises(NotImplementedError): + engine.load_params(rng=self.rng) + + def test_lora_raises_for_nnx(self): + """pure_nnx + LoRA raises NotImplementedError.""" + cfg = self._init_nnx_pyconfig() + engine = maxengine.MaxEngine(cfg, jax.devices()) + with self.assertRaises(NotImplementedError): + engine.load_single_adapter("/nonexistent/adapter/path") + @pytest.mark.skip(reason="Can only pass on CPU.") def test_chunked_prefill(self): """Test identical result between chunked prefill with single and multiple chunked.