Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/MaxText/generate_param_only_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from MaxText import optimizers
from MaxText import pyconfig
from maxtext.common import checkpointing
from MaxText.common_types import DecoderBlockType, MODEL_MODE_TRAIN
from maxtext.common.common_types import DecoderBlockType, MODEL_MODE_TRAIN
from maxtext.layers import quantizations
from maxtext.models import models
from maxtext.utils import gcs_utils
Expand Down
2 changes: 1 addition & 1 deletion src/MaxText/gradient_accumulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import jax.numpy as jnp
from jax.sharding import NamedSharding

from MaxText.common_types import ShardMode
from maxtext.common.common_types import ShardMode
from MaxText.sharding import maybe_shard_with_name


Expand Down
4 changes: 2 additions & 2 deletions src/MaxText/layerwise_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@
from flax.linen import partitioning as nn_partitioning
import jax
import jax.numpy as jnp
from MaxText import common_types
from MaxText import pyconfig
from maxtext.common import common_types
from maxtext.common import checkpointing
from maxtext.layers import quantizations
from maxtext.models import deepseek, models
Expand All @@ -49,6 +48,7 @@
from maxtext.utils import maxtext_utils
import orbax.checkpoint as ocp
from tqdm import tqdm
from MaxText import pyconfig

IGNORE = ocp.PLACEHOLDER
PRNGKeyType = Any
Expand Down
2 changes: 1 addition & 1 deletion src/MaxText/maxengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
import flax

from MaxText import pyconfig
from MaxText.common_types import MODEL_MODE_PREFILL, DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_AUTOREGRESSIVE
from MaxText.globals import MAXTEXT_PKG_DIR
from maxtext.models import models
from maxtext.layers import quantizations
Expand All @@ -48,6 +47,7 @@
from maxtext.utils import max_utils
from maxtext.utils import maxtext_utils
from maxtext.common.gcloud_stub import jetstream, is_decoupled
from maxtext.common.common_types import MODEL_MODE_PREFILL, DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_AUTOREGRESSIVE

config_lib, engine_api, token_utils, tokenizer_api, _token_params_ns = jetstream()
TokenizerParameters = getattr(_token_params_ns, "TokenizerParameters", object) # type: ignore[assignment]
Expand Down
2 changes: 1 addition & 1 deletion src/MaxText/pyconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
import omegaconf

from MaxText import pyconfig_deprecated
from MaxText.common_types import DecoderBlockType, ShardMode
from MaxText.globals import MAXTEXT_CONFIGS_DIR
from maxtext.common.common_types import DecoderBlockType, ShardMode
from maxtext.configs import types
from maxtext.configs.types import MaxTextConfig
from maxtext.inference.inference_utils import str2bool
Expand Down
2 changes: 1 addition & 1 deletion src/MaxText/pyconfig_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from MaxText import accelerator_to_spec_map
from MaxText.globals import MAXTEXT_ASSETS_ROOT, MAXTEXT_REPO_ROOT, MAXTEXT_PKG_DIR
from MaxText.common_types import AttentionType, DecoderBlockType, ShardMode
from maxtext.common.common_types import AttentionType, DecoderBlockType, ShardMode
from maxtext.utils import gcs_utils
from maxtext.utils import max_logging
from maxtext.utils import max_utils
Expand Down
2 changes: 1 addition & 1 deletion src/MaxText/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

import optax

from MaxText.common_types import ShardMode
from maxtext.common.common_types import ShardMode
from maxtext.utils import max_logging
from maxtext.utils import max_utils

