diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py
index 266d884b..c12d4ddc 100644
--- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py
+++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py
@@ -628,11 +628,10 @@ def indirect_indexing(self, index_var, size, check=True):
     def _index_expr(self, tile_desc, renamed_expression, index, base_vector_index):
         # In case of index expr, dimension size should be divisible by tile size
         if not self.kernel_group.tile_desc.is_dim_dividable(self.ranges):
-            new_tile_size = self.kernel_group.tile_desc.adjust_tile_to_divisible(self.ranges)
+            new_tile_size = self.kernel_group.tile_desc.adjust_tile_to_divisible(self.ranges, self.attempted_tile_sizes)
             self.kernel_group.tile_desc.set_tile_size(new_tile_size)
             self.reset("recompile")
             raise mlir_common.RecompileSignal(f"Index access (tile size {self.kernel_group.tile_desc.get_tile_size()} is not divisible by {self.ranges})")
-
         tile_size = tile_desc.get_tile_size_per_lane()
         compute_vec_size = tile_desc.get_compute_vec_size()
         strides = tile_desc.get_tile_stride_per_lane()
@@ -1277,7 +1276,7 @@ def convert_indirect_indexing(self, index :sympy.Expr):
         # Note: In case of indirect indexing, dimensions should be divisible by tile size
         if not self.kernel_group.tile_desc.is_dim_dividable(self.ranges):
-            new_tile_size = self.kernel_group.tile_desc.adjust_tile_to_divisible(self.ranges)
+            new_tile_size = self.kernel_group.tile_desc.adjust_tile_to_divisible(self.ranges, self.attempted_tile_sizes)
             self.kernel_group.tile_desc.set_tile_size(new_tile_size)
             self.reset("recompile")
             raise mlir_common.RecompileSignal(f"Indirect access (tile size {self.kernel_group.tile_desc.get_tile_size()} is not divisible by {self.ranges})")
diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py
index 15408c0d..b58529f2 100644
--- a/PyTorchSimFrontend/mlir/mlir_common.py
+++ b/PyTorchSimFrontend/mlir/mlir_common.py
@@ -313,22 +313,32 @@ def is_dim_dividable(self, dim_sizes: list[int]) -> bool:
         return all(d % t == 0 for d, t in zip(dim_sizes_cpy, self._tile_size))
 
-    def adjust_tile_to_divisible(self, dim_sizes: list[int]) -> list[int]:
+    def adjust_tile_to_divisible(self, dim_sizes: list[int], attempted_tile_sizes) -> list[int]:
         """Adjust current tile to be divisible by given dimensions."""
         if len(dim_sizes) != len(self._tile_size):
             raise ValueError("dim_sizes must match the tile size dimensions")
 
-        def _adjust_one(dim_size, tile_size):
+        dim_sizes_cpy = list(dim_sizes)
+        axis, stride = self.vmap.vlane_split_axis, self.vmap.vlane_stride
+        remain = dim_sizes_cpy[axis] % stride
+        if remain:
+            dim_sizes_cpy[axis] += stride - remain
+
+        def _adjust_one(dim_size, tile_size, is_split_dim, skip_size=[]):
             for candidate in range(tile_size, 0, -1):
                 if dim_size % candidate == 0:
-                    return candidate
+                    if is_split_dim:
+                        remain = candidate % stride
+                        candidate += (stride - remain) if remain else 0
+                    if candidate not in skip_size:
+                        return candidate
             return 1
 
-        candidate_tile_size = [_adjust_one(d, t) for d, t in zip(dim_sizes, self._tile_size)]
+        vlane_axis_skip_size = [dim[axis] for dim in attempted_tile_sizes]
+        candidate_tile_size = [_adjust_one(d, t, i==axis, vlane_axis_skip_size if i == axis else []) for i, (d, t) in enumerate(zip(dim_sizes_cpy, self._tile_size))]
         for i in range(len(candidate_tile_size)):
             self.tile_constraint[i].must_divide_dim = True
-        axis, stride = self.vmap.vlane_split_axis, self.vmap.vlane_stride
         remain = candidate_tile_size[axis] % stride
         if remain:
@@ -609,6 +619,9 @@ def __init__(self, kernel_group, reason=None):
         self.target_buffer_override = contextvars.ContextVar("Handler_compute_override", default=self.compute)
         self.target_cse_override = contextvars.ContextVar("Handler_cse_override", default=self.cse)
 
+        # Track tile sizes already attempted during compilation
+        self.attempted_tile_sizes = set()
+
     def set_ranges(self, lengths, reduction_lengths):
         if self.call_ranges:
             assert self.call_ranges == tuple(lengths) + tuple(
@@ -761,6 +774,7 @@ def codegen_nodes(self, nodes, kernel_name):
         # Set node range info
         vars, reduction_vars = self.set_ranges(group, reduction_group)
         tile_desc = self.compute_tile_size(nodes, vars, reduction_vars)
+        self.attempted_tile_sizes.add(tuple(tile_desc.get_tile_size()))
         self.compute_body_loop.size = tile_desc.get_numel_per_lane()
         self.compute_body_loop.step = tile_desc.get_compute_vec_size()
         try:
diff --git a/configs/systolic_ws_128x128_c2_simple_noc_tpuv3_timing.yml b/configs/systolic_ws_128x128_c2_simple_noc_tpuv3_timing.yml
new file mode 100644
index 00000000..3b9b8fc8
--- /dev/null
+++ b/configs/systolic_ws_128x128_c2_simple_noc_tpuv3_timing.yml
@@ -0,0 +1,30 @@
+num_cores: 2
+core_freq_mhz: 940
+core_stats_print_period_cycles: 10000
+num_systolic_array_per_core: 2
+
+vpu_num_lanes: 128
+vpu_spad_size_kb_per_lane: 128
+vpu_vector_length_bits: 256
+
+dram_type: ramulator2
+dram_freq_mhz: 940
+dram_channels: 32
+dram_req_size_byte: 32
+dram_num_burst_length: 2
+dram_stats_print_period_cycles: 10000
+ramulator_config_path: ../configs/ramulator2_configs/HBM2_TPUv3.yaml
+
+icnt_type: simple
+icnt_latency_cycles: 10
+icnt_freq_mhz: 940
+icnt_injection_ports_per_core: 16
+
+pytorchsim_functional_mode: 0
+pytorchsim_timing_mode: 1
+
+codegen_mapping_strategy: heuristic
+codegen_external_mapping_file: ''
+codegen_autotune_max_retry: 10
+codegen_autotune_template_topk: 4
+codegen_compiler_optimization: all
diff --git a/tests/OPT/config.json b/tests/OPT/config.json
new file mode 100644
index 00000000..562d268b
--- /dev/null
+++ b/tests/OPT/config.json
@@ -0,0 +1,28 @@
+{
+  "_name_or_path": "opt-350m",
+  "activation_dropout": 0.0,
+  "activation_function": "relu",
+  "architectures": [
+    "OPTForCausalLM"
+  ],
+  "attention_dropout": 0.0,
+  "bos_token_id": 2,
+  "do_layer_norm_before": false,
+  "dropout": 0.1,
+  "eos_token_id": 2,
+  "ffn_dim": 4096,
+  "hidden_size": 1024,
+  "init_std": 0.02,
+  "layerdrop": 0.0,
+  "max_position_embeddings": 2048,
+  "model_type": "opt",
+  "num_attention_heads": 16,
+  "num_hidden_layers": 24,
+  "pad_token_id": 1,
+  "prefix": "",
+  "torch_dtype": "float16",
+  "transformers_version": "4.20.0.dev0",
+  "use_cache": true,
+  "vocab_size": 50272,
+  "word_embed_proj_dim": 512
+}
diff --git a/tests/OPT/experiment_cpu.py b/tests/OPT/experiment_cpu.py
new file mode 100644
index 00000000..a283f650
--- /dev/null
+++ b/tests/OPT/experiment_cpu.py
@@ -0,0 +1,254 @@
+import torch
+import torch.nn as nn
+
+from typing import Optional
+
+class LLM_Config:
+    def __init__(self,
+                 embed_dim,
+                 hidden_size,
+                 num_heads,
+                 ffn_dim,
+                 vocab_size,
+                 word_embed_proj_dim,
+                 pad_token_id,
+                 max_position_embeddings,
+                 enable_bias,
+                 layer_norm_elementwise_affine,
+                 do_layer_norm_before):
+        self.embed_dim = embed_dim
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        self.ffn_dim = ffn_dim
+        self.vocab_size = vocab_size
+        self.word_embed_proj_dim = word_embed_proj_dim
+        self.pad_token_id = pad_token_id
+        self.max_position_embeddings = max_position_embeddings
+        self.enable_bias = enable_bias
+        self.layer_norm_elementwise_affine =
layer_norm_elementwise_affine + self.do_layer_norm_before = do_layer_norm_before + + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward( + self, + attention_mask: torch.LongTensor, + past_key_values_length: int = 0, + position_ids: Optional[torch.LongTensor] = None, + ): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + + if position_ids is None: + position_ids = torch.cumsum(attention_mask, dim=1) + position_ids = (position_ids * attention_mask - 1).long() + # cut positions if `past_key_values_length` is > 0 + position_ids = position_ids[:, past_key_values_length:] + + return super().forward(position_ids + self.offset) + +class my_opt_decoder(nn.Module): + def __init__(self, config: LLM_Config, current_seq_len): + super(my_opt_decoder, self).__init__() + self.config = config + + self.head_dim = self.config.embed_dim // self.config.num_heads + self.scaling = self.head_dim**-0.5 + + # Embedding layers + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.word_embed_proj_dim, self.config.pad_token_id) + self.embed_positions = OPTLearnedPositionalEmbedding(self.config.max_position_embeddings, config.hidden_size) + self.project_in = nn.Linear(self.config.word_embed_proj_dim, self.config.hidden_size, bias=False) + + # KV Cache + # self.past_k = torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + # self.past_v = torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + self.register_buffer( + "past_k", + torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + ) + self.register_buffer( + "past_v", + torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + ) + + + # QKV layers + self.k_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + self.v_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + self.q_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + self.o_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + + self.self_attn_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # FC layers + self.activation_fn = nn.ReLU() + self.fc1 = nn.Linear(self.config.embed_dim, config.ffn_dim, bias=config.enable_bias) + self.fc2 = nn.Linear(config.ffn_dim, self.config.embed_dim, bias=config.enable_bias) + self.final_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # LM head + self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) + self.lm_head_linear = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + + def embed(self, input_ids): + # input_ids: (bsz, seq_len) + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.project_in(inputs_embeds) + bsz, seq_len, _ = inputs_embeds.size() + attention_mask = (input_ids != self.config.pad_token_id).long() + position_embeds = self.embed_positions(attention_mask=attention_mask) + hidden_states = inputs_embeds + position_embeds + return hidden_states + + 
# qkv + rms + def qkv(self, hidden_states): + self.residual = hidden_states + self.bsz, self.tgt_len, _ = hidden_states.size() + + if self.config.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + query_states = self.q_proj(hidden_states) * self.scaling + query_states = query_states.view(self.bsz, -1, self.config.num_heads, self.head_dim).transpose(1, 2) + + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + key_states = key_states.view(self.bsz, -1, self.config.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(self.bsz, -1, self.config.num_heads, self.head_dim).transpose(1, 2) + + return query_states, key_states, value_states + + # QK^T + SV + def attn(self, query, key, value, attention_mask, scaling, dropout, **kwargs): + # KV cache update + self.past_k = torch.cat([self.past_k, key], dim=2) + self.past_v = torch.cat([self.past_v, value], dim=2) + + key = self.past_k + value = self.past_v + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=False) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output + + + # out-proj + rms + def out_proj(self, attn_output): + attn_output = attn_output.reshape(self.bsz, self.tgt_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + attn_output = nn.functional.dropout(attn_output, p=0.0, training=False) + attn_output = self.residual + attn_output + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + attn_output = self.self_attn_layer_norm(attn_output) + + return attn_output + + # MLP + rms + def ffn(self, hidden_states): + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=0.0, training=False) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + return outputs + + def lm_head(self, outputs, logits_to_keep): + hidden_states = outputs[0] + hidden_states = self.project_out(hidden_states) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head_linear(hidden_states[:, slice_indices, :]).contiguous() + + return logits + + def forward(self, x): + hidden = self.embed(x) + print(f"after embed hidden shape: {hidden.shape}") + q, k, v = self.qkv(hidden) + print(f"q shape: {q.shape}, k shape: {k.shape}, v shape: {v.shape}") + attn_output = self.attn(q, k, v, None, self.scaling, 0.0) + print(f"attn_output shape: {attn_output.shape}") + attn_output = self.out_proj(attn_output) + 
print(f"after out_proj attn_output shape: {attn_output.shape}") + outputs = self.ffn(attn_output) + print(f"after ffn outputs[0] shape: {outputs[0].shape}") + logits = self.lm_head(outputs, 1) + print(f"lm_head logits shape: {logits.shape}") + return logits + + + +if __name__ == "__main__": + + embed_dim = 1024 + hidden_size = 1024 + num_heads = 16 + ffn_dim = 4096 + vocab_size = 50272 + word_embed_proj_dim = 512 + pad_token_id = 1 + max_position_embeddings = 2048 + enable_bias = True + layer_norm_elementwise_affine = True + do_layer_norm_before = False + + bsz = 128 + seq_len = 1128 + + config = LLM_Config(embed_dim = embed_dim, + hidden_size = hidden_size, + num_heads = num_heads, + ffn_dim = ffn_dim, + vocab_size = vocab_size, + word_embed_proj_dim = word_embed_proj_dim, + pad_token_id = pad_token_id, + max_position_embeddings = max_position_embeddings, + enable_bias = enable_bias, + layer_norm_elementwise_affine = layer_norm_elementwise_affine, + do_layer_norm_before = do_layer_norm_before) + decoder = my_opt_decoder(config, seq_len) + decoder.eval() + + input = torch.randint(0, vocab_size, (bsz, 1)) # (bsz, seq_len) + + with torch.no_grad(): + decoder(input) \ No newline at end of file diff --git a/tests/OPT/experiment_cpu_tp.py b/tests/OPT/experiment_cpu_tp.py new file mode 100644 index 00000000..6f3ada21 --- /dev/null +++ b/tests/OPT/experiment_cpu_tp.py @@ -0,0 +1,258 @@ +import torch +import torch.nn as nn + +from typing import Optional + +TP = 2 + +class LLM_Config: + def __init__(self, + embed_dim, + hidden_size, + num_heads, + ffn_dim, + vocab_size, + word_embed_proj_dim, + pad_token_id, + max_position_embeddings, + enable_bias, + layer_norm_elementwise_affine, + do_layer_norm_before, + tp): + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.num_heads = num_heads + self.ffn_dim = ffn_dim + self.vocab_size = vocab_size + self.word_embed_proj_dim = word_embed_proj_dim + self.pad_token_id = pad_token_id + self.max_position_embeddings = max_position_embeddings + self.enable_bias = enable_bias + self.layer_norm_elementwise_affine = layer_norm_elementwise_affine + self.do_layer_norm_before = do_layer_norm_before + self.tp=tp + + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. 
Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward( + self, + attention_mask: torch.LongTensor, + past_key_values_length: int = 0, + position_ids: Optional[torch.LongTensor] = None, + ): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + + if position_ids is None: + position_ids = torch.cumsum(attention_mask, dim=1) + position_ids = (position_ids * attention_mask - 1).long() + # cut positions if `past_key_values_length` is > 0 + position_ids = position_ids[:, past_key_values_length:] + + return super().forward(position_ids + self.offset) + +class my_opt_decoder(nn.Module): + def __init__(self, config: LLM_Config, current_seq_len): + super(my_opt_decoder, self).__init__() + self.config = config + + self.head_dim = self.config.embed_dim // self.config.num_heads + self.scaling = self.head_dim**-0.5 + + # Embedding layers + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.word_embed_proj_dim, self.config.pad_token_id) + self.embed_positions = OPTLearnedPositionalEmbedding(self.config.max_position_embeddings, config.hidden_size) + self.project_in = nn.Linear(self.config.word_embed_proj_dim, self.config.hidden_size, bias=False) + + # KV Cache + self.register_buffer( + "past_k", + torch.randn(bsz, num_heads // self.config.tp, current_seq_len, self.head_dim) + ) + self.register_buffer( + "past_v", + torch.randn(bsz, num_heads // self.config.tp, current_seq_len, self.head_dim) + ) + + + # QKV layers + self.k_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim // self.config.tp, bias=self.config.enable_bias) + self.v_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim // self.config.tp, bias=self.config.enable_bias) + self.q_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim // self.config.tp, bias=self.config.enable_bias) + self.o_proj = nn.Linear(self.config.embed_dim // self.config.tp, self.config.embed_dim, bias=self.config.enable_bias) + + self.self_attn_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # FC layers + self.activation_fn = nn.ReLU() + self.fc1 = nn.Linear(self.config.embed_dim, config.ffn_dim // self.config.tp, bias=config.enable_bias) + self.fc2 = nn.Linear(self.config.ffn_dim // self.config.tp, self.config.embed_dim, bias=config.enable_bias) + self.final_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # LM head + self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) + self.lm_head_linear = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + + def embed(self, input_ids): + # input_ids: (bsz, seq_len) + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.project_in(inputs_embeds) + bsz, seq_len, _ = inputs_embeds.size() + attention_mask = (input_ids != self.config.pad_token_id).long() + position_embeds = self.embed_positions(attention_mask=attention_mask) + hidden_states = inputs_embeds + position_embeds + return hidden_states + + # qkv + rms + def qkv(self, hidden_states): + self.residual = hidden_states + self.bsz, self.tgt_len, _ = hidden_states.size() + + if self.config.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + query_states = self.q_proj(hidden_states) * self.scaling + query_states = query_states.view(self.bsz, -1, self.config.num_heads // self.config.tp, self.head_dim).transpose(1, 2) + + key_states = 
self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + key_states = key_states.view(self.bsz, -1, self.config.num_heads // self.config.tp, self.head_dim).transpose(1, 2) + value_states = value_states.view(self.bsz, -1, self.config.num_heads // self.config.tp, self.head_dim).transpose(1, 2) + + return query_states, key_states, value_states + + # QK^T + SV + def attn(self, query, key, value, attention_mask, scaling, dropout, **kwargs): + # KV cache update + self.past_k = torch.cat([self.past_k, key], dim=2) + self.past_v = torch.cat([self.past_v, value], dim=2) + + key = self.past_k + value = self.past_v + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=False) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output + + + # out-proj + rms + def out_proj(self, attn_output): + attn_output = attn_output.reshape(self.bsz, self.tgt_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + attn_output = nn.functional.dropout(attn_output, p=0.0, training=False) + attn_output = self.residual + attn_output + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + attn_output = self.self_attn_layer_norm(attn_output) + + return attn_output + + # MLP + rms + def ffn(self, hidden_states): + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=0.0, training=False) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + return outputs + + def lm_head(self, outputs, logits_to_keep): + hidden_states = outputs[0] + hidden_states = self.project_out(hidden_states) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head_linear(hidden_states[:, slice_indices, :]).contiguous() + + return logits + + def forward(self, x): + hidden = self.embed(x) + print(f"after embed hidden shape: {hidden.shape}") + q, k, v = self.qkv(hidden) + print(f"q shape: {q.shape}, k shape: {k.shape}, v shape: {v.shape}") + attn_output = self.attn(q, k, v, None, self.scaling, 0.0) + print(f"attn_output shape: {attn_output.shape}") + attn_output = self.out_proj(attn_output) + print(f"after out_proj attn_output shape: {attn_output.shape}") + outputs = self.ffn(attn_output) + print(f"after ffn outputs[0] shape: {outputs[0].shape}") + logits = self.lm_head(outputs, 1) + print(f"lm_head logits shape: {logits.shape}") + return logits + + + +if __name__ == "__main__": + + embed_dim = 1024 + hidden_size = 1024 + num_heads = 16 + ffn_dim = 4096 + vocab_size = 
50272 + word_embed_proj_dim = 512 + pad_token_id = 1 + max_position_embeddings = 2048 + enable_bias = True + layer_norm_elementwise_affine = True + do_layer_norm_before = False + tp = TP + + bsz = 128 + seq_len = 1128 + + config = LLM_Config(embed_dim = embed_dim, + hidden_size = hidden_size, + num_heads = num_heads, + ffn_dim = ffn_dim, + vocab_size = vocab_size, + word_embed_proj_dim = word_embed_proj_dim, + pad_token_id = pad_token_id, + max_position_embeddings = max_position_embeddings, + enable_bias = enable_bias, + layer_norm_elementwise_affine = layer_norm_elementwise_affine, + do_layer_norm_before = do_layer_norm_before, + tp = tp) + decoder = my_opt_decoder(config, seq_len) + decoder.eval() + + input = torch.randint(0, vocab_size, (bsz, 1)) # (bsz, seq_len) + + with torch.no_grad(): + decoder(input) \ No newline at end of file diff --git a/tests/OPT/experiment_npu.py b/tests/OPT/experiment_npu.py new file mode 100644 index 00000000..b09a7d97 --- /dev/null +++ b/tests/OPT/experiment_npu.py @@ -0,0 +1,270 @@ +import os +import sys + +import torch +import torch.nn as nn + +from typing import Optional + +class LLM_Config: + def __init__(self, + embed_dim, + hidden_size, + num_heads, + ffn_dim, + vocab_size, + word_embed_proj_dim, + pad_token_id, + max_position_embeddings, + enable_bias, + layer_norm_elementwise_affine, + do_layer_norm_before): + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.num_heads = num_heads + self.ffn_dim = ffn_dim + self.vocab_size = vocab_size + self.word_embed_proj_dim = word_embed_proj_dim + self.pad_token_id = pad_token_id + self.max_position_embeddings = max_position_embeddings + self.enable_bias = enable_bias + self.layer_norm_elementwise_affine = layer_norm_elementwise_affine + self.do_layer_norm_before = do_layer_norm_before + + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. 
Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward( + self, + attention_mask: torch.LongTensor, + past_key_values_length: int = 0, + position_ids: Optional[torch.LongTensor] = None, + ): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + + if position_ids is None: + position_ids = torch.cumsum(attention_mask, dim=1) + position_ids = (position_ids * attention_mask - 1).long() + # cut positions if `past_key_values_length` is > 0 + position_ids = position_ids[:, past_key_values_length:] + + return super().forward(position_ids + self.offset) + +class my_opt_decoder(nn.Module): + def __init__(self, config: LLM_Config, current_seq_len): + super(my_opt_decoder, self).__init__() + self.config = config + + self.head_dim = self.config.embed_dim // self.config.num_heads + self.scaling = self.head_dim**-0.5 + + # Embedding layers + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.word_embed_proj_dim, self.config.pad_token_id) + self.embed_positions = OPTLearnedPositionalEmbedding(self.config.max_position_embeddings, config.hidden_size) + self.project_in = nn.Linear(self.config.word_embed_proj_dim, self.config.hidden_size, bias=False) + + # KV Cache + # self.past_k = torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + # self.past_v = torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + self.register_buffer( + "past_k", + torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + ) + self.register_buffer( + "past_v", + torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + ) + + + # QKV layers + self.k_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + self.v_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + self.q_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + self.o_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + + self.self_attn_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # FC layers + self.activation_fn = nn.ReLU() + self.fc1 = nn.Linear(self.config.embed_dim, config.ffn_dim, bias=config.enable_bias) + self.fc2 = nn.Linear(config.ffn_dim, self.config.embed_dim, bias=config.enable_bias) + self.final_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # LM head + self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) + self.lm_head_linear = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + + def embed(self, input_ids): + # input_ids: (bsz, seq_len) + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.project_in(inputs_embeds) + bsz, seq_len, _ = inputs_embeds.size() + attention_mask = (input_ids != self.config.pad_token_id).long() + position_embeds = self.embed_positions(attention_mask=attention_mask) + hidden_states = inputs_embeds + position_embeds + return hidden_states + + # qkv + rms + def qkv(self, hidden_states): + self.residual = hidden_states + self.bsz, self.tgt_len, _ = hidden_states.size() + + if self.config.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + query_states = self.q_proj(hidden_states) * self.scaling + query_states = query_states.view(self.bsz, -1, self.config.num_heads, self.head_dim).transpose(1, 2) + + key_states = 
self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + key_states = key_states.view(self.bsz, -1, self.config.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(self.bsz, -1, self.config.num_heads, self.head_dim).transpose(1, 2) + + return query_states, key_states, value_states + + # QK^T + SV + def attn(self, query, key, value, attention_mask, scaling, dropout, **kwargs): + # KV cache update + self.past_k = torch.cat([self.past_k, key], dim=2) + self.past_v = torch.cat([self.past_v, value], dim=2) + + key = self.past_k + value = self.past_v + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=False) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output + + + # out-proj + rms + def out_proj(self, attn_output): + attn_output = attn_output.reshape(self.bsz, self.tgt_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + attn_output = nn.functional.dropout(attn_output, p=0.0, training=False) + attn_output = self.residual + attn_output + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + attn_output = self.self_attn_layer_norm(attn_output) + + return attn_output + + # MLP + rms + def ffn(self, hidden_states): + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=0.0, training=False) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + return outputs + + def lm_head(self, outputs, logits_to_keep): + hidden_states = outputs[0] + hidden_states = self.project_out(hidden_states) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head_linear(hidden_states[:, slice_indices, :]).contiguous() + + return logits + + def forward(self, x): + # hidden = self.embed(x) + # return hidden + + q, k, v = self.qkv(x) + attn_output = self.attn(q, k, v, None, self.scaling, 0.0) + attn_output = self.out_proj(attn_output) + outputs = self.ffn(attn_output) + logits = self.lm_head(outputs, 1) + return logits + + + + +if __name__ == "__main__": + + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/root/workspace/PyTorchSim')) + + from Scheduler.scheduler import PyTorchSimRunner + module = PyTorchSimRunner.setup_device() + device = module.custom_device() + + embed_dim = 1024 + hidden_size = 1024 + num_heads = 16 + ffn_dim = 4096 + vocab_size = 50272 + word_embed_proj_dim = 512 + pad_token_id = 1 + max_position_embeddings = 2048 + enable_bias = True + layer_norm_elementwise_affine = True 
+ do_layer_norm_before = False + + bsz = 2 + seq_len = 10 + + config = LLM_Config(embed_dim = embed_dim, + hidden_size = hidden_size, + num_heads = num_heads, + ffn_dim = ffn_dim, + vocab_size = vocab_size, + word_embed_proj_dim = word_embed_proj_dim, + pad_token_id = pad_token_id, + max_position_embeddings = max_position_embeddings, + enable_bias = enable_bias, + layer_norm_elementwise_affine = layer_norm_elementwise_affine, + do_layer_norm_before = do_layer_norm_before) + + decoder = my_opt_decoder(config, seq_len) + decoder.eval() + decoder_device = decoder.to(device=device) + opt_decoder = torch.compile(decoder_device, dynamic=False) + + # Embedding is not supported currently, just skip + # input = torch.randint(0, vocab_size, (bsz, seq_len)).to(device) # (bsz, seq_len) + # hidden = opt_decoder.embed(input) + + with torch.no_grad(): + hidden = torch.randn( + bsz, 1, config.hidden_size, + dtype=torch.float32 + ) + hidden_device = hidden.to(device) + opt_decoder(hidden_device) \ No newline at end of file diff --git a/tests/OPT/opt1_3b/attn.py b/tests/OPT/opt1_3b/attn.py new file mode 100644 index 00000000..95234345 --- /dev/null +++ b/tests/OPT/opt1_3b/attn.py @@ -0,0 +1,285 @@ +import os +import sys +import argparse + +import torch +import torch.nn as nn + +from typing import Optional + +class LLM_Config: + def __init__(self, + embed_dim, + hidden_size, + num_heads, + ffn_dim, + vocab_size, + word_embed_proj_dim, + pad_token_id, + max_position_embeddings, + enable_bias, + layer_norm_elementwise_affine, + do_layer_norm_before, + tp): + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.num_heads = num_heads + self.ffn_dim = ffn_dim + self.vocab_size = vocab_size + self.word_embed_proj_dim = word_embed_proj_dim + self.pad_token_id = pad_token_id + self.max_position_embeddings = max_position_embeddings + self.enable_bias = enable_bias + self.layer_norm_elementwise_affine = layer_norm_elementwise_affine + self.do_layer_norm_before = do_layer_norm_before + self.tp = tp + + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. 
Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward( + self, + attention_mask: torch.LongTensor, + past_key_values_length: int = 0, + position_ids: Optional[torch.LongTensor] = None, + ): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + + if position_ids is None: + position_ids = torch.cumsum(attention_mask, dim=1) + position_ids = (position_ids * attention_mask - 1).long() + # cut positions if `past_key_values_length` is > 0 + position_ids = position_ids[:, past_key_values_length:] + + return super().forward(position_ids + self.offset) + +class my_opt_decoder(nn.Module): + def __init__(self, config: LLM_Config, current_seq_len): + super(my_opt_decoder, self).__init__() + self.config = config + + self.head_dim = self.config.embed_dim // self.config.num_heads + self.scaling = self.head_dim**-0.5 + + # Embedding layers + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.word_embed_proj_dim, self.config.pad_token_id) + self.embed_positions = OPTLearnedPositionalEmbedding(self.config.max_position_embeddings, config.hidden_size) + self.project_in = nn.Linear(self.config.word_embed_proj_dim, self.config.hidden_size, bias=False) + + # KV Cache + self.register_buffer( + "past_k", + torch.randn(bsz, num_heads // self.config.tp, current_seq_len, self.head_dim) + ) + self.register_buffer( + "past_v", + torch.randn(bsz, num_heads // self.config.tp, current_seq_len, self.head_dim) + ) + + + # QKV layers + self.k_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim // self.config.tp, bias=self.config.enable_bias) + self.v_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim // self.config.tp, bias=self.config.enable_bias) + self.q_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim // self.config.tp, bias=self.config.enable_bias) + self.o_proj = nn.Linear(self.config.embed_dim // self.config.tp, self.config.embed_dim, bias=self.config.enable_bias) + + self.self_attn_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # FC layers + self.activation_fn = nn.ReLU() + self.fc1 = nn.Linear(self.config.embed_dim, config.ffn_dim // self.config.tp, bias=config.enable_bias) + self.fc2 = nn.Linear(self.config.ffn_dim // self.config.tp, self.config.embed_dim, bias=config.enable_bias) + self.final_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # LM head + self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) + self.lm_head_linear = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + + def embed(self, input_ids): + # input_ids: (bsz, seq_len) + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.project_in(inputs_embeds) + bsz, seq_len, _ = inputs_embeds.size() + attention_mask = (input_ids != self.config.pad_token_id).long() + position_embeds = self.embed_positions(attention_mask=attention_mask) + hidden_states = inputs_embeds + position_embeds + return hidden_states + + # qkv + rms + def qkv(self, hidden_states): + self.residual = hidden_states + self.bsz, self.tgt_len, _ = hidden_states.size() + + if self.config.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + query_states = self.q_proj(hidden_states) * self.scaling + query_states = query_states.view(self.bsz, -1, self.config.num_heads // self.config.tp, self.head_dim).transpose(1, 2) + + key_states = 
self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + key_states = key_states.view(self.bsz, -1, self.config.num_heads // self.config.tp, self.head_dim).transpose(1, 2) + value_states = value_states.view(self.bsz, -1, self.config.num_heads // self.config.tp, self.head_dim).transpose(1, 2) + + return query_states, key_states, value_states + + # QK^T + SV + def attn(self, query, key, value, attention_mask, scaling, dropout, **kwargs): + # KV cache update + self.past_k = torch.cat([self.past_k, key], dim=2) + self.past_v = torch.cat([self.past_v, value], dim=2) + + key = self.past_k + value = self.past_v + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=False) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output + + # out-proj + rms + def out_proj(self, attn_output): + attn_output = attn_output.reshape(self.bsz, self.tgt_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + attn_output = nn.functional.dropout(attn_output, p=0.0, training=False) + attn_output = self.residual + attn_output + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + attn_output = self.self_attn_layer_norm(attn_output) + + return attn_output + + # MLP + rms + def ffn(self, hidden_states): + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=0.0, training=False) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + return outputs + + def lm_head(self, outputs, logits_to_keep): + hidden_states = outputs[0] + hidden_states = self.project_out(hidden_states) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + + print(f"input hidden_states shape: {hidden_states.shape}") + print(f"config.word_embed_proj_dim, config.vocab_size: {config.word_embed_proj_dim} {config.vocab_size}") + logits = self.lm_head_linear(hidden_states[:, slice_indices, :]).contiguous() + + return logits + + def forward(self, q, k, v): + # hidden = self.embed(x) + # return hidden + + attn_output = self.attn(q, k, v, None, self.scaling, 0.0) + return attn_output + + + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description="") + parser.add_argument("--bsz", type=int, default=128, help="Batch size") + parser.add_argument("--seq_len", type=int, default=128, help="Input sequence length") + parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size") + args = parser.parse_args() + + sys.path.append(os.environ.get('TORCHSIM_DIR', 
default='/root/workspace/PyTorchSim')) + + from Scheduler.scheduler import PyTorchSimRunner + module = PyTorchSimRunner.setup_device() + device = module.custom_device() + + embed_dim = 2048 // args.tp + hidden_size = 2048 // args.tp + num_heads = 32 // args.tp + ffn_dim = 8192 + vocab_size = 50272 + word_embed_proj_dim = 2048 + pad_token_id = 1 + max_position_embeddings = 2048 + enable_bias = True + layer_norm_elementwise_affine = True + do_layer_norm_before = True + + bsz = args.bsz + seq_len = args.seq_len + + print(f"Batch size: {bsz}, Seq len: {seq_len}, TP: {args.tp}") + + config = LLM_Config(embed_dim = embed_dim, + hidden_size = hidden_size, + num_heads = num_heads, + ffn_dim = ffn_dim, + vocab_size = vocab_size, + word_embed_proj_dim = word_embed_proj_dim, + pad_token_id = pad_token_id, + max_position_embeddings = max_position_embeddings, + enable_bias = enable_bias, + layer_norm_elementwise_affine = layer_norm_elementwise_affine, + do_layer_norm_before = do_layer_norm_before, + tp = args.tp) + + decoder = my_opt_decoder(config, seq_len) + decoder.eval() + decoder_device = decoder.to(device=device) + opt_decoder = torch.compile(decoder_device, dynamic=False) + + q = torch.randn( + bsz, config.num_heads // config.tp, 1, config.hidden_size // config.num_heads, + dtype=torch.float32 + ) + k = torch.randn( + bsz, config.num_heads // config.tp, 1, config.hidden_size // config.num_heads, + dtype=torch.float32 + ) + v = torch.randn( + bsz, config.num_heads // config.tp, 1, config.hidden_size // config.num_heads, + dtype=torch.float32 + ) + q_device = q.to(device) + k_device = k.to(device) + v_device = v.to(device) + + with torch.no_grad(): + opt_decoder(q_device, k_device, v_device) \ No newline at end of file diff --git a/tests/OPT/opt1_3b/embed.py b/tests/OPT/opt1_3b/embed.py new file mode 100644 index 00000000..4c35764d --- /dev/null +++ b/tests/OPT/opt1_3b/embed.py @@ -0,0 +1,270 @@ +import os +import sys +import argparse + +import torch +import torch.nn as nn + +from typing import Optional + +class LLM_Config: + def __init__(self, + embed_dim, + hidden_size, + num_heads, + ffn_dim, + vocab_size, + word_embed_proj_dim, + pad_token_id, + max_position_embeddings, + enable_bias, + layer_norm_elementwise_affine, + do_layer_norm_before, + tp): + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.num_heads = num_heads + self.ffn_dim = ffn_dim + self.vocab_size = vocab_size + self.word_embed_proj_dim = word_embed_proj_dim + self.pad_token_id = pad_token_id + self.max_position_embeddings = max_position_embeddings + self.enable_bias = enable_bias + self.layer_norm_elementwise_affine = layer_norm_elementwise_affine + self.do_layer_norm_before = do_layer_norm_before + self.tp = tp + + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. 
Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward( + self, + attention_mask: torch.LongTensor, + past_key_values_length: int = 0, + position_ids: Optional[torch.LongTensor] = None, + ): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + + if position_ids is None: + position_ids = torch.full( + (bsz, 1), + past_key_values_length, + dtype=torch.long, + ) + + return super().forward(position_ids + self.offset) + +class my_opt_decoder(nn.Module): + def __init__(self, config: LLM_Config, current_seq_len): + super(my_opt_decoder, self).__init__() + self.config = config + + self.head_dim = self.config.embed_dim // self.config.num_heads + self.scaling = self.head_dim**-0.5 + + # Embedding layers + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.word_embed_proj_dim, self.config.pad_token_id) + self.embed_positions = OPTLearnedPositionalEmbedding(self.config.max_position_embeddings, config.hidden_size) + self.project_in = nn.Linear(self.config.word_embed_proj_dim, self.config.hidden_size, bias=False) + + # KV Cache + self.register_buffer( + "past_k", + torch.randn(bsz, num_heads // self.config.tp, current_seq_len, self.head_dim) + ) + self.register_buffer( + "past_v", + torch.randn(bsz, num_heads // self.config.tp, current_seq_len, self.head_dim) + ) + + + # QKV layers + self.k_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim // self.config.tp, bias=self.config.enable_bias) + self.v_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim // self.config.tp, bias=self.config.enable_bias) + self.q_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim // self.config.tp, bias=self.config.enable_bias) + self.o_proj = nn.Linear(self.config.embed_dim // self.config.tp, self.config.embed_dim, bias=self.config.enable_bias) + + self.self_attn_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # FC layers + self.activation_fn = nn.ReLU() + self.fc1 = nn.Linear(self.config.embed_dim, config.ffn_dim // self.config.tp, bias=config.enable_bias) + self.fc2 = nn.Linear(self.config.ffn_dim // self.config.tp, self.config.embed_dim, bias=config.enable_bias) + self.final_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # LM head + self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) + self.lm_head_linear = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + + def embed(self, input_ids): + # input_ids: (bsz, seq_len) + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.project_in(inputs_embeds) + bsz, seq_len, _ = inputs_embeds.size() + attention_mask = (input_ids != self.config.pad_token_id).long() + position_embeds = self.embed_positions(attention_mask=attention_mask) + hidden_states = inputs_embeds + position_embeds + return hidden_states + + # qkv + rms + def qkv(self, hidden_states): + self.residual = hidden_states + self.bsz, self.tgt_len, _ = hidden_states.size() + + if self.config.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + query_states = self.q_proj(hidden_states) * self.scaling + query_states = query_states.view(self.bsz, -1, self.config.num_heads // self.config.tp, self.head_dim).transpose(1, 2) + + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + key_states = key_states.view(self.bsz, -1, 
self.config.num_heads // self.config.tp, self.head_dim).transpose(1, 2) + value_states = value_states.view(self.bsz, -1, self.config.num_heads // self.config.tp, self.head_dim).transpose(1, 2) + + return query_states, key_states, value_states + + # QK^T + SV + def attn(self, query, key, value, attention_mask, scaling, dropout, **kwargs): + # KV cache update + self.past_k = torch.cat([self.past_k, key], dim=2) + self.past_v = torch.cat([self.past_v, value], dim=2) + + key = self.past_k + value = self.past_v + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=False) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output + + + # out-proj + rms + def out_proj(self, attn_output): + attn_output = attn_output.reshape(self.bsz, self.tgt_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + attn_output = nn.functional.dropout(attn_output, p=0.0, training=False) + attn_output = self.residual + attn_output + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + attn_output = self.self_attn_layer_norm(attn_output) + + return attn_output + + # MLP + rms + def ffn(self, hidden_states): + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=0.0, training=False) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + return outputs + + def lm_head(self, outputs, logits_to_keep): + hidden_states = outputs[0] + hidden_states = self.project_out(hidden_states) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head_linear(hidden_states[:, slice_indices, :]).contiguous() + + return logits + + def forward(self, x): + hidden = self.embed(x) + return hidden + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="") + parser.add_argument("--bsz", type=int, default=128, help="Batch size") + parser.add_argument("--seq_len", type=int, default=128, help="Input sequence length") + parser.add_argument("--tp", type=int, default=1, help="Input sequence length") + args = parser.parse_args() + + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/root/workspace/PyTorchSim')) + + from Scheduler.scheduler import PyTorchSimRunner + module = PyTorchSimRunner.setup_device() + device = module.custom_device() + + embed_dim = 2048 // args.tp + hidden_size = 2048 // args.tp + num_heads = 32 // args.tp + ffn_dim = 8192 + vocab_size = 50272 + word_embed_proj_dim = 2048 + pad_token_id = 1 + max_position_embeddings = 2048 + 
enable_bias = True + layer_norm_elementwise_affine = True + do_layer_norm_before = True + + bsz = args.bsz + seq_len = args.seq_len + + print(f"Batch size: {bsz}, Seq len: {seq_len}, TP: {args.tp}") + + config = LLM_Config(embed_dim = embed_dim, + hidden_size = hidden_size, + num_heads = num_heads, + ffn_dim = ffn_dim, + vocab_size = vocab_size, + word_embed_proj_dim = word_embed_proj_dim, + pad_token_id = pad_token_id, + max_position_embeddings = max_position_embeddings, + enable_bias = enable_bias, + layer_norm_elementwise_affine = layer_norm_elementwise_affine, + do_layer_norm_before = do_layer_norm_before, + tp = args.tp) + + decoder = my_opt_decoder(config, seq_len) + decoder.eval() + decoder_device = decoder.to(device=device) + opt_decoder = torch.compile(decoder_device, dynamic=False) + + # Embedding is not supported currently, just skip + # input = torch.randint(0, vocab_size, (bsz, seq_len)).to(device) # (bsz, seq_len) + # hidden = opt_decoder.embed(input) + + + input = torch.randint(0, vocab_size, (bsz, 1)) + input_device = input.to(device) + + with torch.no_grad(): + opt_decoder(input_device) \ No newline at end of file diff --git a/tests/OPT/opt1_3b/ffn.py b/tests/OPT/opt1_3b/ffn.py new file mode 100644 index 00000000..daa7c637 --- /dev/null +++ b/tests/OPT/opt1_3b/ffn.py @@ -0,0 +1,282 @@ +import os +import sys +import argparse + +import torch +import torch.nn as nn + +from typing import Optional + +class LLM_Config: + def __init__(self, + embed_dim, + hidden_size, + num_heads, + ffn_dim, + vocab_size, + word_embed_proj_dim, + pad_token_id, + max_position_embeddings, + enable_bias, + layer_norm_elementwise_affine, + do_layer_norm_before, + tp): + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.num_heads = num_heads + self.ffn_dim = ffn_dim + self.vocab_size = vocab_size + self.word_embed_proj_dim = word_embed_proj_dim + self.pad_token_id = pad_token_id + self.max_position_embeddings = max_position_embeddings + self.enable_bias = enable_bias + self.layer_norm_elementwise_affine = layer_norm_elementwise_affine + self.do_layer_norm_before = do_layer_norm_before + self.tp = tp + + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. 
Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward( + self, + attention_mask: torch.LongTensor, + past_key_values_length: int = 0, + position_ids: Optional[torch.LongTensor] = None, + ): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + + if position_ids is None: + position_ids = torch.cumsum(attention_mask, dim=1) + position_ids = (position_ids * attention_mask - 1).long() + # cut positions if `past_key_values_length` is > 0 + position_ids = position_ids[:, past_key_values_length:] + + return super().forward(position_ids + self.offset) + +class my_opt_decoder(nn.Module): + def __init__(self, config: LLM_Config, current_seq_len): + super(my_opt_decoder, self).__init__() + self.config = config + + self.head_dim = self.config.embed_dim // self.config.num_heads + self.scaling = self.head_dim**-0.5 + + # Embedding layers + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.word_embed_proj_dim, self.config.pad_token_id) + self.embed_positions = OPTLearnedPositionalEmbedding(self.config.max_position_embeddings, config.hidden_size) + self.project_in = nn.Linear(self.config.word_embed_proj_dim, self.config.hidden_size, bias=False) + + # KV Cache + self.register_buffer( + "past_k", + torch.randn(bsz, num_heads // self.config.tp, current_seq_len, self.head_dim) + ) + self.register_buffer( + "past_v", + torch.randn(bsz, num_heads // self.config.tp, current_seq_len, self.head_dim) + ) + + + # QKV layers + self.k_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim // self.config.tp, bias=self.config.enable_bias) + self.v_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim // self.config.tp, bias=self.config.enable_bias) + self.q_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim // self.config.tp, bias=self.config.enable_bias) + self.o_proj = nn.Linear(self.config.embed_dim // self.config.tp, self.config.embed_dim, bias=self.config.enable_bias) + + self.self_attn_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # FC layers + self.activation_fn = nn.ReLU() + self.fc1 = nn.Linear(self.config.embed_dim, config.ffn_dim // self.config.tp, bias=config.enable_bias) + self.fc2 = nn.Linear(self.config.ffn_dim // self.config.tp, self.config.embed_dim, bias=config.enable_bias) + self.final_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # LM head + self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) + self.lm_head_linear = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + + def embed(self, input_ids): + # input_ids: (bsz, seq_len) + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.project_in(inputs_embeds) + bsz, seq_len, _ = inputs_embeds.size() + attention_mask = (input_ids != self.config.pad_token_id).long() + position_embeds = self.embed_positions(attention_mask=attention_mask) + print(f"input shape: {input_ids.shape}, embed shape: {inputs_embeds.shape}, position shape: {position_embeds.shape}") + hidden_states = inputs_embeds + position_embeds + return hidden_states + + # qkv + rms + def qkv(self, hidden_states): + self.residual = hidden_states + self.bsz, self.tgt_len, _ = hidden_states.size() + + if self.config.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + query_states = self.q_proj(hidden_states) * self.scaling + query_states = 
query_states.view(self.bsz, -1, self.config.num_heads // self.config.tp, self.head_dim).transpose(1, 2) + + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + key_states = key_states.view(self.bsz, -1, self.config.num_heads // self.config.tp, self.head_dim).transpose(1, 2) + value_states = value_states.view(self.bsz, -1, self.config.num_heads // self.config.tp, self.head_dim).transpose(1, 2) + + return query_states, key_states, value_states + + # QK^T + SV + def attn(self, query, key, value, attention_mask, scaling, dropout, **kwargs): + # KV cache update + self.past_k = torch.cat([self.past_k, key], dim=2) + self.past_v = torch.cat([self.past_v, value], dim=2) + + key = self.past_k + value = self.past_v + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=False) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output + + + # out-proj + rms + def out_proj(self, attn_output): + attn_output = attn_output.reshape(self.bsz, self.tgt_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + attn_output = nn.functional.dropout(attn_output, p=0.0, training=False) + attn_output = self.residual + attn_output + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + attn_output = self.self_attn_layer_norm(attn_output) + + return attn_output + + # MLP + rms + def ffn(self, hidden_states): + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=0.0, training=False) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + return outputs + + def lm_head(self, outputs, logits_to_keep): + hidden_states = outputs[0] + hidden_states = self.project_out(hidden_states) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + + print(f"input hidden_states shape: {hidden_states.shape}") + print(f"config.word_embed_proj_dim, config.vocab_size: {config.word_embed_proj_dim} {config.vocab_size}") + logits = self.lm_head_linear(hidden_states[:, slice_indices, :]).contiguous() + + return logits + + def forward(self, x): + # hidden = self.embed(x) + # return hidden + + outputs = self.ffn(x) + return outputs + + + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description="") + parser.add_argument("--bsz", type=int, default=128, help="Batch size") + parser.add_argument("--seq_len", type=int, default=128, help="Input sequence length") + parser.add_argument("--tp", type=int, default=1, help="Input sequence length") + 
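# --tp is the tensor-parallel degree: embed_dim, hidden_size, num_heads and the fc1/fc2 width below are all divided by it +
# illustrative invocation (values are arbitrary): python tests/OPT/opt1_3b/ffn.py --bsz 64 --seq_len 256 --tp 2 +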
args = parser.parse_args() + + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/root/workspace/PyTorchSim')) + + from Scheduler.scheduler import PyTorchSimRunner + module = PyTorchSimRunner.setup_device() + device = module.custom_device() + + embed_dim = 2048 // args.tp + hidden_size = 2048 // args.tp + num_heads = 32 // args.tp + ffn_dim = 8192 + vocab_size = 50272 + word_embed_proj_dim = 2048 + pad_token_id = 1 + max_position_embeddings = 2048 + enable_bias = True + layer_norm_elementwise_affine = True + do_layer_norm_before = True + + bsz = args.bsz + seq_len = args.seq_len + + print(f"Batch size: {bsz}, Seq len: {seq_len}, TP: {args.tp}") + + config = LLM_Config(embed_dim = embed_dim, + hidden_size = hidden_size, + num_heads = num_heads, + ffn_dim = ffn_dim, + vocab_size = vocab_size, + word_embed_proj_dim = word_embed_proj_dim, + pad_token_id = pad_token_id, + max_position_embeddings = max_position_embeddings, + enable_bias = enable_bias, + layer_norm_elementwise_affine = layer_norm_elementwise_affine, + do_layer_norm_before = do_layer_norm_before, + tp = args.tp) + + decoder = my_opt_decoder(config, seq_len) + decoder.eval() + decoder_device = decoder.to(device=device) + opt_decoder = torch.compile(decoder_device, dynamic=False) + + # Embedding is not supported currently, just skip + # input = torch.randint(0, vocab_size, (bsz, seq_len)).to(device) # (bsz, seq_len) + # hidden = opt_decoder.embed(input) + + + hidden = torch.randn( + bsz, 1, config.hidden_size, + dtype=torch.float32 + ) + hidden_device = hidden.to(device) + + with torch.no_grad(): + opt_decoder(hidden_device) \ No newline at end of file diff --git a/tests/OPT/opt1_3b/lm_head.py b/tests/OPT/opt1_3b/lm_head.py new file mode 100644 index 00000000..37a77268 --- /dev/null +++ b/tests/OPT/opt1_3b/lm_head.py @@ -0,0 +1,280 @@ +import os +import sys +import argparse + +import torch +import torch.nn as nn + +from typing import Optional + +class LLM_Config: + def __init__(self, + embed_dim, + hidden_size, + num_heads, + ffn_dim, + vocab_size, + word_embed_proj_dim, + pad_token_id, + max_position_embeddings, + enable_bias, + layer_norm_elementwise_affine, + do_layer_norm_before, + tp): + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.num_heads = num_heads + self.ffn_dim = ffn_dim + self.vocab_size = vocab_size + self.word_embed_proj_dim = word_embed_proj_dim + self.pad_token_id = pad_token_id + self.max_position_embeddings = max_position_embeddings + self.enable_bias = enable_bias + self.layer_norm_elementwise_affine = layer_norm_elementwise_affine + self.do_layer_norm_before = do_layer_norm_before + self.tp = tp + + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. 
Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward( + self, + attention_mask: torch.LongTensor, + past_key_values_length: int = 0, + position_ids: Optional[torch.LongTensor] = None, + ): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + + if position_ids is None: + position_ids = torch.cumsum(attention_mask, dim=1) + position_ids = (position_ids * attention_mask - 1).long() + # cut positions if `past_key_values_length` is > 0 + position_ids = position_ids[:, past_key_values_length:] + + return super().forward(position_ids + self.offset) + +class my_opt_decoder(nn.Module): + def __init__(self, config: LLM_Config, current_seq_len): + super(my_opt_decoder, self).__init__() + self.config = config + + self.head_dim = self.config.embed_dim // self.config.num_heads + self.scaling = self.head_dim**-0.5 + + # Embedding layers + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.word_embed_proj_dim, self.config.pad_token_id) + self.embed_positions = OPTLearnedPositionalEmbedding(self.config.max_position_embeddings, config.hidden_size) + self.project_in = nn.Linear(self.config.word_embed_proj_dim, self.config.hidden_size, bias=False) + + # KV Cache + # self.past_k = torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + # self.past_v = torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + self.register_buffer( + "past_k", + torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + ) + self.register_buffer( + "past_v", + torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + ) + + + # QKV layers + self.k_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + self.v_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + self.q_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + self.o_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + + self.self_attn_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # FC layers + self.activation_fn = nn.ReLU() + self.fc1 = nn.Linear(self.config.embed_dim, config.ffn_dim, bias=config.enable_bias) + self.fc2 = nn.Linear(config.ffn_dim, self.config.embed_dim, bias=config.enable_bias) + self.final_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # LM head + self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) + self.lm_head_linear = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + + def embed(self, input_ids): + # input_ids: (bsz, seq_len) + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.project_in(inputs_embeds) + bsz, seq_len, _ = inputs_embeds.size() + attention_mask = (input_ids != self.config.pad_token_id).long() + position_embeds = self.embed_positions(attention_mask=attention_mask) + print(f"input shape: {input_ids.shape}, embed shape: {inputs_embeds.shape}, position shape: {position_embeds.shape}") + hidden_states = inputs_embeds + position_embeds + return hidden_states + + # qkv + rms + def qkv(self, hidden_states): + self.residual = hidden_states + self.bsz, self.tgt_len, _ = hidden_states.size() + + if self.config.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + query_states = self.q_proj(hidden_states) * self.scaling + 
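# reshape below: (bsz, tgt_len, embed_dim) -> (bsz, num_heads, tgt_len, head_dim), putting heads on dim 1 for attention +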
query_states = query_states.view(self.bsz, -1, self.config.num_heads, self.head_dim).transpose(1, 2) + + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + key_states = key_states.view(self.bsz, -1, self.config.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(self.bsz, -1, self.config.num_heads, self.head_dim).transpose(1, 2) + + return query_states, key_states, value_states + + # QK^T + SV + def attn(self, query, key, value, attention_mask, scaling, dropout, **kwargs): + # KV cache update + self.past_k = torch.cat([self.past_k, key], dim=2) + self.past_v = torch.cat([self.past_v, value], dim=2) + + key = self.past_k + value = self.past_v + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=False) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output + + + # out-proj + rms + def out_proj(self, attn_output): + attn_output = attn_output.reshape(self.bsz, self.tgt_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + attn_output = nn.functional.dropout(attn_output, p=0.0, training=False) + attn_output = self.residual + attn_output + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + attn_output = self.self_attn_layer_norm(attn_output) + + return attn_output + + # MLP + rms + def ffn(self, hidden_states): + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=0.0, training=False) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + return outputs + + def lm_head(self, outputs, logits_to_keep): + hidden_states = outputs[0] + hidden_states = self.project_out(hidden_states) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head_linear(hidden_states[:, slice_indices, :]).contiguous() + + return logits + + def forward(self, x): + # hidden = self.embed(x) + # return hidden + + logits = self.lm_head(x, 1) + return logits + + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="") + parser.add_argument("--bsz", type=int, default=128, help="Batch size") + parser.add_argument("--seq_len", type=int, default=128, help="Input sequence length") + parser.add_argument("--tp", type=int, default=1, help="Input sequence length") + args = parser.parse_args() + + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/root/workspace/PyTorchSim')) + + from Scheduler.scheduler import PyTorchSimRunner + module = 
PyTorchSimRunner.setup_device() + device = module.custom_device() + + embed_dim = 2048 // args.tp + hidden_size = 2048 // args.tp + num_heads = 32 // args.tp + ffn_dim = 8192 + vocab_size = 50272 + word_embed_proj_dim = 2048 + pad_token_id = 1 + max_position_embeddings = 2048 + enable_bias = True + layer_norm_elementwise_affine = True + do_layer_norm_before = True + + bsz = args.bsz + seq_len = args.seq_len + + print(f"Batch size: {bsz}, Seq len: {seq_len}") + + config = LLM_Config(embed_dim = embed_dim, + hidden_size = hidden_size, + num_heads = num_heads, + ffn_dim = ffn_dim, + vocab_size = vocab_size, + word_embed_proj_dim = word_embed_proj_dim, + pad_token_id = pad_token_id, + max_position_embeddings = max_position_embeddings, + enable_bias = enable_bias, + layer_norm_elementwise_affine = layer_norm_elementwise_affine, + do_layer_norm_before = do_layer_norm_before, + tp = args.tp) + + decoder = my_opt_decoder(config, seq_len) + decoder.eval() + decoder_device = decoder.to(device=device) + opt_decoder = torch.compile(decoder_device, dynamic=False) + + # Embedding is not supported currently, just skip + # input = torch.randint(0, vocab_size, (bsz, seq_len)).to(device) # (bsz, seq_len) + # hidden = opt_decoder.embed(input) + + + hidden = torch.randn( + bsz, 1, config.hidden_size, + dtype=torch.float32 + ) + hidden_device = hidden.to(device) + + with torch.no_grad(): + opt_decoder((hidden_device, )) \ No newline at end of file diff --git a/tests/OPT/opt1_3b/out_proj.py b/tests/OPT/opt1_3b/out_proj.py new file mode 100644 index 00000000..2cb07c66 --- /dev/null +++ b/tests/OPT/opt1_3b/out_proj.py @@ -0,0 +1,291 @@ +import os +import sys +import argparse + +import torch +import torch.nn as nn + +from typing import Optional + +class LLM_Config: + def __init__(self, + embed_dim, + hidden_size, + num_heads, + ffn_dim, + vocab_size, + word_embed_proj_dim, + pad_token_id, + max_position_embeddings, + enable_bias, + layer_norm_elementwise_affine, + do_layer_norm_before, + tp): + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.num_heads = num_heads + self.ffn_dim = ffn_dim + self.vocab_size = vocab_size + self.word_embed_proj_dim = word_embed_proj_dim + self.pad_token_id = pad_token_id + self.max_position_embeddings = max_position_embeddings + self.enable_bias = enable_bias + self.layer_norm_elementwise_affine = layer_norm_elementwise_affine + self.do_layer_norm_before = do_layer_norm_before + self.tp = tp + + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. 
Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward( + self, + attention_mask: torch.LongTensor, + past_key_values_length: int = 0, + position_ids: Optional[torch.LongTensor] = None, + ): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + + if position_ids is None: + position_ids = torch.cumsum(attention_mask, dim=1) + position_ids = (position_ids * attention_mask - 1).long() + # cut positions if `past_key_values_length` is > 0 + position_ids = position_ids[:, past_key_values_length:] + + return super().forward(position_ids + self.offset) + +class my_opt_decoder(nn.Module): + def __init__(self, config: LLM_Config, current_seq_len): + super(my_opt_decoder, self).__init__() + self.config = config + + self.head_dim = self.config.embed_dim // self.config.num_heads + self.scaling = self.head_dim**-0.5 + + # Embedding layers + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.word_embed_proj_dim, self.config.pad_token_id) + self.embed_positions = OPTLearnedPositionalEmbedding(self.config.max_position_embeddings, config.hidden_size) + self.project_in = nn.Linear(self.config.word_embed_proj_dim, self.config.hidden_size, bias=False) + + # KV Cache + self.register_buffer( + "past_k", + torch.randn(bsz, num_heads // self.config.tp, current_seq_len, self.head_dim) + ) + self.register_buffer( + "past_v", + torch.randn(bsz, num_heads // self.config.tp, current_seq_len, self.head_dim) + ) + + + # QKV layers + self.k_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim // self.config.tp, bias=self.config.enable_bias) + self.v_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim // self.config.tp, bias=self.config.enable_bias) + self.q_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim // self.config.tp, bias=self.config.enable_bias) + self.o_proj = nn.Linear(self.config.embed_dim // self.config.tp, self.config.embed_dim, bias=self.config.enable_bias) + + self.self_attn_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # FC layers + self.activation_fn = nn.ReLU() + self.fc1 = nn.Linear(self.config.embed_dim, config.ffn_dim // self.config.tp, bias=config.enable_bias) + self.fc2 = nn.Linear(self.config.ffn_dim // self.config.tp, self.config.embed_dim, bias=config.enable_bias) + self.final_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # LM head + self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) + self.lm_head_linear = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + + def embed(self, input_ids): + # input_ids: (bsz, seq_len) + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.project_in(inputs_embeds) + bsz, seq_len, _ = inputs_embeds.size() + attention_mask = (input_ids != self.config.pad_token_id).long() + position_embeds = self.embed_positions(attention_mask=attention_mask) + print(f"input shape: {input_ids.shape}, embed shape: {inputs_embeds.shape}, position shape: {position_embeds.shape}") + hidden_states = inputs_embeds + position_embeds + return hidden_states + + # qkv + rms + def qkv(self, hidden_states): + self.residual = hidden_states + self.bsz, self.tgt_len, _ = hidden_states.size() + + if self.config.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + query_states = self.q_proj(hidden_states) * self.scaling + query_states = 
query_states.view(self.bsz, -1, self.config.num_heads // self.config.tp, self.head_dim).transpose(1, 2) + + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + key_states = key_states.view(self.bsz, -1, self.config.num_heads // self.config.tp, self.head_dim).transpose(1, 2) + value_states = value_states.view(self.bsz, -1, self.config.num_heads // self.config.tp, self.head_dim).transpose(1, 2) + + return query_states, key_states, value_states + + # QK^T + SV + def attn(self, query, key, value, attention_mask, scaling, dropout, **kwargs): + # KV cache update + self.past_k = torch.cat([self.past_k, key], dim=2) + self.past_v = torch.cat([self.past_v, value], dim=2) + + key = self.past_k + value = self.past_v + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=False) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output + + + # out-proj + rms + def out_proj(self, attn_output, residual): + self.bsz = attn_output.size(0) + self.tgt_len = attn_output.size(1) + + attn_output = attn_output.reshape(self.bsz, self.tgt_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + attn_output = nn.functional.dropout(attn_output, p=0.0, training=False) + attn_output = residual + attn_output + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + attn_output = self.self_attn_layer_norm(attn_output) + + return attn_output + + # MLP + rms + def ffn(self, hidden_states): + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=0.0, training=False) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + return outputs + + def lm_head(self, outputs, logits_to_keep): + hidden_states = outputs[0] + hidden_states = self.project_out(hidden_states) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + + print(f"input hidden_states shape: {hidden_states.shape}") + print(f"config.word_embed_proj_dim, config.vocab_size: {config.word_embed_proj_dim} {config.vocab_size}") + logits = self.lm_head_linear(hidden_states[:, slice_indices, :]).contiguous() + + return logits + + def forward(self, x, r): + # hidden = self.embed(x) + # return hidden + + attn_output = self.out_proj(x, r) + return attn_output + + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="") + parser.add_argument("--bsz", type=int, default=128, help="Batch size") + parser.add_argument("--seq_len", type=int, default=128, help="Input sequence 
length") + parser.add_argument("--tp", type=int, default=1, help="Input sequence length") + args = parser.parse_args() + + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/root/workspace/PyTorchSim')) + + from Scheduler.scheduler import PyTorchSimRunner + module = PyTorchSimRunner.setup_device() + device = module.custom_device() + + embed_dim = 2048 // args.tp + hidden_size = 2048 // args.tp + num_heads = 32 // args.tp + ffn_dim = 8192 + vocab_size = 50272 + word_embed_proj_dim = 2048 + pad_token_id = 1 + max_position_embeddings = 2048 + enable_bias = True + layer_norm_elementwise_affine = True + do_layer_norm_before = True + + bsz = args.bsz + seq_len = args.seq_len + + print(f"Batch size: {bsz}, Seq len: {seq_len}, TP: {args.tp}") + + config = LLM_Config(embed_dim = embed_dim, + hidden_size = hidden_size, + num_heads = num_heads, + ffn_dim = ffn_dim, + vocab_size = vocab_size, + word_embed_proj_dim = word_embed_proj_dim, + pad_token_id = pad_token_id, + max_position_embeddings = max_position_embeddings, + enable_bias = enable_bias, + layer_norm_elementwise_affine = layer_norm_elementwise_affine, + do_layer_norm_before = do_layer_norm_before, + tp = args.tp) + + decoder = my_opt_decoder(config, seq_len) + decoder.eval() + decoder_device = decoder.to(device=device) + opt_decoder = torch.compile(decoder_device, dynamic=False) + + # Embedding is not supported currently, just skip + # input = torch.randint(0, vocab_size, (bsz, seq_len)).to(device) # (bsz, seq_len) + # hidden = opt_decoder.embed(input) + + + hidden = torch.randn( + bsz, 1, config.num_heads // config.tp, config.hidden_size // config.num_heads, + dtype=torch.float32 + ) + hidden_device = hidden.to(device) + + # Residual + residual = torch.randn( + bsz, 1, config.hidden_size, + dtype=torch.float32 + ) + residual_device = residual.to(device) + + with torch.no_grad(): + opt_decoder(hidden_device, residual_device) \ No newline at end of file diff --git a/tests/OPT/opt1_3b/qkv.py b/tests/OPT/opt1_3b/qkv.py new file mode 100644 index 00000000..50dfc0b1 --- /dev/null +++ b/tests/OPT/opt1_3b/qkv.py @@ -0,0 +1,277 @@ +import os +import sys +import argparse + +import torch +import torch.nn as nn + +from typing import Optional + +class LLM_Config: + def __init__(self, + embed_dim, + hidden_size, + num_heads, + ffn_dim, + vocab_size, + word_embed_proj_dim, + pad_token_id, + max_position_embeddings, + enable_bias, + layer_norm_elementwise_affine, + do_layer_norm_before, + tp): + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.num_heads = num_heads + self.ffn_dim = ffn_dim + self.vocab_size = vocab_size + self.word_embed_proj_dim = word_embed_proj_dim + self.pad_token_id = pad_token_id + self.max_position_embeddings = max_position_embeddings + self.enable_bias = enable_bias + self.layer_norm_elementwise_affine = layer_norm_elementwise_affine + self.do_layer_norm_before = do_layer_norm_before + self.tp = tp + + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. 
Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward( + self, + attention_mask: torch.LongTensor, + past_key_values_length: int = 0, + position_ids: Optional[torch.LongTensor] = None, + ): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + + if position_ids is None: + position_ids = torch.cumsum(attention_mask, dim=1) + position_ids = (position_ids * attention_mask - 1).long() + # cut positions if `past_key_values_length` is > 0 + position_ids = position_ids[:, past_key_values_length:] + + return super().forward(position_ids + self.offset) + +class my_opt_decoder(nn.Module): + def __init__(self, config: LLM_Config, current_seq_len): + super(my_opt_decoder, self).__init__() + self.config = config + + self.head_dim = self.config.embed_dim // self.config.num_heads + self.scaling = self.head_dim**-0.5 + + # Embedding layers + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.word_embed_proj_dim, self.config.pad_token_id) + self.embed_positions = OPTLearnedPositionalEmbedding(self.config.max_position_embeddings, config.hidden_size) + self.project_in = nn.Linear(self.config.word_embed_proj_dim, self.config.hidden_size, bias=False) + + # KV Cache + self.register_buffer( + "past_k", + torch.randn(bsz, num_heads // self.config.tp, current_seq_len, self.head_dim) + ) + self.register_buffer( + "past_v", + torch.randn(bsz, num_heads // self.config.tp, current_seq_len, self.head_dim) + ) + + + # QKV layers + self.k_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim // self.config.tp, bias=self.config.enable_bias) + self.v_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim // self.config.tp, bias=self.config.enable_bias) + self.q_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim // self.config.tp, bias=self.config.enable_bias) + self.o_proj = nn.Linear(self.config.embed_dim // self.config.tp, self.config.embed_dim, bias=self.config.enable_bias) + + self.self_attn_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # FC layers + self.activation_fn = nn.ReLU() + self.fc1 = nn.Linear(self.config.embed_dim, config.ffn_dim // self.config.tp, bias=config.enable_bias) + self.fc2 = nn.Linear(self.config.ffn_dim // self.config.tp, self.config.embed_dim, bias=config.enable_bias) + self.final_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # LM head + self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) + self.lm_head_linear = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + + def embed(self, input_ids): + # input_ids: (bsz, seq_len) + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.project_in(inputs_embeds) + bsz, seq_len, _ = inputs_embeds.size() + attention_mask = (input_ids != self.config.pad_token_id).long() + position_embeds = self.embed_positions(attention_mask=attention_mask) + hidden_states = inputs_embeds + position_embeds + return hidden_states + + # qkv + rms + def qkv(self, hidden_states): + self.residual = hidden_states + self.bsz, self.tgt_len, _ = hidden_states.size() + + if self.config.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + query_states = self.q_proj(hidden_states) * self.scaling + query_states = query_states.view(self.bsz, -1, self.config.num_heads // self.config.tp, self.head_dim).transpose(1, 2) + + key_states = 
self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + key_states = key_states.view(self.bsz, -1, self.config.num_heads // self.config.tp, self.head_dim).transpose(1, 2) + value_states = value_states.view(self.bsz, -1, self.config.num_heads // self.config.tp, self.head_dim).transpose(1, 2) + + return query_states, key_states, value_states + + # QK^T + SV + def attn(self, query, key, value, attention_mask, scaling, dropout, **kwargs): + # KV cache update + self.past_k = torch.cat([self.past_k, key], dim=2) + self.past_v = torch.cat([self.past_v, value], dim=2) + + key = self.past_k + value = self.past_v + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=False) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output + + + # out-proj + rms + def out_proj(self, attn_output): + attn_output = attn_output.reshape(self.bsz, self.tgt_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + attn_output = nn.functional.dropout(attn_output, p=0.0, training=False) + attn_output = self.residual + attn_output + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + attn_output = self.self_attn_layer_norm(attn_output) + + return attn_output + + # MLP + rms + def ffn(self, hidden_states): + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=0.0, training=False) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + return outputs + + def lm_head(self, outputs, logits_to_keep): + hidden_states = outputs[0] + hidden_states = self.project_out(hidden_states) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head_linear(hidden_states[:, slice_indices, :]).contiguous() + + return logits + + def forward(self, x): + # hidden = self.embed(x) + # return hidden + + q, k, v = self.qkv(x) + return q, k, v + + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="") + parser.add_argument("--bsz", type=int, default=128, help="Batch size") + parser.add_argument("--seq_len", type=int, default=128, help="Input sequence length") + parser.add_argument("--tp", type=int, default=1, help="Input sequence length") + args = parser.parse_args() + + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/root/workspace/PyTorchSim')) + + from Scheduler.scheduler import PyTorchSimRunner + module = PyTorchSimRunner.setup_device() + device = module.custom_device() + + embed_dim = 2048 // args.tp + hidden_size = 
2048 // args.tp + num_heads = 32 // args.tp + ffn_dim = 8192 + vocab_size = 50272 + word_embed_proj_dim = 2048 + pad_token_id = 1 + max_position_embeddings = 2048 + enable_bias = True + layer_norm_elementwise_affine = True + do_layer_norm_before = True + + bsz = args.bsz + seq_len = args.seq_len + + print(f"Batch size: {bsz}, Seq len: {seq_len}, TP: {args.tp}") + + config = LLM_Config(embed_dim = embed_dim, + hidden_size = hidden_size, + num_heads = num_heads, + ffn_dim = ffn_dim, + vocab_size = vocab_size, + word_embed_proj_dim = word_embed_proj_dim, + pad_token_id = pad_token_id, + max_position_embeddings = max_position_embeddings, + enable_bias = enable_bias, + layer_norm_elementwise_affine = layer_norm_elementwise_affine, + do_layer_norm_before = do_layer_norm_before, + tp = args.tp) + + decoder = my_opt_decoder(config, seq_len) + decoder.eval() + decoder_device = decoder.to(device=device) + opt_decoder = torch.compile(decoder_device, dynamic=False) + + # Embedding is not supported currently, just skip + # input = torch.randint(0, vocab_size, (bsz, seq_len)).to(device) # (bsz, seq_len) + # hidden = opt_decoder.embed(input) + + + hidden = torch.randn( + bsz, 1, config.hidden_size, + dtype=torch.float32 + ) + hidden_device = hidden.to(device) + + with torch.no_grad(): + opt_decoder(hidden_device) \ No newline at end of file diff --git a/tests/OPT/tp_1/attn.py b/tests/OPT/tp_1/attn.py new file mode 100644 index 00000000..244ffd80 --- /dev/null +++ b/tests/OPT/tp_1/attn.py @@ -0,0 +1,284 @@ +import os +import sys +import argparse + +import torch +import torch.nn as nn + +from typing import Optional + +class LLM_Config: + def __init__(self, + embed_dim, + hidden_size, + num_heads, + ffn_dim, + vocab_size, + word_embed_proj_dim, + pad_token_id, + max_position_embeddings, + enable_bias, + layer_norm_elementwise_affine, + do_layer_norm_before): + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.num_heads = num_heads + self.ffn_dim = ffn_dim + self.vocab_size = vocab_size + self.word_embed_proj_dim = word_embed_proj_dim + self.pad_token_id = pad_token_id + self.max_position_embeddings = max_position_embeddings + self.enable_bias = enable_bias + self.layer_norm_elementwise_affine = layer_norm_elementwise_affine + self.do_layer_norm_before = do_layer_norm_before + + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. 
Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward( + self, + attention_mask: torch.LongTensor, + past_key_values_length: int = 0, + position_ids: Optional[torch.LongTensor] = None, + ): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + + if position_ids is None: + position_ids = torch.cumsum(attention_mask, dim=1) + position_ids = (position_ids * attention_mask - 1).long() + # cut positions if `past_key_values_length` is > 0 + position_ids = position_ids[:, past_key_values_length:] + + return super().forward(position_ids + self.offset) + +class my_opt_decoder(nn.Module): + def __init__(self, config: LLM_Config, current_seq_len): + super(my_opt_decoder, self).__init__() + self.config = config + + self.head_dim = self.config.embed_dim // self.config.num_heads + self.scaling = self.head_dim**-0.5 + + # Embedding layers + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.word_embed_proj_dim, self.config.pad_token_id) + self.embed_positions = OPTLearnedPositionalEmbedding(self.config.max_position_embeddings, config.hidden_size) + self.project_in = nn.Linear(self.config.word_embed_proj_dim, self.config.hidden_size, bias=False) + + # KV Cache + # self.past_k = torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + # self.past_v = torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + self.register_buffer( + "past_k", + torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + ) + self.register_buffer( + "past_v", + torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + ) + + + # QKV layers + self.k_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + self.v_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + self.q_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + self.o_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + + self.self_attn_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # FC layers + self.activation_fn = nn.ReLU() + self.fc1 = nn.Linear(self.config.embed_dim, config.ffn_dim, bias=config.enable_bias) + self.fc2 = nn.Linear(config.ffn_dim, self.config.embed_dim, bias=config.enable_bias) + self.final_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # LM head + self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) + self.lm_head_linear = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + + def embed(self, input_ids): + # input_ids: (bsz, seq_len) + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.project_in(inputs_embeds) + bsz, seq_len, _ = inputs_embeds.size() + attention_mask = (input_ids != self.config.pad_token_id).long() + position_embeds = self.embed_positions(attention_mask=attention_mask) + print(f"input shape: {input_ids.shape}, embed shape: {inputs_embeds.shape}, position shape: {position_embeds.shape}") + hidden_states = inputs_embeds + position_embeds + return hidden_states + + # qkv + rms + def qkv(self, hidden_states): + self.residual = hidden_states + self.bsz, self.tgt_len, _ = hidden_states.size() + + if self.config.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + query_states = self.q_proj(hidden_states) * self.scaling + 
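# Q/K/V below are reshaped to (bsz, num_heads, tgt_len, head_dim); attn() then appends K/V to the registered past_k/past_v cache +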
query_states = query_states.view(self.bsz, -1, self.config.num_heads, self.head_dim).transpose(1, 2) + + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + key_states = key_states.view(self.bsz, -1, self.config.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(self.bsz, -1, self.config.num_heads, self.head_dim).transpose(1, 2) + + return query_states, key_states, value_states + + # QK^T + SV + def attn(self, query, key, value, attention_mask, scaling, dropout, **kwargs): + # KV cache update + self.past_k = torch.cat([self.past_k, key], dim=2) + self.past_v = torch.cat([self.past_v, value], dim=2) + + key = self.past_k + value = self.past_v + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=False) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output + + # out-proj + rms + def out_proj(self, attn_output): + attn_output = attn_output.reshape(self.bsz, self.tgt_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + attn_output = nn.functional.dropout(attn_output, p=0.0, training=False) + attn_output = self.residual + attn_output + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + attn_output = self.self_attn_layer_norm(attn_output) + + return attn_output + + # MLP + rms + def ffn(self, hidden_states): + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=0.0, training=False) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + return outputs + + def lm_head(self, outputs, logits_to_keep): + hidden_states = outputs[0] + hidden_states = self.project_out(hidden_states) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + + print(f"input hidden_states shape: {hidden_states.shape}") + print(f"config.word_embed_proj_dim, config.vocab_size: {config.word_embed_proj_dim} {config.vocab_size}") + logits = self.lm_head_linear(hidden_states[:, slice_indices, :]).contiguous() + + return logits + + def forward(self, q, k, v): + # hidden = self.embed(x) + # return hidden + + attn_output = self.attn(q, k, v, None, self.scaling, 0.0) + return attn_output + + + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description="") + parser.add_argument("--bsz", type=int, default=128, help="Batch size") + parser.add_argument("--seq_len", type=int, default=128, help="Input sequence length") + args = parser.parse_args() + + sys.path.append(os.environ.get('TORCHSIM_DIR', 
default='/root/workspace/PyTorchSim')) + + from Scheduler.scheduler import PyTorchSimRunner + module = PyTorchSimRunner.setup_device() + device = module.custom_device() + + embed_dim = 1024 + hidden_size = 1024 + num_heads = 16 + ffn_dim = 4096 + vocab_size = 50272 + word_embed_proj_dim = 512 + pad_token_id = 1 + max_position_embeddings = 2048 + enable_bias = True + layer_norm_elementwise_affine = True + do_layer_norm_before = False + + bsz = args.bsz + seq_len = args.seq_len + + print(f"Batch size: {bsz}, Seq len: {seq_len}") + + config = LLM_Config(embed_dim = embed_dim, + hidden_size = hidden_size, + num_heads = num_heads, + ffn_dim = ffn_dim, + vocab_size = vocab_size, + word_embed_proj_dim = word_embed_proj_dim, + pad_token_id = pad_token_id, + max_position_embeddings = max_position_embeddings, + enable_bias = enable_bias, + layer_norm_elementwise_affine = layer_norm_elementwise_affine, + do_layer_norm_before = do_layer_norm_before) + + decoder = my_opt_decoder(config, seq_len) + decoder.eval() + decoder_device = decoder.to(device=device) + opt_decoder = torch.compile(decoder_device, dynamic=False) + + q = torch.randn( + bsz, config.num_heads, 1, config.hidden_size // config.num_heads, + dtype=torch.float32 + ) + k = torch.randn( + bsz, config.num_heads, 1, config.hidden_size // config.num_heads, + dtype=torch.float32 + ) + v = torch.randn( + bsz, config.num_heads, 1, config.hidden_size // config.num_heads, + dtype=torch.float32 + ) + q_device = q.to(device) + k_device = k.to(device) + v_device = v.to(device) + + with torch.no_grad(): + opt_decoder(q_device, k_device, v_device) \ No newline at end of file diff --git a/tests/OPT/tp_1/embed.py b/tests/OPT/tp_1/embed.py new file mode 100644 index 00000000..4ba3f4b2 --- /dev/null +++ b/tests/OPT/tp_1/embed.py @@ -0,0 +1,267 @@ +import os +import sys +import argparse + +import torch +import torch.nn as nn + +from typing import Optional + +class LLM_Config: + def __init__(self, + embed_dim, + hidden_size, + num_heads, + ffn_dim, + vocab_size, + word_embed_proj_dim, + pad_token_id, + max_position_embeddings, + enable_bias, + layer_norm_elementwise_affine, + do_layer_norm_before): + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.num_heads = num_heads + self.ffn_dim = ffn_dim + self.vocab_size = vocab_size + self.word_embed_proj_dim = word_embed_proj_dim + self.pad_token_id = pad_token_id + self.max_position_embeddings = max_position_embeddings + self.enable_bias = enable_bias + self.layer_norm_elementwise_affine = layer_norm_elementwise_affine + self.do_layer_norm_before = do_layer_norm_before + + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. 
Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward( + self, + attention_mask: torch.LongTensor, + past_key_values_length: int = 0, + position_ids: Optional[torch.LongTensor] = None, + ): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + + if position_ids is None: + position_ids = torch.full( + (bsz, 1), + past_key_values_length, + dtype=torch.long, + ) + + return super().forward(position_ids + self.offset) + +class my_opt_decoder(nn.Module): + def __init__(self, config: LLM_Config, current_seq_len): + super(my_opt_decoder, self).__init__() + self.config = config + + self.head_dim = self.config.embed_dim // self.config.num_heads + self.scaling = self.head_dim**-0.5 + + # Embedding layers + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.word_embed_proj_dim, self.config.pad_token_id) + self.embed_positions = OPTLearnedPositionalEmbedding(self.config.max_position_embeddings, config.hidden_size) + self.project_in = nn.Linear(self.config.word_embed_proj_dim, self.config.hidden_size, bias=False) + + # KV Cache + self.register_buffer( + "past_k", + torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + ) + self.register_buffer( + "past_v", + torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + ) + + + # QKV layers + self.k_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + self.v_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + self.q_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + self.o_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + + self.self_attn_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # FC layers + self.activation_fn = nn.ReLU() + self.fc1 = nn.Linear(self.config.embed_dim, config.ffn_dim, bias=config.enable_bias) + self.fc2 = nn.Linear(config.ffn_dim, self.config.embed_dim, bias=config.enable_bias) + self.final_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # LM head + self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) + self.lm_head_linear = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + + def embed(self, input_ids): + # input_ids: (bsz, seq_len) + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.project_in(inputs_embeds) + bsz, seq_len, _ = inputs_embeds.size() + attention_mask = (input_ids != self.config.pad_token_id).long() + position_embeds = self.embed_positions(attention_mask=attention_mask) + hidden_states = inputs_embeds + position_embeds + return hidden_states + + # qkv + rms + def qkv(self, hidden_states): + self.residual = hidden_states + self.bsz, self.tgt_len, _ = hidden_states.size() + + if self.config.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + query_states = self.q_proj(hidden_states) * self.scaling + query_states = query_states.view(self.bsz, -1, self.config.num_heads, self.head_dim).transpose(1, 2) + + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + key_states = key_states.view(self.bsz, -1, self.config.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(self.bsz, -1, self.config.num_heads, self.head_dim).transpose(1, 2) + + return query_states, 
key_states, value_states + + # QK^T + SV + def attn(self, query, key, value, attention_mask, scaling, dropout, **kwargs): + # KV cache update + self.past_k = torch.cat([self.past_k, key], dim=2) + self.past_v = torch.cat([self.past_v, value], dim=2) + + key = self.past_k + value = self.past_v + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=False) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output + + + # out-proj + rms + def out_proj(self, attn_output): + attn_output = attn_output.reshape(self.bsz, self.tgt_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + attn_output = nn.functional.dropout(attn_output, p=0.0, training=False) + attn_output = self.residual + attn_output + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + attn_output = self.self_attn_layer_norm(attn_output) + + return attn_output + + # MLP + rms + def ffn(self, hidden_states): + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=0.0, training=False) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + return outputs + + def lm_head(self, outputs, logits_to_keep): + hidden_states = outputs[0] + hidden_states = self.project_out(hidden_states) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head_linear(hidden_states[:, slice_indices, :]).contiguous() + + return logits + + def forward(self, x): + hidden = self.embed(x) + return hidden + + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="") + parser.add_argument("--bsz", type=int, default=128, help="Batch size") + parser.add_argument("--seq_len", type=int, default=128, help="Input sequence length") + args = parser.parse_args() + + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/root/workspace/PyTorchSim')) + + from Scheduler.scheduler import PyTorchSimRunner + module = PyTorchSimRunner.setup_device() + device = module.custom_device() + + embed_dim = 1024 + hidden_size = 1024 + num_heads = 16 + ffn_dim = 4096 + vocab_size = 50272 + word_embed_proj_dim = 512 + pad_token_id = 1 + max_position_embeddings = 2048 + enable_bias = True + layer_norm_elementwise_affine = True + do_layer_norm_before = False + + bsz = args.bsz + seq_len = args.seq_len + + print(f"Batch size: {bsz}, Seq len: {seq_len}") + + config = LLM_Config(embed_dim = embed_dim, + hidden_size = hidden_size, + num_heads = num_heads, + ffn_dim = ffn_dim, + vocab_size = vocab_size, 
+ word_embed_proj_dim = word_embed_proj_dim, + pad_token_id = pad_token_id, + max_position_embeddings = max_position_embeddings, + enable_bias = enable_bias, + layer_norm_elementwise_affine = layer_norm_elementwise_affine, + do_layer_norm_before = do_layer_norm_before) + + decoder = my_opt_decoder(config, seq_len) + decoder.eval() + decoder_device = decoder.to(device=device) + opt_decoder = torch.compile(decoder_device, dynamic=False) + + # Embedding is not supported currently, just skip + # input = torch.randint(0, vocab_size, (bsz, seq_len)).to(device) # (bsz, seq_len) + # hidden = opt_decoder.embed(input) + + + input = torch.randint(0, vocab_size, (bsz, 1)) + input_device = input.to(device) + with torch.no_grad(): + opt_decoder(input_device) \ No newline at end of file diff --git a/tests/OPT/tp_1/ffn.py b/tests/OPT/tp_1/ffn.py new file mode 100644 index 00000000..d8bc0ee1 --- /dev/null +++ b/tests/OPT/tp_1/ffn.py @@ -0,0 +1,279 @@ +import os +import sys +import argparse + +import torch +import torch.nn as nn + +from typing import Optional + +class LLM_Config: + def __init__(self, + embed_dim, + hidden_size, + num_heads, + ffn_dim, + vocab_size, + word_embed_proj_dim, + pad_token_id, + max_position_embeddings, + enable_bias, + layer_norm_elementwise_affine, + do_layer_norm_before): + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.num_heads = num_heads + self.ffn_dim = ffn_dim + self.vocab_size = vocab_size + self.word_embed_proj_dim = word_embed_proj_dim + self.pad_token_id = pad_token_id + self.max_position_embeddings = max_position_embeddings + self.enable_bias = enable_bias + self.layer_norm_elementwise_affine = layer_norm_elementwise_affine + self.do_layer_norm_before = do_layer_norm_before + + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. 
Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward( + self, + attention_mask: torch.LongTensor, + past_key_values_length: int = 0, + position_ids: Optional[torch.LongTensor] = None, + ): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + + if position_ids is None: + position_ids = torch.cumsum(attention_mask, dim=1) + position_ids = (position_ids * attention_mask - 1).long() + # cut positions if `past_key_values_length` is > 0 + position_ids = position_ids[:, past_key_values_length:] + + return super().forward(position_ids + self.offset) + +class my_opt_decoder(nn.Module): + def __init__(self, config: LLM_Config, current_seq_len): + super(my_opt_decoder, self).__init__() + self.config = config + + self.head_dim = self.config.embed_dim // self.config.num_heads + self.scaling = self.head_dim**-0.5 + + # Embedding layers + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.word_embed_proj_dim, self.config.pad_token_id) + self.embed_positions = OPTLearnedPositionalEmbedding(self.config.max_position_embeddings, config.hidden_size) + self.project_in = nn.Linear(self.config.word_embed_proj_dim, self.config.hidden_size, bias=False) + + # KV Cache + # self.past_k = torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + # self.past_v = torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + self.register_buffer( + "past_k", + torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + ) + self.register_buffer( + "past_v", + torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + ) + + + # QKV layers + self.k_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + self.v_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + self.q_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + self.o_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + + self.self_attn_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # FC layers + self.activation_fn = nn.ReLU() + self.fc1 = nn.Linear(self.config.embed_dim, config.ffn_dim, bias=config.enable_bias) + self.fc2 = nn.Linear(config.ffn_dim, self.config.embed_dim, bias=config.enable_bias) + self.final_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # LM head + self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) + self.lm_head_linear = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + + def embed(self, input_ids): + # input_ids: (bsz, seq_len) + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.project_in(inputs_embeds) + bsz, seq_len, _ = inputs_embeds.size() + attention_mask = (input_ids != self.config.pad_token_id).long() + position_embeds = self.embed_positions(attention_mask=attention_mask) + print(f"input shape: {input_ids.shape}, embed shape: {inputs_embeds.shape}, position shape: {position_embeds.shape}") + hidden_states = inputs_embeds + position_embeds + return hidden_states + + # qkv + rms + def qkv(self, hidden_states): + self.residual = hidden_states + self.bsz, self.tgt_len, _ = hidden_states.size() + + if self.config.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + query_states = self.q_proj(hidden_states) * self.scaling + 
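# note: only ffn() is exercised by forward() in this test; the qkv/attn/out_proj/lm_head paths are not called +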
query_states = query_states.view(self.bsz, -1, self.config.num_heads, self.head_dim).transpose(1, 2) + + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + key_states = key_states.view(self.bsz, -1, self.config.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(self.bsz, -1, self.config.num_heads, self.head_dim).transpose(1, 2) + + return query_states, key_states, value_states + + # QK^T + SV + def attn(self, query, key, value, attention_mask, scaling, dropout, **kwargs): + # KV cache update + self.past_k = torch.cat([self.past_k, key], dim=2) + self.past_v = torch.cat([self.past_v, value], dim=2) + + key = self.past_k + value = self.past_v + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=False) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output + + + # out-proj + rms + def out_proj(self, attn_output): + attn_output = attn_output.reshape(self.bsz, self.tgt_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + attn_output = nn.functional.dropout(attn_output, p=0.0, training=False) + attn_output = self.residual + attn_output + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + attn_output = self.self_attn_layer_norm(attn_output) + + return attn_output + + # MLP + rms + def ffn(self, hidden_states): + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=0.0, training=False) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + return outputs + + def lm_head(self, outputs, logits_to_keep): + hidden_states = outputs[0] + hidden_states = self.project_out(hidden_states) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + + print(f"input hidden_states shape: {hidden_states.shape}") + print(f"config.word_embed_proj_dim, config.vocab_size: {config.word_embed_proj_dim} {config.vocab_size}") + logits = self.lm_head_linear(hidden_states[:, slice_indices, :]).contiguous() + + return logits + + def forward(self, x): + # hidden = self.embed(x) + # return hidden + + outputs = self.ffn(x) + return outputs + + + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description="") + parser.add_argument("--bsz", type=int, default=128, help="Batch size") + parser.add_argument("--seq_len", type=int, default=128, help="Input sequence length") + args = parser.parse_args() + + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/root/workspace/PyTorchSim')) + + 
from Scheduler.scheduler import PyTorchSimRunner + module = PyTorchSimRunner.setup_device() + device = module.custom_device() + + embed_dim = 1024 + hidden_size = 1024 + num_heads = 16 + ffn_dim = 4096 + vocab_size = 50272 + word_embed_proj_dim = 512 + pad_token_id = 1 + max_position_embeddings = 2048 + enable_bias = True + layer_norm_elementwise_affine = True + do_layer_norm_before = False + + bsz = args.bsz + seq_len = args.seq_len + + print(f"Batch size: {bsz}, Seq len: {seq_len}") + + config = LLM_Config(embed_dim = embed_dim, + hidden_size = hidden_size, + num_heads = num_heads, + ffn_dim = ffn_dim, + vocab_size = vocab_size, + word_embed_proj_dim = word_embed_proj_dim, + pad_token_id = pad_token_id, + max_position_embeddings = max_position_embeddings, + enable_bias = enable_bias, + layer_norm_elementwise_affine = layer_norm_elementwise_affine, + do_layer_norm_before = do_layer_norm_before) + + decoder = my_opt_decoder(config, seq_len) + decoder.eval() + decoder_device = decoder.to(device=device) + opt_decoder = torch.compile(decoder_device, dynamic=False) + + # Embedding is not supported currently, just skip + # input = torch.randint(0, vocab_size, (bsz, seq_len)).to(device) # (bsz, seq_len) + # hidden = opt_decoder.embed(input) + + + hidden = torch.randn( + bsz, 1, config.hidden_size, + dtype=torch.float32 + ) + hidden_device = hidden.to(device) + with torch.no_grad(): + opt_decoder(hidden_device) \ No newline at end of file diff --git a/tests/OPT/tp_1/lm_head.py b/tests/OPT/tp_1/lm_head.py new file mode 100644 index 00000000..18013108 --- /dev/null +++ b/tests/OPT/tp_1/lm_head.py @@ -0,0 +1,275 @@ +import os +import sys +import argparse + +import torch +import torch.nn as nn + +from typing import Optional + +class LLM_Config: + def __init__(self, + embed_dim, + hidden_size, + num_heads, + ffn_dim, + vocab_size, + word_embed_proj_dim, + pad_token_id, + max_position_embeddings, + enable_bias, + layer_norm_elementwise_affine, + do_layer_norm_before): + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.num_heads = num_heads + self.ffn_dim = ffn_dim + self.vocab_size = vocab_size + self.word_embed_proj_dim = word_embed_proj_dim + self.pad_token_id = pad_token_id + self.max_position_embeddings = max_position_embeddings + self.enable_bias = enable_bias + self.layer_norm_elementwise_affine = layer_norm_elementwise_affine + self.do_layer_norm_before = do_layer_norm_before + + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. 
Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward( + self, + attention_mask: torch.LongTensor, + past_key_values_length: int = 0, + position_ids: Optional[torch.LongTensor] = None, + ): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + + if position_ids is None: + position_ids = torch.cumsum(attention_mask, dim=1) + position_ids = (position_ids * attention_mask - 1).long() + # cut positions if `past_key_values_length` is > 0 + position_ids = position_ids[:, past_key_values_length:] + + return super().forward(position_ids + self.offset) + +class my_opt_decoder(nn.Module): + def __init__(self, config: LLM_Config, current_seq_len): + super(my_opt_decoder, self).__init__() + self.config = config + + self.head_dim = self.config.embed_dim // self.config.num_heads + self.scaling = self.head_dim**-0.5 + + # Embedding layers + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.word_embed_proj_dim, self.config.pad_token_id) + self.embed_positions = OPTLearnedPositionalEmbedding(self.config.max_position_embeddings, config.hidden_size) + self.project_in = nn.Linear(self.config.word_embed_proj_dim, self.config.hidden_size, bias=False) + + # KV Cache + # self.past_k = torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + # self.past_v = torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + self.register_buffer( + "past_k", + torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + ) + self.register_buffer( + "past_v", + torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + ) + + + # QKV layers + self.k_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + self.v_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + self.q_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + self.o_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + + self.self_attn_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # FC layers + self.activation_fn = nn.ReLU() + self.fc1 = nn.Linear(self.config.embed_dim, config.ffn_dim, bias=config.enable_bias) + self.fc2 = nn.Linear(config.ffn_dim, self.config.embed_dim, bias=config.enable_bias) + self.final_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # LM head + self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) + self.lm_head_linear = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + + def embed(self, input_ids): + # input_ids: (bsz, seq_len) + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.project_in(inputs_embeds) + bsz, seq_len, _ = inputs_embeds.size() + attention_mask = (input_ids != self.config.pad_token_id).long() + position_embeds = self.embed_positions(attention_mask=attention_mask) + print(f"input shape: {input_ids.shape}, embed shape: {inputs_embeds.shape}, position shape: {position_embeds.shape}") + hidden_states = inputs_embeds + position_embeds + return hidden_states + + # qkv + rms + def qkv(self, hidden_states): + self.residual = hidden_states + self.bsz, self.tgt_len, _ = hidden_states.size() + + if self.config.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + query_states = self.q_proj(hidden_states) * self.scaling + 
query_states = query_states.view(self.bsz, -1, self.config.num_heads, self.head_dim).transpose(1, 2) + + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + key_states = key_states.view(self.bsz, -1, self.config.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(self.bsz, -1, self.config.num_heads, self.head_dim).transpose(1, 2) + + return query_states, key_states, value_states + + # QK^T + SV + def attn(self, query, key, value, attention_mask, scaling, dropout, **kwargs): + # KV cache update + self.past_k = torch.cat([self.past_k, key], dim=2) + self.past_v = torch.cat([self.past_v, value], dim=2) + + key = self.past_k + value = self.past_v + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=False) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output + + + # out-proj + rms + def out_proj(self, attn_output): + attn_output = attn_output.reshape(self.bsz, self.tgt_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + attn_output = nn.functional.dropout(attn_output, p=0.0, training=False) + attn_output = self.residual + attn_output + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + attn_output = self.self_attn_layer_norm(attn_output) + + return attn_output + + # MLP + rms + def ffn(self, hidden_states): + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=0.0, training=False) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + return outputs + + def lm_head(self, outputs, logits_to_keep): + hidden_states = outputs[0] + hidden_states = self.project_out(hidden_states) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head_linear(hidden_states[:, slice_indices, :]).contiguous() + + return logits + + def forward(self, x): + # hidden = self.embed(x) + # return hidden + + logits = self.lm_head(x, 1) + return logits + + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="") + parser.add_argument("--bsz", type=int, default=128, help="Batch size") + parser.add_argument("--seq_len", type=int, default=128, help="Input sequence length") + args = parser.parse_args() + + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/root/workspace/PyTorchSim')) + + from Scheduler.scheduler import PyTorchSimRunner + module = PyTorchSimRunner.setup_device() + device = module.custom_device() + + embed_dim = 1024 + hidden_size = 1024 + 
num_heads = 16 + ffn_dim = 4096 + vocab_size = 50272 + word_embed_proj_dim = 512 + pad_token_id = 1 + max_position_embeddings = 2048 + enable_bias = True + layer_norm_elementwise_affine = True + do_layer_norm_before = False + + bsz = args.bsz + seq_len = args.seq_len + + print(f"Batch size: {bsz}, Seq len: {seq_len}") + + config = LLM_Config(embed_dim = embed_dim, + hidden_size = hidden_size, + num_heads = num_heads, + ffn_dim = ffn_dim, + vocab_size = vocab_size, + word_embed_proj_dim = word_embed_proj_dim, + pad_token_id = pad_token_id, + max_position_embeddings = max_position_embeddings, + enable_bias = enable_bias, + layer_norm_elementwise_affine = layer_norm_elementwise_affine, + do_layer_norm_before = do_layer_norm_before) + + decoder = my_opt_decoder(config, seq_len) + decoder.eval() + decoder_device = decoder.to(device=device) + opt_decoder = torch.compile(decoder_device, dynamic=False) + + # Embedding is not supported currently, just skip + # input = torch.randint(0, vocab_size, (bsz, seq_len)).to(device) # (bsz, seq_len) + # hidden = opt_decoder.embed(input) + + + hidden = torch.randn( + bsz, 1, config.hidden_size, + dtype=torch.float32 + ) + hidden_device = hidden.to(device) + with torch.no_grad(): + opt_decoder((hidden_device, )) \ No newline at end of file diff --git a/tests/OPT/tp_1/out_proj.py b/tests/OPT/tp_1/out_proj.py new file mode 100644 index 00000000..dc33e2e0 --- /dev/null +++ b/tests/OPT/tp_1/out_proj.py @@ -0,0 +1,288 @@ +import os +import sys +import argparse + +import torch +import torch.nn as nn + +from typing import Optional + +class LLM_Config: + def __init__(self, + embed_dim, + hidden_size, + num_heads, + ffn_dim, + vocab_size, + word_embed_proj_dim, + pad_token_id, + max_position_embeddings, + enable_bias, + layer_norm_elementwise_affine, + do_layer_norm_before): + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.num_heads = num_heads + self.ffn_dim = ffn_dim + self.vocab_size = vocab_size + self.word_embed_proj_dim = word_embed_proj_dim + self.pad_token_id = pad_token_id + self.max_position_embeddings = max_position_embeddings + self.enable_bias = enable_bias + self.layer_norm_elementwise_affine = layer_norm_elementwise_affine + self.do_layer_norm_before = do_layer_norm_before + + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. 
Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward( + self, + attention_mask: torch.LongTensor, + past_key_values_length: int = 0, + position_ids: Optional[torch.LongTensor] = None, + ): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + + if position_ids is None: + position_ids = torch.cumsum(attention_mask, dim=1) + position_ids = (position_ids * attention_mask - 1).long() + # cut positions if `past_key_values_length` is > 0 + position_ids = position_ids[:, past_key_values_length:] + + return super().forward(position_ids + self.offset) + +class my_opt_decoder(nn.Module): + def __init__(self, config: LLM_Config, current_seq_len): + super(my_opt_decoder, self).__init__() + self.config = config + + self.head_dim = self.config.embed_dim // self.config.num_heads + self.scaling = self.head_dim**-0.5 + + # Embedding layers + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.word_embed_proj_dim, self.config.pad_token_id) + self.embed_positions = OPTLearnedPositionalEmbedding(self.config.max_position_embeddings, config.hidden_size) + self.project_in = nn.Linear(self.config.word_embed_proj_dim, self.config.hidden_size, bias=False) + + # KV Cache + # self.past_k = torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + # self.past_v = torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + self.register_buffer( + "past_k", + torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + ) + self.register_buffer( + "past_v", + torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + ) + self.residual = None + + # QKV layers + self.k_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + self.v_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + self.q_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + self.o_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + + self.self_attn_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # FC layers + self.activation_fn = nn.ReLU() + self.fc1 = nn.Linear(self.config.embed_dim, config.ffn_dim, bias=config.enable_bias) + self.fc2 = nn.Linear(config.ffn_dim, self.config.embed_dim, bias=config.enable_bias) + self.final_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # LM head + self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) + self.lm_head_linear = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + + def embed(self, input_ids): + # input_ids: (bsz, seq_len) + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.project_in(inputs_embeds) + bsz, seq_len, _ = inputs_embeds.size() + attention_mask = (input_ids != self.config.pad_token_id).long() + position_embeds = self.embed_positions(attention_mask=attention_mask) + print(f"input shape: {input_ids.shape}, embed shape: {inputs_embeds.shape}, position shape: {position_embeds.shape}") + hidden_states = inputs_embeds + position_embeds + return hidden_states + + # qkv + rms + def qkv(self, hidden_states): + self.residual = hidden_states + self.bsz, self.tgt_len, _ = hidden_states.size() + + if self.config.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + query_states = self.q_proj(hidden_states) * 
self.scaling + query_states = query_states.view(self.bsz, -1, self.config.num_heads, self.head_dim).transpose(1, 2) + + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + key_states = key_states.view(self.bsz, -1, self.config.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(self.bsz, -1, self.config.num_heads, self.head_dim).transpose(1, 2) + + return query_states, key_states, value_states + + # QK^T + SV + def attn(self, query, key, value, attention_mask, scaling, dropout, **kwargs): + # KV cache update + self.past_k = torch.cat([self.past_k, key], dim=2) + self.past_v = torch.cat([self.past_v, value], dim=2) + + key = self.past_k + value = self.past_v + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=False) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output + + + # out-proj + rms + def out_proj(self, attn_output, residual): + self.bsz = attn_output.size(0) + self.tgt_len = attn_output.size(1) + + attn_output = attn_output.reshape(self.bsz, self.tgt_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + attn_output = nn.functional.dropout(attn_output, p=0.0, training=False) + attn_output = residual + attn_output + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + attn_output = self.self_attn_layer_norm(attn_output) + + return attn_output + + # MLP + rms + def ffn(self, hidden_states): + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=0.0, training=False) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + return outputs + + def lm_head(self, outputs, logits_to_keep): + hidden_states = outputs[0] + hidden_states = self.project_out(hidden_states) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + + print(f"input hidden_states shape: {hidden_states.shape}") + print(f"config.word_embed_proj_dim, config.vocab_size: {config.word_embed_proj_dim} {config.vocab_size}") + logits = self.lm_head_linear(hidden_states[:, slice_indices, :]).contiguous() + + return logits + + def forward(self, x, r): + # hidden = self.embed(x) + # return hidden + + attn_output = self.out_proj(x, r) + return attn_output + + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="") + parser.add_argument("--bsz", type=int, default=128, help="Batch size") + parser.add_argument("--seq_len", type=int, default=128, help="Input sequence length") + args = 
parser.parse_args() + + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/root/workspace/PyTorchSim')) + + from Scheduler.scheduler import PyTorchSimRunner + module = PyTorchSimRunner.setup_device() + device = module.custom_device() + + embed_dim = 1024 + hidden_size = 1024 + num_heads = 16 + ffn_dim = 4096 + vocab_size = 50272 + word_embed_proj_dim = 512 + pad_token_id = 1 + max_position_embeddings = 2048 + enable_bias = True + layer_norm_elementwise_affine = True + do_layer_norm_before = False + + bsz = args.bsz + seq_len = args.seq_len + + print(f"Batch size: {bsz}, Seq len: {seq_len}") + + config = LLM_Config(embed_dim = embed_dim, + hidden_size = hidden_size, + num_heads = num_heads, + ffn_dim = ffn_dim, + vocab_size = vocab_size, + word_embed_proj_dim = word_embed_proj_dim, + pad_token_id = pad_token_id, + max_position_embeddings = max_position_embeddings, + enable_bias = enable_bias, + layer_norm_elementwise_affine = layer_norm_elementwise_affine, + do_layer_norm_before = do_layer_norm_before) + + decoder = my_opt_decoder(config, seq_len) + decoder.eval() + decoder_device = decoder.to(device=device) + opt_decoder = torch.compile(decoder_device, dynamic=False) + + # Embedding is not supported currently, just skip + # input = torch.randint(0, vocab_size, (bsz, seq_len)).to(device) # (bsz, seq_len) + # hidden = opt_decoder.embed(input) + + + hidden = torch.randn( + bsz, 1, config.num_heads, config.hidden_size // config.num_heads, + dtype=torch.float32 + ) + hidden_device = hidden.to(device) + + # Residual + residual = torch.randn( + bsz, 1, config.hidden_size, + dtype=torch.float32 + ) + residual_device = residual.to(device) + with torch.no_grad(): + opt_decoder(hidden_device, residual_device) \ No newline at end of file diff --git a/tests/OPT/tp_1/qkv.py b/tests/OPT/tp_1/qkv.py new file mode 100644 index 00000000..1ac497f4 --- /dev/null +++ b/tests/OPT/tp_1/qkv.py @@ -0,0 +1,272 @@ +import os +import sys +import argparse + +import torch +import torch.nn as nn + +from typing import Optional + +class LLM_Config: + def __init__(self, + embed_dim, + hidden_size, + num_heads, + ffn_dim, + vocab_size, + word_embed_proj_dim, + pad_token_id, + max_position_embeddings, + enable_bias, + layer_norm_elementwise_affine, + do_layer_norm_before): + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.num_heads = num_heads + self.ffn_dim = ffn_dim + self.vocab_size = vocab_size + self.word_embed_proj_dim = word_embed_proj_dim + self.pad_token_id = pad_token_id + self.max_position_embeddings = max_position_embeddings + self.enable_bias = enable_bias + self.layer_norm_elementwise_affine = layer_norm_elementwise_affine + self.do_layer_norm_before = do_layer_norm_before + + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. 
Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward( + self, + attention_mask: torch.LongTensor, + past_key_values_length: int = 0, + position_ids: Optional[torch.LongTensor] = None, + ): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + + if position_ids is None: + position_ids = torch.cumsum(attention_mask, dim=1) + position_ids = (position_ids * attention_mask - 1).long() + # cut positions if `past_key_values_length` is > 0 + position_ids = position_ids[:, past_key_values_length:] + + return super().forward(position_ids + self.offset) + +class my_opt_decoder(nn.Module): + def __init__(self, config: LLM_Config, current_seq_len): + super(my_opt_decoder, self).__init__() + self.config = config + + self.head_dim = self.config.embed_dim // self.config.num_heads + self.scaling = self.head_dim**-0.5 + + # Embedding layers + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.word_embed_proj_dim, self.config.pad_token_id) + self.embed_positions = OPTLearnedPositionalEmbedding(self.config.max_position_embeddings, config.hidden_size) + self.project_in = nn.Linear(self.config.word_embed_proj_dim, self.config.hidden_size, bias=False) + + # KV Cache + self.register_buffer( + "past_k", + torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + ) + self.register_buffer( + "past_v", + torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + ) + + + # QKV layers + self.k_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + self.v_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + self.q_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + self.o_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + + self.self_attn_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # FC layers + self.activation_fn = nn.ReLU() + self.fc1 = nn.Linear(self.config.embed_dim, config.ffn_dim, bias=config.enable_bias) + self.fc2 = nn.Linear(config.ffn_dim, self.config.embed_dim, bias=config.enable_bias) + self.final_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # LM head + self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) + self.lm_head_linear = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + + def embed(self, input_ids): + # input_ids: (bsz, seq_len) + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.project_in(inputs_embeds) + bsz, seq_len, _ = inputs_embeds.size() + attention_mask = (input_ids != self.config.pad_token_id).long() + position_embeds = self.embed_positions(attention_mask=attention_mask) + hidden_states = inputs_embeds + position_embeds + return hidden_states + + # qkv + rms + def qkv(self, hidden_states): + self.residual = hidden_states + self.bsz, self.tgt_len, _ = hidden_states.size() + + if self.config.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + query_states = self.q_proj(hidden_states) * self.scaling + query_states = query_states.view(self.bsz, -1, self.config.num_heads, self.head_dim).transpose(1, 2) + + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + key_states = key_states.view(self.bsz, -1, self.config.num_heads, self.head_dim).transpose(1, 2) 
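+        # Shape note (derived from this test's defaults: hidden_size=1024, num_heads=16, so head_dim=64, and a single decode token): after the view/transpose, query/key/value each become (bsz, num_heads, 1, head_dim).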
+ value_states = value_states.view(self.bsz, -1, self.config.num_heads, self.head_dim).transpose(1, 2) + + return query_states, key_states, value_states + + # QK^T + SV + def attn(self, query, key, value, attention_mask, scaling, dropout, **kwargs): + # KV cache update + self.past_k = torch.cat([self.past_k, key], dim=2) + self.past_v = torch.cat([self.past_v, value], dim=2) + + key = self.past_k + value = self.past_v + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=False) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output + + + # out-proj + rms + def out_proj(self, attn_output): + attn_output = attn_output.reshape(self.bsz, self.tgt_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + attn_output = nn.functional.dropout(attn_output, p=0.0, training=False) + attn_output = self.residual + attn_output + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + attn_output = self.self_attn_layer_norm(attn_output) + + return attn_output + + # MLP + rms + def ffn(self, hidden_states): + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=0.0, training=False) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + return outputs + + def lm_head(self, outputs, logits_to_keep): + hidden_states = outputs[0] + hidden_states = self.project_out(hidden_states) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head_linear(hidden_states[:, slice_indices, :]).contiguous() + + return logits + + def forward(self, x): + # hidden = self.embed(x) + # return hidden + + q, k, v = self.qkv(x) + return q, k, v + + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="") + parser.add_argument("--bsz", type=int, default=128, help="Batch size") + parser.add_argument("--seq_len", type=int, default=128, help="Input sequence length") + args = parser.parse_args() + + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/root/workspace/PyTorchSim')) + + from Scheduler.scheduler import PyTorchSimRunner + module = PyTorchSimRunner.setup_device() + device = module.custom_device() + + embed_dim = 1024 + hidden_size = 1024 + num_heads = 16 + ffn_dim = 4096 + vocab_size = 50272 + word_embed_proj_dim = 512 + pad_token_id = 1 + max_position_embeddings = 2048 + enable_bias = True + layer_norm_elementwise_affine = True + do_layer_norm_before = False + + bsz = args.bsz + seq_len = args.seq_len + + print(f"Batch size: 
{bsz}, Seq len: {seq_len}") + + config = LLM_Config(embed_dim = embed_dim, + hidden_size = hidden_size, + num_heads = num_heads, + ffn_dim = ffn_dim, + vocab_size = vocab_size, + word_embed_proj_dim = word_embed_proj_dim, + pad_token_id = pad_token_id, + max_position_embeddings = max_position_embeddings, + enable_bias = enable_bias, + layer_norm_elementwise_affine = layer_norm_elementwise_affine, + do_layer_norm_before = do_layer_norm_before) + + decoder = my_opt_decoder(config, seq_len) + decoder.eval() + decoder_device = decoder.to(device=device) + opt_decoder = torch.compile(decoder_device, dynamic=False) + + # Embedding is not supported currently, just skip + # input = torch.randint(0, vocab_size, (bsz, seq_len)).to(device) # (bsz, seq_len) + # hidden = opt_decoder.embed(input) + + + hidden = torch.randn( + bsz, 1, config.hidden_size, + dtype=torch.float32 + ) + hidden_device = hidden.to(device) + with torch.no_grad(): + opt_decoder(hidden_device) \ No newline at end of file diff --git a/tests/OPT/tp_n/attn.py b/tests/OPT/tp_n/attn.py new file mode 100644 index 00000000..c19b8267 --- /dev/null +++ b/tests/OPT/tp_n/attn.py @@ -0,0 +1,285 @@ +import os +import sys +import argparse + +import torch +import torch.nn as nn + +from typing import Optional + +class LLM_Config: + def __init__(self, + embed_dim, + hidden_size, + num_heads, + ffn_dim, + vocab_size, + word_embed_proj_dim, + pad_token_id, + max_position_embeddings, + enable_bias, + layer_norm_elementwise_affine, + do_layer_norm_before, + tp): + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.num_heads = num_heads + self.ffn_dim = ffn_dim + self.vocab_size = vocab_size + self.word_embed_proj_dim = word_embed_proj_dim + self.pad_token_id = pad_token_id + self.max_position_embeddings = max_position_embeddings + self.enable_bias = enable_bias + self.layer_norm_elementwise_affine = layer_norm_elementwise_affine + self.do_layer_norm_before = do_layer_norm_before + self.tp = tp + + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. 
Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward( + self, + attention_mask: torch.LongTensor, + past_key_values_length: int = 0, + position_ids: Optional[torch.LongTensor] = None, + ): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + + if position_ids is None: + position_ids = torch.cumsum(attention_mask, dim=1) + position_ids = (position_ids * attention_mask - 1).long() + # cut positions if `past_key_values_length` is > 0 + position_ids = position_ids[:, past_key_values_length:] + + return super().forward(position_ids + self.offset) + +class my_opt_decoder(nn.Module): + def __init__(self, config: LLM_Config, current_seq_len): + super(my_opt_decoder, self).__init__() + self.config = config + + self.head_dim = self.config.embed_dim // self.config.num_heads + self.scaling = self.head_dim**-0.5 + + # Embedding layers + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.word_embed_proj_dim, self.config.pad_token_id) + self.embed_positions = OPTLearnedPositionalEmbedding(self.config.max_position_embeddings, config.hidden_size) + self.project_in = nn.Linear(self.config.word_embed_proj_dim, self.config.hidden_size, bias=False) + + # KV Cache + self.register_buffer( + "past_k", + torch.randn(bsz, num_heads // self.config.tp, current_seq_len, self.head_dim) + ) + self.register_buffer( + "past_v", + torch.randn(bsz, num_heads // self.config.tp, current_seq_len, self.head_dim) + ) + + + # QKV layers + self.k_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim // self.config.tp, bias=self.config.enable_bias) + self.v_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim // self.config.tp, bias=self.config.enable_bias) + self.q_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim // self.config.tp, bias=self.config.enable_bias) + self.o_proj = nn.Linear(self.config.embed_dim // self.config.tp, self.config.embed_dim, bias=self.config.enable_bias) + + self.self_attn_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # FC layers + self.activation_fn = nn.ReLU() + self.fc1 = nn.Linear(self.config.embed_dim, config.ffn_dim // self.config.tp, bias=config.enable_bias) + self.fc2 = nn.Linear(self.config.ffn_dim // self.config.tp, self.config.embed_dim, bias=config.enable_bias) + self.final_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # LM head + self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) + self.lm_head_linear = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + + def embed(self, input_ids): + # input_ids: (bsz, seq_len) + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.project_in(inputs_embeds) + bsz, seq_len, _ = inputs_embeds.size() + attention_mask = (input_ids != self.config.pad_token_id).long() + position_embeds = self.embed_positions(attention_mask=attention_mask) + hidden_states = inputs_embeds + position_embeds + return hidden_states + + # qkv + rms + def qkv(self, hidden_states): + self.residual = hidden_states + self.bsz, self.tgt_len, _ = hidden_states.size() + + if self.config.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + query_states = self.q_proj(hidden_states) * self.scaling + query_states = query_states.view(self.bsz, -1, self.config.num_heads // self.config.tp, self.head_dim).transpose(1, 2) + + key_states = 
self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + key_states = key_states.view(self.bsz, -1, self.config.num_heads // self.config.tp, self.head_dim).transpose(1, 2) + value_states = value_states.view(self.bsz, -1, self.config.num_heads // self.config.tp, self.head_dim).transpose(1, 2) + + return query_states, key_states, value_states + + # QK^T + SV + def attn(self, query, key, value, attention_mask, scaling, dropout, **kwargs): + # KV cache update + self.past_k = torch.cat([self.past_k, key], dim=2) + self.past_v = torch.cat([self.past_v, value], dim=2) + + key = self.past_k + value = self.past_v + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=False) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output + + # out-proj + rms + def out_proj(self, attn_output): + attn_output = attn_output.reshape(self.bsz, self.tgt_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + attn_output = nn.functional.dropout(attn_output, p=0.0, training=False) + attn_output = self.residual + attn_output + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + attn_output = self.self_attn_layer_norm(attn_output) + + return attn_output + + # MLP + rms + def ffn(self, hidden_states): + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=0.0, training=False) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + return outputs + + def lm_head(self, outputs, logits_to_keep): + hidden_states = outputs[0] + hidden_states = self.project_out(hidden_states) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + + print(f"input hidden_states shape: {hidden_states.shape}") + print(f"config.word_embed_proj_dim, config.vocab_size: {config.word_embed_proj_dim} {config.vocab_size}") + logits = self.lm_head_linear(hidden_states[:, slice_indices, :]).contiguous() + + return logits + + def forward(self, q, k, v): + # hidden = self.embed(x) + # return hidden + + attn_output = self.attn(q, k, v, None, self.scaling, 0.0) + return attn_output + + + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description="") + parser.add_argument("--bsz", type=int, default=128, help="Batch size") + parser.add_argument("--seq_len", type=int, default=128, help="Input sequence length") + parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size") + args = parser.parse_args() + + sys.path.append(os.environ.get('TORCHSIM_DIR', 
default='/root/workspace/PyTorchSim')) + + from Scheduler.scheduler import PyTorchSimRunner + module = PyTorchSimRunner.setup_device() + device = module.custom_device() + + embed_dim = 1024 + hidden_size = 1024 + num_heads = 16 + ffn_dim = 4096 + vocab_size = 50272 + word_embed_proj_dim = 512 + pad_token_id = 1 + max_position_embeddings = 2048 + enable_bias = True + layer_norm_elementwise_affine = True + do_layer_norm_before = False + + bsz = args.bsz + seq_len = args.seq_len + + print(f"Batch size: {bsz}, Seq len: {seq_len}, TP: {args.tp}") + + config = LLM_Config(embed_dim = embed_dim, + hidden_size = hidden_size, + num_heads = num_heads, + ffn_dim = ffn_dim, + vocab_size = vocab_size, + word_embed_proj_dim = word_embed_proj_dim, + pad_token_id = pad_token_id, + max_position_embeddings = max_position_embeddings, + enable_bias = enable_bias, + layer_norm_elementwise_affine = layer_norm_elementwise_affine, + do_layer_norm_before = do_layer_norm_before, + tp = args.tp) + + decoder = my_opt_decoder(config, seq_len) + decoder.eval() + decoder_device = decoder.to(device=device) + opt_decoder = torch.compile(decoder_device, dynamic=False) + + q = torch.randn( + bsz, config.num_heads // config.tp, 1, config.hidden_size // config.num_heads, + dtype=torch.float32 + ) + k = torch.randn( + bsz, config.num_heads // config.tp, 1, config.hidden_size // config.num_heads, + dtype=torch.float32 + ) + v = torch.randn( + bsz, config.num_heads // config.tp, 1, config.hidden_size // config.num_heads, + dtype=torch.float32 + ) + q_device = q.to(device) + k_device = k.to(device) + v_device = v.to(device) + + with torch.no_grad(): + opt_decoder(q_device, k_device, v_device) \ No newline at end of file diff --git a/tests/OPT/tp_n/embed.py b/tests/OPT/tp_n/embed.py new file mode 100644 index 00000000..5d30ec3c --- /dev/null +++ b/tests/OPT/tp_n/embed.py @@ -0,0 +1,270 @@ +import os +import sys +import argparse + +import torch +import torch.nn as nn + +from typing import Optional + +class LLM_Config: + def __init__(self, + embed_dim, + hidden_size, + num_heads, + ffn_dim, + vocab_size, + word_embed_proj_dim, + pad_token_id, + max_position_embeddings, + enable_bias, + layer_norm_elementwise_affine, + do_layer_norm_before, + tp): + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.num_heads = num_heads + self.ffn_dim = ffn_dim + self.vocab_size = vocab_size + self.word_embed_proj_dim = word_embed_proj_dim + self.pad_token_id = pad_token_id + self.max_position_embeddings = max_position_embeddings + self.enable_bias = enable_bias + self.layer_norm_elementwise_affine = layer_norm_elementwise_affine + self.do_layer_norm_before = do_layer_norm_before + self.tp = tp + + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. 
Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward( + self, + attention_mask: torch.LongTensor, + past_key_values_length: int = 0, + position_ids: Optional[torch.LongTensor] = None, + ): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + + if position_ids is None: + position_ids = torch.full( + (bsz, 1), + past_key_values_length, + dtype=torch.long, + ) + + return super().forward(position_ids + self.offset) + +class my_opt_decoder(nn.Module): + def __init__(self, config: LLM_Config, current_seq_len): + super(my_opt_decoder, self).__init__() + self.config = config + + self.head_dim = self.config.embed_dim // self.config.num_heads + self.scaling = self.head_dim**-0.5 + + # Embedding layers + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.word_embed_proj_dim, self.config.pad_token_id) + self.embed_positions = OPTLearnedPositionalEmbedding(self.config.max_position_embeddings, config.hidden_size) + self.project_in = nn.Linear(self.config.word_embed_proj_dim, self.config.hidden_size, bias=False) + + # KV Cache + self.register_buffer( + "past_k", + torch.randn(bsz, num_heads // self.config.tp, current_seq_len, self.head_dim) + ) + self.register_buffer( + "past_v", + torch.randn(bsz, num_heads // self.config.tp, current_seq_len, self.head_dim) + ) + + + # QKV layers + self.k_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim // self.config.tp, bias=self.config.enable_bias) + self.v_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim // self.config.tp, bias=self.config.enable_bias) + self.q_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim // self.config.tp, bias=self.config.enable_bias) + self.o_proj = nn.Linear(self.config.embed_dim // self.config.tp, self.config.embed_dim, bias=self.config.enable_bias) + + self.self_attn_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # FC layers + self.activation_fn = nn.ReLU() + self.fc1 = nn.Linear(self.config.embed_dim, config.ffn_dim // self.config.tp, bias=config.enable_bias) + self.fc2 = nn.Linear(self.config.ffn_dim // self.config.tp, self.config.embed_dim, bias=config.enable_bias) + self.final_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # LM head + self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) + self.lm_head_linear = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + + def embed(self, input_ids): + # input_ids: (bsz, seq_len) + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.project_in(inputs_embeds) + bsz, seq_len, _ = inputs_embeds.size() + attention_mask = (input_ids != self.config.pad_token_id).long() + position_embeds = self.embed_positions(attention_mask=attention_mask) + hidden_states = inputs_embeds + position_embeds + return hidden_states + + # qkv + rms + def qkv(self, hidden_states): + self.residual = hidden_states + self.bsz, self.tgt_len, _ = hidden_states.size() + + if self.config.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + query_states = self.q_proj(hidden_states) * self.scaling + query_states = query_states.view(self.bsz, -1, self.config.num_heads // self.config.tp, self.head_dim).transpose(1, 2) + + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + key_states = key_states.view(self.bsz, -1, 
self.config.num_heads // self.config.tp, self.head_dim).transpose(1, 2) + value_states = value_states.view(self.bsz, -1, self.config.num_heads // self.config.tp, self.head_dim).transpose(1, 2) + + return query_states, key_states, value_states + + # QK^T + SV + def attn(self, query, key, value, attention_mask, scaling, dropout, **kwargs): + # KV cache update + self.past_k = torch.cat([self.past_k, key], dim=2) + self.past_v = torch.cat([self.past_v, value], dim=2) + + key = self.past_k + value = self.past_v + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=False) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output + + + # out-proj + rms + def out_proj(self, attn_output): + attn_output = attn_output.reshape(self.bsz, self.tgt_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + attn_output = nn.functional.dropout(attn_output, p=0.0, training=False) + attn_output = self.residual + attn_output + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + attn_output = self.self_attn_layer_norm(attn_output) + + return attn_output + + # MLP + rms + def ffn(self, hidden_states): + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=0.0, training=False) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + return outputs + + def lm_head(self, outputs, logits_to_keep): + hidden_states = outputs[0] + hidden_states = self.project_out(hidden_states) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head_linear(hidden_states[:, slice_indices, :]).contiguous() + + return logits + + def forward(self, x): + hidden = self.embed(x) + return hidden + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="") + parser.add_argument("--bsz", type=int, default=128, help="Batch size") + parser.add_argument("--seq_len", type=int, default=128, help="Input sequence length") + parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size") + args = parser.parse_args() + + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/root/workspace/PyTorchSim')) + + from Scheduler.scheduler import PyTorchSimRunner + module = PyTorchSimRunner.setup_device() + device = module.custom_device() + + embed_dim = 1024 + hidden_size = 1024 + num_heads = 16 + ffn_dim = 4096 + vocab_size = 50272 + word_embed_proj_dim = 512 + pad_token_id = 1 + max_position_embeddings = 2048 + enable_bias = True + 
layer_norm_elementwise_affine = True + do_layer_norm_before = False + + bsz = args.bsz + seq_len = args.seq_len + + print(f"Batch size: {bsz}, Seq len: {seq_len}, TP: {args.tp}") + + config = LLM_Config(embed_dim = embed_dim, + hidden_size = hidden_size, + num_heads = num_heads, + ffn_dim = ffn_dim, + vocab_size = vocab_size, + word_embed_proj_dim = word_embed_proj_dim, + pad_token_id = pad_token_id, + max_position_embeddings = max_position_embeddings, + enable_bias = enable_bias, + layer_norm_elementwise_affine = layer_norm_elementwise_affine, + do_layer_norm_before = do_layer_norm_before, + tp = args.tp) + + decoder = my_opt_decoder(config, seq_len) + decoder.eval() + decoder_device = decoder.to(device=device) + opt_decoder = torch.compile(decoder_device, dynamic=False) + + # Embedding is not supported currently, just skip + # input = torch.randint(0, vocab_size, (bsz, seq_len)).to(device) # (bsz, seq_len) + # hidden = opt_decoder.embed(input) + + + input = torch.randint(0, vocab_size, (bsz, 1)) + input_device = input.to(device) + + with torch.no_grad(): + opt_decoder(input_device) \ No newline at end of file diff --git a/tests/OPT/tp_n/ffn.py b/tests/OPT/tp_n/ffn.py new file mode 100644 index 00000000..a82b2a77 --- /dev/null +++ b/tests/OPT/tp_n/ffn.py @@ -0,0 +1,282 @@ +import os +import sys +import argparse + +import torch +import torch.nn as nn + +from typing import Optional + +class LLM_Config: + def __init__(self, + embed_dim, + hidden_size, + num_heads, + ffn_dim, + vocab_size, + word_embed_proj_dim, + pad_token_id, + max_position_embeddings, + enable_bias, + layer_norm_elementwise_affine, + do_layer_norm_before, + tp): + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.num_heads = num_heads + self.ffn_dim = ffn_dim + self.vocab_size = vocab_size + self.word_embed_proj_dim = word_embed_proj_dim + self.pad_token_id = pad_token_id + self.max_position_embeddings = max_position_embeddings + self.enable_bias = enable_bias + self.layer_norm_elementwise_affine = layer_norm_elementwise_affine + self.do_layer_norm_before = do_layer_norm_before + self.tp = tp + + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. 
Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward( + self, + attention_mask: torch.LongTensor, + past_key_values_length: int = 0, + position_ids: Optional[torch.LongTensor] = None, + ): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + + if position_ids is None: + position_ids = torch.cumsum(attention_mask, dim=1) + position_ids = (position_ids * attention_mask - 1).long() + # cut positions if `past_key_values_length` is > 0 + position_ids = position_ids[:, past_key_values_length:] + + return super().forward(position_ids + self.offset) + +class my_opt_decoder(nn.Module): + def __init__(self, config: LLM_Config, current_seq_len): + super(my_opt_decoder, self).__init__() + self.config = config + + self.head_dim = self.config.embed_dim // self.config.num_heads + self.scaling = self.head_dim**-0.5 + + # Embedding layers + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.word_embed_proj_dim, self.config.pad_token_id) + self.embed_positions = OPTLearnedPositionalEmbedding(self.config.max_position_embeddings, config.hidden_size) + self.project_in = nn.Linear(self.config.word_embed_proj_dim, self.config.hidden_size, bias=False) + + # KV Cache + self.register_buffer( + "past_k", + torch.randn(bsz, num_heads // self.config.tp, current_seq_len, self.head_dim) + ) + self.register_buffer( + "past_v", + torch.randn(bsz, num_heads // self.config.tp, current_seq_len, self.head_dim) + ) + + + # QKV layers + self.k_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim // self.config.tp, bias=self.config.enable_bias) + self.v_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim // self.config.tp, bias=self.config.enable_bias) + self.q_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim // self.config.tp, bias=self.config.enable_bias) + self.o_proj = nn.Linear(self.config.embed_dim // self.config.tp, self.config.embed_dim, bias=self.config.enable_bias) + + self.self_attn_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # FC layers + self.activation_fn = nn.ReLU() + self.fc1 = nn.Linear(self.config.embed_dim, config.ffn_dim // self.config.tp, bias=config.enable_bias) + self.fc2 = nn.Linear(self.config.ffn_dim // self.config.tp, self.config.embed_dim, bias=config.enable_bias) + self.final_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # LM head + self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) + self.lm_head_linear = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + + def embed(self, input_ids): + # input_ids: (bsz, seq_len) + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.project_in(inputs_embeds) + bsz, seq_len, _ = inputs_embeds.size() + attention_mask = (input_ids != self.config.pad_token_id).long() + position_embeds = self.embed_positions(attention_mask=attention_mask) + print(f"input shape: {input_ids.shape}, embed shape: {inputs_embeds.shape}, position shape: {position_embeds.shape}") + hidden_states = inputs_embeds + position_embeds + return hidden_states + + # qkv + rms + def qkv(self, hidden_states): + self.residual = hidden_states + self.bsz, self.tgt_len, _ = hidden_states.size() + + if self.config.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + query_states = self.q_proj(hidden_states) * self.scaling + query_states = 
query_states.view(self.bsz, -1, self.config.num_heads // self.config.tp, self.head_dim).transpose(1, 2) + + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + key_states = key_states.view(self.bsz, -1, self.config.num_heads // self.config.tp, self.head_dim).transpose(1, 2) + value_states = value_states.view(self.bsz, -1, self.config.num_heads // self.config.tp, self.head_dim).transpose(1, 2) + + return query_states, key_states, value_states + + # QK^T + SV + def attn(self, query, key, value, attention_mask, scaling, dropout, **kwargs): + # KV cache update + self.past_k = torch.cat([self.past_k, key], dim=2) + self.past_v = torch.cat([self.past_v, value], dim=2) + + key = self.past_k + value = self.past_v + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=False) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output + + + # out-proj + rms + def out_proj(self, attn_output): + attn_output = attn_output.reshape(self.bsz, self.tgt_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + attn_output = nn.functional.dropout(attn_output, p=0.0, training=False) + attn_output = self.residual + attn_output + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + attn_output = self.self_attn_layer_norm(attn_output) + + return attn_output + + # MLP + rms + def ffn(self, hidden_states): + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=0.0, training=False) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + return outputs + + def lm_head(self, outputs, logits_to_keep): + hidden_states = outputs[0] + hidden_states = self.project_out(hidden_states) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + + print(f"input hidden_states shape: {hidden_states.shape}") + print(f"config.word_embed_proj_dim, config.vocab_size: {config.word_embed_proj_dim} {config.vocab_size}") + logits = self.lm_head_linear(hidden_states[:, slice_indices, :]).contiguous() + + return logits + + def forward(self, x): + # hidden = self.embed(x) + # return hidden + + outputs = self.ffn(x) + return outputs + + + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description="") + parser.add_argument("--bsz", type=int, default=128, help="Batch size") + parser.add_argument("--seq_len", type=int, default=128, help="Input sequence length") + parser.add_argument("--tp", type=int, default=1, help="Input sequence length") + 
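+ # --tp is the tensor-parallel degree: it shards the attention heads (num_heads // tp) and the FFN width (ffn_dim // tp), so this script models a single TP rank.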
args = parser.parse_args() + + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/root/workspace/PyTorchSim')) + + from Scheduler.scheduler import PyTorchSimRunner + module = PyTorchSimRunner.setup_device() + device = module.custom_device() + + embed_dim = 1024 + hidden_size = 1024 + num_heads = 16 + ffn_dim = 4096 + vocab_size = 50272 + word_embed_proj_dim = 512 + pad_token_id = 1 + max_position_embeddings = 2048 + enable_bias = True + layer_norm_elementwise_affine = True + do_layer_norm_before = False + + bsz = args.bsz + seq_len = args.seq_len + + print(f"Batch size: {bsz}, Seq len: {seq_len}, TP: {args.tp}") + + config = LLM_Config(embed_dim = embed_dim, + hidden_size = hidden_size, + num_heads = num_heads, + ffn_dim = ffn_dim, + vocab_size = vocab_size, + word_embed_proj_dim = word_embed_proj_dim, + pad_token_id = pad_token_id, + max_position_embeddings = max_position_embeddings, + enable_bias = enable_bias, + layer_norm_elementwise_affine = layer_norm_elementwise_affine, + do_layer_norm_before = do_layer_norm_before, + tp = args.tp) + + decoder = my_opt_decoder(config, seq_len) + decoder.eval() + decoder_device = decoder.to(device=device) + opt_decoder = torch.compile(decoder_device, dynamic=False) + + # Embedding is not supported currently, just skip + # input = torch.randint(0, vocab_size, (bsz, seq_len)).to(device) # (bsz, seq_len) + # hidden = opt_decoder.embed(input) + + + hidden = torch.randn( + bsz, 1, config.hidden_size, + dtype=torch.float32 + ) + hidden_device = hidden.to(device) + + with torch.no_grad(): + opt_decoder(hidden_device) \ No newline at end of file diff --git a/tests/OPT/tp_n/lm_head.py b/tests/OPT/tp_n/lm_head.py new file mode 100644 index 00000000..6abca2b5 --- /dev/null +++ b/tests/OPT/tp_n/lm_head.py @@ -0,0 +1,280 @@ +import os +import sys +import argparse + +import torch +import torch.nn as nn + +from typing import Optional + +class LLM_Config: + def __init__(self, + embed_dim, + hidden_size, + num_heads, + ffn_dim, + vocab_size, + word_embed_proj_dim, + pad_token_id, + max_position_embeddings, + enable_bias, + layer_norm_elementwise_affine, + do_layer_norm_before, + tp): + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.num_heads = num_heads + self.ffn_dim = ffn_dim + self.vocab_size = vocab_size + self.word_embed_proj_dim = word_embed_proj_dim + self.pad_token_id = pad_token_id + self.max_position_embeddings = max_position_embeddings + self.enable_bias = enable_bias + self.layer_norm_elementwise_affine = layer_norm_elementwise_affine + self.do_layer_norm_before = do_layer_norm_before + self.tp = tp + + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. 
Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward( + self, + attention_mask: torch.LongTensor, + past_key_values_length: int = 0, + position_ids: Optional[torch.LongTensor] = None, + ): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + + if position_ids is None: + position_ids = torch.cumsum(attention_mask, dim=1) + position_ids = (position_ids * attention_mask - 1).long() + # cut positions if `past_key_values_length` is > 0 + position_ids = position_ids[:, past_key_values_length:] + + return super().forward(position_ids + self.offset) + +class my_opt_decoder(nn.Module): + def __init__(self, config: LLM_Config, current_seq_len): + super(my_opt_decoder, self).__init__() + self.config = config + + self.head_dim = self.config.embed_dim // self.config.num_heads + self.scaling = self.head_dim**-0.5 + + # Embedding layers + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.word_embed_proj_dim, self.config.pad_token_id) + self.embed_positions = OPTLearnedPositionalEmbedding(self.config.max_position_embeddings, config.hidden_size) + self.project_in = nn.Linear(self.config.word_embed_proj_dim, self.config.hidden_size, bias=False) + + # KV Cache + # self.past_k = torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + # self.past_v = torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + self.register_buffer( + "past_k", + torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + ) + self.register_buffer( + "past_v", + torch.randn(bsz, num_heads, current_seq_len, self.head_dim) + ) + + + # QKV layers + self.k_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + self.v_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + self.q_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + self.o_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias) + + self.self_attn_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # FC layers + self.activation_fn = nn.ReLU() + self.fc1 = nn.Linear(self.config.embed_dim, config.ffn_dim, bias=config.enable_bias) + self.fc2 = nn.Linear(config.ffn_dim, self.config.embed_dim, bias=config.enable_bias) + self.final_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # LM head + self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) + self.lm_head_linear = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + + def embed(self, input_ids): + # input_ids: (bsz, seq_len) + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.project_in(inputs_embeds) + bsz, seq_len, _ = inputs_embeds.size() + attention_mask = (input_ids != self.config.pad_token_id).long() + position_embeds = self.embed_positions(attention_mask=attention_mask) + print(f"input shape: {input_ids.shape}, embed shape: {inputs_embeds.shape}, position shape: {position_embeds.shape}") + hidden_states = inputs_embeds + position_embeds + return hidden_states + + # qkv + rms + def qkv(self, hidden_states): + self.residual = hidden_states + self.bsz, self.tgt_len, _ = hidden_states.size() + + if self.config.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + query_states = self.q_proj(hidden_states) * self.scaling + 
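+ # query_states is already pre-scaled by head_dim**-0.5; the view/transpose below splits the embedding into per-head slices: (bsz, tgt_len, embed_dim) -> (bsz, num_heads, tgt_len, head_dim).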
query_states = query_states.view(self.bsz, -1, self.config.num_heads, self.head_dim).transpose(1, 2) + + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + key_states = key_states.view(self.bsz, -1, self.config.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(self.bsz, -1, self.config.num_heads, self.head_dim).transpose(1, 2) + + return query_states, key_states, value_states + + # QK^T + SV + def attn(self, query, key, value, attention_mask, scaling, dropout, **kwargs): + # KV cache update + self.past_k = torch.cat([self.past_k, key], dim=2) + self.past_v = torch.cat([self.past_v, value], dim=2) + + key = self.past_k + value = self.past_v + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=False) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output + + + # out-proj + rms + def out_proj(self, attn_output): + attn_output = attn_output.reshape(self.bsz, self.tgt_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + attn_output = nn.functional.dropout(attn_output, p=0.0, training=False) + attn_output = self.residual + attn_output + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + attn_output = self.self_attn_layer_norm(attn_output) + + return attn_output + + # MLP + rms + def ffn(self, hidden_states): + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=0.0, training=False) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + return outputs + + def lm_head(self, outputs, logits_to_keep): + hidden_states = outputs[0] + hidden_states = self.project_out(hidden_states) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head_linear(hidden_states[:, slice_indices, :]).contiguous() + + return logits + + def forward(self, x): + # hidden = self.embed(x) + # return hidden + + logits = self.lm_head(x, 1) + return logits + + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="") + parser.add_argument("--bsz", type=int, default=128, help="Batch size") + parser.add_argument("--seq_len", type=int, default=128, help="Input sequence length") + parser.add_argument("--tp", type=int, default=1, help="Input sequence length") + args = parser.parse_args() + + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/root/workspace/PyTorchSim')) + + from Scheduler.scheduler import PyTorchSimRunner + module = 
PyTorchSimRunner.setup_device() + device = module.custom_device() + + embed_dim = 1024 + hidden_size = 1024 + num_heads = 16 + ffn_dim = 4096 + vocab_size = 50272 + word_embed_proj_dim = 512 + pad_token_id = 1 + max_position_embeddings = 2048 + enable_bias = True + layer_norm_elementwise_affine = True + do_layer_norm_before = False + + bsz = args.bsz + seq_len = args.seq_len + + print(f"Batch size: {bsz}, Seq len: {seq_len}") + + config = LLM_Config(embed_dim = embed_dim, + hidden_size = hidden_size, + num_heads = num_heads, + ffn_dim = ffn_dim, + vocab_size = vocab_size, + word_embed_proj_dim = word_embed_proj_dim, + pad_token_id = pad_token_id, + max_position_embeddings = max_position_embeddings, + enable_bias = enable_bias, + layer_norm_elementwise_affine = layer_norm_elementwise_affine, + do_layer_norm_before = do_layer_norm_before, + tp = args.tp) + + decoder = my_opt_decoder(config, seq_len) + decoder.eval() + decoder_device = decoder.to(device=device) + opt_decoder = torch.compile(decoder_device, dynamic=False) + + # Embedding is not supported currently, just skip + # input = torch.randint(0, vocab_size, (bsz, seq_len)).to(device) # (bsz, seq_len) + # hidden = opt_decoder.embed(input) + + + hidden = torch.randn( + bsz, 1, config.hidden_size, + dtype=torch.float32 + ) + hidden_device = hidden.to(device) + + with torch.no_grad(): + opt_decoder((hidden_device, )) \ No newline at end of file diff --git a/tests/OPT/tp_n/out_proj.py b/tests/OPT/tp_n/out_proj.py new file mode 100644 index 00000000..a0e54bca --- /dev/null +++ b/tests/OPT/tp_n/out_proj.py @@ -0,0 +1,291 @@ +import os +import sys +import argparse + +import torch +import torch.nn as nn + +from typing import Optional + +class LLM_Config: + def __init__(self, + embed_dim, + hidden_size, + num_heads, + ffn_dim, + vocab_size, + word_embed_proj_dim, + pad_token_id, + max_position_embeddings, + enable_bias, + layer_norm_elementwise_affine, + do_layer_norm_before, + tp): + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.num_heads = num_heads + self.ffn_dim = ffn_dim + self.vocab_size = vocab_size + self.word_embed_proj_dim = word_embed_proj_dim + self.pad_token_id = pad_token_id + self.max_position_embeddings = max_position_embeddings + self.enable_bias = enable_bias + self.layer_norm_elementwise_affine = layer_norm_elementwise_affine + self.do_layer_norm_before = do_layer_norm_before + self.tp = tp + + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. 
Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward( + self, + attention_mask: torch.LongTensor, + past_key_values_length: int = 0, + position_ids: Optional[torch.LongTensor] = None, + ): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + + if position_ids is None: + position_ids = torch.cumsum(attention_mask, dim=1) + position_ids = (position_ids * attention_mask - 1).long() + # cut positions if `past_key_values_length` is > 0 + position_ids = position_ids[:, past_key_values_length:] + + return super().forward(position_ids + self.offset) + +class my_opt_decoder(nn.Module): + def __init__(self, config: LLM_Config, current_seq_len): + super(my_opt_decoder, self).__init__() + self.config = config + + self.head_dim = self.config.embed_dim // self.config.num_heads + self.scaling = self.head_dim**-0.5 + + # Embedding layers + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.word_embed_proj_dim, self.config.pad_token_id) + self.embed_positions = OPTLearnedPositionalEmbedding(self.config.max_position_embeddings, config.hidden_size) + self.project_in = nn.Linear(self.config.word_embed_proj_dim, self.config.hidden_size, bias=False) + + # KV Cache + self.register_buffer( + "past_k", + torch.randn(bsz, num_heads // self.config.tp, current_seq_len, self.head_dim) + ) + self.register_buffer( + "past_v", + torch.randn(bsz, num_heads // self.config.tp, current_seq_len, self.head_dim) + ) + + + # QKV layers + self.k_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim // self.config.tp, bias=self.config.enable_bias) + self.v_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim // self.config.tp, bias=self.config.enable_bias) + self.q_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim // self.config.tp, bias=self.config.enable_bias) + self.o_proj = nn.Linear(self.config.embed_dim // self.config.tp, self.config.embed_dim, bias=self.config.enable_bias) + + self.self_attn_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # FC layers + self.activation_fn = nn.ReLU() + self.fc1 = nn.Linear(self.config.embed_dim, config.ffn_dim // self.config.tp, bias=config.enable_bias) + self.fc2 = nn.Linear(self.config.ffn_dim // self.config.tp, self.config.embed_dim, bias=config.enable_bias) + self.final_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # LM head + self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) + self.lm_head_linear = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + + def embed(self, input_ids): + # input_ids: (bsz, seq_len) + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.project_in(inputs_embeds) + bsz, seq_len, _ = inputs_embeds.size() + attention_mask = (input_ids != self.config.pad_token_id).long() + position_embeds = self.embed_positions(attention_mask=attention_mask) + print(f"input shape: {input_ids.shape}, embed shape: {inputs_embeds.shape}, position shape: {position_embeds.shape}") + hidden_states = inputs_embeds + position_embeds + return hidden_states + + # qkv + rms + def qkv(self, hidden_states): + self.residual = hidden_states + self.bsz, self.tgt_len, _ = hidden_states.size() + + if self.config.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + query_states = self.q_proj(hidden_states) * self.scaling + query_states = 
query_states.view(self.bsz, -1, self.config.num_heads // self.config.tp, self.head_dim).transpose(1, 2) + + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + key_states = key_states.view(self.bsz, -1, self.config.num_heads // self.config.tp, self.head_dim).transpose(1, 2) + value_states = value_states.view(self.bsz, -1, self.config.num_heads // self.config.tp, self.head_dim).transpose(1, 2) + + return query_states, key_states, value_states + + # QK^T + SV + def attn(self, query, key, value, attention_mask, scaling, dropout, **kwargs): + # KV cache update + self.past_k = torch.cat([self.past_k, key], dim=2) + self.past_v = torch.cat([self.past_v, value], dim=2) + + key = self.past_k + value = self.past_v + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=False) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output + + + # out-proj + rms + def out_proj(self, attn_output, residual): + self.bsz = attn_output.size(0) + self.tgt_len = attn_output.size(1) + + attn_output = attn_output.reshape(self.bsz, self.tgt_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + attn_output = nn.functional.dropout(attn_output, p=0.0, training=False) + attn_output = residual + attn_output + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + attn_output = self.self_attn_layer_norm(attn_output) + + return attn_output + + # MLP + rms + def ffn(self, hidden_states): + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=0.0, training=False) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + return outputs + + def lm_head(self, outputs, logits_to_keep): + hidden_states = outputs[0] + hidden_states = self.project_out(hidden_states) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + + print(f"input hidden_states shape: {hidden_states.shape}") + print(f"config.word_embed_proj_dim, config.vocab_size: {config.word_embed_proj_dim} {config.vocab_size}") + logits = self.lm_head_linear(hidden_states[:, slice_indices, :]).contiguous() + + return logits + + def forward(self, x, r): + # hidden = self.embed(x) + # return hidden + + attn_output = self.out_proj(x, r) + return attn_output + + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="") + parser.add_argument("--bsz", type=int, default=128, help="Batch size") + parser.add_argument("--seq_len", type=int, default=128, help="Input sequence 
length") + parser.add_argument("--tp", type=int, default=1, help="Input sequence length") + args = parser.parse_args() + + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/root/workspace/PyTorchSim')) + + from Scheduler.scheduler import PyTorchSimRunner + module = PyTorchSimRunner.setup_device() + device = module.custom_device() + + embed_dim = 1024 + hidden_size = 1024 + num_heads = 16 + ffn_dim = 4096 + vocab_size = 50272 + word_embed_proj_dim = 512 + pad_token_id = 1 + max_position_embeddings = 2048 + enable_bias = True + layer_norm_elementwise_affine = True + do_layer_norm_before = False + + bsz = args.bsz + seq_len = args.seq_len + + print(f"Batch size: {bsz}, Seq len: {seq_len}, TP: {args.tp}") + + config = LLM_Config(embed_dim = embed_dim, + hidden_size = hidden_size, + num_heads = num_heads, + ffn_dim = ffn_dim, + vocab_size = vocab_size, + word_embed_proj_dim = word_embed_proj_dim, + pad_token_id = pad_token_id, + max_position_embeddings = max_position_embeddings, + enable_bias = enable_bias, + layer_norm_elementwise_affine = layer_norm_elementwise_affine, + do_layer_norm_before = do_layer_norm_before, + tp = args.tp) + + decoder = my_opt_decoder(config, seq_len) + decoder.eval() + decoder_device = decoder.to(device=device) + opt_decoder = torch.compile(decoder_device, dynamic=False) + + # Embedding is not supported currently, just skip + # input = torch.randint(0, vocab_size, (bsz, seq_len)).to(device) # (bsz, seq_len) + # hidden = opt_decoder.embed(input) + + + hidden = torch.randn( + bsz, 1, config.num_heads // config.tp, config.hidden_size // config.num_heads, + dtype=torch.float32 + ) + hidden_device = hidden.to(device) + + # Residual + residual = torch.randn( + bsz, 1, config.hidden_size, + dtype=torch.float32 + ) + residual_device = residual.to(device) + + with torch.no_grad(): + opt_decoder(hidden_device, residual_device) \ No newline at end of file diff --git a/tests/OPT/tp_n/qkv.py b/tests/OPT/tp_n/qkv.py new file mode 100644 index 00000000..568a00fe --- /dev/null +++ b/tests/OPT/tp_n/qkv.py @@ -0,0 +1,277 @@ +import os +import sys +import argparse + +import torch +import torch.nn as nn + +from typing import Optional + +class LLM_Config: + def __init__(self, + embed_dim, + hidden_size, + num_heads, + ffn_dim, + vocab_size, + word_embed_proj_dim, + pad_token_id, + max_position_embeddings, + enable_bias, + layer_norm_elementwise_affine, + do_layer_norm_before, + tp): + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.num_heads = num_heads + self.ffn_dim = ffn_dim + self.vocab_size = vocab_size + self.word_embed_proj_dim = word_embed_proj_dim + self.pad_token_id = pad_token_id + self.max_position_embeddings = max_position_embeddings + self.enable_bias = enable_bias + self.layer_norm_elementwise_affine = layer_norm_elementwise_affine + self.do_layer_norm_before = do_layer_norm_before + self.tp = tp + + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. 
Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward( + self, + attention_mask: torch.LongTensor, + past_key_values_length: int = 0, + position_ids: Optional[torch.LongTensor] = None, + ): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + + if position_ids is None: + position_ids = torch.cumsum(attention_mask, dim=1) + position_ids = (position_ids * attention_mask - 1).long() + # cut positions if `past_key_values_length` is > 0 + position_ids = position_ids[:, past_key_values_length:] + + return super().forward(position_ids + self.offset) + +class my_opt_decoder(nn.Module): + def __init__(self, config: LLM_Config, current_seq_len): + super(my_opt_decoder, self).__init__() + self.config = config + + self.head_dim = self.config.embed_dim // self.config.num_heads + self.scaling = self.head_dim**-0.5 + + # Embedding layers + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.word_embed_proj_dim, self.config.pad_token_id) + self.embed_positions = OPTLearnedPositionalEmbedding(self.config.max_position_embeddings, config.hidden_size) + self.project_in = nn.Linear(self.config.word_embed_proj_dim, self.config.hidden_size, bias=False) + + # KV Cache + self.register_buffer( + "past_k", + torch.randn(bsz, num_heads // self.config.tp, current_seq_len, self.head_dim) + ) + self.register_buffer( + "past_v", + torch.randn(bsz, num_heads // self.config.tp, current_seq_len, self.head_dim) + ) + + + # QKV layers + self.k_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim // self.config.tp, bias=self.config.enable_bias) + self.v_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim // self.config.tp, bias=self.config.enable_bias) + self.q_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim // self.config.tp, bias=self.config.enable_bias) + self.o_proj = nn.Linear(self.config.embed_dim // self.config.tp, self.config.embed_dim, bias=self.config.enable_bias) + + self.self_attn_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # FC layers + self.activation_fn = nn.ReLU() + self.fc1 = nn.Linear(self.config.embed_dim, config.ffn_dim // self.config.tp, bias=config.enable_bias) + self.fc2 = nn.Linear(self.config.ffn_dim // self.config.tp, self.config.embed_dim, bias=config.enable_bias) + self.final_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + + # LM head + self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) + self.lm_head_linear = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + + def embed(self, input_ids): + # input_ids: (bsz, seq_len) + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.project_in(inputs_embeds) + bsz, seq_len, _ = inputs_embeds.size() + attention_mask = (input_ids != self.config.pad_token_id).long() + position_embeds = self.embed_positions(attention_mask=attention_mask) + hidden_states = inputs_embeds + position_embeds + return hidden_states + + # qkv + rms + def qkv(self, hidden_states): + self.residual = hidden_states + self.bsz, self.tgt_len, _ = hidden_states.size() + + if self.config.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + query_states = self.q_proj(hidden_states) * self.scaling + query_states = query_states.view(self.bsz, -1, self.config.num_heads // self.config.tp, self.head_dim).transpose(1, 2) + + key_states = 
self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + key_states = key_states.view(self.bsz, -1, self.config.num_heads // self.config.tp, self.head_dim).transpose(1, 2) + value_states = value_states.view(self.bsz, -1, self.config.num_heads // self.config.tp, self.head_dim).transpose(1, 2) + + return query_states, key_states, value_states + + # QK^T + SV + def attn(self, query, key, value, attention_mask, scaling, dropout, **kwargs): + # KV cache update + self.past_k = torch.cat([self.past_k, key], dim=2) + self.past_v = torch.cat([self.past_v, value], dim=2) + + key = self.past_k + value = self.past_v + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=False) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output + + + # out-proj + rms + def out_proj(self, attn_output): + attn_output = attn_output.reshape(self.bsz, self.tgt_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + attn_output = nn.functional.dropout(attn_output, p=0.0, training=False) + attn_output = self.residual + attn_output + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + attn_output = self.self_attn_layer_norm(attn_output) + + return attn_output + + # MLP + rms + def ffn(self, hidden_states): + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=0.0, training=False) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.config.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + return outputs + + def lm_head(self, outputs, logits_to_keep): + hidden_states = outputs[0] + hidden_states = self.project_out(hidden_states) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head_linear(hidden_states[:, slice_indices, :]).contiguous() + + return logits + + def forward(self, x): + # hidden = self.embed(x) + # return hidden + + q, k, v = self.qkv(x) + return q, k, v + + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="") + parser.add_argument("--bsz", type=int, default=128, help="Batch size") + parser.add_argument("--seq_len", type=int, default=128, help="Input sequence length") + parser.add_argument("--tp", type=int, default=1, help="Input sequence length") + args = parser.parse_args() + + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/root/workspace/PyTorchSim')) + + from Scheduler.scheduler import PyTorchSimRunner + module = PyTorchSimRunner.setup_device() + device = module.custom_device() + + embed_dim = 1024 + hidden_size = 1024 + 
num_heads = 16 + ffn_dim = 4096 + vocab_size = 50272 + word_embed_proj_dim = 512 + pad_token_id = 1 + max_position_embeddings = 2048 + enable_bias = True + layer_norm_elementwise_affine = True + do_layer_norm_before = False + + bsz = args.bsz + seq_len = args.seq_len + + print(f"Batch size: {bsz}, Seq len: {seq_len}, TP: {args.tp}") + + config = LLM_Config(embed_dim = embed_dim, + hidden_size = hidden_size, + num_heads = num_heads, + ffn_dim = ffn_dim, + vocab_size = vocab_size, + word_embed_proj_dim = word_embed_proj_dim, + pad_token_id = pad_token_id, + max_position_embeddings = max_position_embeddings, + enable_bias = enable_bias, + layer_norm_elementwise_affine = layer_norm_elementwise_affine, + do_layer_norm_before = do_layer_norm_before, + tp = args.tp) + + decoder = my_opt_decoder(config, seq_len) + decoder.eval() + decoder_device = decoder.to(device=device) + opt_decoder = torch.compile(decoder_device, dynamic=False) + + # Embedding is not supported currently, just skip + # input = torch.randint(0, vocab_size, (bsz, seq_len)).to(device) # (bsz, seq_len) + # hidden = opt_decoder.embed(input) + + + hidden = torch.randn( + bsz, 1, config.hidden_size, + dtype=torch.float32 + ) + hidden_device = hidden.to(device) + + with torch.no_grad(): + opt_decoder(hidden_device) \ No newline at end of file
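
The tp_n scripts above each exercise one decoder stage in isolation on the simulator device. As a reference for the `// tp` shape bookkeeping they share, here is a minimal CPU-only sketch (plain PyTorch, toy sizes, stand-alone nn.Linear modules, no simulator) that chains the sharded attention stages for a single decode token. It mirrors the projection shapes and KV-cache concat of tests/OPT/tp_n/qkv.py and tests/OPT/tp_n/out_proj.py, but the module names and sizes here are illustrative only, not part of the test suite.

import torch
import torch.nn as nn

bsz, seq_len, tp = 4, 128, 2            # toy sizes; the scripts default to 128 / 128 / 1
embed_dim, num_heads = 1024, 16
head_dim = embed_dim // num_heads       # 64
local_heads = num_heads // tp           # heads owned by one TP rank
scaling = head_dim ** -0.5

# Column-sharded QKV projections and row-sharded output projection,
# matching the nn.Linear shapes used in the tp_n scripts.
q_proj = nn.Linear(embed_dim, embed_dim // tp)
k_proj = nn.Linear(embed_dim, embed_dim // tp)
v_proj = nn.Linear(embed_dim, embed_dim // tp)
o_proj = nn.Linear(embed_dim // tp, embed_dim)

hidden = torch.randn(bsz, 1, embed_dim)                 # one new token per sequence
q = (q_proj(hidden) * scaling).view(bsz, -1, local_heads, head_dim).transpose(1, 2)
k = k_proj(hidden).view(bsz, -1, local_heads, head_dim).transpose(1, 2)
v = v_proj(hidden).view(bsz, -1, local_heads, head_dim).transpose(1, 2)

# KV-cache append, as in attn(): the cache holds seq_len past tokens per local head.
past_k = torch.randn(bsz, local_heads, seq_len, head_dim)
past_v = torch.randn(bsz, local_heads, seq_len, head_dim)
k = torch.cat([past_k, k], dim=2)                       # (bsz, local_heads, seq_len+1, head_dim)
v = torch.cat([past_v, v], dim=2)

attn_weights = torch.softmax(q @ k.transpose(-1, -2), dim=-1)   # (bsz, local_heads, 1, seq_len+1)
attn_out = attn_weights @ v                                     # (bsz, local_heads, 1, head_dim)
out = o_proj(attn_out.transpose(1, 2).reshape(bsz, 1, -1))      # (bsz, 1, embed_dim)
print(out.shape)

Under this layout each rank holds num_heads // tp heads and produces a partial o_proj output; summing the partial outputs across ranks (an all-reduce, not modeled by these single-rank scripts) recovers the full result, which is presumably why o_proj maps embed_dim // tp back to embed_dim.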