Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/diffusers/models/auto_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def from_config(cls, pretrained_model_name_or_path_or_dict: str | os.PathLike |
subfolder=subfolder,
module_file=module_file,
class_name=class_name,
trust_remote_code=trust_remote_code,
**hub_kwargs,
)
else:
Expand All @@ -143,6 +144,7 @@ def from_config(cls, pretrained_model_name_or_path_or_dict: str | os.PathLike |
importable_classes=ALL_IMPORTABLE_CLASSES,
pipelines=None,
is_pipeline_module=False,
trust_remote_code=trust_remote_code,
)

if model_cls is None:
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/modular_pipelines/modular_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ def from_pretrained(
pretrained_model_name_or_path,
module_file=module_file,
class_name=class_name,
trust_remote_code=trust_remote_code,
**hub_kwargs,
)
expected_kwargs, optional_kwargs = block_cls._get_signature_keys(block_cls)
Expand Down
20 changes: 18 additions & 2 deletions src/diffusers/pipelines/pipeline_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,14 @@ def simple_get_class_obj(library_name, class_name):


def get_class_obj_and_candidates(
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
library_name,
class_name,
importable_classes,
pipelines,
is_pipeline_module,
component_name=None,
cache_dir=None,
trust_remote_code: bool = False,
):
"""Simple helper method to retrieve class object of module as well as potential parent class objects"""
component_folder = os.path.join(cache_dir, component_name) if component_name and cache_dir else None
Expand All @@ -426,7 +433,10 @@ def get_class_obj_and_candidates(
elif component_folder and os.path.isfile(os.path.join(component_folder, library_name + ".py")):
# load custom component
class_obj = get_class_from_dynamic_module(
component_folder, module_file=library_name + ".py", class_name=class_name
component_folder,
module_file=library_name + ".py",
class_name=class_name,
trust_remote_code=trust_remote_code,
)
class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
else:
Expand All @@ -450,6 +460,7 @@ def _get_custom_pipeline_class(
class_name=None,
cache_dir=None,
revision=None,
trust_remote_code: bool = False,
):
if custom_pipeline.endswith(".py"):
path = Path(custom_pipeline)
Expand All @@ -473,6 +484,7 @@ def _get_custom_pipeline_class(
class_name=class_name,
cache_dir=cache_dir,
revision=revision,
trust_remote_code=trust_remote_code,
)


Expand All @@ -486,6 +498,7 @@ def _get_pipeline_class(
class_name=None,
cache_dir=None,
revision=None,
trust_remote_code: bool = False,
):
if custom_pipeline is not None:
return _get_custom_pipeline_class(
Expand All @@ -495,6 +508,7 @@ def _get_pipeline_class(
class_name=class_name,
cache_dir=cache_dir,
revision=revision,
trust_remote_code=trust_remote_code,
)

if class_obj.__name__ != "DiffusionPipeline" and class_obj.__name__ != "ModularPipeline":
Expand Down Expand Up @@ -766,6 +780,7 @@ def load_sub_model(
disable_mmap: bool,
quantization_config: Any | None = None,
use_flashpack: bool = False,
trust_remote_code: bool = False,
):
"""Helper method to load the module `name` from `library_name` and `class_name`"""
from ..quantizers import PipelineQuantizationConfig
Expand All @@ -780,6 +795,7 @@ def load_sub_model(
is_pipeline_module,
component_name=name,
cache_dir=cached_folder,
trust_remote_code=trust_remote_code,
)

load_method_name = None
Expand Down
25 changes: 9 additions & 16 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwa
quantization_config = kwargs.pop("quantization_config", None)
use_flashpack = kwargs.pop("use_flashpack", False)
disable_mmap = kwargs.pop("disable_mmap", False)
trust_remote_code = kwargs.pop("trust_remote_code", False)

if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype):
torch_dtype = torch.float32
Expand Down Expand Up @@ -871,6 +872,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwa
variant=variant,
dduf_file=dduf_file,
load_connected_pipeline=load_connected_pipeline,
trust_remote_code=trust_remote_code,
**kwargs,
)
else:
Expand Down Expand Up @@ -928,6 +930,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwa
class_name=custom_class_name,
cache_dir=cache_dir,
revision=custom_revision,
trust_remote_code=trust_remote_code,
)

