
Commit d5b342f

gemma2 disable

1 parent: 442bc0f

File tree

1 file changed: +15 −1 lines

unsloth/trainer.py

Lines changed: 15 additions & 1 deletion
@@ -251,6 +251,16 @@ def new_init(self, *args, **kwargs):
     else:
         config_arg = kwargs.get("args")
 
+    # Check if model is Gemma 2 (which doesn't support padding_free without flash attention)
+    model = kwargs.get("model")
+    is_unsupported_gemma = False
+    if model is not None:
+        model_config = getattr(model, "config", None)
+        if model_config is not None:
+            model_type = getattr(model_config, "model_type", "").lower()
+            # Gemma 2 uses slow_attention_softcapping which has torch.compile issues with padding_free
+            is_unsupported_gemma = model_type == "gemma2"
+
     processing_class = kwargs.get("processing_class") or kwargs.get("tokenizer")
     data_collator = kwargs.get("data_collator")
 
@@ -278,12 +288,16 @@ def new_init(self, *args, **kwargs):
     if not blocked:
         if padding_free_requested:
             configure_padding_free(config_arg)
-        elif _should_auto_padding_free(config_arg):
+        elif not is_unsupported_gemma and _should_auto_padding_free(config_arg):
             configure_padding_free(config_arg)
             auto_padding_free_active = True
             logger.info(
                 "Unsloth: Padding-free batching auto-enabled for SFTTrainer instance."
             )
+        elif is_unsupported_gemma and _should_auto_padding_free(config_arg):
+            logger.info(
+                "Unsloth: Padding-free batching auto-disabled for Gemma 2 (requires flash attention)."
+            )
 
     original_init(self, *args, **kwargs)
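For context, here is a minimal standalone sketch of the check this commit adds. It is not the unsloth code itself; it assumes a Hugging Face checkpoint whose config exposes `model_type`, which Gemma 2 models report as "gemma2":

```python
# Sketch: reproduce the gemma2 detection that gates auto padding-free batching.
from transformers import AutoConfig

def is_unsupported_gemma(model_name: str) -> bool:
    # Hugging Face model configs carry a `model_type` string; Gemma 2 reports "gemma2".
    config = AutoConfig.from_pretrained(model_name)
    return getattr(config, "model_type", "").lower() == "gemma2"

print(is_unsupported_gemma("google/gemma-2-9b"))   # True  -> auto padding-free is skipped
print(is_unsupported_gemma("unsloth/llama-3-8b"))  # False -> auto padding-free may engage
```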

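With the change applied, constructing an SFTTrainer around a Gemma 2 model should no longer auto-enable padding-free batching; the new branch logs the auto-disable message instead. A hedged usage sketch, assuming unsloth's patched trainer is active and using an illustrative checkpoint name and toy dataset:

```python
# Sketch, assuming this commit's patched SFTTrainer.__init__ is in effect.
from unsloth import FastLanguageModel  # import unsloth first so its patches apply

from datasets import Dataset
from trl import SFTConfig, SFTTrainer

model, tokenizer = FastLanguageModel.from_pretrained("unsloth/gemma-2-9b")  # illustrative checkpoint
dataset = Dataset.from_list([{"text": "Hello, world."}])  # toy dataset

trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=dataset,
    args=SFTConfig(output_dir="outputs"),
)
# Expected log from the new elif branch (instead of the auto-enabled message):
# "Unsloth: Padding-free batching auto-disabled for Gemma 2 (requires flash attention)."
```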