@@ -481,3 +481,29 @@ def forward(self, query, key, value, attn_mask):
481481 # attn: [batch_size, n_heads, len_q, len_k] value: [batch_size, n_heads, len_v(=len_k), d_v]
482482 context = matmul4d (attn , value )
483483 return context , attn
484+
485+
class PoswiseFeedForwardNet(layer.Layer):
    """Feed-forward sub-block: two Linear3D projections with a residual path.

    The input is projected from ``d_model`` to ``dim_feedforward``, passed
    through ReLU, projected back to ``d_model``, summed with the original
    input and finally handed to ``LayerNorm(d_model)``.
    """

    def __init__(self, d_model=512, dim_feedforward=2048, bias=False):
        """Build the sub-layers.

        Args:
            d_model: feature size of the input and of the output.
            dim_feedforward: feature size of the hidden projection.
            bias: forwarded as ``bias=`` to both Linear3D layers.
        """
        super(PoswiseFeedForwardNet, self).__init__()

        # Keep the constructor arguments around as plain attributes.
        self.d_model = d_model
        self.dim_feedforward = dim_feedforward
        self.bias = bias

        # Sub-layers, created in the order they are applied in forward().
        self.linear1 = Linear3D(d_model, dim_feedforward, bias=bias)
        self.relu = layer.ReLU()
        self.linear2 = Linear3D(dim_feedforward, d_model, bias=bias)
        self.add = layer.Add()
        self.norm = LayerNorm(d_model)

    def forward(self, inputs):
        """Apply the feed-forward block to ``inputs``.

        Args:
            inputs: expected as [batch_size, seq_len, d_model].

        Returns:
            The normalized sum of the projected features and ``inputs``;
            expected as [batch_size, seq_len, d_model].
        """
        # Expand to dim_feedforward, apply the non-linearity, project back.
        hidden = self.relu(self.linear1(inputs))
        projected = self.linear2(hidden)
        # Residual connection with the untouched input, then normalization.
        return self.norm(self.add(projected, inputs))
0 commit comments