Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion benchmarks/api_server/maxtext_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@

from dataclasses import dataclass, field

from MaxText import maxengine, pyconfig
from MaxText import pyconfig
from maxtext.inference.maxengine import maxengine
from maxtext.multimodal import processor as mm_processor
from maxtext.multimodal import utils as mm_utils
from maxtext.utils import max_logging, max_utils
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/mmlu/mmlu_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
from tqdm import tqdm

from MaxText import pyconfig
from MaxText import maxengine
from maxtext.inference.maxengine import maxengine
from maxtext.utils import max_logging
from maxtext.utils import max_utils

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

import jax

from MaxText import maxengine
from MaxText import pyconfig
from maxtext.inference.maxengine import maxengine
from maxtext.utils import max_utils


Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@

from absl import app

from MaxText import maxengine
from MaxText import pyconfig
from maxtext.common import profiler
from maxtext.common.gcloud_stub import jetstream, is_decoupled
from maxtext.inference.maxengine import maxengine
from maxtext.multimodal import processor as mm_processor
from maxtext.multimodal import utils as mm_utils
from maxtext.utils import max_utils
Comment on lines -25 to 31
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't related to your PR, but can this be moved to inference/?

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@

from transformers import AutoModelForCausalLM, AutoConfig

from MaxText import maxengine
from MaxText import pyconfig
from MaxText.globals import MAXTEXT_PKG_DIR
from maxtext.inference.maxengine import maxengine
from maxtext.utils import max_logging
from maxtext.utils import max_utils

Expand Down
3 changes: 2 additions & 1 deletion src/maxtext/inference/decode_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@

import jax

from MaxText import maxengine, pyconfig
from MaxText import pyconfig
from maxtext.inference.maxengine import maxengine
from maxtext.utils import max_utils

_NUM_STREAMS = 5
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/inference/inference_microbenchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
from absl import app
from collections.abc import MutableMapping

from MaxText import maxengine
from MaxText import pyconfig
from maxtext.common import profiler
from maxtext.inference.maxengine import maxengine
from maxtext.input_pipeline.packing import prefill_packing
from maxtext.utils import gcs_utils
from maxtext.utils import max_utils
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@
# limitations under the License.

cd /maxtext
python3 -m MaxText.maxengine_server $@
python3 -m maxtext.inference.maxengine.maxengine_server $@
15 changes: 15 additions & 0 deletions src/maxtext/inference/maxengine/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MaxEngine inference module."""
Original file line number Diff line number Diff line change
Expand Up @@ -1747,7 +1747,7 @@ def create_engine_from_config_flags(
assert "load_parameters_path" in args, "load_parameters_path must be defined"
if maxengine_config_filepath is None:
maxengine_config_filepath = os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")
updated_args = [os.path.join(MAXTEXT_PKG_DIR, "maxengine_server.py"), maxengine_config_filepath]
updated_args = [os.path.join(MAXTEXT_PKG_DIR, "inference", "maxengine", "maxengine_server.py"), maxengine_config_filepath]
for k, v in args.items():
option = f"{k}={v}"
updated_args.append(option)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
import jax

from maxtext.common.gcloud_stub import jetstream, is_decoupled
from maxtext.inference.maxengine import maxengine

config_lib, engine_api, _token_utils, _tokenizer_api, _token_params_ns = jetstream()

from MaxText import maxengine


# TODO: merge it with the above create_maxengine().
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
import jax

from MaxText import pyconfig
from MaxText import maxengine_config
from maxtext.common import gcloud_stub
from maxtext.inference.maxengine import maxengine_config

# _PORT = flags.DEFINE_integer('port', 9000, 'port to listen on')
# _THREADS = flags.DEFINE_integer(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@
# limitations under the License.

cd /maxtext
python3 -m MaxText.maxengine_server \
python3 -m maxtext.inference.maxengine.maxengine_server \
maxtext/configs/base.yml $@
15 changes: 15 additions & 0 deletions src/maxtext/inference/mlperf/microbenchmarks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MLPerf microbenchmarks module."""
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@

from absl import app

from MaxText import maxengine
from MaxText import pyconfig
from maxtext.inference.maxengine import maxengine
from maxtext.utils import max_utils

_WARMUP_ITERS = 2
Expand Down
4 changes: 2 additions & 2 deletions src/maxtext/inference/mlperf/offline_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
from jetstream.engine import engine_api

