diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py
index f4751f59a..4aad9a0ba 100644
--- a/internlm/core/context/parallel_context.py
+++ b/internlm/core/context/parallel_context.py
@@ -84,6 +84,12 @@ def update(self, config):
             self._add_item(k, v)
         return self
 
+    def __delattr__(self, key):
+        if key in self:
+            super().__delitem__(key)
+        else:
+            raise AttributeError(f"{key} does not exist")
+
     @staticmethod
     def from_file(filename: str):
         """Reads a python file and constructs a corresponding :class:`Config` object.
diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 35b3d646c..817e6721d 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -362,6 +362,9 @@ def args_sanity_check():
             "Please make sure you are using flash attention in cuda device."
         )
 
+    if "mlp_layer_fusion" not in model:
+        model._add_item("mlp_layer_fusion", False)
+
     if "MoE" in gpc.config.get("model_type", ModelType.INTERNLM.name):
         if "num_experts" not in model:
             model._add_item("num_experts", 1)
@@ -375,9 +378,8 @@ def args_sanity_check():
             model._add_item("moe_type", "GShard")
         if "moe_layer_kwargs" not in model:
            model.moe_layer_kwargs = {}
-
-    if "mlp_layer_fusion" not in model:
-        model._add_item("mlp_layer_fusion", False)
+        if model.mlp_layer_fusion is False:
+            logger.warning("The config 'mlp_layer_fusion' is False; we recommend setting it to True when using MoE.")
 
     # qk_interleaved config
     if "qk_interleaved" not in gpc.config.model:
diff --git a/internlm/model/modules/mlp.py b/internlm/model/modules/mlp.py
index e51e5897f..3163e523a 100644
--- a/internlm/model/modules/mlp.py
+++ b/internlm/model/modules/mlp.py
@@ -106,20 +106,19 @@ def __init__(
             "w2", hidden_features, out_features, bias, device=device, dtype=dtype, is_expert=is_expert
         )
 
-    def forward(self, x):
-        if not self.mlp_layer_fusion:
-            w1_o = self.w1(x)
-            w3_o = self.w3(x)
-        else:
-            fussed_out = self.fused_w1_w3(x)
-            w1_o, w3_o = torch.split(fussed_out, fussed_out.shape[-1] // 2, dim=-1)
-
         if self.activation_type is ActivationType.swiglu.name:
-            out = self.w2(Silu(w1_o, w3_o))
+            self.activation_fn = Silu
         else:
-            out = self.w2(Gelu(w1_o, w3_o))
+            self.activation_fn = Gelu
 
-        return out
+    def forward(self, x):
+        if self.mlp_layer_fusion:
+            fused_out = self.fused_w1_w3(x)
+            w1_o, w3_o = torch.split(fused_out, fused_out.shape[-1] // 2, dim=-1)
+        else:
+            w1_o = self.w1(x)
+            w3_o = self.w3(x)
+        return self.w2(self.activation_fn(w1_o, w3_o))
 
 
 class GroupedFeedForward(nn.Module):
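
Note: below is a minimal, self-contained sketch of the MLP refactor in mlp.py above, not the InternEvo code itself. It illustrates the pattern only: choose the activation combiner once in __init__, then resolve the optional w1/w3 fusion in forward. Plain nn.Linear layers and simple silu/gelu gating functions stand in for InternEvo's new_linear, Silu, and Gelu helpers; ToyFeedForward and the combiner names are illustrative placeholders.

import torch
import torch.nn.functional as F
from torch import nn


def silu_combine(w1_o, w3_o):
    # Stand-in for InternEvo's Silu(w1_o, w3_o): SwiGLU-style gating.
    return F.silu(w1_o) * w3_o


def gelu_combine(w1_o, w3_o):
    # Stand-in for InternEvo's Gelu(w1_o, w3_o): GeGLU-style gating.
    return F.gelu(w1_o) * w3_o


class ToyFeedForward(nn.Module):
    """Illustrative gated MLP mirroring the refactor: activation picked at
    init time, fused vs. unfused w1/w3 resolved at forward time."""

    def __init__(self, dim, hidden_dim, activation="swiglu", mlp_layer_fusion=False):
        super().__init__()
        self.mlp_layer_fusion = mlp_layer_fusion
        if mlp_layer_fusion:
            # One projection produces both the gate (w1) and value (w3) halves.
            self.fused_w1_w3 = nn.Linear(dim, 2 * hidden_dim, bias=False)
        else:
            self.w1 = nn.Linear(dim, hidden_dim, bias=False)
            self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        # Chosen once here instead of being re-checked on every forward pass.
        self.activation_fn = silu_combine if activation == "swiglu" else gelu_combine

    def forward(self, x):
        if self.mlp_layer_fusion:
            fused_out = self.fused_w1_w3(x)
            w1_o, w3_o = torch.split(fused_out, fused_out.shape[-1] // 2, dim=-1)
        else:
            w1_o = self.w1(x)
            w3_o = self.w3(x)
        return self.w2(self.activation_fn(w1_o, w3_o))


if __name__ == "__main__":
    x = torch.randn(2, 4, 32)
    for fused in (False, True):
        mlp = ToyFeedForward(32, 128, mlp_layer_fusion=fused)
        print(fused, mlp(x).shape)  # torch.Size([2, 4, 32]) in both cases

In the fused path both halves come from a single matmul and are split afterwards, which is the behavior gated by the mlp_layer_fusion flag that launch.py above now defaults to False and recommends enabling for MoE models.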