Skip to content

Commit 5b75ef1

Browse files
Perseus14prishajain1
authored andcommitted
Add LTX2 Transformer integrated with Attention.
adaln.py for nnx, classes NNXPixArtAlphaCombinedTimestepSizeEmbeddings and NNXTimesteps added added dummy attention and rope, verified transformer_ltx2.py, added unit tests for transformer mesh init in test fix config file added fix fix fix fix debug debug debug debug debug fix change in rope args debug nnx.scan fix fix attention mask dim change change jax.nnx features added trying diff out_axes reverted added more unit tests change for scan_layers= False change for scan_layers= False change for scan_layers= False deleted attention and rope folder renamed from ltx-2 to ltx2 changes for integrating transformer and attention import corrected Attenton -> LTX2Attention use_bias - > bias params names changed in transformer file fixes missing mesh issue and rope reshape error passed mesh param to test for transformer dimensions changed for tests dim changed in test dim changed in test changes to transformer and test fix fix added attention_kernel param fixed num_attention_heads in test active mesh audio_path_size rope arg changed args val changed fix for tpu flash attention fix test file dims changed printing shapes for final output in dot product test as well numerical parity test fix in numerical tests fix test weights conversion weight conversion logic fix print fixed and weight conversion testing weight mapping unit tests added import fix in ltx_2_transformer_test.py import fix in ltx_2_transformer_test.py changed input dims in unit tests test for transformer removed extra prints from transformer_ltx2.py removed unused arg qk_norm reformatted adding support for rope modifying test for testing rope = split test for rope type split added debug statement debug statement debug testing with rope_type = split added rope_type param to attention calls testing with interleaved Cleaned up ltx2 transformer tests and implementations clean up, ensuring layers are in fp32 fix Added attention_ltx2.py and transformer_ltx2.py Reformatted with pyink
1 parent 9bd1dde commit 5b75ef1

File tree

7 files changed

+2861
-0
lines changed

7 files changed

+2861
-0
lines changed
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
# hardware
hardware: 'tpu'
skip_jax_distributed_system: False
attention: 'flash'
attention_sharding_uniform: True

jax_cache_dir: ''
weights_dtype: 'bfloat16'
activations_dtype: 'bfloat16'

run_name: ''
output_dir: ''
config_path: ''
save_config_to_gcs: False

# Checkpoints
text_encoder_model_name_or_path: "ariG23498/t5-v1-1-xxl-flax"
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
frame_rate: 30
max_sequence_length: 512
sampler: "from_checkpoint"

# Generation parameters
pipeline_type: multi-scale
prompt: "A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie."
#negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
height: 512
width: 512
num_frames: 88
flow_shift: 5.0
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.7.safetensors"
prompt_enhancement_words_threshold: 120
stg_mode: "attention_values"
decode_timestep: 0.05
decode_noise_scale: 0.025
seed: 10
conditioning_media_paths: None #["IMAGE_PATH"]
conditioning_start_frames: [0]

# Multi-scale pipeline: per-step guidance schedules for the two passes.
first_pass:
  guidance_scale: [1, 1, 6, 8, 6, 1, 1]
  stg_scale: [0, 0, 4, 4, 4, 2, 1]
  rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1]
  guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180]
  skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]]
  num_inference_steps: 30
  skip_final_inference_steps: 3
  skip_initial_inference_steps: 0
  cfg_star_rescale: True

second_pass:
  guidance_scale: [1]
  stg_scale: [1]
  rescaling_scale: [1]
  guidance_timesteps: [1.0]
  skip_block_list: [27]
  num_inference_steps: 30
  skip_initial_inference_steps: 17
  skip_final_inference_steps: 0
  cfg_star_rescale: True

# parallelism
mesh_axes: ['data', 'fsdp', 'context', 'tensor']
logical_axis_rules: [
  ['batch', 'data'],
  ['activation_heads', 'fsdp'],
  ['activation_batch', 'data'],
  ['activation_kv', 'tensor'],
  ['mlp', 'tensor'],
  ['embed', 'fsdp'],
  ['heads', 'tensor'],
  ['norm', 'fsdp'],
  ['conv_batch', ['data', 'fsdp']],
  ['out_channels', 'tensor'],
  ['conv_out', 'fsdp'],
  ['conv_in', 'fsdp']
]
data_sharding: [['data', 'fsdp', 'context', 'tensor']]
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: -1
dcn_context_parallelism: 1
dcn_tensor_parallelism: 1
ici_data_parallelism: 1
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_context_parallelism: 1
ici_tensor_parallelism: 1

