1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
class FeedForward(nn.Module):
    """Gated feed-forward block computing ``w2(silu(w1(x)) * w3(x))``."""

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        """
        Build the three projection layers of the feed-forward block.

        The effective hidden width is derived from ``hidden_dim``: it is
        scaled to two thirds, optionally scaled again by
        ``ffn_dim_multiplier``, and finally rounded up to the next multiple
        of ``multiple_of``.

        Args:
            dim (int): Input feature size; also the output size of ``w2``.
            hidden_dim (int): Nominal hidden size before the adjustments
                described above.
            multiple_of (int): The effective hidden width is rounded up to a
                multiple of this value.
            ffn_dim_multiplier (float, optional): Extra scale factor applied
                to the hidden width. Pass ``None`` to skip this step; the
                argument itself is required.

        Attributes:
            w1 (ColumnParallelLinear): ``dim -> hidden`` projection whose
                output goes through SiLU.
            w2 (RowParallelLinear): ``hidden -> dim`` output projection.
            w3 (ColumnParallelLinear): ``dim -> hidden`` projection that is
                multiplied element-wise with the activated ``w1`` output.
        """
        super().__init__()

        # Two thirds of the nominal width, truncated to an integer.
        inner_dim = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            inner_dim = int(ffn_dim_multiplier * inner_dim)
        # Ceiling division, then scale back: smallest multiple of
        # `multiple_of` that is >= inner_dim.
        num_multiples = (inner_dim + multiple_of - 1) // multiple_of
        inner_dim = multiple_of * num_multiples

        def _identity(tensor):
            # Passed as init_method: hands the tensor back unchanged.
            return tensor

        # Both dim -> hidden projections are configured identically.
        column_opts = dict(
            bias=False, gather_output=False, init_method=_identity
        )

        # Layers are created in the order w1, w2, w3 so that parameter
        # registration order matches the attribute names.
        self.w1 = ColumnParallelLinear(dim, inner_dim, **column_opts)
        self.w2 = RowParallelLinear(
            inner_dim,
            dim,
            bias=False,
            input_is_parallel=True,
            init_method=_identity,
        )
        self.w3 = ColumnParallelLinear(dim, inner_dim, **column_opts)

    def forward(self, x):
        """Return ``w2(silu(w1(x)) * w3(x))`` for input ``x``."""
        activated = F.silu(self.w1(x))
        gated = activated * self.w3(x)
        return self.w2(gated)
|