Expand Down
2 changes: 1 addition & 1 deletion src/MaxText/vocabulary_tiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
all_gather_over_fsdp,
create_sharding,
)
from MaxText.common_types import ShardMode
from maxtext.common.common_types import ShardMode
from maxtext.utils import max_utils


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@
from jax.sharding import Mesh
from MaxText import optimizers
from MaxText import pyconfig
from maxtext.common import checkpointing
from MaxText.common_types import MODEL_MODE_TRAIN
from MaxText.globals import MAXTEXT_PKG_DIR
from maxtext.common import checkpointing
from maxtext.common.common_types import MODEL_MODE_TRAIN
from maxtext.layers import quantizations
from maxtext.models.models import transformer_as_linen
from maxtext.utils import max_logging
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/checkpoint_conversion/to_maxtext.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
from maxtext.checkpoint_conversion.standalone_scripts.llama_or_mistral_ckpt import save_weights_to_checkpoint
from maxtext.checkpoint_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING
from maxtext.checkpoint_conversion.utils.utils import HF_IDS, MemoryMonitorTqdm, apply_hook_fns, get_hf_model, print_peak_memory, print_ram_usage, validate_and_filter_param_map_keys
from MaxText.common_types import MODEL_MODE_TRAIN
from maxtext.common.common_types import MODEL_MODE_TRAIN
from maxtext.inference.inference_utils import str2bool
from maxtext.layers import quantizations
from maxtext.models import models
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023–2025 Google LLC
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
7 changes: 3 additions & 4 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,17 @@
from typing import Any, Literal, NewType, Optional

import jax
from MaxText import accelerator_to_spec_map
from MaxText.common_types import AttentionType, DecoderBlockType, ShardMode
from MaxText.globals import MAXTEXT_ASSETS_ROOT
from maxtext.common.common_types import AttentionType, DecoderBlockType, ShardMode
from maxtext.utils import gcs_utils
from maxtext.utils import max_utils
from MaxText import accelerator_to_spec_map
from MaxText.globals import MAXTEXT_ASSETS_ROOT
from pydantic.config import ConfigDict
from pydantic.fields import Field
from pydantic.functional_validators import field_validator, model_validator
from pydantic.main import BaseModel
from pydantic.types import NonNegativeFloat, NonNegativeInt, PositiveInt


class XProfTPUPowerTraceMode(enum.IntEnum): # pylint: disable=invalid-name
"""Enum for XProfTPUPowerTraceMode."""

Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/examples/demo_decoding.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@
"import numpy as np\n",
"\n",
"import MaxText as mt\n",
"from MaxText import common_types\n",
"from maxtext.common import common_types\n",
"from MaxText import pyconfig\n",
"from MaxText.input_pipeline import _input_pipeline_utils\n",
"from maxtext.checkpoint_conversion import to_maxtext\n",
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/experimental/rl/grpo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import jaxtyping
from typing import Any, Callable

from MaxText.common_types import DecoderBlockType
from maxtext.common.common_types import DecoderBlockType
from maxtext.inference.offline_engine import InputData
from maxtext.utils import max_logging
from maxtext.utils import max_utils
Expand Down
4 changes: 2 additions & 2 deletions src/maxtext/inference/kvcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
from maxtext.layers import nnx_wrappers
from maxtext.layers.initializers import variable_to_logically_partitioned

from MaxText.common_types import Array, AxisNames, AxisIdxes, Config, CACHE_BATCH_PREFILL, DType, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, MODEL_MODE_AUTOREGRESSIVE, CACHE_HEADS_NONE, DECODING_ACTIVE_SEQUENCE_INDICATOR
from MaxText.common_types import CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV, CACHE_SCALE_BATCH, CACHE_SCALE_SEQUENCE, CACHE_SCALE_HEADS, CACHE_SCALE_KV
from maxtext.common.common_types import Array, AxisNames, AxisIdxes, Config, CACHE_BATCH_PREFILL, DType, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, MODEL_MODE_AUTOREGRESSIVE, CACHE_HEADS_NONE, DECODING_ACTIVE_SEQUENCE_INDICATOR
from maxtext.common.common_types import CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV, CACHE_SCALE_BATCH, CACHE_SCALE_SEQUENCE, CACHE_SCALE_HEADS, CACHE_SCALE_KV


MAX_INT8 = 127.5
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/inference/page_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

from jaxtyping import Array, Integer, Bool

from MaxText.common_types import Config
from maxtext.common.common_types import Config