allow_split_physical_axes: False
learning_rate_schedule_steps: -1
max_train_steps: 500
pretrained_model_name_or_path: ''
unet_checkpoint: ''
dataset_name: 'diffusers/pokemon-gpt4-captions'
train_split: 'train'
dataset_type: 'tf'
cache_latents_text_encoder_outputs: True
per_device_batch_size: 1
compile_topology_num_slices: -1
quantization_local_shard_count: -1
use_qwix_quantization: False
jit_initializers: True
enable_single_replica_ckpt_restoring: False

src/maxdiffusion/models/embeddings_flax.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,3 +501,80 @@ def __call__(self, timestep, guidance, pooled_projection):
501501
conditioning = time_guidance_emb + pooled_projections
502502

503503
return conditioning
504+
505+
506+
class NNXTimesteps(nnx.Module):
  """Sinusoidal timestep-projection module (NNX).

  Thin stateless wrapper that fixes a configuration for
  ``get_sinusoidal_embeddings`` so it can be attached as a submodule.
  """

  def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
    # All four values are stored verbatim and forwarded on every call.
    self.scale = scale
    self.downscale_freq_shift = downscale_freq_shift
    self.flip_sin_to_cos = flip_sin_to_cos
    self.num_channels = num_channels

  def __call__(self, timesteps: jax.Array) -> jax.Array:
    """Project raw timesteps to sinusoidal embeddings of width ``num_channels``."""
    return get_sinusoidal_embeddings(
        timesteps=timesteps,
        embedding_dim=self.num_channels,
        flip_sin_to_cos=self.flip_sin_to_cos,
        freq_shift=self.downscale_freq_shift,
        scale=self.scale,
    )
522+
523+
524+
class NNXPixArtAlphaCombinedTimestepSizeEmbeddings(nnx.Module):
  """PixArt-Alpha style combined timestep + size conditioning embeddings (NNX).

  Always embeds the timestep; when ``use_additional_conditions`` is set, also
  embeds resolution and aspect ratio and adds their concatenation to the
  timestep embedding.
  """

  def __init__(
      self,
      rngs: nnx.Rngs,
      embedding_dim: int,
      size_emb_dim: int,
      use_additional_conditions: bool = False,
      dtype: jnp.dtype = jnp.float32,
      weights_dtype: jnp.dtype = jnp.float32,
  ):
    self.outdim = size_emb_dim
    self.use_additional_conditions = use_additional_conditions

    # 256-channel sinusoidal projection feeding the timestep MLP embedder.
    self.time_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
    self.timestep_embedder = NNXTimestepEmbedding(
        rngs=rngs, in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, weights_dtype=weights_dtype
    )

    if use_additional_conditions:
      # Shared sinusoidal projection for both resolution and aspect ratio,
      # each followed by its own MLP embedder of width size_emb_dim.
      self.additional_condition_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
      self.resolution_embedder = NNXTimestepEmbedding(
          rngs=rngs, in_channels=256, time_embed_dim=size_emb_dim, dtype=dtype, weights_dtype=weights_dtype
      )
      self.aspect_ratio_embedder = NNXTimestepEmbedding(
          rngs=rngs, in_channels=256, time_embed_dim=size_emb_dim, dtype=dtype, weights_dtype=weights_dtype
      )

  def __call__(
      self,
      timestep: jax.Array,
      resolution: Optional[jax.Array] = None,
      aspect_ratio: Optional[jax.Array] = None,
      hidden_dtype: jnp.dtype = jnp.float32,
  ) -> jax.Array:
    """Return the conditioning vector for a batch of timesteps.

    Raises:
      ValueError: if additional conditions are enabled but ``resolution`` or
        ``aspect_ratio`` is missing.
    """
    proj = self.time_proj(timestep)
    time_emb = self.timestep_embedder(proj.astype(hidden_dtype))

    if not self.use_additional_conditions:
      return time_emb

    if resolution is None or aspect_ratio is None:
      raise ValueError("resolution and aspect_ratio must be provided when use_additional_conditions is True")

    # Flatten, sinusoidally project, embed, then restore the batch dimension —
    # mirrors PyTorch's .reshape(batch_size, -1) in the reference implementation.
    batch_size = timestep.shape[0]
    res_emb = self.resolution_embedder(
        self.additional_condition_proj(resolution.flatten()).astype(hidden_dtype)
    ).reshape(batch_size, -1)
    ar_emb = self.aspect_ratio_embedder(
        self.additional_condition_proj(aspect_ratio.flatten()).astype(hidden_dtype)
    ).reshape(batch_size, -1)

    return time_emb + jnp.concatenate([res_emb, ar_emb], axis=1)

0 commit comments

Comments
 (0)