diff --git a/src/MaxText/generate_param_only_checkpoint.py b/src/MaxText/generate_param_only_checkpoint.py index a2b6bc1008..59582ea162 100644 --- a/src/MaxText/generate_param_only_checkpoint.py +++ b/src/MaxText/generate_param_only_checkpoint.py @@ -32,7 +32,7 @@ from MaxText import optimizers from MaxText import pyconfig from maxtext.common import checkpointing -from MaxText.common_types import DecoderBlockType, MODEL_MODE_TRAIN +from maxtext.common.common_types import DecoderBlockType, MODEL_MODE_TRAIN from maxtext.layers import quantizations from maxtext.models import models from maxtext.utils import gcs_utils diff --git a/src/MaxText/gradient_accumulation.py b/src/MaxText/gradient_accumulation.py index f2bf2ffdb0..adcb116520 100644 --- a/src/MaxText/gradient_accumulation.py +++ b/src/MaxText/gradient_accumulation.py @@ -18,7 +18,7 @@ import jax.numpy as jnp from jax.sharding import NamedSharding -from MaxText.common_types import ShardMode +from maxtext.common.common_types import ShardMode from MaxText.sharding import maybe_shard_with_name diff --git a/src/MaxText/layerwise_quantization.py b/src/MaxText/layerwise_quantization.py index 4446e981a7..65f19283a7 100644 --- a/src/MaxText/layerwise_quantization.py +++ b/src/MaxText/layerwise_quantization.py @@ -39,8 +39,7 @@ from flax.linen import partitioning as nn_partitioning import jax import jax.numpy as jnp -from MaxText import common_types -from MaxText import pyconfig +from maxtext.common import common_types from maxtext.common import checkpointing from maxtext.layers import quantizations from maxtext.models import deepseek, models @@ -49,6 +48,7 @@ from maxtext.utils import maxtext_utils import orbax.checkpoint as ocp from tqdm import tqdm +from MaxText import pyconfig IGNORE = ocp.PLACEHOLDER PRNGKeyType = Any diff --git a/src/MaxText/maxengine.py b/src/MaxText/maxengine.py index d67cddf806..be0cac067e 100644 --- a/src/MaxText/maxengine.py +++ b/src/MaxText/maxengine.py @@ -37,7 +37,6 @@ import flax from MaxText import pyconfig -from MaxText.common_types import MODEL_MODE_PREFILL, DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_AUTOREGRESSIVE from MaxText.globals import MAXTEXT_PKG_DIR from maxtext.models import models from maxtext.layers import quantizations @@ -48,6 +47,7 @@ from maxtext.utils import max_utils from maxtext.utils import maxtext_utils from maxtext.common.gcloud_stub import jetstream, is_decoupled +from maxtext.common.common_types import MODEL_MODE_PREFILL, DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_AUTOREGRESSIVE config_lib, engine_api, token_utils, tokenizer_api, _token_params_ns = jetstream() TokenizerParameters = getattr(_token_params_ns, "TokenizerParameters", object) # type: ignore[assignment] diff --git a/src/MaxText/pyconfig.py b/src/MaxText/pyconfig.py index 07732d2bd4..0d9cd6dd07 100644 --- a/src/MaxText/pyconfig.py +++ b/src/MaxText/pyconfig.py @@ -29,8 +29,8 @@ import omegaconf from MaxText import pyconfig_deprecated -from MaxText.common_types import DecoderBlockType, ShardMode from MaxText.globals import MAXTEXT_CONFIGS_DIR +from maxtext.common.common_types import DecoderBlockType, ShardMode from maxtext.configs import types from maxtext.configs.types import MaxTextConfig from maxtext.inference.inference_utils import str2bool diff --git a/src/MaxText/pyconfig_deprecated.py b/src/MaxText/pyconfig_deprecated.py index 582d7a122f..a47f246943 100644 --- a/src/MaxText/pyconfig_deprecated.py +++ b/src/MaxText/pyconfig_deprecated.py @@ -30,7 +30,7 @@ from MaxText import accelerator_to_spec_map from 
MaxText.globals import MAXTEXT_ASSETS_ROOT, MAXTEXT_REPO_ROOT, MAXTEXT_PKG_DIR -from MaxText.common_types import AttentionType, DecoderBlockType, ShardMode +from maxtext.common.common_types import AttentionType, DecoderBlockType, ShardMode from maxtext.utils import gcs_utils from maxtext.utils import max_logging from maxtext.utils import max_utils diff --git a/src/MaxText/sharding.py b/src/MaxText/sharding.py index 7053457d60..3cf858687e 100644 --- a/src/MaxText/sharding.py +++ b/src/MaxText/sharding.py @@ -25,7 +25,7 @@ import optax -from MaxText.common_types import ShardMode +from maxtext.common.common_types import ShardMode from maxtext.utils import max_logging from maxtext.utils import max_utils diff --git a/src/MaxText/vocabulary_tiling.py b/src/MaxText/vocabulary_tiling.py index ba387f797f..a5ad3c5088 100644 --- a/src/MaxText/vocabulary_tiling.py +++ b/src/MaxText/vocabulary_tiling.py @@ -25,7 +25,7 @@ all_gather_over_fsdp, create_sharding, ) -from MaxText.common_types import ShardMode +from maxtext.common.common_types import ShardMode from maxtext.utils import max_utils diff --git a/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py b/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py index 011824bd25..6a95f8faa8 100644 --- a/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py +++ b/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py @@ -44,9 +44,9 @@ from jax.sharding import Mesh from MaxText import optimizers from MaxText import pyconfig -from maxtext.common import checkpointing -from MaxText.common_types import MODEL_MODE_TRAIN from MaxText.globals import MAXTEXT_PKG_DIR +from maxtext.common import checkpointing +from maxtext.common.common_types import MODEL_MODE_TRAIN from maxtext.layers import quantizations from maxtext.models.models import transformer_as_linen from maxtext.utils import max_logging diff --git a/src/maxtext/checkpoint_conversion/to_maxtext.py b/src/maxtext/checkpoint_conversion/to_maxtext.py index 92161c4725..0409c8ed88 100644 --- a/src/maxtext/checkpoint_conversion/to_maxtext.py +++ b/src/maxtext/checkpoint_conversion/to_maxtext.py @@ -74,7 +74,7 @@ from maxtext.checkpoint_conversion.standalone_scripts.llama_or_mistral_ckpt import save_weights_to_checkpoint from maxtext.checkpoint_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING from maxtext.checkpoint_conversion.utils.utils import HF_IDS, MemoryMonitorTqdm, apply_hook_fns, get_hf_model, print_peak_memory, print_ram_usage, validate_and_filter_param_map_keys -from MaxText.common_types import MODEL_MODE_TRAIN +from maxtext.common.common_types import MODEL_MODE_TRAIN from maxtext.inference.inference_utils import str2bool from maxtext.layers import quantizations from maxtext.models import models diff --git a/src/MaxText/common_types.py b/src/maxtext/common/common_types.py similarity index 98% rename from src/MaxText/common_types.py rename to src/maxtext/common/common_types.py index f36b991cef..085b798817 100644 --- a/src/MaxText/common_types.py +++ b/src/maxtext/common/common_types.py @@ -1,4 +1,4 @@ -# Copyright 2023–2025 Google LLC +# Copyright 2023–2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 0eeee0582b..a498ac1c35 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -27,18 +27,17 @@ from typing import Any, Literal, NewType, Optional import jax -from MaxText import accelerator_to_spec_map -from MaxText.common_types import AttentionType, DecoderBlockType, ShardMode -from MaxText.globals import MAXTEXT_ASSETS_ROOT +from maxtext.common.common_types import AttentionType, DecoderBlockType, ShardMode from maxtext.utils import gcs_utils from maxtext.utils import max_utils +from MaxText import accelerator_to_spec_map +from MaxText.globals import MAXTEXT_ASSETS_ROOT from pydantic.config import ConfigDict from pydantic.fields import Field from pydantic.functional_validators import field_validator, model_validator from pydantic.main import BaseModel from pydantic.types import NonNegativeFloat, NonNegativeInt, PositiveInt - class XProfTPUPowerTraceMode(enum.IntEnum): # pylint: disable=invalid-name """Enum for XProfTPUPowerTraceMode.""" diff --git a/src/maxtext/examples/demo_decoding.ipynb b/src/maxtext/examples/demo_decoding.ipynb index de5c064103..68b0515a69 100644 --- a/src/maxtext/examples/demo_decoding.ipynb +++ b/src/maxtext/examples/demo_decoding.ipynb @@ -134,7 +134,7 @@ "import numpy as np\n", "\n", "import MaxText as mt\n", - "from MaxText import common_types\n", + "from maxtext.common import common_types\n", "from MaxText import pyconfig\n", "from MaxText.input_pipeline import _input_pipeline_utils\n", "from maxtext.checkpoint_conversion import to_maxtext\n", diff --git a/src/maxtext/experimental/rl/grpo_utils.py b/src/maxtext/experimental/rl/grpo_utils.py index 5fc4a6e869..352a2b3b8d 100644 --- a/src/maxtext/experimental/rl/grpo_utils.py +++ b/src/maxtext/experimental/rl/grpo_utils.py @@ -21,7 +21,7 @@ import jaxtyping from typing import Any, Callable -from MaxText.common_types import DecoderBlockType +from maxtext.common.common_types import DecoderBlockType from maxtext.inference.offline_engine import InputData from maxtext.utils import max_logging from maxtext.utils import max_utils diff --git a/src/maxtext/inference/kvcache.py b/src/maxtext/inference/kvcache.py index 0ac2fe1098..e9e708f041 100644 --- a/src/maxtext/inference/kvcache.py +++ b/src/maxtext/inference/kvcache.py @@ -29,8 +29,8 @@ from maxtext.layers import nnx_wrappers from maxtext.layers.initializers import variable_to_logically_partitioned -from MaxText.common_types import Array, AxisNames, AxisIdxes, Config, CACHE_BATCH_PREFILL, DType, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, MODEL_MODE_AUTOREGRESSIVE, CACHE_HEADS_NONE, DECODING_ACTIVE_SEQUENCE_INDICATOR -from MaxText.common_types import CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV, CACHE_SCALE_BATCH, CACHE_SCALE_SEQUENCE, CACHE_SCALE_HEADS, CACHE_SCALE_KV +from maxtext.common.common_types import Array, AxisNames, AxisIdxes, Config, CACHE_BATCH_PREFILL, DType, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, MODEL_MODE_AUTOREGRESSIVE, CACHE_HEADS_NONE, DECODING_ACTIVE_SEQUENCE_INDICATOR +from maxtext.common.common_types import CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV, CACHE_SCALE_BATCH, CACHE_SCALE_SEQUENCE, CACHE_SCALE_HEADS, CACHE_SCALE_KV MAX_INT8 = 127.5 diff --git a/src/maxtext/inference/page_manager.py b/src/maxtext/inference/page_manager.py index cd684a8620..eb9d514ef3 100644 --- a/src/maxtext/inference/page_manager.py +++ b/src/maxtext/inference/page_manager.py @@ -31,7 +31,7 @@ from jaxtyping import Array, Integer, Bool -from MaxText.common_types 
import Config +from maxtext.common.common_types import Config # Aliases using d convention # We use string names for dimensions as they are symbolic within the type hints. diff --git a/src/maxtext/inference/paged_attention.py b/src/maxtext/inference/paged_attention.py index e6bf2248dd..e9d682ccff 100644 --- a/src/maxtext/inference/paged_attention.py +++ b/src/maxtext/inference/paged_attention.py @@ -26,7 +26,7 @@ import jax.numpy as jnp from jax.sharding import Mesh from jax.sharding import PartitionSpec as P -from MaxText.common_types import Array, AxisNames, BATCH, DType, D_KV, HEAD, LENGTH, MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_PREFILL +from maxtext.common.common_types import Array, AxisNames, BATCH, DType, D_KV, HEAD, LENGTH, MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_PREFILL from maxtext.inference import page_manager from maxtext.inference import paged_attention_kernel_v2 from maxtext.layers.initializers import variable_to_logically_partitioned diff --git a/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py b/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py index 670ee92d32..a942377a7c 100644 --- a/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py +++ b/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py @@ -22,8 +22,8 @@ from jax import numpy as jnp from jax.sharding import Mesh from MaxText import pyconfig -from MaxText.common_types import MODEL_MODE_AUTOREGRESSIVE from MaxText.globals import MAXTEXT_CONFIGS_DIR +from maxtext.common.common_types import MODEL_MODE_AUTOREGRESSIVE from maxtext.utils import max_logging from maxtext.utils import model_creation_utils diff --git a/src/maxtext/kernels/attention/ragged_attention.py b/src/maxtext/kernels/attention/ragged_attention.py index 4dec71b493..2da57c0186 100644 --- a/src/maxtext/kernels/attention/ragged_attention.py +++ b/src/maxtext/kernels/attention/ragged_attention.py @@ -24,7 +24,7 @@ from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp -from MaxText.common_types import DEFAULT_MASK_VALUE +from maxtext.common.common_types import DEFAULT_MASK_VALUE def get_mha_cost_estimate(shape_dtype): diff --git a/src/maxtext/layers/attention_mla.py b/src/maxtext/layers/attention_mla.py index fe353319f6..02b1cc2848 100644 --- a/src/maxtext/layers/attention_mla.py +++ b/src/maxtext/layers/attention_mla.py @@ -32,7 +32,7 @@ from flax import nnx -from MaxText.common_types import ( +from maxtext.common.common_types import ( Array, AxisIdxes, AxisNames, diff --git a/src/maxtext/layers/attention_op.py b/src/maxtext/layers/attention_op.py index cadc9f829f..caa7124da8 100644 --- a/src/maxtext/layers/attention_op.py +++ b/src/maxtext/layers/attention_op.py @@ -33,7 +33,7 @@ from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask import jax.numpy as jnp from jax.sharding import Mesh, NamedSharding -from MaxText.common_types import ( +from maxtext.common.common_types import ( Array, AttentionType, AxisIdxes, diff --git a/src/maxtext/layers/attentions.py b/src/maxtext/layers/attentions.py index 0591f1ea91..80307551dc 100644 --- a/src/maxtext/layers/attentions.py +++ b/src/maxtext/layers/attentions.py @@ -25,7 +25,7 @@ from flax import nnx -from MaxText.common_types import ( +from maxtext.common.common_types import ( DecoderBlockType, BATCH, BATCH_NO_EXP, diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index 8433fa633f..354bbf07e7 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -27,8 +27,8 @@ import jax.numpy as 
jnp from jax.sharding import Mesh from MaxText import sharding -from MaxText.common_types import Config, DecoderBlockType, EP_AS_CONTEXT, ShardMode -from MaxText.common_types import MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN +from maxtext.common.common_types import Config, DecoderBlockType, EP_AS_CONTEXT, ShardMode +from maxtext.common.common_types import MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN from maxtext.inference import page_manager from maxtext.layers import linears from maxtext.layers import mhc diff --git a/src/maxtext/layers/embeddings.py b/src/maxtext/layers/embeddings.py index 4d0303fa86..2d716768e8 100644 --- a/src/maxtext/layers/embeddings.py +++ b/src/maxtext/layers/embeddings.py @@ -25,7 +25,7 @@ from flax import nnx from MaxText.sharding import logical_to_mesh_axes, create_sharding -from MaxText.common_types import ShardMode, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, Array, Config, DType +from maxtext.common.common_types import ShardMode, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, Array, Config, DType from maxtext.layers import nnx_wrappers from maxtext.layers.initializers import Initializer, default_embed_init, variable_to_logically_partitioned from maxtext.utils import max_logging diff --git a/src/maxtext/layers/encoders.py b/src/maxtext/layers/encoders.py index a04514fc0e..916ffbd5bb 100644 --- a/src/maxtext/layers/encoders.py +++ b/src/maxtext/layers/encoders.py @@ -18,7 +18,7 @@ from flax import nnx from jax.sharding import Mesh -from MaxText.common_types import Config +from maxtext.common.common_types import Config from maxtext.layers import nnx_wrappers from maxtext.layers import initializers diff --git a/src/maxtext/layers/engram.py b/src/maxtext/layers/engram.py index 3b9dba359c..ca75cfb2ea 100644 --- a/src/maxtext/layers/engram.py +++ b/src/maxtext/layers/engram.py @@ -24,7 +24,7 @@ import jax import jax.numpy as jnp from jax.sharding import Mesh -from MaxText.common_types import Array, Config, MODEL_MODE_TRAIN +from maxtext.common.common_types import Array, Config, MODEL_MODE_TRAIN from maxtext.input_pipeline.tokenizer import HFTokenizer from maxtext.layers.embeddings import Embed from maxtext.layers.initializers import NdInitializer, nd_dense_init diff --git a/src/maxtext/layers/initializers.py b/src/maxtext/layers/initializers.py index a6b65c858b..20baf9a633 100644 --- a/src/maxtext/layers/initializers.py +++ b/src/maxtext/layers/initializers.py @@ -22,7 +22,7 @@ from flax import nnx from aqt.jax.v2 import aqt_tensor -from MaxText.common_types import Array, DType, Shape, PRNGKey +from maxtext.common.common_types import Array, DType, Shape, PRNGKey Initializer = Callable[[PRNGKey, Shape, DType], Array] InitializerAxis = int | tuple[int, ...] 
diff --git a/src/maxtext/layers/linears.py b/src/maxtext/layers/linears.py index 168f088ab3..dc19bad5e7 100644 --- a/src/maxtext/layers/linears.py +++ b/src/maxtext/layers/linears.py @@ -30,8 +30,8 @@ import flax.linen as nn from MaxText.sharding import maybe_shard_with_logical -from MaxText.common_types import DecoderBlockType, ShardMode, DType, Array, Config -from MaxText.common_types import MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, EP_AS_CONTEXT +from maxtext.common.common_types import DecoderBlockType, ShardMode, DType, Array, Config +from maxtext.common.common_types import MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, EP_AS_CONTEXT from maxtext.layers import nnx_wrappers, quantizations from maxtext.layers import normalizations from maxtext.layers.initializers import NdInitializer, nd_dense_init, default_bias_init, variable_to_logically_partitioned diff --git a/src/maxtext/layers/mhc.py b/src/maxtext/layers/mhc.py index a4873a61f2..a4a4771c91 100644 --- a/src/maxtext/layers/mhc.py +++ b/src/maxtext/layers/mhc.py @@ -19,8 +19,8 @@ import jax import jax.numpy as jnp from jax.sharding import Mesh -from MaxText.common_types import Array, Config -from MaxText.common_types import HyperConnectionType +from maxtext.common.common_types import Array, Config +from maxtext.common.common_types import HyperConnectionType from maxtext.layers.initializers import default_bias_init, default_scalar_init, nd_dense_init from maxtext.layers.normalizations import RMSNorm diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 15849cf726..37341b0269 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -29,8 +29,8 @@ from jax.sharding import NamedSharding, Mesh from jax.sharding import PartitionSpec as P import jax.numpy as jnp -from MaxText import common_types as ctypes -from MaxText.common_types import ShardMode +from maxtext.common import common_types as ctypes +from maxtext.common.common_types import ShardMode from MaxText.sharding import maybe_shard_with_logical, create_sharding from MaxText.sharding import logical_to_mesh_axes from maxtext.layers import attentions, linears, nnx_wrappers, quantizations diff --git a/src/maxtext/layers/multi_token_prediction.py b/src/maxtext/layers/multi_token_prediction.py index 70e7582003..bab54614ed 100644 --- a/src/maxtext/layers/multi_token_prediction.py +++ b/src/maxtext/layers/multi_token_prediction.py @@ -22,7 +22,7 @@ import jax.numpy as jnp from jax.sharding import Mesh from MaxText import sharding -from MaxText.common_types import Config, MODEL_MODE_TRAIN +from maxtext.common.common_types import Config, MODEL_MODE_TRAIN from MaxText.globals import EPS from maxtext.layers import nnx_wrappers from maxtext.layers.decoders import DecoderLayer diff --git a/src/maxtext/layers/normalizations.py b/src/maxtext/layers/normalizations.py index 29af4f7e03..195d5bcc14 100644 --- a/src/maxtext/layers/normalizations.py +++ b/src/maxtext/layers/normalizations.py @@ -23,7 +23,7 @@ from jax import lax import jax.numpy as jnp from jax.sharding import NamedSharding -from MaxText.common_types import Array, DType, ShardMode +from maxtext.common.common_types import Array, DType, ShardMode from maxtext.layers import nnx_wrappers from maxtext.layers.initializers import Initializer, variable_to_logically_partitioned from maxtext.utils import max_logging diff --git a/src/maxtext/layers/pipeline.py b/src/maxtext/layers/pipeline.py index fabedc9a41..4d98dabf15 100644 --- a/src/maxtext/layers/pipeline.py +++ b/src/maxtext/layers/pipeline.py @@ -28,7 +28,7 @@ from flax 
import linen as nn from flax.linen.spmd import LogicallyPartitioned -from MaxText.common_types import Config, MODEL_MODE_TRAIN, EP_AS_CONTEXT, ShardMode +from maxtext.common.common_types import Config, MODEL_MODE_TRAIN, EP_AS_CONTEXT, ShardMode from MaxText.sharding import ( maybe_shard_with_logical, maybe_shard_with_name, diff --git a/src/maxtext/layers/quantizations.py b/src/maxtext/layers/quantizations.py index c9d61ed794..2aecfe3667 100644 --- a/src/maxtext/layers/quantizations.py +++ b/src/maxtext/layers/quantizations.py @@ -36,7 +36,7 @@ from flax.linen import initializers as flax_initializers import flax.linen as nn -from MaxText.common_types import DType, Config +from maxtext.common.common_types import DType, Config from maxtext.inference.kvcache import KVQuant # Params used to define mixed precision quantization configs diff --git a/src/maxtext/models/deepseek.py b/src/maxtext/models/deepseek.py index d528ee92d2..e999e28542 100644 --- a/src/maxtext/models/deepseek.py +++ b/src/maxtext/models/deepseek.py @@ -22,8 +22,8 @@ from jax.ad_checkpoint import checkpoint_name import jax.numpy as jnp from jax.sharding import Mesh -from MaxText.common_types import Config -from MaxText.common_types import HyperConnectionType, MODEL_MODE_PREFILL +from maxtext.common.common_types import Config +from maxtext.common.common_types import HyperConnectionType, MODEL_MODE_PREFILL from maxtext.inference import page_manager from maxtext.layers import attention_mla from maxtext.layers import initializers diff --git a/src/maxtext/models/gemma.py b/src/maxtext/models/gemma.py index 1090e99ff1..f73fd12ced 100644 --- a/src/maxtext/models/gemma.py +++ b/src/maxtext/models/gemma.py @@ -22,7 +22,7 @@ from jax.sharding import Mesh import jax.numpy as jnp -from MaxText.common_types import Config +from maxtext.common.common_types import Config from maxtext.layers import initializers from maxtext.layers import nnx_wrappers from maxtext.layers import quantizations diff --git a/src/maxtext/models/gemma2.py b/src/maxtext/models/gemma2.py index 48818da6cd..a7315763eb 100644 --- a/src/maxtext/models/gemma2.py +++ b/src/maxtext/models/gemma2.py @@ -22,7 +22,7 @@ from jax.sharding import Mesh import jax.numpy as jnp -from MaxText.common_types import MODEL_MODE_PREFILL, Config +from maxtext.common.common_types import MODEL_MODE_PREFILL, Config from maxtext.layers import attentions from maxtext.layers import initializers from maxtext.layers import nnx_wrappers diff --git a/src/maxtext/models/gemma3.py b/src/maxtext/models/gemma3.py index 6acae303fe..588ffa6db2 100644 --- a/src/maxtext/models/gemma3.py +++ b/src/maxtext/models/gemma3.py @@ -22,7 +22,7 @@ from flax import linen as nn from flax import nnx -from MaxText.common_types import Config, AttentionType, MODEL_MODE_PREFILL +from maxtext.common.common_types import Config, AttentionType, MODEL_MODE_PREFILL from maxtext.layers import quantizations from maxtext.layers import nnx_wrappers from maxtext.layers import initializers diff --git a/src/maxtext/models/gpt3.py b/src/maxtext/models/gpt3.py index b25ddd73d4..f4d3203392 100644 --- a/src/maxtext/models/gpt3.py +++ b/src/maxtext/models/gpt3.py @@ -27,7 +27,7 @@ from flax import linen as nn from flax import nnx -from MaxText.common_types import Config, DType, AxisNames, BATCH, LENGTH, EMBED, HEAD, D_KV, Array, MODEL_MODE_TRAIN +from maxtext.common.common_types import Config, DType, AxisNames, BATCH, LENGTH, EMBED, HEAD, D_KV, Array, MODEL_MODE_TRAIN from maxtext.layers import initializers, nnx_wrappers from 
maxtext.layers.linears import DenseGeneral, MlpBlock, canonicalize_tuple, normalize_axes from maxtext.layers import quantizations diff --git a/src/maxtext/models/gpt_oss.py b/src/maxtext/models/gpt_oss.py index ec973ef485..58a0a2db8f 100644 --- a/src/maxtext/models/gpt_oss.py +++ b/src/maxtext/models/gpt_oss.py @@ -25,7 +25,7 @@ from jax.ad_checkpoint import checkpoint_name import jax.numpy as jnp from jax.sharding import Mesh -from MaxText.common_types import AttentionType, Config +from maxtext.common.common_types import AttentionType, Config from maxtext.layers import attentions from maxtext.layers import initializers from maxtext.layers import moe diff --git a/src/maxtext/models/llama2.py b/src/maxtext/models/llama2.py index 4157ae1f4d..bfa2f1ed5d 100644 --- a/src/maxtext/models/llama2.py +++ b/src/maxtext/models/llama2.py @@ -21,8 +21,8 @@ from jax.ad_checkpoint import checkpoint_name import jax.numpy as jnp from jax.sharding import Mesh -from MaxText.common_types import Config -from MaxText.common_types import MODEL_MODE_PREFILL +from maxtext.common.common_types import Config +from maxtext.common.common_types import MODEL_MODE_PREFILL from maxtext.inference import page_manager from maxtext.layers import initializers from maxtext.layers import nnx_wrappers @@ -31,8 +31,8 @@ from maxtext.layers.linears import Dropout, MlpBlock from maxtext.layers.normalizations import RMSNorm from maxtext.layers.quantizations import AqtQuantization as Quant -from MaxText.sharding import create_sharding, maybe_shard_with_logical from maxtext.utils import max_utils +from MaxText.sharding import create_sharding, maybe_shard_with_logical # ----------------------------------------- # The Decoder Layer specific for Llama2 diff --git a/src/maxtext/models/llama4.py b/src/maxtext/models/llama4.py index 3e9b1b5bdc..c66e80440b 100644 --- a/src/maxtext/models/llama4.py +++ b/src/maxtext/models/llama4.py @@ -23,8 +23,8 @@ from jax.ad_checkpoint import checkpoint_name import jax.numpy as jnp from jax.sharding import Mesh -from MaxText.common_types import Array, AttentionType, Config, MODEL_MODE_TRAIN -from MaxText.common_types import MODEL_MODE_PREFILL +from maxtext.common.common_types import Array, AttentionType, Config, MODEL_MODE_TRAIN +from maxtext.common.common_types import MODEL_MODE_PREFILL from maxtext.inference import page_manager from maxtext.layers import initializers from maxtext.layers import linears diff --git a/src/maxtext/models/mistral.py b/src/maxtext/models/mistral.py index e168cf1dc1..c590a36f85 100644 --- a/src/maxtext/models/mistral.py +++ b/src/maxtext/models/mistral.py @@ -21,7 +21,7 @@ from jax.ad_checkpoint import checkpoint_name import jax.numpy as jnp from jax.sharding import Mesh -from MaxText.common_types import Config +from maxtext.common.common_types import Config from maxtext.layers import initializers, nnx_wrappers from maxtext.layers import quantizations from maxtext.layers.attentions import Attention diff --git a/src/maxtext/models/mixtral.py b/src/maxtext/models/mixtral.py index 63cf258571..46441096d5 100644 --- a/src/maxtext/models/mixtral.py +++ b/src/maxtext/models/mixtral.py @@ -22,7 +22,7 @@ from jax.ad_checkpoint import checkpoint_name import jax.numpy as jnp from jax.sharding import Mesh -from MaxText.common_types import Config +from maxtext.common.common_types import Config from maxtext.layers import initializers, nnx_wrappers from maxtext.layers import moe from maxtext.layers import quantizations diff --git a/src/maxtext/models/models.py b/src/maxtext/models/models.py 
index 7d6bfbee29..53173ee929 100644 --- a/src/maxtext/models/models.py +++ b/src/maxtext/models/models.py @@ -23,7 +23,7 @@ import jax import jax.numpy as jnp from jax.sharding import Mesh -from MaxText.common_types import Config, DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_TRAIN +from maxtext.common.common_types import Config, DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_TRAIN from maxtext.inference import page_manager from maxtext.layers import initializers from maxtext.layers import nnx_wrappers @@ -49,9 +49,9 @@ class TransformerLinenPure(nn.Module): config: Config mesh: Mesh quant: Quant - # Possible model_mode values can be found in MaxText.common_types. - # We generally use MaxText.common_types.MODEL_MODE_TRAIN or - # MaxText.common_types.MODEL_MODE_PREFILL for initializations here. + # Possible model_mode values can be found in maxtext.common.common_types. + # We generally use maxtext.common.common_types.MODEL_MODE_TRAIN or + # maxtext.common.common_types.MODEL_MODE_PREFILL for initializations here. # TODO: Make model_mode required after confirming no users are affected. model_mode: str = MODEL_MODE_TRAIN # May be different than the model_mode passed to __call__ # pylint: enable=attribute-defined-outside-init diff --git a/src/maxtext/models/olmo3.py b/src/maxtext/models/olmo3.py index 613c386b6d..c28020d781 100644 --- a/src/maxtext/models/olmo3.py +++ b/src/maxtext/models/olmo3.py @@ -26,7 +26,7 @@ from jax.ad_checkpoint import checkpoint_name import jax.numpy as jnp from jax.sharding import Mesh -from MaxText.common_types import AttentionType, Config +from maxtext.common.common_types import AttentionType, Config from maxtext.layers import attentions from maxtext.layers import initializers from maxtext.layers import nnx_wrappers diff --git a/src/maxtext/models/qwen3.py b/src/maxtext/models/qwen3.py index 8da1162542..e9d6778346 100644 --- a/src/maxtext/models/qwen3.py +++ b/src/maxtext/models/qwen3.py @@ -29,7 +29,7 @@ from flax import linen as nn from flax import nnx -from MaxText.common_types import AttentionType, Config, DType, Array, BATCH, LENGTH_NO_EXP, EMBED, MODEL_MODE_TRAIN +from maxtext.common.common_types import AttentionType, Config, DType, Array, BATCH, LENGTH_NO_EXP, EMBED, MODEL_MODE_TRAIN from maxtext.layers import attentions from maxtext.layers import initializers as max_initializers from maxtext.layers import moe diff --git a/src/maxtext/models/simple_layer.py b/src/maxtext/models/simple_layer.py index db118e13eb..2bd3d7634e 100644 --- a/src/maxtext/models/simple_layer.py +++ b/src/maxtext/models/simple_layer.py @@ -18,7 +18,7 @@ from jax.sharding import Mesh from flax import nnx -from MaxText.common_types import Config, ShardMode +from maxtext.common.common_types import Config, ShardMode from MaxText.sharding import create_sharding from maxtext.layers import quantizations, nnx_wrappers from maxtext.layers.initializers import variable_to_logically_partitioned diff --git a/src/maxtext/scratch_code/demo_from_config.ipynb b/src/maxtext/scratch_code/demo_from_config.ipynb index ac92ff3ab9..2473de393e 100644 --- a/src/maxtext/scratch_code/demo_from_config.ipynb +++ b/src/maxtext/scratch_code/demo_from_config.ipynb @@ -53,7 +53,7 @@ "import numpy as np\n", "from MaxText.input_pipeline import _input_pipeline_utils\n", "import os\n", - "from MaxText import common_types\n", + "from maxtext.common import common_types\n", "import jax\n", "from maxtext.inference import inference_utils\n", "from maxtext.utils import 
max_logging\n", diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 6bb4e46d5a..bcae3c1ee5 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -38,9 +38,9 @@ from flax import linen as nn from flax.linen import partitioning as nn_partitioning +from maxtext.common.common_types import ShardMode from MaxText import pyconfig from MaxText import sharding -from MaxText.common_types import ShardMode from MaxText.globals import EPS from MaxText.gradient_accumulation import gradient_accumulation_loss_and_grad diff --git a/src/maxtext/trainers/pre_train/train_compile.py b/src/maxtext/trainers/pre_train/train_compile.py index c39f671dd8..a9c9913b4c 100644 --- a/src/maxtext/trainers/pre_train/train_compile.py +++ b/src/maxtext/trainers/pre_train/train_compile.py @@ -36,7 +36,7 @@ from MaxText import optimizers from MaxText import pyconfig from MaxText import sharding -from MaxText.common_types import MODEL_MODE_TRAIN, ShardMode +from maxtext.common.common_types import MODEL_MODE_TRAIN, ShardMode from maxtext.layers import quantizations from maxtext.models import models from maxtext.trainers.diloco import diloco diff --git a/src/maxtext/utils/max_utils.py b/src/maxtext/utils/max_utils.py index 765122478e..0683b7baa8 100644 --- a/src/maxtext/utils/max_utils.py +++ b/src/maxtext/utils/max_utils.py @@ -41,7 +41,7 @@ from maxtext.common.gcloud_stub import is_decoupled from maxtext.common.gcloud_stub import writer, _TENSORBOARDX_AVAILABLE from maxtext.utils import max_logging -from MaxText.common_types import MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_TRAIN +from maxtext.common.common_types import MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_TRAIN initialize_multi_tier_checkpointing = initialization.initialize_multi_tier_checkpointing HYBRID_RING_64X4 = "hybrid_ring_64x4" diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index b6f487570e..c895351cdb 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -37,7 +37,7 @@ import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager from MaxText import sharding -from MaxText.common_types import DecoderBlockType, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE +from maxtext.common.common_types import DecoderBlockType, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE from maxtext.configs import types from maxtext.inference.page_manager import PageState from maxtext.common import checkpointing diff --git a/src/maxtext/utils/model_creation_utils.py b/src/maxtext/utils/model_creation_utils.py index eb37c622c6..5c744aa373 100644 --- a/src/maxtext/utils/model_creation_utils.py +++ b/src/maxtext/utils/model_creation_utils.py @@ -25,7 +25,7 @@ import jax from jax.sharding import AxisType, Mesh from MaxText import pyconfig -from MaxText.common_types import MODEL_MODE_TRAIN, ShardMode +from maxtext.common.common_types import MODEL_MODE_TRAIN, ShardMode from maxtext.layers import quantizations from maxtext.models import models from maxtext.utils import max_utils diff --git a/src/maxtext/vllm_decode.py b/src/maxtext/vllm_decode.py index 355f306bd6..75f522e688 100644 --- a/src/maxtext/vllm_decode.py +++ b/src/maxtext/vllm_decode.py @@ -46,14 +46,14 @@ from maxtext.utils import model_creation_utils from maxtext.utils import max_logging -from MaxText import pyconfig -from MaxText.common_types import Config -from MaxText.globals 
import MAXTEXT_CONFIGS_DIR +from maxtext.common.common_types import Config from maxtext.integration.tunix.tunix_adapter import TunixMaxTextAdapter from tunix.rl.rollout import base_rollout from tunix.rl.rollout.vllm_rollout import VllmRollout from vllm import LLM from vllm.sampling_params import SamplingParams +from MaxText import pyconfig +from MaxText.globals import MAXTEXT_CONFIGS_DIR os.environ["SKIP_JAX_PRECOMPILE"] = "1" os.environ["NEW_MODEL_DESIGN"] = "1" diff --git a/tests/assets/logits_generation/generate_grpo_golden_logits.py b/tests/assets/logits_generation/generate_grpo_golden_logits.py index e96f97c205..a42d315f22 100644 --- a/tests/assets/logits_generation/generate_grpo_golden_logits.py +++ b/tests/assets/logits_generation/generate_grpo_golden_logits.py @@ -32,7 +32,7 @@ import jsonlines from MaxText import maxengine from MaxText import pyconfig -from MaxText.common_types import Array, MODEL_MODE_TRAIN +from maxtext.common.common_types import Array, MODEL_MODE_TRAIN from maxtext.experimental.rl.grpo_trainer import _merge_grpo_state, generate_completions, grpo_loss_fn from maxtext.experimental.rl.grpo_utils import compute_log_probs from MaxText.globals import MAXTEXT_PKG_DIR, MAXTEXT_TEST_ASSETS_ROOT diff --git a/tests/inference/kvcache_test.py b/tests/inference/kvcache_test.py index 372ce237ca..276c157158 100644 --- a/tests/inference/kvcache_test.py +++ b/tests/inference/kvcache_test.py @@ -19,7 +19,7 @@ import jax import jax.numpy as jnp -from MaxText.common_types import MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE +from maxtext.common.common_types import MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE from maxtext.inference import kvcache diff --git a/tests/integration/grpo_correctness.py b/tests/integration/grpo_correctness.py index 6bb3fb3897..0751a8fe55 100644 --- a/tests/integration/grpo_correctness.py +++ b/tests/integration/grpo_correctness.py @@ -21,7 +21,7 @@ import jax.numpy as jnp from jax.sharding import Mesh from MaxText import pyconfig -from MaxText.common_types import MODEL_MODE_TRAIN +from maxtext.common.common_types import MODEL_MODE_TRAIN from maxtext.experimental.rl.grpo_trainer import _merge_grpo_state, grpo_loss_fn from maxtext.experimental.rl.grpo_utils import compute_log_probs from MaxText.globals import MAXTEXT_PKG_DIR diff --git a/tests/integration/grpo_trainer_correctness_test.py b/tests/integration/grpo_trainer_correctness_test.py index 2179e1777a..ba67538a30 100644 --- a/tests/integration/grpo_trainer_correctness_test.py +++ b/tests/integration/grpo_trainer_correctness_test.py @@ -37,7 +37,7 @@ import MaxText as mt from MaxText import maxengine from MaxText import pyconfig -from MaxText.common_types import MODEL_MODE_TRAIN +from maxtext.common.common_types import MODEL_MODE_TRAIN from maxtext.experimental.rl import grpo_utils from maxtext.experimental.rl.grpo_trainer import _merge_grpo_state, grpo_loss_fn, setup_train_loop from maxtext.experimental.rl.grpo_utils import compute_log_probs diff --git a/tests/integration/sft_trainer_correctness_test.py b/tests/integration/sft_trainer_correctness_test.py index 7c17d8bfff..91a45004a6 100644 --- a/tests/integration/sft_trainer_correctness_test.py +++ b/tests/integration/sft_trainer_correctness_test.py @@ -34,7 +34,7 @@ from jax.sharding import Mesh import jsonlines from MaxText import pyconfig -from MaxText.common_types import MODEL_MODE_TRAIN +from maxtext.common.common_types import MODEL_MODE_TRAIN from MaxText.globals import MAXTEXT_ASSETS_ROOT from MaxText.globals import MAXTEXT_PKG_DIR from 
MaxText.globals import MAXTEXT_TEST_ASSETS_ROOT diff --git a/tests/unit/attention_test.py b/tests/unit/attention_test.py index 14257b6fa7..608f1d4d14 100644 --- a/tests/unit/attention_test.py +++ b/tests/unit/attention_test.py @@ -28,8 +28,7 @@ from maxtext.utils import maxtext_utils from maxtext.common.gcloud_stub import is_decoupled -from MaxText import pyconfig -from MaxText.common_types import ( +from maxtext.common.common_types import ( AttentionType, DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_AUTOREGRESSIVE, @@ -39,6 +38,7 @@ from maxtext.layers.attention_mla import MLA from maxtext.layers.attention_op import ChunkedCausalMask, _generate_chunk_attention_mask, _make_bidirectional_block_mask from maxtext.layers.attentions import Attention +from MaxText import pyconfig import numpy as np import pytest diff --git a/tests/unit/deepseek32_vs_reference_test.py b/tests/unit/deepseek32_vs_reference_test.py index 21f25816d4..f2bd58a94b 100644 --- a/tests/unit/deepseek32_vs_reference_test.py +++ b/tests/unit/deepseek32_vs_reference_test.py @@ -52,7 +52,7 @@ from MaxText import pyconfig from maxtext.layers import embeddings, attention_mla -from MaxText.common_types import MODEL_MODE_TRAIN +from maxtext.common.common_types import MODEL_MODE_TRAIN from maxtext.utils import maxtext_utils from tests.utils.test_helpers import get_test_config_path diff --git a/tests/unit/gpt3_test.py b/tests/unit/gpt3_test.py index 890c264b9e..5d712e4fbb 100644 --- a/tests/unit/gpt3_test.py +++ b/tests/unit/gpt3_test.py @@ -21,7 +21,7 @@ import jax.numpy as jnp from jax.sharding import Mesh from MaxText import pyconfig -from MaxText.common_types import MODEL_MODE_TRAIN +from maxtext.common.common_types import MODEL_MODE_TRAIN from maxtext.layers import quantizations from maxtext.models import models from maxtext.utils import maxtext_utils diff --git a/tests/unit/llama4_layers_test.py b/tests/unit/llama4_layers_test.py index 9d6bee16b1..e5e5034001 100644 --- a/tests/unit/llama4_layers_test.py +++ b/tests/unit/llama4_layers_test.py @@ -25,11 +25,11 @@ from jax.sharding import Mesh from jax.experimental import mesh_utils -from MaxText.common_types import MODEL_MODE_TRAIN, AttentionType -from MaxText import pyconfig from maxtext.layers import attentions, embeddings from maxtext.models import llama4 +from maxtext.common.common_types import MODEL_MODE_TRAIN, AttentionType from maxtext.utils import maxtext_utils +from MaxText import pyconfig import numpy as np from tests.utils.test_helpers import get_test_config_path diff --git a/tests/unit/maxengine_test.py b/tests/unit/maxengine_test.py index c36e7ce4f6..7094d75b99 100644 --- a/tests/unit/maxengine_test.py +++ b/tests/unit/maxengine_test.py @@ -22,9 +22,9 @@ import jax.numpy as jnp from jax.sharding import Mesh from MaxText import maxengine, pyconfig -from MaxText.common_types import DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_PREFILL -from maxtext.layers import quantizations from MaxText.maxengine import MaxEngine +from maxtext.common.common_types import DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_PREFILL +from maxtext.layers import quantizations from maxtext.models import models from maxtext.utils import maxtext_utils from tests.utils.test_helpers import get_test_config_path diff --git a/tests/unit/maxtext_utils_test.py b/tests/unit/maxtext_utils_test.py index 8ac66d3421..297698f2f1 100644 --- a/tests/unit/maxtext_utils_test.py +++ b/tests/unit/maxtext_utils_test.py @@ -28,12 +28,12 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec from MaxText 
import pyconfig from MaxText import sharding +from MaxText.sharding import assert_params_sufficiently_sharded, get_formatted_sharding_annotations from maxtext.common.gcloud_stub import is_decoupled -from MaxText.common_types import MODEL_MODE_TRAIN +from maxtext.common.common_types import MODEL_MODE_TRAIN from maxtext.inference import inference_utils from maxtext.layers import quantizations from maxtext.models import models -from MaxText.sharding import assert_params_sufficiently_sharded, get_formatted_sharding_annotations from maxtext.utils import max_utils from maxtext.utils import maxtext_utils from tests.utils.test_helpers import get_test_config_path diff --git a/tests/unit/mhc_test.py b/tests/unit/mhc_test.py index 619899ee81..825f3952b4 100644 --- a/tests/unit/mhc_test.py +++ b/tests/unit/mhc_test.py @@ -26,8 +26,8 @@ import numpy as np from MaxText import pyconfig -from MaxText.common_types import HyperConnectionType from MaxText.globals import MAXTEXT_PKG_DIR +from maxtext.common.common_types import HyperConnectionType from maxtext.layers import attention_mla, linears, mhc, moe from maxtext.layers.initializers import nd_dense_init from maxtext.layers.normalizations import RMSNorm diff --git a/tests/unit/model_test.py b/tests/unit/model_test.py index c3529d2485..f7797f7633 100644 --- a/tests/unit/model_test.py +++ b/tests/unit/model_test.py @@ -21,7 +21,7 @@ from jax.sharding import Mesh from MaxText import pyconfig from maxtext.common.gcloud_stub import is_decoupled -from MaxText.common_types import DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN +from maxtext.common.common_types import DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN from maxtext.layers import quantizations from maxtext.models import models from maxtext.utils import maxtext_utils diff --git a/tests/unit/moe_test.py b/tests/unit/moe_test.py index 9086ff600e..2d38641d3c 100644 --- a/tests/unit/moe_test.py +++ b/tests/unit/moe_test.py @@ -23,7 +23,7 @@ from jax.sharding import Mesh from MaxText import pyconfig from maxtext.common.gcloud_stub import is_decoupled -from MaxText.common_types import Config, DType +from maxtext.common.common_types import Config, DType from maxtext.layers import linears from maxtext.layers import moe from maxtext.layers import nnx_wrappers diff --git a/tests/unit/multi_token_prediction_test.py b/tests/unit/multi_token_prediction_test.py index 185b834680..ccd24c8ce9 100644 --- a/tests/unit/multi_token_prediction_test.py +++ b/tests/unit/multi_token_prediction_test.py @@ -20,12 +20,12 @@ from jax.sharding import Mesh from flax import nnx -from MaxText.common_types import Config from MaxText import pyconfig from maxtext.layers.decoders import DecoderLayer from maxtext.layers import multi_token_prediction # The class under test from maxtext.layers import embeddings -from MaxText.common_types import MODEL_MODE_TRAIN +from maxtext.common.common_types import MODEL_MODE_TRAIN +from maxtext.common.common_types import Config from maxtext.common.gcloud_stub import is_decoupled from maxtext.utils import max_logging from maxtext.utils import maxtext_utils diff --git a/tests/unit/pipeline_parallelism_test.py b/tests/unit/pipeline_parallelism_test.py index 4070f70d3a..dec27dea15 100644 --- a/tests/unit/pipeline_parallelism_test.py +++ b/tests/unit/pipeline_parallelism_test.py @@ -26,8 +26,8 @@ import jax.numpy as jnp from jax.sharding import Mesh from MaxText import pyconfig -from MaxText.common_types 
import MODEL_MODE_TRAIN from MaxText.globals import MAXTEXT_ASSETS_ROOT +from maxtext.common.common_types import MODEL_MODE_TRAIN from maxtext.common.gcloud_stub import is_decoupled from maxtext.layers import nnx_wrappers from maxtext.layers import pipeline diff --git a/tests/unit/quantizations_test.py b/tests/unit/quantizations_test.py index 9cc24a65d4..7982cd2926 100644 --- a/tests/unit/quantizations_test.py +++ b/tests/unit/quantizations_test.py @@ -27,9 +27,9 @@ from jax import numpy as jnp from jax.sharding import Mesh from MaxText import pyconfig -from maxtext.common.gcloud_stub import is_decoupled -from MaxText.common_types import DECODING_ACTIVE_SEQUENCE_INDICATOR from MaxText.globals import MAXTEXT_CONFIGS_DIR +from maxtext.common.gcloud_stub import is_decoupled +from maxtext.common.common_types import DECODING_ACTIVE_SEQUENCE_INDICATOR from maxtext.kernels.megablox import gmm from maxtext.layers import nnx_wrappers, quantizations from maxtext.utils import maxtext_utils diff --git a/tests/unit/qwen3_omni_layers_test.py b/tests/unit/qwen3_omni_layers_test.py index f3a858ab6e..a0a34da6e2 100644 --- a/tests/unit/qwen3_omni_layers_test.py +++ b/tests/unit/qwen3_omni_layers_test.py @@ -25,10 +25,10 @@ import jax import jax.numpy as jnp from jax.sharding import Mesh -from MaxText import common_types from MaxText import maxengine from MaxText import pyconfig from MaxText.globals import MAXTEXT_REPO_ROOT +from maxtext.common import common_types from maxtext.layers.attentions import Attention from maxtext.layers.embeddings import ( PositionalEmbedding, diff --git a/tests/unit/state_dtypes_test.py b/tests/unit/state_dtypes_test.py index d9b16aed38..7c24acb39f 100644 --- a/tests/unit/state_dtypes_test.py +++ b/tests/unit/state_dtypes_test.py @@ -21,7 +21,7 @@ from MaxText import optimizers from MaxText import pyconfig from maxtext.common.gcloud_stub import is_decoupled -from MaxText.common_types import MODEL_MODE_TRAIN +from maxtext.common.common_types import MODEL_MODE_TRAIN from maxtext.layers import quantizations from maxtext.models import models from maxtext.utils import maxtext_utils diff --git a/tests/unit/tiling_test.py b/tests/unit/tiling_test.py index 4339482762..aa9dba8a3e 100644 --- a/tests/unit/tiling_test.py +++ b/tests/unit/tiling_test.py @@ -24,13 +24,13 @@ import jax.numpy as jnp from jax.sharding import Mesh from MaxText import pyconfig -from MaxText.common_types import Config -from MaxText.common_types import MODEL_MODE_TRAIN +from MaxText.vocabulary_tiling import vocab_tiling_linen_loss +from maxtext.common.common_types import Config +from maxtext.common.common_types import MODEL_MODE_TRAIN from maxtext.layers import quantizations from maxtext.models import models from maxtext.utils import max_utils from maxtext.utils import maxtext_utils -from MaxText.vocabulary_tiling import vocab_tiling_linen_loss from tests.utils.test_helpers import get_test_config_path import pytest diff --git a/tests/utils/attention_test_util.py b/tests/utils/attention_test_util.py index b4504c3df5..2003b18b22 100644 --- a/tests/utils/attention_test_util.py +++ b/tests/utils/attention_test_util.py @@ -22,10 +22,10 @@ import jax.numpy as jnp from jax.sharding import Mesh, NamedSharding, PartitionSpec as P from MaxText import pyconfig +from MaxText.sharding import maybe_shard_with_name from maxtext.common.gcloud_stub import is_decoupled -from MaxText.common_types import AttentionType, DECODING_ACTIVE_SEQUENCE_INDICATOR, EP_AS_CONTEXT, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, ShardMode +from 
maxtext.common.common_types import AttentionType, DECODING_ACTIVE_SEQUENCE_INDICATOR, EP_AS_CONTEXT, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, ShardMode from maxtext.layers.attention_mla import MLA -from MaxText.sharding import maybe_shard_with_name from maxtext.utils import max_utils from maxtext.utils import maxtext_utils from tests.utils.test_helpers import get_test_config_path diff --git a/tests/utils/forward_pass_logit_checker.py b/tests/utils/forward_pass_logit_checker.py index 78567bfad5..88cee48e36 100644 --- a/tests/utils/forward_pass_logit_checker.py +++ b/tests/utils/forward_pass_logit_checker.py @@ -45,9 +45,9 @@ import jax import jax.numpy as jnp from MaxText import pyconfig -from maxtext.checkpoint_conversion.utils.hf_utils import convert_jax_weight_to_torch -from MaxText.common_types import DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_TRAIN from MaxText.globals import MAXTEXT_TEST_ASSETS_ROOT +from maxtext.checkpoint_conversion.utils.hf_utils import convert_jax_weight_to_torch +from maxtext.common.common_types import DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_TRAIN from maxtext.layers import quantizations from maxtext.models import models from maxtext.utils import max_logging
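The bulk of this patch follows one rule: symbols that used to be imported from MaxText.common_types now come from maxtext.common.common_types, while other MaxText modules (pyconfig, sharding, globals) keep their existing paths. A minimal usage sketch, not part of the patch itself; the ctypes alias mirrors the one used in src/maxtext/layers/moe.py:

# Old location (removed by this patch):
#   from MaxText.common_types import MODEL_MODE_TRAIN, ShardMode
# New location introduced by the rename to src/maxtext/common/common_types.py:
from maxtext.common.common_types import DecoderBlockType, MODEL_MODE_TRAIN, ShardMode
from maxtext.common import common_types as ctypes  # module-level alias, as in moe.py

# Both import forms resolve to the same module attributes.
assert ctypes.MODEL_MODE_TRAIN == MODEL_MODE_TRAIN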