diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 8868e942ce3d..5cd232e82f36 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -758,6 +758,7 @@ def load_sub_model( use_safetensors: bool, dduf_entries: Optional[Dict[str, DDUFEntry]], provider_options: Any, + disable_mmap: bool, quantization_config: Optional[Any] = None, ): """Helper method to load the module `name` from `library_name` and `class_name`""" @@ -854,6 +855,9 @@ def load_sub_model( else: loading_kwargs["low_cpu_mem_usage"] = False + if is_diffusers_model: + loading_kwargs["disable_mmap"] = disable_mmap + if is_transformers_model and is_transformers_version(">=", "4.57.0"): loading_kwargs.pop("offload_state_dict") diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 392d5fb3feb4..510e74ce88d7 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -707,6 +707,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P loading `from_flax`. dduf_file(`str`, *optional*): Load weights from the specified dduf file. + disable_mmap (`bool`, *optional*, defaults to `False`): + Whether to disable mmap when loading a Safetensors model. This option can perform better when the model + is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well. > [!TIP] > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with `hf > auth login`. 
@@ -758,6 +761,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P use_onnx = kwargs.pop("use_onnx", None) load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) quantization_config = kwargs.pop("quantization_config", None) + disable_mmap = kwargs.pop("disable_mmap", False) if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype): torch_dtype = torch.float32 @@ -1041,6 +1045,7 @@ def load_module(name, value): use_safetensors=use_safetensors, dduf_entries=dduf_entries, provider_options=provider_options, + disable_mmap=disable_mmap, quantization_config=quantization_config, ) logger.info(