# Aliases using <Dims><Type><Rank>d convention
# We use string names for dimensions as they are symbolic within the type hints.
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/inference/paged_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import jax.numpy as jnp
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from MaxText.common_types import Array, AxisNames, BATCH, DType, D_KV, HEAD, LENGTH, MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_PREFILL
from maxtext.common.common_types import Array, AxisNames, BATCH, DType, D_KV, HEAD, LENGTH, MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_PREFILL
from maxtext.inference import page_manager
from maxtext.inference import paged_attention_kernel_v2
from maxtext.layers.initializers import variable_to_logically_partitioned
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
from jax import numpy as jnp
from jax.sharding import Mesh
from MaxText import pyconfig
from MaxText.common_types import MODEL_MODE_AUTOREGRESSIVE
from MaxText.globals import MAXTEXT_CONFIGS_DIR
from maxtext.common.common_types import MODEL_MODE_AUTOREGRESSIVE
from maxtext.utils import max_logging
from maxtext.utils import model_creation_utils

Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/kernels/attention/ragged_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp

from MaxText.common_types import DEFAULT_MASK_VALUE
from maxtext.common.common_types import DEFAULT_MASK_VALUE


def get_mha_cost_estimate(shape_dtype):
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/layers/attention_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

from flax import nnx

from MaxText.common_types import (
from maxtext.common.common_types import (
Array,
AxisIdxes,
AxisNames,
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/layers/attention_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding
from MaxText.common_types import (
from maxtext.common.common_types import (
Array,
AttentionType,
AxisIdxes,
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/layers/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from flax import nnx

from MaxText.common_types import (
from maxtext.common.common_types import (
DecoderBlockType,
BATCH,
BATCH_NO_EXP,
Expand Down
4 changes: 2 additions & 2 deletions src/maxtext/layers/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
import jax.numpy as jnp
from jax.sharding import Mesh
from MaxText import sharding
from MaxText.common_types import Config, DecoderBlockType, EP_AS_CONTEXT, ShardMode
from MaxText.common_types import MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN
from maxtext.common.common_types import Config, DecoderBlockType, EP_AS_CONTEXT, ShardMode
from maxtext.common.common_types import MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN
from maxtext.inference import page_manager
from maxtext.layers import linears
from maxtext.layers import mhc
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/layers/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from flax import nnx

from MaxText.sharding import logical_to_mesh_axes, create_sharding
from MaxText.common_types import ShardMode, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, Array, Config, DType
from maxtext.common.common_types import ShardMode, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, Array, Config, DType
from maxtext.layers import nnx_wrappers
from maxtext.layers.initializers import Initializer, default_embed_init, variable_to_logically_partitioned
from maxtext.utils import max_logging
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/layers/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from flax import nnx
from jax.sharding import Mesh

from MaxText.common_types import Config
from maxtext.common.common_types import Config
from maxtext.layers import nnx_wrappers
from maxtext.layers import initializers

Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/layers/engram.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from MaxText.common_types import Array, Config, MODEL_MODE_TRAIN
from maxtext.common.common_types import Array, Config, MODEL_MODE_TRAIN
from maxtext.input_pipeline.tokenizer import HFTokenizer
from maxtext.layers.embeddings import Embed
from maxtext.layers.initializers import NdInitializer, nd_dense_init
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/layers/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from flax import nnx
from aqt.jax.v2 import aqt_tensor

from MaxText.common_types import Array, DType, Shape, PRNGKey
from maxtext.common.common_types import Array, DType, Shape, PRNGKey

Initializer = Callable[[PRNGKey, Shape, DType], Array]
InitializerAxis = int | tuple[int, ...]
Expand Down
4 changes: 2 additions & 2 deletions src/maxtext/layers/linears.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
import flax.linen as nn

from MaxText.sharding import maybe_shard_with_logical
from MaxText.common_types import DecoderBlockType, ShardMode, DType, Array, Config
from MaxText.common_types import MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, EP_AS_CONTEXT
from maxtext.common.common_types import DecoderBlockType, ShardMode, DType, Array, Config
from maxtext.common.common_types import MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, EP_AS_CONTEXT
from maxtext.layers import nnx_wrappers, quantizations
from maxtext.layers import normalizations
from maxtext.layers.initializers import NdInitializer, nd_dense_init, default_bias_init, variable_to_logically_partitioned
Expand Down
4 changes: 2 additions & 2 deletions src/maxtext/layers/mhc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from MaxText.common_types import Array, Config
from MaxText.common_types import HyperConnectionType
from maxtext.common.common_types import Array, Config
from maxtext.common.common_types import HyperConnectionType
from maxtext.layers.initializers import default_bias_init, default_scalar_init, nd_dense_init
from maxtext.layers.normalizations import RMSNorm

Expand Down
4 changes: 2 additions & 2 deletions src/maxtext/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
from jax.sharding import NamedSharding, Mesh
from jax.sharding import PartitionSpec as P
import jax.numpy as jnp
from MaxText import common_types as ctypes
from MaxText.common_types import ShardMode
from maxtext.common import common_types as ctypes
from maxtext.common.common_types import ShardMode
from MaxText.sharding import maybe_shard_with_logical, create_sharding
from MaxText.sharding import logical_to_mesh_axes
from maxtext.layers import attentions, linears, nnx_wrappers, quantizations
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/layers/multi_token_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import jax.numpy as jnp
from jax.sharding import Mesh
from MaxText import sharding
from MaxText.common_types import Config, MODEL_MODE_TRAIN
from maxtext.common.common_types import Config, MODEL_MODE_TRAIN
from MaxText.globals import EPS
from maxtext.layers import nnx_wrappers
from maxtext.layers.decoders import DecoderLayer
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/layers/normalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from jax import lax
import jax.numpy as jnp
from jax.sharding import NamedSharding
from MaxText.common_types import Array, DType, ShardMode
from maxtext.common.common_types import Array, DType, ShardMode
from maxtext.layers import nnx_wrappers
from maxtext.layers.initializers import Initializer, variable_to_logically_partitioned
from maxtext.utils import max_logging
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/layers/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from flax import linen as nn
from flax.linen.spmd import LogicallyPartitioned

from MaxText.common_types import Config, MODEL_MODE_TRAIN, EP_AS_CONTEXT, ShardMode
from maxtext.common.common_types import Config, MODEL_MODE_TRAIN, EP_AS_CONTEXT, ShardMode
from MaxText.sharding import (
maybe_shard_with_logical,
maybe_shard_with_name,
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/layers/quantizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from flax.linen import initializers as flax_initializers
import flax.linen as nn

from MaxText.common_types import DType, Config
from maxtext.common.common_types import DType, Config
from maxtext.inference.kvcache import KVQuant

# Params used to define mixed precision quantization configs
Expand Down
4 changes: 2 additions & 2 deletions src/maxtext/models/deepseek.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
from jax.ad_checkpoint import checkpoint_name
import jax.numpy as jnp
from jax.sharding import Mesh
from MaxText.common_types import Config
from MaxText.common_types import HyperConnectionType, MODEL_MODE_PREFILL
from maxtext.common.common_types import Config
from maxtext.common.common_types import HyperConnectionType, MODEL_MODE_PREFILL
from maxtext.inference import page_manager
from maxtext.layers import attention_mla
from maxtext.layers import initializers
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/models/gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from jax.sharding import Mesh
import jax.numpy as jnp

from MaxText.common_types import Config
from maxtext.common.common_types import Config
from maxtext.layers import initializers
from maxtext.layers import nnx_wrappers
from maxtext.layers import quantizations
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/models/gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from jax.sharding import Mesh
import jax.numpy as jnp

from MaxText.common_types import MODEL_MODE_PREFILL, Config
from maxtext.common.common_types import MODEL_MODE_PREFILL, Config
from maxtext.layers import attentions
from maxtext.layers import initializers
from maxtext.layers import nnx_wrappers
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/models/gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from flax import linen as nn
from flax import nnx

from MaxText.common_types import Config, AttentionType, MODEL_MODE_PREFILL
from maxtext.common.common_types import Config, AttentionType, MODEL_MODE_PREFILL
from maxtext.layers import quantizations
from maxtext.layers import nnx_wrappers
from maxtext.layers import initializers
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/models/gpt3.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from flax import linen as nn
from flax import nnx

from MaxText.common_types import Config, DType, AxisNames, BATCH, LENGTH, EMBED, HEAD, D_KV, Array, MODEL_MODE_TRAIN
from maxtext.common.common_types import Config, DType, AxisNames, BATCH, LENGTH, EMBED, HEAD, D_KV, Array, MODEL_MODE_TRAIN
from maxtext.layers import initializers, nnx_wrappers
from maxtext.layers.linears import DenseGeneral, MlpBlock, canonicalize_tuple, normalize_axes
from maxtext.layers import quantizations
Expand Down
Loading
Loading