if device_map is not None and pipeline_class._load_connected_pipes:
Expand Down Expand Up @@ -1077,6 +1080,7 @@ def load_module(name, value):
disable_mmap=disable_mmap,
quantization_config=quantization_config,
use_flashpack=use_flashpack,
trust_remote_code=trust_remote_code,
)
logger.info(
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
Expand Down Expand Up @@ -1684,21 +1688,6 @@ def download(cls, pretrained_model_name, **kwargs) -> str | os.PathLike:
custom_class_name = config_dict["_class_name"][1]

load_pipe_from_hub = custom_pipeline is not None and f"{custom_pipeline}.py" in filenames
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to note, this parameter remains in order to control `repo_id`:

repo_id=pretrained_model_name if load_pipe_from_hub else None,

def _get_custom_pipeline_class(
custom_pipeline,
repo_id=None,
hub_revision=None,
class_name=None,
cache_dir=None,
revision=None,
):
if custom_pipeline.endswith(".py"):
path = Path(custom_pipeline)
# decompose into folder & file
file_name = path.name
custom_pipeline = path.parent.absolute()
elif repo_id is not None:
file_name = f"{custom_pipeline}.py"
custom_pipeline = repo_id
else:
file_name = CUSTOM_PIPELINE_FILE_NAME

It helps distinguish between:
a) custom_pipeline is e.g. my_pipeline and that filename exists in pretrained_model_name's files
b) custom_pipeline is Hub repo (and pipeline.py is used)

Maybe could be renamed load_pipe_from_hub -> hub_contains_custom_pipeline?

load_components_from_hub = len(custom_components) > 0
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this going?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See case 3 - this code is in download which is only reached when we use a Hub path with from_pretrained. It is replaced by the check in get_cached_module_file, specifically the local code path.

Consider this scenario:

hf download rotcasuoicilam/SuperCoolNewModel --local-dir rotcasuoicilam/SuperCoolNewModel
from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained("rotcasuoicilam/SuperCoolNewModel")

rotcasuoicilam/SuperCoolNewModel contains malicious custom components, user downloads the Hub repo assuming it is safe, Diffusers loads the custom components without the user's consent.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't get it.

hf download rotcasuoicilam/SuperCoolNewModel --local-dir rotcasuoicilam/SuperCoolNewModel

is agnostic to DiffusionPipeline.from_pretrained(...).

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. A malicious actor uploads a model with malicious custom components
  2. Either:
    2a. The user follows instructions that say to download the model first then run from the local path
    2b. The user chooses to download the model first out of personal preference
  3. DiffusionPipeline.from_pretrained(the_local_path)
  4. pwned

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add a test case for this scenario then.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

c0e0731

On main:

FAILED tests/pipelines/test_pipelines.py::CustomPipelineTests::test_custom_components_from_local_dir - AssertionError: ValueError not raised

PR:

1 passed

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit more elaborate explanation for other reviewers (feel free to correct).

The critical branching point is:

if not os.path.isdir(pretrained_model_name_or_path):                                                                                     
    # ... calls cls.download() which had the trust_remote_code check                                                                     
else:                                                                                                                                    
    cached_folder = pretrained_model_name_or_path  

When you call from_pretrained("rotcasuoicilam/SuperCoolNewModel"):

  1. os.path.isdir("rotcasuoicilam/SuperCoolNewModel") is checked.
  2. If the user previously ran hf download ... --local-dir rotcasuoicilam/SuperCoolNewModel, that directory exists locally.
  3. So os.path.isdir() returns True, and the code takes the else branch at line 871 — it just sets cached_folder = pretrained_model_name_or_path directly.
  4. The download() method is never called.

The old trust_remote_code check for custom components lived inside download(). Since download() is skipped entirely when the path is a
local directory, the check never runs. The custom components in that local folder get loaded without any consent gate.

That's why the fix moves the trust_remote_code check into get_cached_module_file — that's where the actual import of custom .py files
happens, and it runs regardless of whether the code came through download() or the local else branch.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well just to be clear, using the same local-dir as Hub repo is a possible trick to hide the attack as anyone who didn't pre-download wouldn't be affected but any local directory is affected, and the local directory could be from other sources like snapshot_download or git clone.