# pylint: disable=no-name-in-module
from MaxText.maxengine import MaxEngine
from MaxText.maxengine import set_engine_vars_from_base_engine
from maxtext.inference.maxengine.maxengine import MaxEngine
from maxtext.inference.maxengine.maxengine import set_engine_vars_from_base_engine
from maxtext.input_pipeline.packing.prefill_packing import PrefillProcessor
from maxtext.input_pipeline.packing.prefill_packing import BatchedPrefillProcessor

Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/inference/mlperf/offline_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import mlperf_loadgen as lg # pytype: disable=import-error
# pylint: disable=no-name-in-module

from MaxText.maxengine import create_engine_from_config_flags
from maxtext.inference.maxengine.maxengine import create_engine_from_config_flags
from maxtext.inference.mlperf import offline_inference


Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/inference/offline_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
from jax.sharding import Mesh
from jax.experimental import mesh_utils

from MaxText.maxengine import MaxEngine
from maxtext.inference.maxengine.maxengine import MaxEngine
from maxtext.input_pipeline.packing.prefill_packing import PrefillProcessor
from maxtext.input_pipeline.packing.prefill_packing import BatchedPrefillProcessor
from maxtext.utils import max_logging
Expand Down
3 changes: 2 additions & 1 deletion src/maxtext/inference/scripts/decode_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@

import jax

from MaxText import maxengine, pyconfig
from MaxText import pyconfig
from maxtext.inference.maxengine import maxengine
from maxtext.utils import max_utils

_NUM_STREAMS = 5
Expand Down
3 changes: 1 addition & 2 deletions src/maxtext/input_pipeline/packing/prefill_packing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import numpy as np

from maxtext.common.gcloud_stub import jetstream, is_decoupled
from maxtext.inference.maxengine.maxengine import MaxEngine

config_lib, engine_api, token_utils, tokenizer_api, token_params_ns = jetstream()

Expand All @@ -29,8 +30,6 @@
if is_decoupled() and jetstream_is_stub:
raise RuntimeError("prefill_packing imported while DECOUPLE_GCLOUD=TRUE. This module depends on JetStream.")

from MaxText.maxengine import MaxEngine

import warnings
import logging

Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/scratch_code/gemma_7b.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ export M_MAX_TARGET_LENGTH=2048

#python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma run_name=runner_2024-03-06-04-17 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=gemma-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 scan_layers=false

python3 -m MaxText.maxengine_server "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma run_name=runner_2024-03-06-04-17 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=gemma-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 scan_layers=false
python3 -m maxtext.inference.maxengine.maxengine_server "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma run_name=runner_2024-03-06-04-17 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=gemma-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 scan_layers=false
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@
import jax.numpy as jnp
from jax.sharding import Mesh
import jsonlines
from MaxText import maxengine
from MaxText import pyconfig
from MaxText.common_types import Array, MODEL_MODE_TRAIN
from MaxText.globals import MAXTEXT_PKG_DIR, MAXTEXT_TEST_ASSETS_ROOT
from maxtext.experimental.rl.grpo_trainer import _merge_grpo_state, generate_completions, grpo_loss_fn
from maxtext.experimental.rl.grpo_utils import compute_log_probs
from MaxText.globals import MAXTEXT_PKG_DIR, MAXTEXT_TEST_ASSETS_ROOT
from maxtext.inference.maxengine import maxengine
from maxtext.models import models
from maxtext.utils import maxtext_utils
from tests.integration.grpo_trainer_correctness_test import prepare_maxtext_inputs
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/grpo_trainer_correctness_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@
from jax.sharding import Mesh
import jsonlines
import MaxText as mt
from MaxText import maxengine
from MaxText import pyconfig
from MaxText.common_types import MODEL_MODE_TRAIN
from MaxText.globals import MAXTEXT_ASSETS_ROOT, MAXTEXT_PKG_DIR, MAXTEXT_TEST_ASSETS_ROOT
from maxtext.experimental.rl import grpo_utils
from maxtext.experimental.rl.grpo_trainer import _merge_grpo_state, grpo_loss_fn, setup_train_loop
from maxtext.experimental.rl.grpo_utils import compute_log_probs
from MaxText.globals import MAXTEXT_ASSETS_ROOT, MAXTEXT_PKG_DIR, MAXTEXT_TEST_ASSETS_ROOT
from maxtext.inference import offline_engine
from maxtext.inference.maxengine import maxengine
from maxtext.inference.offline_engine import InputData
from maxtext.layers import quantizations
from maxtext.models import models
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/vision_encoder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
import jax
import jax.numpy as jnp
import jsonlines
from MaxText import maxengine
from MaxText import pyconfig
from MaxText.globals import MAXTEXT_ASSETS_ROOT, MAXTEXT_TEST_ASSETS_ROOT
from maxtext.inference.maxengine import maxengine
from maxtext.models import models
from maxtext.multimodal import processor_gemma3
from maxtext.multimodal import utils as mm_utils
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/gcloud_stub_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,26 +153,26 @@ def test_gcs_utils_guard_raises_when_not_decoupled_and_stub(self):
def test_maxengine_config_create_exp_maxengine_signature_decoupled(self):
# Import lazily under decoupled mode (safe even without JetStream installed).
with mock.patch.dict(os.environ, {"DECOUPLE_GCLOUD": "TRUE"}):
maxengine_config = importlib.import_module("MaxText.maxengine_config")
maxengine_config = importlib.import_module("maxtext.inference.maxengine.maxengine_config")
importlib.reload(maxengine_config)

