diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 9e3896244c..b1b8a8fb78 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1988,6 +1988,9 @@ def unsloth_fast_generate( *args, **kwargs, ): + # If the model starts out in training mode, restore training mode after generation + restore_training_mode = self.training + FastLlamaModel.for_inference(self) dtype = _get_dtype(dtype_from_config(self.config)) @@ -2043,7 +2046,8 @@ def unsloth_fast_generate( # accelerate.utils.operations.send_to_device = accelerate_old_send_to_device # pass - FastLlamaModel.for_training(self) + if restore_training_mode: + FastLlamaModel.for_training(self) return output diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 828fe17320..a8ddbfed2a 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -258,7 +258,7 @@ def from_pretrained( if model_name.lower().endswith("-bf16"): load_in_4bit = False load_in_8bit = False - load_in_fp8 = False + load_in_fp8 = False load_in_16bit = True if USE_MODELSCOPE and not os.path.exists(model_name): @@ -387,7 +387,7 @@ def from_pretrained( if model_name.lower().endswith("-bf16"): load_in_4bit = False load_in_8bit = False - load_in_fp8 = False + load_in_fp8 = False load_in_16bit = True model_config = AutoConfig.from_pretrained( @@ -711,10 +711,16 @@ def from_pretrained( ) load_in_4bit = False load_in_8bit = False - load_in_fp8 = False + load_in_fp8 = False load_in_16bit = False - if int(load_in_4bit) + int(load_in_8bit) + int(load_in_16bit) + int(load_in_fp8 != False) >= 2: + if ( + int(load_in_4bit) + + int(load_in_8bit) + + int(load_in_16bit) + + int(load_in_fp8 != False) + >= 2 + ): raise RuntimeError( "Unsloth: Can only load in 4bit or 8bit or 16bit, not a combination!\n" "Also, we by default set `load_in_4bit = True`.\n"