From bc220e5b2fb6b70f90e4dea97af2e1f8881e2979 Mon Sep 17 00:00:00 2001 From: hengtaoguo Date: Fri, 20 Feb 2026 19:56:06 +0000 Subject: [PATCH] Move maxengine for restructuring --- benchmarks/api_server/maxtext_generator.py | 3 ++- benchmarks/mmlu/mmlu_eval.py | 2 +- .../load_and_quantize_checkpoint.py | 2 +- src/maxtext/decode.py | 2 +- .../ckpt_conversion_agent/utils/save_param.py | 2 +- src/maxtext/inference/decode_multi.py | 3 ++- src/maxtext/inference/inference_microbenchmark.py | 2 +- .../jetstream_pathways_entrypoint.sh | 2 +- src/maxtext/inference/maxengine/__init__.py | 15 +++++++++++++++ .../inference/maxengine}/maxengine.py | 2 +- .../inference/maxengine}/maxengine_config.py | 2 +- .../inference/maxengine}/maxengine_server.py | 2 +- .../maxengine_server_entrypoint.sh | 2 +- .../inference/mlperf/microbenchmarks/__init__.py | 15 +++++++++++++++ .../microbenchmarks}/benchmark_chunked_prefill.py | 2 +- src/maxtext/inference/mlperf/offline_inference.py | 4 ++-- src/maxtext/inference/mlperf/offline_mode.py | 2 +- src/maxtext/inference/offline_engine.py | 2 +- src/maxtext/inference/scripts/decode_multi.py | 3 ++- .../input_pipeline/packing/prefill_packing.py | 3 +-- src/maxtext/scratch_code/gemma_7b.sh | 2 +- .../generate_grpo_golden_logits.py | 4 ++-- .../integration/grpo_trainer_correctness_test.py | 4 ++-- tests/integration/vision_encoder_test.py | 2 +- tests/unit/gcloud_stub_test.py | 8 ++++---- tests/unit/maxengine_test.py | 5 +++-- tests/unit/test_env_smoke.py | 2 +- .../data_generation/generate_distillation_data.py | 2 +- 28 files changed, 67 insertions(+), 34 deletions(-) create mode 100644 src/maxtext/inference/maxengine/__init__.py rename src/{MaxText => maxtext/inference/maxengine}/maxengine.py (99%) rename src/{MaxText => maxtext/inference/maxengine}/maxengine_config.py (98%) rename src/{MaxText => maxtext/inference/maxengine}/maxengine_server.py (98%) create mode 100644 src/maxtext/inference/mlperf/microbenchmarks/__init__.py rename src/{MaxText => maxtext/inference/mlperf/microbenchmarks}/benchmark_chunked_prefill.py (99%) diff --git a/benchmarks/api_server/maxtext_generator.py b/benchmarks/api_server/maxtext_generator.py index 383a601ce7..b611ebe86f 100644 --- a/benchmarks/api_server/maxtext_generator.py +++ b/benchmarks/api_server/maxtext_generator.py @@ -34,7 +34,8 @@ from dataclasses import dataclass, field -from MaxText import maxengine, pyconfig +from MaxText import pyconfig +from maxtext.inference.maxengine import maxengine from maxtext.multimodal import processor as mm_processor from maxtext.multimodal import utils as mm_utils from maxtext.utils import max_logging, max_utils diff --git a/benchmarks/mmlu/mmlu_eval.py b/benchmarks/mmlu/mmlu_eval.py index 3d5b017631..4aa77f0b5d 100644 --- a/benchmarks/mmlu/mmlu_eval.py +++ b/benchmarks/mmlu/mmlu_eval.py @@ -57,7 +57,7 @@ from tqdm import tqdm from MaxText import pyconfig -from MaxText import maxengine +from maxtext.inference.maxengine import maxengine from maxtext.utils import max_logging from maxtext.utils import max_utils diff --git a/src/maxtext/checkpoint_conversion/load_and_quantize_checkpoint.py b/src/maxtext/checkpoint_conversion/load_and_quantize_checkpoint.py index d5ba5881c9..190d300123 100644 --- a/src/maxtext/checkpoint_conversion/load_and_quantize_checkpoint.py +++ b/src/maxtext/checkpoint_conversion/load_and_quantize_checkpoint.py @@ -21,8 +21,8 @@ import jax -from MaxText import maxengine from MaxText import pyconfig +from maxtext.inference.maxengine import maxengine from maxtext.utils import max_utils diff --git a/src/maxtext/decode.py b/src/maxtext/decode.py index d683ed307e..810dbc38d0 100644 --- a/src/maxtext/decode.py +++ b/src/maxtext/decode.py @@ -22,10 +22,10 @@ from absl import app -from MaxText import maxengine from MaxText import pyconfig from maxtext.common import profiler from maxtext.common.gcloud_stub import jetstream, is_decoupled +from maxtext.inference.maxengine import maxengine from maxtext.multimodal import processor as mm_processor from maxtext.multimodal import utils as mm_utils from maxtext.utils import max_utils diff --git a/src/maxtext/experimental/agent/ckpt_conversion_agent/utils/save_param.py b/src/maxtext/experimental/agent/ckpt_conversion_agent/utils/save_param.py index 82c4ca5a15..6b542f1725 100644 --- a/src/maxtext/experimental/agent/ckpt_conversion_agent/utils/save_param.py +++ b/src/maxtext/experimental/agent/ckpt_conversion_agent/utils/save_param.py @@ -27,9 +27,9 @@ from transformers import AutoModelForCausalLM, AutoConfig -from MaxText import maxengine from MaxText import pyconfig from MaxText.globals import MAXTEXT_PKG_DIR +from maxtext.inference.maxengine import maxengine from maxtext.utils import max_logging from maxtext.utils import max_utils diff --git a/src/maxtext/inference/decode_multi.py b/src/maxtext/inference/decode_multi.py index a037f6724a..e53447b38d 100644 --- a/src/maxtext/inference/decode_multi.py +++ b/src/maxtext/inference/decode_multi.py @@ -22,7 +22,8 @@ import jax -from MaxText import maxengine, pyconfig +from MaxText import pyconfig +from maxtext.inference.maxengine import maxengine from maxtext.utils import max_utils _NUM_STREAMS = 5 diff --git a/src/maxtext/inference/inference_microbenchmark.py b/src/maxtext/inference/inference_microbenchmark.py index 0d378faf5c..99220dddc5 100644 --- a/src/maxtext/inference/inference_microbenchmark.py +++ b/src/maxtext/inference/inference_microbenchmark.py @@ -22,9 +22,9 @@ from absl import app from collections.abc import MutableMapping -from MaxText import maxengine from MaxText import pyconfig from maxtext.common import profiler +from maxtext.inference.maxengine import maxengine from maxtext.input_pipeline.packing import prefill_packing from maxtext.utils import gcs_utils from maxtext.utils import max_utils diff --git a/src/maxtext/inference/jetstream_pathways/jetstream_pathways_entrypoint.sh b/src/maxtext/inference/jetstream_pathways/jetstream_pathways_entrypoint.sh index 0d80a93951..eadb6a8490 100644 --- a/src/maxtext/inference/jetstream_pathways/jetstream_pathways_entrypoint.sh +++ b/src/maxtext/inference/jetstream_pathways/jetstream_pathways_entrypoint.sh @@ -15,4 +15,4 @@ # limitations under the License. cd /maxtext -python3 -m MaxText.maxengine_server $@ +python3 -m maxtext.inference.maxengine.maxengine_server $@ diff --git a/src/maxtext/inference/maxengine/__init__.py b/src/maxtext/inference/maxengine/__init__.py new file mode 100644 index 0000000000..fe12f6e842 --- /dev/null +++ b/src/maxtext/inference/maxengine/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MaxEngine inference module.""" diff --git a/src/MaxText/maxengine.py b/src/maxtext/inference/maxengine/maxengine.py similarity index 99% rename from src/MaxText/maxengine.py rename to src/maxtext/inference/maxengine/maxengine.py index d67cddf806..1a6d4be5fb 100644 --- a/src/MaxText/maxengine.py +++ b/src/maxtext/inference/maxengine/maxengine.py @@ -1747,7 +1747,7 @@ def create_engine_from_config_flags( assert "load_parameters_path" in args, "load_parameters_path must be defined" if maxengine_config_filepath is None: maxengine_config_filepath = os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml") - updated_args = [os.path.join(MAXTEXT_PKG_DIR, "maxengine_server.py"), maxengine_config_filepath] + updated_args = [os.path.join(MAXTEXT_PKG_DIR, "inference", "maxengine", "maxengine_server.py"), maxengine_config_filepath] for k, v in args.items(): option = f"{k}={v}" updated_args.append(option) diff --git a/src/MaxText/maxengine_config.py b/src/maxtext/inference/maxengine/maxengine_config.py similarity index 98% rename from src/MaxText/maxengine_config.py rename to src/maxtext/inference/maxengine/maxengine_config.py index a38cb63f8d..72ddbd218e 100644 --- a/src/MaxText/maxengine_config.py +++ b/src/maxtext/inference/maxengine/maxengine_config.py @@ -19,10 +19,10 @@ import jax from maxtext.common.gcloud_stub import jetstream, is_decoupled +from maxtext.inference.maxengine import maxengine config_lib, engine_api, _token_utils, _tokenizer_api, _token_params_ns = jetstream() -from MaxText import maxengine # TODO: merge it with the above create_maxengine(). diff --git a/src/MaxText/maxengine_server.py b/src/maxtext/inference/maxengine/maxengine_server.py similarity index 98% rename from src/MaxText/maxengine_server.py rename to src/maxtext/inference/maxengine/maxengine_server.py index ecc7d88fee..30ecabf0dd 100644 --- a/src/MaxText/maxengine_server.py +++ b/src/maxtext/inference/maxengine/maxengine_server.py @@ -23,8 +23,8 @@ import jax from MaxText import pyconfig -from MaxText import maxengine_config from maxtext.common import gcloud_stub +from maxtext.inference.maxengine import maxengine_config # _PORT = flags.DEFINE_integer('port', 9000, 'port to listen on') # _THREADS = flags.DEFINE_integer( diff --git a/src/maxtext/inference/maxengine_server/maxengine_server_entrypoint.sh b/src/maxtext/inference/maxengine_server/maxengine_server_entrypoint.sh index 3199c9a732..8399f3a7e7 100644 --- a/src/maxtext/inference/maxengine_server/maxengine_server_entrypoint.sh +++ b/src/maxtext/inference/maxengine_server/maxengine_server_entrypoint.sh @@ -15,5 +15,5 @@ # limitations under the License. cd /maxtext -python3 -m MaxText.maxengine_server \ +python3 -m maxtext.inference.maxengine.maxengine_server \ maxtext/configs/base.yml $@ diff --git a/src/maxtext/inference/mlperf/microbenchmarks/__init__.py b/src/maxtext/inference/mlperf/microbenchmarks/__init__.py new file mode 100644 index 0000000000..92af9595ab --- /dev/null +++ b/src/maxtext/inference/mlperf/microbenchmarks/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MLPerf microbenchmarks module.""" diff --git a/src/MaxText/benchmark_chunked_prefill.py b/src/maxtext/inference/mlperf/microbenchmarks/benchmark_chunked_prefill.py similarity index 99% rename from src/MaxText/benchmark_chunked_prefill.py rename to src/maxtext/inference/mlperf/microbenchmarks/benchmark_chunked_prefill.py index ee220bbb5b..862846f35e 100644 --- a/src/MaxText/benchmark_chunked_prefill.py +++ b/src/maxtext/inference/mlperf/microbenchmarks/benchmark_chunked_prefill.py @@ -47,8 +47,8 @@ from absl import app -from MaxText import maxengine from MaxText import pyconfig +from maxtext.inference.maxengine import maxengine from maxtext.utils import max_utils _WARMUP_ITERS = 2 diff --git a/src/maxtext/inference/mlperf/offline_inference.py b/src/maxtext/inference/mlperf/offline_inference.py index c1aea9bafe..30171b6579 100644 --- a/src/maxtext/inference/mlperf/offline_inference.py +++ b/src/maxtext/inference/mlperf/offline_inference.py @@ -33,8 +33,8 @@ from jetstream.engine import engine_api # pylint: disable=no-name-in-module -from MaxText.maxengine import MaxEngine -from MaxText.maxengine import set_engine_vars_from_base_engine +from maxtext.inference.maxengine.maxengine import MaxEngine +from maxtext.inference.maxengine.maxengine import set_engine_vars_from_base_engine from maxtext.input_pipeline.packing.prefill_packing import PrefillProcessor from maxtext.input_pipeline.packing.prefill_packing import BatchedPrefillProcessor diff --git a/src/maxtext/inference/mlperf/offline_mode.py b/src/maxtext/inference/mlperf/offline_mode.py index 6adb315839..1544ca9398 100644 --- a/src/maxtext/inference/mlperf/offline_mode.py +++ b/src/maxtext/inference/mlperf/offline_mode.py @@ -35,7 +35,7 @@ import mlperf_loadgen as lg # pytype: disable=import-error # pylint: disable=no-name-in-module -from MaxText.maxengine import create_engine_from_config_flags +from maxtext.inference.maxengine.maxengine import create_engine_from_config_flags from maxtext.inference.mlperf import offline_inference diff --git a/src/maxtext/inference/offline_engine.py b/src/maxtext/inference/offline_engine.py index 8e5d83101a..594ba52eeb 100644 --- a/src/maxtext/inference/offline_engine.py +++ b/src/maxtext/inference/offline_engine.py @@ -53,7 +53,7 @@ from jax.sharding import Mesh from jax.experimental import mesh_utils -from MaxText.maxengine import MaxEngine +from maxtext.inference.maxengine.maxengine import MaxEngine from maxtext.input_pipeline.packing.prefill_packing import PrefillProcessor from maxtext.input_pipeline.packing.prefill_packing import BatchedPrefillProcessor from maxtext.utils import max_logging diff --git a/src/maxtext/inference/scripts/decode_multi.py b/src/maxtext/inference/scripts/decode_multi.py index a037f6724a..e53447b38d 100644 --- a/src/maxtext/inference/scripts/decode_multi.py +++ b/src/maxtext/inference/scripts/decode_multi.py @@ -22,7 +22,8 @@ import jax -from MaxText import maxengine, pyconfig +from MaxText import pyconfig +from maxtext.inference.maxengine import maxengine from maxtext.utils import max_utils _NUM_STREAMS = 5 diff --git a/src/maxtext/input_pipeline/packing/prefill_packing.py b/src/maxtext/input_pipeline/packing/prefill_packing.py index fc5152d19b..8cc8816437 100644 --- a/src/maxtext/input_pipeline/packing/prefill_packing.py +++ b/src/maxtext/input_pipeline/packing/prefill_packing.py @@ -21,6 +21,7 @@ import numpy as np from maxtext.common.gcloud_stub import jetstream, is_decoupled +from maxtext.inference.maxengine.maxengine import MaxEngine config_lib, engine_api, token_utils, tokenizer_api, token_params_ns = jetstream() @@ -29,8 +30,6 @@ if is_decoupled() and jetstream_is_stub: raise RuntimeError("prefill_packing imported while DECOUPLE_GCLOUD=TRUE. This module depends on JetStream.") -from MaxText.maxengine import MaxEngine - import warnings import logging diff --git a/src/maxtext/scratch_code/gemma_7b.sh b/src/maxtext/scratch_code/gemma_7b.sh index 1c09ac0100..1ee8674e08 100644 --- a/src/maxtext/scratch_code/gemma_7b.sh +++ b/src/maxtext/scratch_code/gemma_7b.sh @@ -5,4 +5,4 @@ export M_MAX_TARGET_LENGTH=2048 #python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma run_name=runner_2024-03-06-04-17 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=gemma-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 scan_layers=false -python3 -m MaxText.maxengine_server "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma run_name=runner_2024-03-06-04-17 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=gemma-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 scan_layers=false \ No newline at end of file +python3 -m maxtext.inference.maxengine.maxengine_server "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma run_name=runner_2024-03-06-04-17 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=gemma-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 scan_layers=false \ No newline at end of file diff --git a/tests/assets/logits_generation/generate_grpo_golden_logits.py b/tests/assets/logits_generation/generate_grpo_golden_logits.py index e96f97c205..b53cd6c541 100644 --- a/tests/assets/logits_generation/generate_grpo_golden_logits.py +++ b/tests/assets/logits_generation/generate_grpo_golden_logits.py @@ -30,12 +30,12 @@ import jax.numpy as jnp from jax.sharding import Mesh import jsonlines -from MaxText import maxengine from MaxText import pyconfig from MaxText.common_types import Array, MODEL_MODE_TRAIN +from MaxText.globals import MAXTEXT_PKG_DIR, MAXTEXT_TEST_ASSETS_ROOT from maxtext.experimental.rl.grpo_trainer import _merge_grpo_state, generate_completions, grpo_loss_fn from maxtext.experimental.rl.grpo_utils import compute_log_probs -from MaxText.globals import MAXTEXT_PKG_DIR, MAXTEXT_TEST_ASSETS_ROOT +from maxtext.inference.maxengine import maxengine from maxtext.models import models from maxtext.utils import maxtext_utils from tests.integration.grpo_trainer_correctness_test import prepare_maxtext_inputs diff --git a/tests/integration/grpo_trainer_correctness_test.py b/tests/integration/grpo_trainer_correctness_test.py index 2179e1777a..1a89efdf57 100644 --- a/tests/integration/grpo_trainer_correctness_test.py +++ b/tests/integration/grpo_trainer_correctness_test.py @@ -35,14 +35,14 @@ from jax.sharding import Mesh import jsonlines import MaxText as mt -from MaxText import maxengine from MaxText import pyconfig from MaxText.common_types import MODEL_MODE_TRAIN +from MaxText.globals import MAXTEXT_ASSETS_ROOT, MAXTEXT_PKG_DIR, MAXTEXT_TEST_ASSETS_ROOT from maxtext.experimental.rl import grpo_utils from maxtext.experimental.rl.grpo_trainer import _merge_grpo_state, grpo_loss_fn, setup_train_loop from maxtext.experimental.rl.grpo_utils import compute_log_probs -from MaxText.globals import MAXTEXT_ASSETS_ROOT, MAXTEXT_PKG_DIR, MAXTEXT_TEST_ASSETS_ROOT from maxtext.inference import offline_engine +from maxtext.inference.maxengine import maxengine from maxtext.inference.offline_engine import InputData from maxtext.layers import quantizations from maxtext.models import models diff --git a/tests/integration/vision_encoder_test.py b/tests/integration/vision_encoder_test.py index 632648f441..feed05298e 100644 --- a/tests/integration/vision_encoder_test.py +++ b/tests/integration/vision_encoder_test.py @@ -22,9 +22,9 @@ import jax import jax.numpy as jnp import jsonlines -from MaxText import maxengine from MaxText import pyconfig from MaxText.globals import MAXTEXT_ASSETS_ROOT, MAXTEXT_TEST_ASSETS_ROOT +from maxtext.inference.maxengine import maxengine from maxtext.models import models from maxtext.multimodal import processor_gemma3 from maxtext.multimodal import utils as mm_utils diff --git a/tests/unit/gcloud_stub_test.py b/tests/unit/gcloud_stub_test.py index 2dbcb8acc3..caf77cc3a7 100644 --- a/tests/unit/gcloud_stub_test.py +++ b/tests/unit/gcloud_stub_test.py @@ -153,26 +153,26 @@ def test_gcs_utils_guard_raises_when_not_decoupled_and_stub(self): def test_maxengine_config_create_exp_maxengine_signature_decoupled(self): # Import lazily under decoupled mode (safe even without JetStream installed). with mock.patch.dict(os.environ, {"DECOUPLE_GCLOUD": "TRUE"}): - maxengine_config = importlib.import_module("MaxText.maxengine_config") + maxengine_config = importlib.import_module("maxtext.inference.maxengine.maxengine_config") importlib.reload(maxengine_config) mock_devices = mock.MagicMock() mock_config = mock.MagicMock() - with mock.patch("MaxText.maxengine.MaxEngine") as mock_engine: + with mock.patch("maxtext.inference.maxengine.maxengine.MaxEngine") as mock_engine: maxengine_config.create_exp_maxengine(mock_devices, mock_config) mock_engine.assert_called_once_with(mock_config) def test_maxengine_config_create_exp_maxengine_signature_not_decoupled(self): # Import safely (under decoupled) then flip behavior only for the call. with mock.patch.dict(os.environ, {"DECOUPLE_GCLOUD": "TRUE"}): - maxengine_config = importlib.import_module("MaxText.maxengine_config") + maxengine_config = importlib.import_module("maxtext.inference.maxengine.maxengine_config") importlib.reload(maxengine_config) with mock.patch.object(maxengine_config, "is_decoupled", return_value=False): mock_devices = mock.MagicMock() mock_config = mock.MagicMock() - with mock.patch("MaxText.maxengine.MaxEngine") as mock_engine: + with mock.patch("maxtext.inference.maxengine.maxengine.MaxEngine") as mock_engine: maxengine_config.create_exp_maxengine(mock_devices, mock_config) mock_engine.assert_called_once_with(config=mock_config, devices=mock_devices) diff --git a/tests/unit/maxengine_test.py b/tests/unit/maxengine_test.py index c36e7ce4f6..3528d7bda3 100644 --- a/tests/unit/maxengine_test.py +++ b/tests/unit/maxengine_test.py @@ -21,10 +21,11 @@ import jax import jax.numpy as jnp from jax.sharding import Mesh -from MaxText import maxengine, pyconfig +from MaxText import pyconfig from MaxText.common_types import DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_PREFILL from maxtext.layers import quantizations -from MaxText.maxengine import MaxEngine +from maxtext.inference.maxengine import maxengine +from maxtext.inference.maxengine.maxengine import MaxEngine from maxtext.models import models from maxtext.utils import maxtext_utils from tests.utils.test_helpers import get_test_config_path diff --git a/tests/unit/test_env_smoke.py b/tests/unit/test_env_smoke.py index 8f918ee37d..ef7c7696ba 100644 --- a/tests/unit/test_env_smoke.py +++ b/tests/unit/test_env_smoke.py @@ -33,7 +33,7 @@ from maxtext.common.gcloud_stub import is_decoupled CORE_IMPORTS = ["jax", "jax.numpy", "flax", "numpy"] -OPTIONAL_IMPORTS = ["transformers", "MaxText", "MaxText.pyconfig", "MaxText.maxengine"] +OPTIONAL_IMPORTS = ["transformers", "MaxText", "MaxText.pyconfig", "maxtext.inference.maxengine.maxengine"] _defects: list[str] = [] diff --git a/tools/data_generation/generate_distillation_data.py b/tools/data_generation/generate_distillation_data.py index 06852a5d5e..3bc0d5fb78 100644 --- a/tools/data_generation/generate_distillation_data.py +++ b/tools/data_generation/generate_distillation_data.py @@ -40,7 +40,7 @@ For more information, check out `python3 -m MaxText.generate_distillation_data --help`. Note: Make sure to run maxengine server in a new terminal before executing this command. Example command to run maxengine server: - python3 -m MaxText.maxengine_server src/maxtext/configs/base.yml \ + python3 -m maxtext.inference.maxengine.maxengine_server src/maxtext/configs/base.yml \ model_name=deepseek2-16b tokenizer_path=deepseek-ai/DeepSeek-V2-Lite-chat tokenizer_type=huggingface \ load_parameters_path= \ max_target_length=2048 max_prefill_predict_length=256 \