Skip to content

Commit 7ac29b6

Browse files
committed
disable_mmap in from_pretrained
1 parent f9c1e61 commit 7ac29b6

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -758,6 +758,7 @@ def load_sub_model(
758758
use_safetensors: bool,
759759
dduf_entries: Optional[Dict[str, DDUFEntry]],
760760
provider_options: Any,
761+
disable_mmap: bool,
761762
quantization_config: Optional[Any] = None,
762763
):
763764
"""Helper method to load the module `name` from `library_name` and `class_name`"""
@@ -854,6 +855,9 @@ def load_sub_model(
854855
else:
855856
loading_kwargs["low_cpu_mem_usage"] = False
856857

858+
if is_diffusers_model:
859+
loading_kwargs["disable_mmap"] = disable_mmap
860+
857861
if is_transformers_model and is_transformers_version(">=", "4.57.0"):
858862
loading_kwargs.pop("offload_state_dict")
859863

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -707,6 +707,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
707707
loading `from_flax`.
708708
dduf_file(`str`, *optional*):
709709
Load weights from the specified dduf file.
710+
disable_mmap (`bool`, *optional*, defaults to `False`):
711+
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
712+
is on a network mount or hard drive, which may not handle the seek-heavy access pattern of mmap very well.
710713
711714
> [!TIP]
> To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log in
712715
with `hf auth login`.
@@ -758,6 +761,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
758761
use_onnx = kwargs.pop("use_onnx", None)
759762
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
760763
quantization_config = kwargs.pop("quantization_config", None)
764+
disable_mmap = kwargs.pop("disable_mmap", False)
761765

762766
if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype):
763767
torch_dtype = torch.float32
@@ -1041,6 +1045,7 @@ def load_module(name, value):
10411045
use_safetensors=use_safetensors,
10421046
dduf_entries=dduf_entries,
10431047
provider_options=provider_options,
1048+
disable_mmap=disable_mmap,
10441049
quantization_config=quantization_config,
10451050
)
10461051
logger.info(

0 commit comments

Comments
 (0)