if load_pipe_from_hub and not trust_remote_code:
raise ValueError(
f"The repository for {pretrained_model_name} contains custom code in {custom_pipeline}.py which must be executed to correctly "
f"load the model. You can inspect the repository content at https://hf.co/{pretrained_model_name}/blob/main/{custom_pipeline}.py.\n"
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
)

if load_components_from_hub and not trust_remote_code:
raise ValueError(
f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k, v in custom_components.items()])} which must be executed to correctly "
f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k, v in custom_components.items()])}.\n"
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
)

# retrieve passed components that should not be downloaded
pipeline_class = _get_pipeline_class(
Expand All @@ -1711,6 +1700,7 @@ def download(cls, pretrained_model_name, **kwargs) -> str | os.PathLike:
class_name=custom_class_name,
cache_dir=cache_dir,
revision=custom_revision,
trust_remote_code=trust_remote_code,
)
expected_components, _ = cls._get_signature_keys(pipeline_class)
passed_components = [k for k in expected_components if k in kwargs]
Expand Down Expand Up @@ -2127,13 +2117,16 @@ def from_pipe(cls, pipeline, **kwargs):

original_config = dict(pipeline.config)
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
trust_remote_code = kwargs.pop("trust_remote_code", False)

# derive the pipeline class to instantiate
custom_pipeline = kwargs.pop("custom_pipeline", None)
custom_revision = kwargs.pop("custom_revision", None)

if custom_pipeline is not None:
pipeline_class = _get_custom_pipeline_class(custom_pipeline, revision=custom_revision)
pipeline_class = _get_custom_pipeline_class(
custom_pipeline, revision=custom_revision, trust_remote_code=trust_remote_code
)
else:
pipeline_class = cls

Expand Down
31 changes: 31 additions & 0 deletions src/diffusers/utils/dynamic_modules_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ def get_cached_module_file(
revision: str | None = None,
local_files_only: bool = False,
local_dir: str | None = None,
trust_remote_code: bool = False,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to add the ValueError to the caller sites of get_cached_module_file instead? Because the function itself isn't specifically tied to custom pipelines, I think.

):
"""
Prepares and downloads a module from a local folder or a distant repo and returns its path inside the cached
Expand Down Expand Up @@ -289,6 +290,10 @@ def get_cached_module_file(
identifier allowed by git.
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, will only try to load the tokenizer configuration from local files.
trust_remote_code (`bool`, *optional*, defaults to `False`):
Whether or not to allow for custom pipelines and components defined on the Hub in their own files. This
option should only be set to `True` for repositories you trust and in which you have read the code, as it
will execute code present on the Hub on your local machine.

> [!TIP]
> You may pass a token in `token` if you are not logged in (`hf auth login`) and want to use private or
> [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
Expand All @@ -307,6 +312,12 @@ def get_cached_module_file(
if os.path.isfile(module_file_or_url):
resolved_module_file = module_file_or_url
submodule = "local"
if not trust_remote_code:
raise ValueError(
f"The directory {pretrained_model_name_or_path} contains custom code in {module_file} which must be executed to correctly "
f"load the model. You can inspect the file content at {module_file_or_url}.\n"
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
)
elif pretrained_model_name_or_path.count("/") == 0:
available_versions = get_diffusers_versions()
# cut ".dev0"
Expand All @@ -326,6 +337,13 @@ def get_cached_module_file(
f" {', '.join(available_versions + ['main'])}."
)

if not trust_remote_code:
raise ValueError(
f"The community pipeline for {pretrained_model_name_or_path} contains custom code which must be executed to correctly "
f"load the model. You can inspect the repository content at https://hf.co/datasets/{COMMUNITY_PIPELINES_MIRROR_ID}/blob/main/{revision}/{pretrained_model_name_or_path}.py.\n"
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
)

try:
resolved_module_file = hf_hub_download(
repo_id=COMMUNITY_PIPELINES_MIRROR_ID,
Expand All @@ -349,6 +367,12 @@ def get_cached_module_file(
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
raise
else:
if not trust_remote_code:
raise ValueError(
f"The repository for {pretrained_model_name_or_path} contains custom code in {module_file} which must be executed to correctly "
f"load the model. You can inspect the repository content at https://hf.co/{pretrained_model_name_or_path}/blob/main/{module_file}.\n"
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
)
try:
# Load from URL or cache if already cached
resolved_module_file = hf_hub_download(
Expand Down Expand Up @@ -426,6 +450,7 @@ def get_cached_module_file(
revision=revision,
local_files_only=local_files_only,
local_dir=local_dir,
trust_remote_code=trust_remote_code,
)
return os.path.join(full_submodule, module_file)

Expand All @@ -443,6 +468,7 @@ def get_class_from_dynamic_module(
revision: str | None = None,
local_files_only: bool = False,
local_dir: str | None = None,
trust_remote_code: bool = False,
):
"""
Extracts a class from a module file, present in the local folder or repository of a model.
Expand Down Expand Up @@ -482,6 +508,10 @@ def get_class_from_dynamic_module(
identifier allowed by git.
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, will only try to load the tokenizer configuration from local files.
trust_remote_code (`bool`, *optional*, defaults to `False`):
Whether or not to allow for custom pipelines and components defined on the Hub in their own files. This
option should only be set to `True` for repositories you trust and in which you have read the code, as it
will execute code present on the Hub on your local machine.

> [!TIP]
> You may pass a token in `token` if you are not logged in (`hf auth login`) and want to use private or
> [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
Expand All @@ -508,5 +538,6 @@ def get_class_from_dynamic_module(
revision=revision,
local_files_only=local_files_only,
local_dir=local_dir,
trust_remote_code=trust_remote_code,
)
return get_class_in_module(class_name, final_module)
65 changes: 59 additions & 6 deletions tests/pipelines/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,8 +1011,15 @@ def test_get_pipeline_class_from_flax(self):

class CustomPipelineTests(unittest.TestCase):
def test_load_custom_pipeline(self):
with self.assertRaises(ValueError):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we investigate the ValueError messaging as well (it should have something related to the use of trust_remote_code or not something else)?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And on main, this should not have yielded a ValueError, right? That is how we know, for one instance, that it's broken.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On main:

pretrained custom_pipeline trust_remote_code?
hub/repoA my_pipeline
hub/repoA one_step_unet[1]
hub/repoA hub/repoB
any local directory any

PR:

pretrained custom_pipeline trust_remote_code?
hub/repoA my_pipeline
hub/repoA one_step_unet[1]
hub/repoA hub/repoB
any local directory any

[1] or any community pipeline name

This case is more implicit vs explicit consent, but on main there is potential for misuse by combining the "trusted" nature of community pipeline names and third party Hub repos.

A user may copy an example like:

pipeline = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="one_step_unet")

or

pipeline = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="pipeline_stable_diffusion_xl_controlnet_adapter_inpaint")