mock_devices = mock.MagicMock()
mock_config = mock.MagicMock()

with mock.patch("MaxText.maxengine.MaxEngine") as mock_engine:
with mock.patch("maxtext.inference.maxengine.maxengine.MaxEngine") as mock_engine:
maxengine_config.create_exp_maxengine(mock_devices, mock_config)
mock_engine.assert_called_once_with(mock_config)

def test_maxengine_config_create_exp_maxengine_signature_not_decoupled(self):
# Import safely (under decoupled) then flip behavior only for the call.
with mock.patch.dict(os.environ, {"DECOUPLE_GCLOUD": "TRUE"}):
maxengine_config = importlib.import_module("MaxText.maxengine_config")
maxengine_config = importlib.import_module("maxtext.inference.maxengine.maxengine_config")
importlib.reload(maxengine_config)

with mock.patch.object(maxengine_config, "is_decoupled", return_value=False):
mock_devices = mock.MagicMock()
mock_config = mock.MagicMock()
with mock.patch("MaxText.maxengine.MaxEngine") as mock_engine:
with mock.patch("maxtext.inference.maxengine.maxengine.MaxEngine") as mock_engine:
maxengine_config.create_exp_maxengine(mock_devices, mock_config)
mock_engine.assert_called_once_with(config=mock_config, devices=mock_devices)

Expand Down
5 changes: 3 additions & 2 deletions tests/unit/maxengine_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from MaxText import maxengine, pyconfig
from MaxText import pyconfig
from MaxText.common_types import DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_PREFILL
from maxtext.layers import quantizations
from MaxText.maxengine import MaxEngine
from maxtext.inference.maxengine import maxengine
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this extra import intended?

from maxtext.inference.maxengine.maxengine import MaxEngine
from maxtext.models import models
from maxtext.utils import maxtext_utils
from tests.utils.test_helpers import get_test_config_path
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_env_smoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from maxtext.common.gcloud_stub import is_decoupled

CORE_IMPORTS = ["jax", "jax.numpy", "flax", "numpy"]
OPTIONAL_IMPORTS = ["transformers", "MaxText", "MaxText.pyconfig", "MaxText.maxengine"]
OPTIONAL_IMPORTS = ["transformers", "MaxText", "MaxText.pyconfig", "maxtext.inference.maxengine.maxengine"]

_defects: list[str] = []

Expand Down
2 changes: 1 addition & 1 deletion tools/data_generation/generate_distillation_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
For more information, check out `python3 -m MaxText.generate_distillation_data --help`.
Note:
Make sure to run maxengine server in a new terminal before executing this command. Example command to run maxengine server:
python3 -m MaxText.maxengine_server src/maxtext/configs/base.yml \
python3 -m maxtext.inference.maxengine.maxengine_server src/maxtext/configs/base.yml \
model_name=deepseek2-16b tokenizer_path=deepseek-ai/DeepSeek-V2-Lite-chat tokenizer_type=huggingface \
load_parameters_path=<unscanned checkpoint path> \
max_target_length=2048 max_prefill_predict_length=256 \
Expand Down
Loading