
Commit 1d4c226

if mlp doesn't exist in layer module check for feed_forward name for falcon h1 (#2913)
1 parent 770c88f commit 1d4c226

File tree

1 file changed: +15 −4 lines


unsloth/models/llama.py

Lines changed: 15 additions & 4 deletions
@@ -2718,10 +2718,21 @@ def patch_peft_model(
     if lora_dropout == 0 and bias == "none":
         for idx, layer in enumerate(model.model.model.layers):
 
+            # Determine MLP module name (falcon_h1 has feed_forward, llama style has mlp)
+            if hasattr(layer, "mlp"):
+                mlp_module_name = "mlp"
+            elif hasattr(layer, "feed_forward"):
+                mlp_module_name = "feed_forward"
+            else:
+                logger.warning_once(f"Unsloth: No MLP module found in layer {idx} so skipping peft mlp patching")
+                continue
+
+            mlp_module = getattr(layer, mlp_module_name)
+
             # MLP patching
-            gate_proj = layer.mlp.gate_proj
-            up_proj   = layer.mlp. up_proj
-            down_proj = layer.mlp.down_proj
+            gate_proj = mlp_module.gate_proj
+            up_proj   = mlp_module. up_proj
+            down_proj = mlp_module.down_proj
 
             if hasattr(gate_proj, "lora_A") and \
                hasattr( up_proj, "lora_A") and \
@@ -2734,7 +2745,7 @@ def patch_peft_model(
                (len(getattr(down_proj, "lora_magnitude_vector", []) or []) == 0):
 
                 # https://stackoverflow.com/questions/50599045/python-replacing-a-function-within-a-class-of-a-module
-                layer.mlp.forward = types.MethodType(_apply_lora_mlp, layer.mlp)
+                mlp_module.forward = types.MethodType(_apply_lora_mlp, mlp_module)
                 n_mlp += 1
             else:
                 logger.warning_once(
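
In short, the patch resolves each layer's MLP submodule by attribute name instead of assuming layer.mlp, so Falcon-H1 layers (which expose feed_forward) get patched too, and layers with neither attribute are skipped with a warning. Below is a minimal, self-contained sketch of that hasattr/getattr fallback pattern; the class and helper names are hypothetical, for illustration only, and are not part of the Unsloth codebase.

# Minimal sketch (illustration only) of the attribute-name fallback used above.
# LlamaStyleLayer, FalconH1StyleLayer and find_mlp_module are hypothetical names.

class LlamaStyleLayer:
    def __init__(self):
        self.mlp = "llama-style MLP module"

class FalconH1StyleLayer:
    def __init__(self):
        self.feed_forward = "falcon_h1-style MLP module"

def find_mlp_module(layer):
    # Prefer "mlp" (llama style), fall back to "feed_forward" (falcon_h1);
    # return None so the caller can skip layers that have neither.
    for name in ("mlp", "feed_forward"):
        if hasattr(layer, name):
            return getattr(layer, name)
    return None

print(find_mlp_module(LlamaStyleLayer()))     # llama-style MLP module
print(find_mlp_module(FalconH1StyleLayer()))  # falcon_h1-style MLP module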