or

pipeline = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="pipeline_stable_diffusion_x/_controlnet_adapter_inpaint")

The first two are harmless from diffusers/community-pipelines-mirror, the third is malicious with a user registered as pipeline_stable_diffusion_x with a repo name _controlnet_adapter_inpaint. There are many community pipelines so many potential username/repo name combinations that could easily be missed.

Considering that I think community pipeline names should remain trusted, WDYT? We can just remove this to do so.

if not trust_remote_code:
raise ValueError(
f"The community pipeline for {pretrained_model_name_or_path} contains custom code which must be executed to correctly "
f"load the model. You can inspect the repository content at https://hf.co/datasets/{COMMUNITY_PIPELINES_MIRROR_ID}/blob/main/{revision}/{pretrained_model_name_or_path}.py.\n"
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
)

pipeline = DiffusionPipeline.from_pretrained(
"google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline"
)

pipeline = DiffusionPipeline.from_pretrained(
"google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline"
"google/ddpm-cifar10-32",
custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline",
trust_remote_code=True,
)
pipeline = pipeline.to(torch_device)
# NOTE that `"CustomPipeline"` is not a class that is defined in this library, but solely on the Hub
Expand All @@ -1021,7 +1028,10 @@ def test_load_custom_pipeline(self):

def test_load_custom_github(self):
pipeline = DiffusionPipeline.from_pretrained(
"google/ddpm-cifar10-32", custom_pipeline="one_step_unet", custom_revision="main"
"google/ddpm-cifar10-32",
custom_pipeline="one_step_unet",
custom_revision="main",
trust_remote_code=True,
)

# make sure that on "main" pipeline gives only ones because of: https://github.com/huggingface/diffusers/pull/1690
Expand All @@ -1035,7 +1045,10 @@ def test_load_custom_github(self):
del sys.modules["diffusers_modules.git.one_step_unet"]

