
Commit fa52fef

Add Vocoder
1 parent 2f447f1 commit fa52fef

2 files changed: +602 -0 lines changed
Lines changed: 231 additions & 0 deletions
@@ -0,0 +1,231 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import math
from typing import Sequence

import jax
import jax.numpy as jnp
from flax import nnx
from ... import common_types

Array = common_types.Array
DType = common_types.DType


class ResBlock(nnx.Module):
  """
  Residual Block for the LTX-2 Vocoder.
  """
  def __init__(
      self,
      channels: int,
      kernel_size: int = 3,
      stride: int = 1,
      dilations: Sequence[int] = (1, 3, 5),
      leaky_relu_negative_slope: float = 0.1,
      *,
      rngs: nnx.Rngs,
      dtype: DType = jnp.float32,
  ):
    self.dilations = dilations
    self.negative_slope = leaky_relu_negative_slope

    self.convs1 = nnx.List(
        [
            nnx.Conv(
                in_features=channels,
                out_features=channels,
                kernel_size=(kernel_size,),
                strides=(stride,),
                kernel_dilation=(dilation,),
                padding="SAME",
                rngs=rngs,
                dtype=dtype,
            )
            for dilation in dilations
        ]
    )

    self.convs2 = nnx.List(
        [
            nnx.Conv(
                in_features=channels,
                out_features=channels,
                kernel_size=(kernel_size,),
                strides=(stride,),
                kernel_dilation=(1,),
                padding="SAME",
                rngs=rngs,
                dtype=dtype,
            )
            for _ in range(len(dilations))
        ]
    )

  def __call__(self, x: Array) -> Array:
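    # Each dilation applies a two-conv residual branch: x = x + conv2(lrelu(conv1(lrelu(x)))).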
    for conv1, conv2 in zip(self.convs1, self.convs2):
      xt = jax.nn.leaky_relu(x, negative_slope=self.negative_slope)
      xt = conv1(xt)
      xt = jax.nn.leaky_relu(xt, negative_slope=self.negative_slope)
      xt = conv2(xt)
      x = x + xt
    return x


class LTX2Vocoder(nnx.Module):
  """
  LTX 2.0 vocoder for converting generated mel spectrograms back to audio waveforms.
  """
  def __init__(
      self,
      in_channels: int = 128,
      hidden_channels: int = 1024,
      out_channels: int = 2,
      upsample_kernel_sizes: Sequence[int] = (16, 15, 8, 4, 4),
      upsample_factors: Sequence[int] = (6, 5, 2, 2, 2),
      resnet_kernel_sizes: Sequence[int] = (3, 7, 11),
      resnet_dilations: Sequence[Sequence[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
      leaky_relu_negative_slope: float = 0.1,
      # output_sampling_rate is unused in model structure but kept for config compat
      output_sampling_rate: int = 24000,
      *,
      rngs: nnx.Rngs,
      dtype: DType = jnp.float32,
  ):
    self.num_upsample_layers = len(upsample_kernel_sizes)
    self.resnets_per_upsample = len(resnet_kernel_sizes)
    self.out_channels = out_channels
    self.total_upsample_factor = math.prod(upsample_factors)
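    # With the default upsample_factors (6, 5, 2, 2, 2) this product is 240, i.e. each input
    # mel frame is expanded to 240 output audio samples.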
    self.negative_slope = leaky_relu_negative_slope
    self.dtype = dtype

    if self.num_upsample_layers != len(upsample_factors):
      raise ValueError(
          f"`upsample_kernel_sizes` and `upsample_factors` should be lists of the same length but are length"
          f" {self.num_upsample_layers} and {len(upsample_factors)}, respectively."
      )

    if self.resnets_per_upsample != len(resnet_dilations):
      raise ValueError(
          f"`resnet_kernel_sizes` and `resnet_dilations` should be lists of the same length but are length"
          f" {self.resnets_per_upsample} and {len(resnet_dilations)}, respectively."
      )

    # PyTorch Conv1d expects (Batch, Channels, Length), we use (Batch, Length, Channels)
    # So in_channels/out_channels args are standard, but data layout is transposed in __call__
    self.conv_in = nnx.Conv(
        in_features=in_channels,
        out_features=hidden_channels,
        kernel_size=(7,),
        strides=(1,),
        padding="SAME",
        rngs=rngs,
        dtype=self.dtype,
    )

    self.upsamplers = nnx.List()
    self.resnets = nnx.List()
    input_channels = hidden_channels

    for i, (stride, kernel_size) in enumerate(zip(upsample_factors, upsample_kernel_sizes)):
      output_channels = input_channels // 2
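      # With the default hidden_channels of 1024 and five upsample stages, the channel count
      # halves each stage: 1024 -> 512 -> 256 -> 128 -> 64 -> 32.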

      # ConvTranspose with padding='SAME' matches PyTorch's specific padding logic
      # for these standard HiFi-GAN upsampling configurations.
      self.upsamplers.append(
          nnx.ConvTranspose(
              in_features=input_channels,
              out_features=output_channels,
              kernel_size=(kernel_size,),
              strides=(stride,),
              padding="SAME",
              rngs=rngs,
              dtype=self.dtype,
          )
      )

      for res_kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations):
        self.resnets.append(
            ResBlock(
                channels=output_channels,
                kernel_size=res_kernel_size,
                dilations=dilations,
                leaky_relu_negative_slope=leaky_relu_negative_slope,
                rngs=rngs,
                dtype=self.dtype,
            )
        )
      input_channels = output_channels

    self.conv_out = nnx.Conv(
        in_features=input_channels,
        out_features=out_channels,
        kernel_size=(7,),
        strides=(1,),
        padding="SAME",
        rngs=rngs,
        dtype=self.dtype,
    )

  def __call__(self, hidden_states: Array, time_last: bool = False) -> Array:
    """
    Forward pass of the vocoder.

    Args:
      hidden_states: Input Mel spectrogram tensor.
        Shape: `(B, C, T, F)` or `(B, C, F, T)`
      time_last: Legacy flag for input layout.

    Returns:
      Audio waveform: `(B, OutChannels, AudioLength)`
    """
    # Ensure layout: (Batch, Channels, MelBins, Time)
    if not time_last:
      hidden_states = jnp.transpose(hidden_states, (0, 1, 3, 2))

    # Flatten Channels and MelBins -> (Batch, Features, Time)
    batch, channels, mel_bins, time = hidden_states.shape
    hidden_states = hidden_states.reshape(batch, channels * mel_bins, time)
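    # After flattening, channels * mel_bins must equal `in_channels` (128 with the defaults)
    # so that conv_in accepts the feature dimension.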

    # Transpose to (Batch, Time, Features) for Flax NWC Convolutions
    hidden_states = jnp.transpose(hidden_states, (0, 2, 1))

    hidden_states = self.conv_in(hidden_states)

    for i in range(self.num_upsample_layers):
      hidden_states = jax.nn.leaky_relu(hidden_states, negative_slope=self.negative_slope)
      hidden_states = self.upsamplers[i](hidden_states)

      # Accumulate ResNet outputs (Memory Optimization)
      start = i * self.resnets_per_upsample
      end = (i + 1) * self.resnets_per_upsample

      res_sum = 0.0
      for j in range(start, end):
        res_sum = res_sum + self.resnets[j](hidden_states)

      # Average the outputs (matches PyTorch mean(stack))
      hidden_states = res_sum / self.resnets_per_upsample

    # Final Post-Processing
    # Note: using 0.01 slope here specifically (matches Diffusers implementation quirk)
    hidden_states = jax.nn.leaky_relu(hidden_states, negative_slope=0.01)
    hidden_states = self.conv_out(hidden_states)
    hidden_states = jnp.tanh(hidden_states)

    # Transpose back to (Batch, Channels, Time) to match PyTorch/Diffusers output format
    hidden_states = jnp.transpose(hidden_states, (0, 2, 1))

    return hidden_states
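
For reference, a minimal usage sketch of the new class (not part of this commit; it assumes `LTX2Vocoder`, `nnx`, and `jnp` are in scope as in the module above and uses the default configuration):

# Illustrative sketch only, not included in the commit.
vocoder = LTX2Vocoder(rngs=nnx.Rngs(0))

# Dummy mel input of shape (B, C, T, F) = (1, 1, 10, 128);
# channels * mel_bins = 128 matches the default in_channels.
mel = jnp.zeros((1, 1, 10, 128), dtype=jnp.float32)

audio = vocoder(mel)
# Expected output shape: (1, 2, 2400), i.e. 10 mel frames * total_upsample_factor (240)
# samples per output channel.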
