From df0e2e9c81ae270a9b3dc0eb26c505493c24d920 Mon Sep 17 00:00:00 2001 From: ExtremeViscent Date: Wed, 28 Aug 2024 03:51:14 +0800 Subject: [PATCH 1/2] Enabled fully sharded decoding --- videosys/core/comm.py | 121 +++++++- .../autoencoders/autoencoder_kl_open_sora.py | 96 ++++++- .../autoencoder_kl_open_sora_plan.py | 103 +++++++ videosys/pipelines/latte/pipeline_latte.py | 17 +- .../open_sora_plan/pipeline_open_sora_plan.py | 3 + videosys/utils/vae_utils.py | 267 ++++++++++++++++++ 6 files changed, 604 insertions(+), 3 deletions(-) create mode 100644 videosys/utils/vae_utils.py diff --git a/videosys/core/comm.py b/videosys/core/comm.py index 175fba59..40796507 100644 --- a/videosys/core/comm.py +++ b/videosys/core/comm.py @@ -255,7 +255,7 @@ def _split_sequence_func(input_, pg: dist.ProcessGroup, dim: int, pad: int): # skip if only one rank involved world_size = dist.get_world_size(pg) rank = dist.get_rank(pg) - if world_size == 1: + if world_size == 1 or input_.size(dim) < world_size: return input_ if pad > 0: @@ -418,3 +418,122 @@ def all_to_all_with_pad( input_ = input_.narrow(gather_dim, 0, input_.size(gather_dim) - gather_pad) return input_ + +# ====================================================== +# Halo Exchange +# ====================================================== + + +def _halo_exchange_func(input_, pg: dist.ProcessGroup, dim: int, pad: int): + # skip if only one rank involved + if input_.size(dim) < dist.get_world_size(pg): + return input_ + world_size = dist.get_world_size(pg) + rank = dist.get_rank(pg) + rank_list = dist.get_process_group_ranks(pg) + input_.shape[dim] // world_size + + dst_l = (rank - 1) % world_size + dst_r = (rank + 1) % world_size + + send_l = input_.narrow(dim, 0, pad).contiguous() + send_r = input_.narrow(dim, input_.size(dim) - pad, pad).contiguous() + recv_l = torch.zeros_like(send_l) + recv_r = torch.zeros_like(send_r) + + is_odd = rank % 2 == 1 + dst_l = rank_list[dst_l] + dst_r = rank_list[dst_r] + if is_odd: + dist.send(send_l, dst_l, group=pg) + dist.send(send_r, dst_r, group=pg) + else: + dist.recv(recv_r, dst_r, group=pg) + dist.recv(recv_l, dst_l, group=pg) + if is_odd: + dist.recv(recv_r, dst_r, group=pg) + dist.recv(recv_l, dst_l, group=pg) + else: + dist.send(send_l, dst_l, group=pg) + dist.send(send_r, dst_r, group=pg) + + if rank == 0: + output = torch.cat([input_, recv_r], dim=dim) + elif rank == world_size - 1: + output = torch.cat([recv_l, input_], dim=dim) + else: + output = torch.cat([recv_l, input_, recv_r], dim=dim) + + return output + + +class _HaloExchange(torch.autograd.Function): + """ + Halo exchange. + + Args: + input_: input matrix. + process_group: process group. 
+ dim: dimension + pad: padding size + """ + + @staticmethod + def symbolic(graph, input_): + return _halo_exchange_func(input_) + + @staticmethod + def forward(ctx, input_, process_group, dim, pad): + ctx.process_group = process_group + ctx.dim = dim + ctx.pad = pad + return _halo_exchange_func(input_, process_group, dim, pad) + + @staticmethod + def backward(ctx, grad_output): + raise NotImplementedError("Halo exchange does not support backward now.") + + +def halo_exchange(input_, process_group, dim, pad): + return _HaloExchange.apply(input_, process_group, dim, pad) + + +# ====================================================== +# All Reduce +# ====================================================== + + +def _all_reduce_func(input_, pg: dist.ProcessGroup, op: dist.ReduceOp): + dist.get_world_size(pg) + dist.get_rank(pg) + dist.all_reduce(input_, op=op, group=pg) + return input_ + + +class _AllReduce(torch.autograd.Function): + """ + All reduce. + + Args: + input_: input matrix. + process_group: process group. + op: reduce operation + """ + + @staticmethod + def symbolic(graph, input_): + return _all_reduce_func(input_) + + @staticmethod + def forward(ctx, input_, process_group, op): + ctx.process_group = process_group + ctx.op = op + return _all_reduce_func(input_, process_group, op) + + @staticmethod + def backward(ctx, grad_output): + raise NotImplementedError("All reduce does not support backward now.") + + +def all_reduce(input_, process_group, op=dist.ReduceOp.SUM): + return _AllReduce.apply(input_, process_group, op) \ No newline at end of file diff --git a/videosys/models/autoencoders/autoencoder_kl_open_sora.py b/videosys/models/autoencoders/autoencoder_kl_open_sora.py index 2919b296..c4016df1 100644 --- a/videosys/models/autoencoders/autoencoder_kl_open_sora.py +++ b/videosys/models/autoencoders/autoencoder_kl_open_sora.py @@ -17,7 +17,9 @@ from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder from einops import rearrange from transformers import PretrainedConfig, PreTrainedModel - +from videosys.utils.vae_utils import _replace_conv_fwd, _replace_groupnorm_fwd, _replace_conv_opensora_fwd, dynamic_switch +from videosys.core.parallel_mgr import enable_sequence_parallel, get_sequence_parallel_group, get_sequence_parallel_size +from videosys.core.comm import split_sequence, gather_sequence class DiagonalGaussianDistribution(object): def __init__( @@ -119,6 +121,10 @@ def __init__( dilation = (dilation, 1, 1) self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + def set_sequence_parallel(self): + _replace_conv_fwd(self.conv) + _replace_conv_opensora_fwd(self) + def forward(self, x): x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) x = self.conv(x) @@ -152,6 +158,14 @@ def __init__( else: self.conv3 = conv_fn(in_channels, self.filters, kernel_size=(1, 1, 1), bias=False) + def set_sequence_parallel(self): + _replace_groupnorm_fwd(self.norm1) + _replace_groupnorm_fwd(self.norm2) + self.conv1.set_sequence_parallel() + self.conv2.set_sequence_parallel() + if self.in_channels != self.filters: + self.conv3.set_sequence_parallel() + def forward(self, x): residual = x x = self.norm1(x) @@ -353,6 +367,19 @@ def __init__( self.conv_out = self.conv_fn(filters, in_out_channels, 3) + def set_sequence_parallel(self): + self.conv1.set_sequence_parallel() + for i in range(self.num_res_blocks): + self.res_blocks[i].set_sequence_parallel() + for i in range(self.num_blocks): + for j in range(self.num_res_blocks): + 
self.block_res_blocks[i][j].set_sequence_parallel() + if i > 0: + if isinstance(self.conv_blocks[i - 1], CausalConv3d): + self.conv_blocks[i - 1].set_sequence_parallel() + _replace_groupnorm_fwd(self.norm1) + self.conv_out.set_sequence_parallel() + def forward(self, x): x = self.conv1(x) for i in range(self.num_res_blocks): @@ -439,6 +466,23 @@ def get_latent_size(self, input_size): latent_size.append(lsize) return latent_size + def get_video_size(self, latent_size): + video_size = [] + for i in range(3): + if latent_size[i] is None: + vsize = None + elif i == 0: + time_padding = ( + 0 + if (latent_size[i] % self.time_downsample_factor == 0) + else self.time_downsample_factor - latent_size[i] % self.time_downsample_factor + ) + vsize = latent_size[i] * self.patch_size[i] - time_padding + else: + vsize = latent_size[i] * self.patch_size[i] + video_size.append(vsize) + return video_size + def encode(self, x): time_padding = ( 0 @@ -546,6 +590,15 @@ def get_latent_size(self, input_size): # ), "Input size must be divisible by patch size" latent_size.append(input_size[i] // self.patch_size[i] if input_size[i] is not None else None) return latent_size + + def get_video_size(self, latent_size): + video_size = [] + for i in range(3): + # assert ( + # latent_size[i] is None or latent_size[i] % self.patch_size[i] == 0 + # ), "Latent size must be divisible by patch size" + video_size.append(latent_size[i] * self.patch_size[i] if latent_size[i] is not None else None) + return video_size @property def device(self): @@ -650,6 +703,12 @@ def __init__(self, config: VideoAutoencoderPipelineConfig): shift = shift[None, :, None, None, None] self.register_buffer("scale", scale) self.register_buffer("shift", shift) + if enable_sequence_parallel(): + self.set_sequence_parallel() + + def set_sequence_parallel(self): + self.temporal_vae.decoder.set_sequence_parallel() + self.temporal_vae.post_quant_conv.set_sequence_parallel() def encode(self, x): x_z = self.spatial_vae.encode(x) @@ -674,8 +733,18 @@ def decode(self, z, num_frames=None): if not self.cal_loss: z = z * self.scale.to(z.dtype) + self.shift.to(z.dtype) + if enable_sequence_parallel(): + padding_s = z.shape[4] % get_sequence_parallel_size() + expected_s = self.get_video_size(z.shape[2:])[2] + z = F.pad(z, (0, padding_s, 0, 0, 0, 0)) + z = split_sequence(z, get_sequence_parallel_group(), dim=4) + if self.micro_frame_size is None: x_z = self.temporal_vae.decode(z, num_frames=num_frames) + if enable_sequence_parallel(): + padding_f = x_z.shape[2] % get_sequence_parallel_size() + x_z = F.pad(x_z, (0, 0, 0, 0, 0, padding_f)) + x_z = dynamic_switch(x_z, False, 2, 4) if enable_sequence_parallel() else x_z x = self.spatial_vae.decode(x_z) else: x_z_list = [] @@ -685,8 +754,20 @@ def decode(self, z, num_frames=None): x_z_list.append(x_z_bs) num_frames -= self.micro_frame_size x_z = torch.cat(x_z_list, dim=2) + if enable_sequence_parallel(): + padding_f = x_z.shape[2] % get_sequence_parallel_size() + x_z = F.pad(x_z, (0, 0, 0, 0, 0, padding_f)) + x_z = dynamic_switch(x_z, False, 2, 4) if enable_sequence_parallel() else x_z x = self.spatial_vae.decode(x_z) + if enable_sequence_parallel(): + x = gather_sequence(x, get_sequence_parallel_group(), dim=2) + x_z = gather_sequence(x_z, get_sequence_parallel_group(), dim=2) + x = x.narrow(2, 0, x.shape[2] - padding_f) + x = x.narrow(4, 0, min(expected_s, x.shape[4])) + x_z = x_z.narrow(2, 0, x_z.shape[2] - padding_f) + x_z = x_z.narrow(4, 0, x_z.shape[4] - padding_s) + if self.cal_loss: return x, x_z else: @@ -710,6 
+791,19 @@ def get_latent_size(self, input_size): remain_size = self.temporal_vae.get_latent_size(remain_temporal_size) sub_latent_size[0] += remain_size[0] return sub_latent_size + + def get_video_size(self, latent_size): + if self.micro_frame_size is None or latent_size[0] is None: + return self.spatial_vae.get_video_size(self.temporal_vae.get_video_size(latent_size)) + else: + sub_latent_size = [self.micro_z_frame_size, latent_size[1], latent_size[2]] + sub_video_size = self.spatial_vae.get_video_size(self.temporal_vae.get_video_size(sub_latent_size)) + sub_video_size[0] = sub_video_size[0] * (latent_size[0] // self.micro_z_frame_size) + remain_temporal_size = [latent_size[0] % self.micro_z_frame_size, None, None] + if remain_temporal_size[0] > 0: + remain_size = self.spatial_vae.get_video_size(self.temporal_vae.get_video_size(remain_temporal_size)) + sub_video_size[0] += remain_size[0] + return sub_video_size def get_temporal_last_layer(self): return self.temporal_vae.decoder.conv_out.conv.weight diff --git a/videosys/models/autoencoders/autoencoder_kl_open_sora_plan.py b/videosys/models/autoencoders/autoencoder_kl_open_sora_plan.py index 162060c5..221555f0 100644 --- a/videosys/models/autoencoders/autoencoder_kl_open_sora_plan.py +++ b/videosys/models/autoencoders/autoencoder_kl_open_sora_plan.py @@ -22,6 +22,10 @@ from einops import rearrange from torch import nn +from videosys.utils.vae_utils import _replace_groupnorm_fwd, dynamic_switch, _replace_conv_fwd +from videosys.core.parallel_mgr import enable_sequence_parallel, get_sequence_parallel_group, get_sequence_parallel_size +from videosys.core.comm import gather_sequence, split_sequence + logging.set_verbosity_error() @@ -332,7 +336,49 @@ def __init__( self.norm_out = Normalize(block_in) self.conv_out = resolve_str_to_obj(conv_out)(block_in, 3, kernel_size=3, padding=1) + def set_sequence_parallel(self): + self.conv_in.set_sequence_parallel() + self.conv_out.set_sequence_parallel() + _replace_groupnorm_fwd(self.norm_out) + self.mid.block_1.set_sequence_parallel() + self.mid.block_2.set_sequence_parallel() + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + self.up[i_level].block[i_block].set_sequence_parallel() + if hasattr(self.up[i_level], "upsample"): + self.up[i_level].upsample.set_sequence_parallel() + if hasattr(self.up[i_level], "time_upsample"): + self.up[i_level].time_upsample.set_sequence_parallel() + + def _forward_sp(self, z): + h = self.conv_in(z) + h = self.mid.block_1(h) + h = dynamic_switch(h, to_spatial_shard=False) + h = self.mid.attn_1(h) + h = dynamic_switch(h, to_spatial_shard=True) + h = self.mid.block_2(h) + + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + dynamic_switch(h, to_spatial_shard=False) + h = self.up[i_level].attn[i_block](h) + dynamic_switch(h, to_spatial_shard=True) + if hasattr(self.up[i_level], "upsample"): + dynamic_switch(h, to_spatial_shard=False) + h = self.up[i_level].upsample(h) + dynamic_switch(h, to_spatial_shard=True) + if hasattr(self.up[i_level], "time_upsample"): + h = self.up[i_level].time_upsample(h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + def forward(self, z): + if enable_sequence_parallel(): + return self._forward_sp(z) h = self.conv_in(z) h = self.mid.block_1(h) h = self.mid.attn_1(h) @@ -466,6 +512,13 @@ def __init__( quant_conv_cls = 
resolve_str_to_obj(q_conv) self.quant_conv = quant_conv_cls(2 * z_channels, 2 * embed_dim, 1) self.post_quant_conv = quant_conv_cls(embed_dim, z_channels, 1) + self.sp = False + + def set_sequence_parallel(self): + # TODO: sequence parallel for encoder + self.post_quant_conv.set_sequence_parallel() + self.decoder.set_sequence_parallel() + self.sp = True def encode(self, x): if self.use_tiling and ( @@ -750,8 +803,20 @@ def tiled_decode2d(self, z): i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size, ] + if enable_sequence_parallel(): + b, c, t, h, w = tile.shape + tile = tile.view(b * c, t, h, w) + padding_f = t % get_sequence_parallel_size() + padding_s = w % get_sequence_parallel_size() + tile = F.pad(tile, (0, padding_s, 0, 0, 0, padding_f)) + tile = tile.view(b, c, t + padding_f, h, w + padding_s) + tile = split_sequence(tile, get_sequence_parallel_group(), dim=4) tile = self.post_quant_conv(tile) decoded = self.decoder(tile) + if enable_sequence_parallel(): + decoded = gather_sequence(decoded, get_sequence_parallel_group(), dim=4) + decoded = decoded.narrow(4, 0, decoded.shape[4] - padding_s) + decoded = decoded.narrow(2, 0, decoded.shape[2] - padding_f) row.append(decoded) rows.append(row) result_rows = [] @@ -1157,6 +1222,10 @@ def _init_weights(self, init_method): if self.conv.bias is not None: nn.init.constant_(self.conv.bias, 0) + def set_sequence_parallel(self): + if enable_sequence_parallel(): + _replace_conv_fwd(self.conv) + def forward(self, x): # 1 + 16 16 as video, 1 as image first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1)) # b c t h w @@ -1430,6 +1499,34 @@ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropo else: self.nin_shortcut = CausalConv3d(in_channels, out_channels, 1, padding=0) + def set_sequence_parallel(self): + self.conv1.set_sequence_parallel() + self.conv2.set_sequence_parallel() + _replace_groupnorm_fwd(self.norm1) + _replace_groupnorm_fwd(self.norm2) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut.set_sequence_parallel() + else: + self.nin_shortcut.set_sequence_parallel() + self.forward = self.forward_sp + + def forward_sp(self, x): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + return x + h + def forward(self, x): h = x h = self.norm1(h) @@ -1518,6 +1615,9 @@ def __init__( self.kernel_size = kernel_size self.conv = CausalConv3d(self.chan_in, self.chan_out, (1,) + self.kernel_size, stride=(1,) + stride, padding=1) + def set_sequence_parallel(self): + self.conv.set_sequence_parallel() + def forward(self, x): t = x.shape[2] x = rearrange(x, "b c t h w -> b (c t) h w") @@ -1584,6 +1684,9 @@ def __init__( self.conv = CausalConv3d(in_channels, out_channels, kernel_size, padding=1) self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor])) + def set_sequence_parallel(self): + self.conv.set_sequence_parallel() + def forward(self, x): alpha = torch.sigmoid(self.mix_factor) if x.size(2) > 1: diff --git a/videosys/pipelines/latte/pipeline_latte.py b/videosys/pipelines/latte/pipeline_latte.py index d373fee0..0ef9ad5c 100644 --- a/videosys/pipelines/latte/pipeline_latte.py +++ b/videosys/pipelines/latte/pipeline_latte.py @@ -34,9 +34,11 @@ update_steps, ) from videosys.core.pipeline import 
VideoSysPipeline, VideoSysPipelineOutput +from videosys.core.parallel_mgr import enable_sequence_parallel, get_sequence_parallel_size from videosys.models.transformers.latte_transformer_3d import LatteT2V from videosys.utils.logging import logger from videosys.utils.utils import save_video +from videosys.utils.vae_utils import _replace_decoder_fwd, _replace_mid_fwd, _replace_stres_fwd, _replace_up_fwd class LattePABConfig(PABConfig): @@ -213,6 +215,19 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + if enable_sequence_parallel(): + self.set_sequence_parallel() + + def set_sequence_parallel(self): + resnets = self.vae.decoder.mid_block.resnets + _replace_mid_fwd(self.vae.decoder.mid_block) + _replace_decoder_fwd(self.vae.decoder) + for up_block in self.vae.decoder.up_blocks: + resnets += up_block.resnets + _replace_up_fwd(up_block) + for resnet in resnets: + _replace_stres_fwd(resnet) + # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py def mask_text_embeddings(self, emb, mask): if emb.shape[0] == 1: @@ -895,7 +910,7 @@ def decode_latents_with_temporal_decoder(self, latents): latents = einops.rearrange(latents, "b c f h w -> (b f) c h w") video = [] - decode_chunk_size = 14 + decode_chunk_size = 14 * get_sequence_parallel_size() for frame_idx in range(0, latents.shape[0], decode_chunk_size): num_frames_in = latents[frame_idx : frame_idx + decode_chunk_size].shape[0] diff --git a/videosys/pipelines/open_sora_plan/pipeline_open_sora_plan.py b/videosys/pipelines/open_sora_plan/pipeline_open_sora_plan.py index e51031fd..6e0ea609 100644 --- a/videosys/pipelines/open_sora_plan/pipeline_open_sora_plan.py +++ b/videosys/pipelines/open_sora_plan/pipeline_open_sora_plan.py @@ -33,6 +33,7 @@ update_steps, ) from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput +from videosys.core.parallel_mgr import enable_sequence_parallel, get_sequence_parallel_rank from videosys.utils.logging import logger from videosys.utils.utils import save_video @@ -881,6 +882,8 @@ def generate( return VideoSysPipelineOutput(video=video) def decode_latents(self, latents): + if enable_sequence_parallel() and not self.vae.vae.sp: + self.vae.vae.set_sequence_parallel() video = self.vae.decode(latents) # b t c h w # b t c h w -> b t h w c video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().permute(0, 1, 3, 4, 2).contiguous() diff --git a/videosys/utils/vae_utils.py b/videosys/utils/vae_utils.py new file mode 100644 index 00000000..b88d0fe4 --- /dev/null +++ b/videosys/utils/vae_utils.py @@ -0,0 +1,267 @@ +from functools import partial +from typing import Optional + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import nn + +from videosys.core.comm import all_reduce, all_to_all_with_pad, gather_sequence, halo_exchange, split_sequence +from videosys.core.parallel_mgr import ( + get_sequence_parallel_group, + get_sequence_parallel_rank, + get_sequence_parallel_size, +) + + +def _replace_stres_fwd(module: torch.nn.Module): + bound_method = _forward_stres.__get__(module, module.__class__) + setattr(module, "forward", bound_method) + if module.temporal_res_block.conv1.kernel_size[1] != 0: + bound_method = _forward_conv3d_sp.__get__(module, module.__class__) + _replace_conv_fwd(module.temporal_res_block.conv1) + _replace_conv_fwd(module.temporal_res_block.conv2) + if 
module.temporal_res_block.use_in_shortcut: + _replace_conv_fwd(module.temporal_res_block.conv_shortcut) + _replace_groupnorm_fwd(module.temporal_res_block.norm1) + _replace_groupnorm_fwd(module.temporal_res_block.norm2) + + +def _replace_mid_fwd(module: torch.nn.Module): + bound_method = _forward_mid.__get__(module, module.__class__) + setattr(module, "forward", bound_method) + + +def _replace_up_fwd(module: torch.nn.Module): + bound_method = _forward_up.__get__(module, module.__class__) + setattr(module, "forward", bound_method) + + +def _replace_decoder_fwd(module: torch.nn.Module): + bound_method = _forward_decoder.__get__(module, module.__class__) + setattr(module, "forward", bound_method) + + +def _replace_groupnorm_fwd(module: nn.GroupNorm): + bound_method = dist_groupnorm.__get__(module, module.__class__) + bound_method = partial( + bound_method, + weight=module.weight, + bias=module.bias, + eps=module.eps, + group_num=module.num_groups, + group=get_sequence_parallel_group(), + ) + setattr(module, "forward", bound_method) + + +def _replace_conv_fwd(module: nn.Conv3d): + bound_method = _forward_conv3d_sp.__get__(module, module.__class__) + setattr(module, "forward", bound_method) + set_sp_padding(module) + +def _replace_conv_opensora_fwd(module: nn.Module): + bound_method = _forward_conv3d_opensora.__get__(module, module.__class__) + setattr(module, "forward", bound_method) + + +def dist_groupnorm( + self, x: torch.Tensor, group_num: int, weight: torch.Tensor, bias: torch.Tensor, eps: float, group +) -> torch.Tensor: + # x: input features with shape [N,C,H,W] + # weight, bias: scale and offset, with shape [C] + # group_num: number of groups for GN + + x_shape = x.shape + batch_size = x_shape[0] + dtype = x.dtype + x = x.to(torch.float32) + x = x.reshape(batch_size, group_num, -1) + + mean = x.mean(dim=-1, keepdim=True) + mean = all_reduce(mean, group) + mean = mean / dist.get_world_size(group) + + var = ((x - mean) ** 2).mean(dim=-1, keepdim=True) + var = all_reduce(var, group) + var = var / dist.get_world_size(group) + + x = (x - mean) / torch.sqrt(var + eps) + + x = x.view(x_shape).to(dtype) + x = weight.view(1, -1, 1, 1, 1) * x + bias.view(1, -1, 1, 1, 1) + return x + + +def set_sp_padding(module: nn.Conv3d): + padding_ = module.padding + padding = [] + for i in range(3): + padding.insert(0, padding_[i]) + padding.insert(0, padding_[i]) + module.width_pad = padding[0] + if get_sequence_parallel_rank() == 0: + padding[1] = 0 + elif get_sequence_parallel_rank() == get_sequence_parallel_size() - 1: + padding[0] = 0 + else: + padding[0] = 0 + padding[1] = 0 + module.padding = (0, 0, 0) + module.padding_mode = "zeros" + module._reversed_padding_repeated_twice = tuple(padding) + + +def _forward_conv3d_sp(self, x): + halo_size = self.kernel_size[2] // 2 + x = halo_exchange(x, get_sequence_parallel_group(), 4, halo_size) + x = F.pad(x, self._reversed_padding_repeated_twice, mode="constant", value=0) + ret = self._conv_forward(x, self.weight, self.bias) + return ret + + +def dynamic_switch(x, to_spatial_shard: bool, scatter_dim: int = 0, gather_dim: int = 3): + if to_spatial_shard: + scatter_dim, gather_dim = gather_dim, scatter_dim + + if x.shape[scatter_dim] % get_sequence_parallel_size() != 0 or x.shape[gather_dim] < get_sequence_parallel_size(): + return x + + x = all_to_all_with_pad( + x, + get_sequence_parallel_group(), + scatter_dim=scatter_dim, + gather_dim=gather_dim, + ) + return x + + +def _forward_stres( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = 
None, + image_only_indicator: Optional[torch.Tensor] = None, +): + hidden_states = dynamic_switch(hidden_states, to_spatial_shard=False) + num_frames = image_only_indicator.shape[-1] + hidden_states = self.spatial_res_block(hidden_states, temb) + hidden_states = dynamic_switch(hidden_states, to_spatial_shard=True) + + batch_frames, channels, height, width = hidden_states.shape + batch_size = batch_frames // num_frames + + hidden_states_mix = ( + hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4) + ) + hidden_states = ( + hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4) + ) + + if temb is not None: + temb = temb.reshape(batch_size, num_frames, -1) + + hidden_states = self.temporal_res_block(hidden_states, temb) + hidden_states = self.time_mixer( + x_spatial=hidden_states_mix, + x_temporal=hidden_states, + image_only_indicator=image_only_indicator, + ) + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width) + return hidden_states + + +def _forward_mid( + self, + hidden_states: torch.Tensor, + image_only_indicator: torch.Tensor, +): + hidden_states = self.resnets[0]( + hidden_states, + image_only_indicator=image_only_indicator, + ) + for resnet, attn in zip(self.resnets[1:], self.attentions): + hidden_states = dynamic_switch(hidden_states, to_spatial_shard=False) + hidden_states = attn(hidden_states) + hidden_states = dynamic_switch(hidden_states, to_spatial_shard=True) + hidden_states = resnet( + hidden_states, + image_only_indicator=image_only_indicator, + ) + return hidden_states + + +def _forward_up( + self, + hidden_states: torch.Tensor, + image_only_indicator: torch.Tensor, +) -> torch.Tensor: + for resnet in self.resnets: + hidden_states = resnet( + hidden_states, + image_only_indicator=image_only_indicator, + ) + + if self.upsamplers is not None: + hidden_states = dynamic_switch(hidden_states, to_spatial_shard=False) + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + hidden_states = dynamic_switch(hidden_states, to_spatial_shard=True) + + return hidden_states + + +def _forward_decoder( + self, + sample: torch.Tensor, + image_only_indicator: torch.Tensor, + num_frames: int = 1, +) -> torch.Tensor: + r"""The forward method of the `Decoder` class.""" + + sample = self.conv_in(sample) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + + # middle + padding_s = sample.shape[-1] % get_sequence_parallel_size() + padding_f = sample.shape[0] % get_sequence_parallel_size() + sample = F.pad(sample, (0, padding_s, 0, 0, 0, padding_f)) + sample = split_sequence(sample, get_sequence_parallel_group(), dim=3) + sample = self.mid_block(sample, image_only_indicator=image_only_indicator) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = up_block(sample, image_only_indicator=image_only_indicator) + + sample = gather_sequence(sample, get_sequence_parallel_group(), dim=3) + sample = sample.narrow(3, 0, sample.shape[3] - padding_s) + sample = sample.narrow(0, 0, sample.shape[0] - padding_f) + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + batch_frames, channels, height, width = sample.shape + batch_size = batch_frames // num_frames + sample = sample[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4) + sample = self.time_conv_out(sample) + + sample = 
sample.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width) + return sample + +def _forward_conv3d_opensora(self, x): + time_causal_padding = list(self.time_causal_padding) + width_pad = time_causal_padding[0] + if get_sequence_parallel_rank() == 0: + time_causal_padding[1] = 0 + elif get_sequence_parallel_rank() == get_sequence_parallel_size() - 1: + time_causal_padding[0] = 0 + else: + time_causal_padding[0] = 0 + time_causal_padding[1] = 0 + time_causal_padding = tuple(time_causal_padding) + x = F.pad(x, time_causal_padding, mode=self.pad_mode) + x = self.conv(x) + return x \ No newline at end of file From f31df2f71278c1f63c56ed61da47b53cc44f2eb7 Mon Sep 17 00:00:00 2001 From: ExtremeViscent Date: Thu, 12 Sep 2024 16:00:53 +0800 Subject: [PATCH 2/2] Add encoder sharding for OpenSora and OpenSora-Plan --- .../autoencoders/autoencoder_kl_open_sora.py | 32 ++++++- .../autoencoder_kl_open_sora_plan.py | 89 ++++++++++++++++++- videosys/utils/vae_utils.py | 2 - 3 files changed, 115 insertions(+), 8 deletions(-) diff --git a/videosys/models/autoencoders/autoencoder_kl_open_sora.py b/videosys/models/autoencoders/autoencoder_kl_open_sora.py index c4016df1..248d588d 100644 --- a/videosys/models/autoencoders/autoencoder_kl_open_sora.py +++ b/videosys/models/autoencoders/autoencoder_kl_open_sora.py @@ -18,7 +18,7 @@ from einops import rearrange from transformers import PretrainedConfig, PreTrainedModel from videosys.utils.vae_utils import _replace_conv_fwd, _replace_groupnorm_fwd, _replace_conv_opensora_fwd, dynamic_switch -from videosys.core.parallel_mgr import enable_sequence_parallel, get_sequence_parallel_group, get_sequence_parallel_size +from videosys.core.parallel_mgr import enable_sequence_parallel, get_sequence_parallel_group, get_sequence_parallel_size, get_sequence_parallel_rank from videosys.core.comm import split_sequence, gather_sequence class DiagonalGaussianDistribution(object): @@ -270,6 +270,19 @@ def __init__( self.conv2 = self.conv_fn(prev_filters, self.embedding_dim, kernel_size=(1, 1, 1), padding="same") + def set_sequence_parallel(self): + self.conv_in.set_sequence_parallel() + for i in range(self.num_blocks): + for j in range(self.num_res_blocks): + self.block_res_blocks[i][j].set_sequence_parallel() + if i < self.num_blocks - 1: + if isinstance(self.conv_blocks[i], CausalConv3d): + self.conv_blocks[i].set_sequence_parallel() + for i in range(self.num_res_blocks): + self.res_blocks[i].set_sequence_parallel() + _replace_groupnorm_fwd(self.norm1) + self.conv2.set_sequence_parallel() + def forward(self, x): x = self.conv_in(x) @@ -707,12 +720,23 @@ def __init__(self, config: VideoAutoencoderPipelineConfig): self.set_sequence_parallel() def set_sequence_parallel(self): + self.temporal_vae.encoder.set_sequence_parallel() self.temporal_vae.decoder.set_sequence_parallel() + self.temporal_vae.quant_conv.set_sequence_parallel() self.temporal_vae.post_quant_conv.set_sequence_parallel() def encode(self, x): + if enable_sequence_parallel(): + padding_f = x.shape[2] % get_sequence_parallel_size() + x = F.pad(x, (0, 0, 0, 0, 0, padding_f)) + x = split_sequence(x, get_sequence_parallel_group(), dim=2) x_z = self.spatial_vae.encode(x) - + padding_s = 0 + if enable_sequence_parallel(): + padding_s = x_z.shape[4] % get_sequence_parallel_size() + x_z = F.pad(x_z, (0, padding_s, 0, 0, 0, 0)) + x_z = dynamic_switch(x_z, True, 2, 4) + x_z = x_z.narrow(2, 0, x_z.shape[2] - padding_f) if self.micro_frame_size is None: posterior = self.temporal_vae.encode(x_z) z = posterior.sample() 
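
Note: the encode/decode hunks above and below repeat one pattern around the sequence-parallel region: right-pad the shard dimension until it divides the parallel size, split_sequence across ranks, run the VAE stage, then gather_sequence and narrow the padding back off. The snippet below is a minimal single-process sketch of that pattern, not the patch's code: torch.chunk/torch.cat stand in for the split_sequence/gather_sequence collectives, and pad_to_multiple is a hypothetical helper used only to make the shape bookkeeping explicit.

    import torch
    import torch.nn.functional as F

    def pad_to_multiple(x: torch.Tensor, dim: int, world_size: int):
        # Right-pad `dim` with zeros so its length divides evenly across ranks.
        pad = (-x.size(dim)) % world_size              # 0 when already divisible
        if pad > 0:
            # F.pad takes (left, right) pairs starting from the last dimension.
            spec = [0, 0] * (x.dim() - dim - 1) + [0, pad]
            x = F.pad(x, spec)
        return x, pad

    world_size = 4
    z = torch.randn(1, 4, 15, 30, 53)                  # b c t h w; W not divisible by 4

    z_padded, pad_w = pad_to_multiple(z, dim=4, world_size=world_size)
    shards = list(torch.chunk(z_padded, world_size, dim=4))   # stands in for split_sequence
    assert all(s.size(4) == z_padded.size(4) // world_size for s in shards)

    # ... each rank would run its VAE stage on its own shard here ...

    gathered = torch.cat(shards, dim=4)                # stands in for gather_sequence
    restored = gathered.narrow(4, 0, gathered.size(4) - pad_w)
    assert restored.shape == z.shape

Splitting along W (dim 4) mirrors the spatial sharding used around the temporal VAE here; the same bookkeeping applies to the frame dimension (dim 2) when the activations are re-sharded for the spatial VAE via dynamic_switch.
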
@@ -723,6 +747,10 @@ def encode(self, x): posterior = self.temporal_vae.encode(x_z_bs) z_list.append(posterior.sample()) z = torch.cat(z_list, dim=2) + + if enable_sequence_parallel(): + z = gather_sequence(z, get_sequence_parallel_group(), dim=4) + z = z.narrow(4, 0, z.shape[4] - padding_s) if self.cal_loss: return z, posterior, x_z diff --git a/videosys/models/autoencoders/autoencoder_kl_open_sora_plan.py b/videosys/models/autoencoders/autoencoder_kl_open_sora_plan.py index 221555f0..25c37cac 100644 --- a/videosys/models/autoencoders/autoencoder_kl_open_sora_plan.py +++ b/videosys/models/autoencoders/autoencoder_kl_open_sora_plan.py @@ -228,6 +228,54 @@ def __init__( padding=1, ) + def set_sequence_parallel(self): + self.conv_out.set_sequence_parallel() + _replace_groupnorm_fwd(self.norm_out) + self.mid.block_1.set_sequence_parallel() + self.mid.block_2.set_sequence_parallel() + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + if self.down[i_level].block[i_block].__class__.__name__ == "ResnetBlock3D": + self.down[i_level].block[i_block].set_sequence_parallel() + if hasattr(self.down[i_level], "time_downsample"): + self.down[i_level].time_downsample.set_sequence_parallel() + + def _forward_sp(self, x): + x = dynamic_switch(x, to_spatial_shard=False) + hs = [self.conv_in(x)] + hs[-1] = dynamic_switch(hs[-1], to_spatial_shard=True) + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + if self.down[i_level].block[i_block].__class__.__name__ == "ResnetBlock3D": + h = self.down[i_level].block[i_block](hs[-1]) + else: + h_ = dynamic_switch(hs[-1], to_spatial_shard=False) + h = self.down[i_level].block[i_block](h_) + h = dynamic_switch(h, to_spatial_shard=True) + if len(self.down[i_level].attn) > 0: + h = dynamic_switch(h, to_spatial_shard=False) + h = self.down[i_level].attn[i_block](h) + h = dynamic_switch(h, to_spatial_shard=True) + hs.append(h) + if hasattr(self.down[i_level], "downsample"): + h_ = dynamic_switch(hs[-1], to_spatial_shard=False) + hs.append(self.down[i_level].downsample(h_)) + hs[-1] = dynamic_switch(hs[-1], to_spatial_shard=True) + if hasattr(self.down[i_level], "time_downsample"): + hs_down = self.down[i_level].time_downsample(hs[-1]) + hs.append(hs_down) + + h = self.mid.block_1(h) + h = dynamic_switch(h, to_spatial_shard=False) + h = self.mid.attn_1(h) + h = dynamic_switch(h, to_spatial_shard=True) + h = self.mid.block_2(h) + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + def forward(self, x): hs = [self.conv_in(x)] for i_level in range(self.num_resolutions): @@ -362,13 +410,13 @@ def _forward_sp(self, z): for i_block in range(self.num_res_blocks + 1): h = self.up[i_level].block[i_block](h) if len(self.up[i_level].attn) > 0: - dynamic_switch(h, to_spatial_shard=False) + h = dynamic_switch(h, to_spatial_shard=False) h = self.up[i_level].attn[i_block](h) - dynamic_switch(h, to_spatial_shard=True) + h = dynamic_switch(h, to_spatial_shard=True) if hasattr(self.up[i_level], "upsample"): - dynamic_switch(h, to_spatial_shard=False) + h = dynamic_switch(h, to_spatial_shard=False) h = self.up[i_level].upsample(h) - dynamic_switch(h, to_spatial_shard=True) + h = dynamic_switch(h, to_spatial_shard=True) if hasattr(self.up[i_level], "time_upsample"): h = self.up[i_level].time_upsample(h) h = self.norm_out(h) @@ -516,7 +564,10 @@ def __init__( def set_sequence_parallel(self): # TODO: sequence parallel for encoder + print("set_sequence_parallel") 
self.post_quant_conv.set_sequence_parallel() + self.quant_conv.set_sequence_parallel() + self.encoder.set_sequence_parallel() self.decoder.set_sequence_parallel() self.sp = True @@ -527,7 +578,21 @@ def encode(self, x): or x.shape[-3] > self.tile_sample_min_size_t ): return self.tiled_encode(x) + if enable_sequence_parallel(): + b, c, t, h, w = x.shape + x = x.view(b * c, t, h, w) + padding_f = t % get_sequence_parallel_size() + padding_s = w % get_sequence_parallel_size() + x = F.pad(x, (0, padding_s, 0, 0, 0, padding_f)) + x = x.view(b, c, t + padding_f, h, w + padding_s) + x = split_sequence(x, get_sequence_parallel_group(), dim=4) + print(x.shape) h = self.encoder(x) + if enable_sequence_parallel(): + h = gather_sequence(h, get_sequence_parallel_group(), dim=4) + # h = h.narrow(4, 0, h.shape[4] - padding_s) + # h = h.narrow(2, 0, h.shape[2] - padding_f) + print(h.shape) moments = self.quant_conv(h) posterior = DiagonalGaussianDistribution(moments) return posterior @@ -762,8 +827,21 @@ def tiled_encode2d(self, x, return_moments=False): i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size, ] + if enable_sequence_parallel(): + b, c, t, h, w = tile.shape + tile = tile.view(b * c, t, h, w) + padding_f = t % get_sequence_parallel_size() + padding_s = w % get_sequence_parallel_size() + tile = F.pad(tile, (0, padding_s, 0, 0, 0, padding_f)) + tile = tile.view(b, c, t + padding_f, h, w + padding_s) + tile = split_sequence(tile, get_sequence_parallel_group(), dim=4) + print(tile.shape) tile = self.encoder(tile) tile = self.quant_conv(tile) + if enable_sequence_parallel(): + tile = gather_sequence(tile, get_sequence_parallel_group(), dim=4) + tile = tile.narrow(4, 0, tile.shape[4] - padding_s) + tile = tile.narrow(2, 0, tile.shape[2] - padding_f) row.append(tile) rows.append(row) result_rows = [] @@ -1665,6 +1743,9 @@ def __init__( self.conv = nn.Conv3d(in_channels, out_channels, self.kernel_size, stride=(2, 1, 1), padding=(0, 1, 1)) self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor])) + def set_sequence_parallel(self): + _replace_conv_fwd(self.conv) + def forward(self, x): alpha = torch.sigmoid(self.mix_factor) first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.kernel_size[0] - 1, 1, 1)) diff --git a/videosys/utils/vae_utils.py b/videosys/utils/vae_utils.py index b88d0fe4..ff877e21 100644 --- a/videosys/utils/vae_utils.py +++ b/videosys/utils/vae_utils.py @@ -126,7 +126,6 @@ def dynamic_switch(x, to_spatial_shard: bool, scatter_dim: int = 0, gather_dim: if x.shape[scatter_dim] % get_sequence_parallel_size() != 0 or x.shape[gather_dim] < get_sequence_parallel_size(): return x - x = all_to_all_with_pad( x, get_sequence_parallel_group(), @@ -253,7 +252,6 @@ def _forward_decoder( def _forward_conv3d_opensora(self, x): time_causal_padding = list(self.time_causal_padding) - width_pad = time_causal_padding[0] if get_sequence_parallel_rank() == 0: time_causal_padding[1] = 0 elif get_sequence_parallel_rank() == get_sequence_parallel_size() - 1:
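
For reference, the statistic-combining step behind dist_groupnorm in videosys/utils/vae_utils.py can be checked numerically in a single process. The sketch below is an illustration under the assumption of equal-sized shards along W; torch.stack(...).mean(0) stands in for the all_reduce followed by division by the group's world size, and the affine weight/bias application is omitted.

    import torch

    torch.manual_seed(0)
    N, C, T, H, W = 2, 8, 4, 6, 8
    groups, world_size, eps = 4, 4, 1e-6
    x = torch.randn(N, C, T, H, W)

    # Reference statistics over the full, unsharded tensor (no affine parameters).
    ref = torch.nn.functional.group_norm(x, groups, eps=eps)

    # "Sharded" view: each rank would see one slice along W.
    shards = torch.chunk(x, world_size, dim=4)
    means = torch.stack([s.reshape(N, groups, -1).mean(-1, keepdim=True) for s in shards])
    mean = means.mean(0)                               # stands in for all_reduce(mean) / world_size
    vars_ = torch.stack(
        [((s.reshape(N, groups, -1) - mean) ** 2).mean(-1, keepdim=True) for s in shards]
    )
    var = vars_.mean(0)                                # stands in for all_reduce(var) / world_size

    y = (x.reshape(N, groups, -1) - mean) / torch.sqrt(var + eps)
    y = y.reshape(N, C, T, H, W)
    print(torch.allclose(y, ref, atol=1e-5))           # True: sharded stats match GroupNorm

With equal shard sizes, averaging per-shard means and per-shard E[(x - mean)^2] reproduces plain GroupNorm exactly; the halo_exchange added to comm.py plays the analogous role for the sharded convolutions, exchanging kernel_size // 2 boundary columns with neighbouring ranks so their outputs match the unsharded case.
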