pipeline = DiffusionPipeline.from_pretrained(
"google/ddpm-cifar10-32", custom_pipeline="one_step_unet", custom_revision="0.10.2"
"google/ddpm-cifar10-32",
custom_pipeline="one_step_unet",
custom_revision="0.10.2",
trust_remote_code=True,
)
with torch.no_grad():
output = pipeline()
Expand All @@ -1045,8 +1058,15 @@ def test_load_custom_github(self):
assert pipeline.__class__.__name__ == "UnetSchedulerOneForwardPipeline"

def test_run_custom_pipeline(self):
with self.assertRaises(ValueError):
pipeline = DiffusionPipeline.from_pretrained(
"google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline"
)

pipeline = DiffusionPipeline.from_pretrained(
"google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline"
"google/ddpm-cifar10-32",
custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline",
trust_remote_code=True,
)
pipeline = pipeline.to(torch_device)
images, output_str = pipeline(num_inference_steps=2, output_type="np")
Expand Down Expand Up @@ -1089,6 +1109,37 @@ def test_remote_components(self):

assert images.shape == (1, 64, 64, 3)

def test_custom_components_from_local_dir(self):
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
path = snapshot_download("hf-internal-testing/tiny-sdxl-custom-components", cache_dir=tmpdirname)
# make sure that trust remote code has to be passed
with self.assertRaises(ValueError):
pipeline = DiffusionPipeline.from_pretrained(path)

# Check that only loading custom components "my_unet", "my_scheduler" works
pipeline = DiffusionPipeline.from_pretrained(path, trust_remote_code=True)

assert pipeline.config.unet == ("diffusers_modules.local.my_unet_model", "MyUNetModel")
assert pipeline.config.scheduler == ("diffusers_modules.local.my_scheduler", "MyScheduler")
assert pipeline.__class__.__name__ == "StableDiffusionXLPipeline"

pipeline = pipeline.to(torch_device)
images = pipeline("test", num_inference_steps=2, output_type="np")[0]

assert images.shape == (1, 64, 64, 3)

# Check that only loading custom components "my_unet", "my_scheduler" and explicit custom pipeline works
pipeline = DiffusionPipeline.from_pretrained(path, custom_pipeline="my_pipeline", trust_remote_code=True)

assert pipeline.config.unet == ("diffusers_modules.local.my_unet_model", "MyUNetModel")
assert pipeline.config.scheduler == ("diffusers_modules.local.my_scheduler", "MyScheduler")
assert pipeline.__class__.__name__ == "MyPipeline"

pipeline = pipeline.to(torch_device)
images = pipeline("test", num_inference_steps=2, output_type="np")[0]

assert images.shape == (1, 64, 64, 3)

def test_remote_auto_custom_pipe(self):
# make sure that trust remote code has to be passed
with self.assertRaises(ValueError):
Expand Down Expand Up @@ -1126,7 +1177,7 @@ def test_remote_custom_pipe_with_dot_in_name(self):
def test_local_custom_pipeline_repo(self):
local_custom_pipeline_path = get_tests_dir("fixtures/custom_pipeline")
pipeline = DiffusionPipeline.from_pretrained(
"google/ddpm-cifar10-32", custom_pipeline=local_custom_pipeline_path
"google/ddpm-cifar10-32", custom_pipeline=local_custom_pipeline_path, trust_remote_code=True
)
pipeline = pipeline.to(torch_device)
images, output_str = pipeline(num_inference_steps=2, output_type="np")
Expand All @@ -1140,7 +1191,9 @@ def test_local_custom_pipeline_file(self):
local_custom_pipeline_path = get_tests_dir("fixtures/custom_pipeline")
local_custom_pipeline_path = os.path.join(local_custom_pipeline_path, "what_ever.py")
pipeline = DiffusionPipeline.from_pretrained(
"google/ddpm-cifar10-32", custom_pipeline=local_custom_pipeline_path
"google/ddpm-cifar10-32",
custom_pipeline=local_custom_pipeline_path,
trust_remote_code=True,
)
pipeline = pipeline.to(torch_device)
images, output_str = pipeline(num_inference_steps=2, output_type="np")
Expand Down
Loading