@@ -251,6 +251,16 @@ def new_init(self, *args, **kwargs):
     else:
         config_arg = kwargs.get("args")
 
+    # Check if model is Gemma 2 (which doesn't support padding_free without flash attention)
+    model = kwargs.get("model")
+    is_unsupported_gemma = False
+    if model is not None:
+        model_config = getattr(model, "config", None)
+        if model_config is not None:
+            model_type = getattr(model_config, "model_type", "").lower()
+            # Gemma 2 uses slow_attention_softcapping, which has torch.compile issues with padding_free
+            is_unsupported_gemma = model_type == "gemma2"
+
     processing_class = kwargs.get("processing_class") or kwargs.get("tokenizer")
     data_collator = kwargs.get("data_collator")
@@ -278,12 +288,16 @@ def new_init(self, *args, **kwargs):
     if not blocked:
         if padding_free_requested:
             configure_padding_free(config_arg)
-        elif _should_auto_padding_free(config_arg):
+        elif not is_unsupported_gemma and _should_auto_padding_free(config_arg):
             configure_padding_free(config_arg)
             auto_padding_free_active = True
             logger.info(
                 "Unsloth: Padding-free batching auto-enabled for SFTTrainer instance."
             )
+        elif is_unsupported_gemma and _should_auto_padding_free(config_arg):
+            logger.info(
+                "Unsloth: Padding-free batching auto-disabled for Gemma 2 (requires flash attention)."
+            )
 
     original_init(self, *args, **kwargs)

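For reference, a minimal standalone sketch of the detection logic added above, assuming only a transformers-style model object whose `config` exposes a `model_type` string (the `SimpleNamespace` stand-ins below are hypothetical and for illustration only):

```python
from types import SimpleNamespace

def is_unsupported_gemma(model) -> bool:
    # Mirror the diff: treat the model as unsupported only if
    # model.config.model_type resolves to "gemma2".
    model_config = getattr(model, "config", None)
    if model_config is None:
        return False
    return getattr(model_config, "model_type", "").lower() == "gemma2"

# Hypothetical stand-ins for real model objects.
gemma2_model = SimpleNamespace(config=SimpleNamespace(model_type="gemma2"))
llama_model = SimpleNamespace(config=SimpleNamespace(model_type="llama"))

assert is_unsupported_gemma(gemma2_model) is True
assert is_unsupported_gemma(llama_model) is False
```

With this flag set, auto padding-free is skipped for Gemma 2 and an informational log line is emitted instead; an explicit `padding_free_requested` still takes the existing code path.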