-
Notifications
You must be signed in to change notification settings - Fork 6.9k
Improve trust_remote_code
#13448
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Improve trust_remote_code
#13448
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -787,6 +787,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwa | |
| quantization_config = kwargs.pop("quantization_config", None) | ||
| use_flashpack = kwargs.pop("use_flashpack", False) | ||
| disable_mmap = kwargs.pop("disable_mmap", False) | ||
| trust_remote_code = kwargs.pop("trust_remote_code", False) | ||
|
|
||
| if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype): | ||
| torch_dtype = torch.float32 | ||
|
|
@@ -871,6 +872,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwa | |
| variant=variant, | ||
| dduf_file=dduf_file, | ||
| load_connected_pipeline=load_connected_pipeline, | ||
| trust_remote_code=trust_remote_code, | ||
| **kwargs, | ||
| ) | ||
| else: | ||
|
|
@@ -928,6 +930,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwa | |
| class_name=custom_class_name, | ||
| cache_dir=cache_dir, | ||
| revision=custom_revision, | ||
| trust_remote_code=trust_remote_code, | ||
| ) | ||
|
|
||
| if device_map is not None and pipeline_class._load_connected_pipes: | ||
|
|
@@ -1077,6 +1080,7 @@ def load_module(name, value): | |
| disable_mmap=disable_mmap, | ||
| quantization_config=quantization_config, | ||
| use_flashpack=use_flashpack, | ||
| trust_remote_code=trust_remote_code, | ||
| ) | ||
| logger.info( | ||
| f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}." | ||
|
|
@@ -1684,21 +1688,6 @@ def download(cls, pretrained_model_name, **kwargs) -> str | os.PathLike: | |
| custom_class_name = config_dict["_class_name"][1] | ||
|
|
||
| load_pipe_from_hub = custom_pipeline is not None and f"{custom_pipeline}.py" in filenames | ||
| load_components_from_hub = len(custom_components) > 0 | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is this going?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See case 3 - this code is in Consider this scenario: hf download rotcasuoicilam/SuperCoolNewModel --local-dir rotcasuoicilam/SuperCoolNewModelfrom diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained("rotcasuoicilam/SuperCoolNewModel")
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I didn't get it. hf download rotcasuoicilam/SuperCoolNewModel --local-dir rotcasuoicilam/SuperCoolNewModelis agnostic to
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's add a test case for this scenario then.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. On PR:
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A bit more elaborate explanation for other reviewers (feel free to correct). The critical branching point is: if not os.path.isdir(pretrained_model_name_or_path):
# ... calls cls.download() which had the trust_remote_code check
else:
cached_folder = pretrained_model_name_or_path When you call
The old That's why the fix moves the
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well just to be clear, using the same local-dir as Hub repo is a possible trick to hide the attack as anyone who didn't pre-download wouldn't be affected but any local directory is affected, and the local directory could be from other sources like |
||
|
|
||
| if load_pipe_from_hub and not trust_remote_code: | ||
| raise ValueError( | ||
| f"The repository for {pretrained_model_name} contains custom code in {custom_pipeline}.py which must be executed to correctly " | ||
| f"load the model. You can inspect the repository content at https://hf.co/{pretrained_model_name}/blob/main/{custom_pipeline}.py.\n" | ||
| f"Please pass the argument `trust_remote_code=True` to allow custom code to be run." | ||
| ) | ||
|
|
||
| if load_components_from_hub and not trust_remote_code: | ||
| raise ValueError( | ||
| f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k, v in custom_components.items()])} which must be executed to correctly " | ||
| f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k, v in custom_components.items()])}.\n" | ||
| f"Please pass the argument `trust_remote_code=True` to allow custom code to be run." | ||
| ) | ||
|
|
||
| # retrieve passed components that should not be downloaded | ||
| pipeline_class = _get_pipeline_class( | ||
|
|
@@ -1711,6 +1700,7 @@ def download(cls, pretrained_model_name, **kwargs) -> str | os.PathLike: | |
| class_name=custom_class_name, | ||
| cache_dir=cache_dir, | ||
| revision=custom_revision, | ||
| trust_remote_code=trust_remote_code, | ||
| ) | ||
| expected_components, _ = cls._get_signature_keys(pipeline_class) | ||
| passed_components = [k for k in expected_components if k in kwargs] | ||
|
|
@@ -2127,13 +2117,16 @@ def from_pipe(cls, pipeline, **kwargs): | |
|
|
||
| original_config = dict(pipeline.config) | ||
| torch_dtype = kwargs.pop("torch_dtype", torch.float32) | ||
| trust_remote_code = kwargs.pop("trust_remote_code", False) | ||
|
|
||
| # derive the pipeline class to instantiate | ||
| custom_pipeline = kwargs.pop("custom_pipeline", None) | ||
| custom_revision = kwargs.pop("custom_revision", None) | ||
|
|
||
| if custom_pipeline is not None: | ||
| pipeline_class = _get_custom_pipeline_class(custom_pipeline, revision=custom_revision) | ||
| pipeline_class = _get_custom_pipeline_class( | ||
| custom_pipeline, revision=custom_revision, trust_remote_code=trust_remote_code | ||
| ) | ||
| else: | ||
| pipeline_class = cls | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -254,6 +254,7 @@ def get_cached_module_file( | |
| revision: str | None = None, | ||
| local_files_only: bool = False, | ||
| local_dir: str | None = None, | ||
| trust_remote_code: bool = False, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it make sense to add the |
||
| ): | ||
| """ | ||
| Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached | ||
|
|
@@ -289,6 +290,10 @@ def get_cached_module_file( | |
| identifier allowed by git. | ||
| local_files_only (`bool`, *optional*, defaults to `False`): | ||
| If `True`, will only try to load the tokenizer configuration from local files. | ||
| trust_remote_code (`bool`, *optional*, defaults to `False`): | ||
| Whether or not to allow for custom pipelines and components defined on the Hub in their own files. This | ||
| option should only be set to `True` for repositories you trust and in which you have read the code, as it | ||
| will execute code present on the Hub on your local machine. | ||
|
|
||
| > [!TIP] > You may pass a token in `token` if you are not logged in (`hf auth login`) and want to use private or | ||
| [gated > models](https://huggingface.co/docs/hub/models-gated#gated-models). | ||
|
|
@@ -307,6 +312,12 @@ def get_cached_module_file( | |
| if os.path.isfile(module_file_or_url): | ||
| resolved_module_file = module_file_or_url | ||
| submodule = "local" | ||
| if not trust_remote_code: | ||
| raise ValueError( | ||
| f"The directory {pretrained_model_name_or_path} contains custom code in {module_file} which must be executed to correctly " | ||
| f"load the model. You can inspect the file content at {module_file_or_url}.\n" | ||
| f"Please pass the argument `trust_remote_code=True` to allow custom code to be run." | ||
| ) | ||
| elif pretrained_model_name_or_path.count("/") == 0: | ||
| available_versions = get_diffusers_versions() | ||
| # cut ".dev0" | ||
|
|
@@ -326,6 +337,13 @@ def get_cached_module_file( | |
| f" {', '.join(available_versions + ['main'])}." | ||
| ) | ||
|
|
||
| if not trust_remote_code: | ||
| raise ValueError( | ||
| f"The community pipeline for {pretrained_model_name_or_path} contains custom code which must be executed to correctly " | ||
| f"load the model. You can inspect the repository content at https://hf.co/datasets/{COMMUNITY_PIPELINES_MIRROR_ID}/blob/main/{revision}/{pretrained_model_name_or_path}.py.\n" | ||
| f"Please pass the argument `trust_remote_code=True` to allow custom code to be run." | ||
| ) | ||
|
|
||
| try: | ||
| resolved_module_file = hf_hub_download( | ||
| repo_id=COMMUNITY_PIPELINES_MIRROR_ID, | ||
|
|
@@ -349,6 +367,12 @@ def get_cached_module_file( | |
| logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.") | ||
| raise | ||
| else: | ||
| if not trust_remote_code: | ||
| raise ValueError( | ||
| f"The repository for {pretrained_model_name_or_path} contains custom code in {module_file} which must be executed to correctly " | ||
| f"load the model. You can inspect the repository content at https://hf.co/{pretrained_model_name_or_path}/blob/main/{module_file}.\n" | ||
| f"Please pass the argument `trust_remote_code=True` to allow custom code to be run." | ||
| ) | ||
| try: | ||
| # Load from URL or cache if already cached | ||
| resolved_module_file = hf_hub_download( | ||
|
|
@@ -426,6 +450,7 @@ def get_cached_module_file( | |
| revision=revision, | ||
| local_files_only=local_files_only, | ||
| local_dir=local_dir, | ||
| trust_remote_code=trust_remote_code, | ||
| ) | ||
| return os.path.join(full_submodule, module_file) | ||
|
|
||
|
|
@@ -443,6 +468,7 @@ def get_class_from_dynamic_module( | |
| revision: str | None = None, | ||
| local_files_only: bool = False, | ||
| local_dir: str | None = None, | ||
| trust_remote_code: bool = False, | ||
| ): | ||
| """ | ||
| Extracts a class from a module file, present in the local folder or repository of a model. | ||
|
|
@@ -482,6 +508,10 @@ def get_class_from_dynamic_module( | |
| identifier allowed by git. | ||
| local_files_only (`bool`, *optional*, defaults to `False`): | ||
| If `True`, will only try to load the tokenizer configuration from local files. | ||
| trust_remote_code (`bool`, *optional*, defaults to `False`): | ||
| Whether or not to allow for custom pipelines and components defined on the Hub in their own files. This | ||
| option should only be set to `True` for repositories you trust and in which you have read the code, as it | ||
| will execute code present on the Hub on your local machine. | ||
|
|
||
| > [!TIP] > You may pass a token in `token` if you are not logged in (`hf auth login`) and want to use private or | ||
| [gated > models](https://huggingface.co/docs/hub/models-gated#gated-models). | ||
|
|
@@ -508,5 +538,6 @@ def get_class_from_dynamic_module( | |
| revision=revision, | ||
| local_files_only=local_files_only, | ||
| local_dir=local_dir, | ||
| trust_remote_code=trust_remote_code, | ||
| ) | ||
| return get_class_in_module(class_name, final_module) | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -1011,8 +1011,15 @@ def test_get_pipeline_class_from_flax(self): | |||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| class CustomPipelineTests(unittest.TestCase): | ||||||||||||||||||||||||||||||||||||||||||||
| def test_load_custom_pipeline(self): | ||||||||||||||||||||||||||||||||||||||||||||
| with self.assertRaises(ValueError): | ||||||||||||||||||||||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we investigate the
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And on
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. On
PR:
[1] or any community pipeline name This case is more A user may copy an example like: pipeline = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="one_step_unet")or pipeline = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="pipeline_stable_diffusion_xl_controlnet_adapter_inpaint")or pipeline = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="pipeline_stable_diffusion_x/_controlnet_adapter_inpaint")The first two are harmless from Considering that I think community pipeline names should remain trusted, WDYT? We can just remove this to do so. diffusers/src/diffusers/utils/dynamic_modules_utils.py Lines 340 to 345 in 78a5028
|
||||||||||||||||||||||||||||||||||||||||||||
| pipeline = DiffusionPipeline.from_pretrained( | ||||||||||||||||||||||||||||||||||||||||||||
| "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline" | ||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| pipeline = DiffusionPipeline.from_pretrained( | ||||||||||||||||||||||||||||||||||||||||||||
| "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline" | ||||||||||||||||||||||||||||||||||||||||||||
| "google/ddpm-cifar10-32", | ||||||||||||||||||||||||||||||||||||||||||||
| custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline", | ||||||||||||||||||||||||||||||||||||||||||||
| trust_remote_code=True, | ||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||
| pipeline = pipeline.to(torch_device) | ||||||||||||||||||||||||||||||||||||||||||||
| # NOTE that `"CustomPipeline"` is not a class that is defined in this library, but solely on the Hub | ||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1021,7 +1028,10 @@ def test_load_custom_pipeline(self): | |||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| def test_load_custom_github(self): | ||||||||||||||||||||||||||||||||||||||||||||
| pipeline = DiffusionPipeline.from_pretrained( | ||||||||||||||||||||||||||||||||||||||||||||
| "google/ddpm-cifar10-32", custom_pipeline="one_step_unet", custom_revision="main" | ||||||||||||||||||||||||||||||||||||||||||||
| "google/ddpm-cifar10-32", | ||||||||||||||||||||||||||||||||||||||||||||
| custom_pipeline="one_step_unet", | ||||||||||||||||||||||||||||||||||||||||||||
| custom_revision="main", | ||||||||||||||||||||||||||||||||||||||||||||
| trust_remote_code=True, | ||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| # make sure that on "main" pipeline gives only ones because of: https://github.com/huggingface/diffusers/pull/1690 | ||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1035,7 +1045,10 @@ def test_load_custom_github(self): | |||||||||||||||||||||||||||||||||||||||||||
| del sys.modules["diffusers_modules.git.one_step_unet"] | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| pipeline = DiffusionPipeline.from_pretrained( | ||||||||||||||||||||||||||||||||||||||||||||
| "google/ddpm-cifar10-32", custom_pipeline="one_step_unet", custom_revision="0.10.2" | ||||||||||||||||||||||||||||||||||||||||||||
| "google/ddpm-cifar10-32", | ||||||||||||||||||||||||||||||||||||||||||||
| custom_pipeline="one_step_unet", | ||||||||||||||||||||||||||||||||||||||||||||
| custom_revision="0.10.2", | ||||||||||||||||||||||||||||||||||||||||||||
| trust_remote_code=True, | ||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||
| with torch.no_grad(): | ||||||||||||||||||||||||||||||||||||||||||||
| output = pipeline() | ||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1045,8 +1058,15 @@ def test_load_custom_github(self): | |||||||||||||||||||||||||||||||||||||||||||
| assert pipeline.__class__.__name__ == "UnetSchedulerOneForwardPipeline" | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| def test_run_custom_pipeline(self): | ||||||||||||||||||||||||||||||||||||||||||||
| with self.assertRaises(ValueError): | ||||||||||||||||||||||||||||||||||||||||||||
| pipeline = DiffusionPipeline.from_pretrained( | ||||||||||||||||||||||||||||||||||||||||||||
| "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline" | ||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| pipeline = DiffusionPipeline.from_pretrained( | ||||||||||||||||||||||||||||||||||||||||||||
| "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline" | ||||||||||||||||||||||||||||||||||||||||||||
| "google/ddpm-cifar10-32", | ||||||||||||||||||||||||||||||||||||||||||||
| custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline", | ||||||||||||||||||||||||||||||||||||||||||||
| trust_remote_code=True, | ||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||
| pipeline = pipeline.to(torch_device) | ||||||||||||||||||||||||||||||||||||||||||||
| images, output_str = pipeline(num_inference_steps=2, output_type="np") | ||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1089,6 +1109,37 @@ def test_remote_components(self): | |||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| assert images.shape == (1, 64, 64, 3) | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| def test_custom_components_from_local_dir(self): | ||||||||||||||||||||||||||||||||||||||||||||
| with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname: | ||||||||||||||||||||||||||||||||||||||||||||
| path = snapshot_download("hf-internal-testing/tiny-sdxl-custom-components", cache_dir=tmpdirname) | ||||||||||||||||||||||||||||||||||||||||||||
| # make sure that trust remote code has to be passed | ||||||||||||||||||||||||||||||||||||||||||||
| with self.assertRaises(ValueError): | ||||||||||||||||||||||||||||||||||||||||||||
| pipeline = DiffusionPipeline.from_pretrained(path) | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| # Check that only loading custom components "my_unet", "my_scheduler" works | ||||||||||||||||||||||||||||||||||||||||||||
| pipeline = DiffusionPipeline.from_pretrained(path, trust_remote_code=True) | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| assert pipeline.config.unet == ("diffusers_modules.local.my_unet_model", "MyUNetModel") | ||||||||||||||||||||||||||||||||||||||||||||
| assert pipeline.config.scheduler == ("diffusers_modules.local.my_scheduler", "MyScheduler") | ||||||||||||||||||||||||||||||||||||||||||||
| assert pipeline.__class__.__name__ == "StableDiffusionXLPipeline" | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| pipeline = pipeline.to(torch_device) | ||||||||||||||||||||||||||||||||||||||||||||
| images = pipeline("test", num_inference_steps=2, output_type="np")[0] | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| assert images.shape == (1, 64, 64, 3) | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| # Check that only loading custom components "my_unet", "my_scheduler" and explicit custom pipeline works | ||||||||||||||||||||||||||||||||||||||||||||
| pipeline = DiffusionPipeline.from_pretrained(path, custom_pipeline="my_pipeline", trust_remote_code=True) | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| assert pipeline.config.unet == ("diffusers_modules.local.my_unet_model", "MyUNetModel") | ||||||||||||||||||||||||||||||||||||||||||||
| assert pipeline.config.scheduler == ("diffusers_modules.local.my_scheduler", "MyScheduler") | ||||||||||||||||||||||||||||||||||||||||||||
| assert pipeline.__class__.__name__ == "MyPipeline" | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| pipeline = pipeline.to(torch_device) | ||||||||||||||||||||||||||||||||||||||||||||
| images = pipeline("test", num_inference_steps=2, output_type="np")[0] | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| assert images.shape == (1, 64, 64, 3) | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| def test_remote_auto_custom_pipe(self): | ||||||||||||||||||||||||||||||||||||||||||||
| # make sure that trust remote code has to be passed | ||||||||||||||||||||||||||||||||||||||||||||
| with self.assertRaises(ValueError): | ||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1126,7 +1177,7 @@ def test_remote_custom_pipe_with_dot_in_name(self): | |||||||||||||||||||||||||||||||||||||||||||
| def test_local_custom_pipeline_repo(self): | ||||||||||||||||||||||||||||||||||||||||||||
| local_custom_pipeline_path = get_tests_dir("fixtures/custom_pipeline") | ||||||||||||||||||||||||||||||||||||||||||||
| pipeline = DiffusionPipeline.from_pretrained( | ||||||||||||||||||||||||||||||||||||||||||||
| "google/ddpm-cifar10-32", custom_pipeline=local_custom_pipeline_path | ||||||||||||||||||||||||||||||||||||||||||||
| "google/ddpm-cifar10-32", custom_pipeline=local_custom_pipeline_path, trust_remote_code=True | ||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||
| pipeline = pipeline.to(torch_device) | ||||||||||||||||||||||||||||||||||||||||||||
| images, output_str = pipeline(num_inference_steps=2, output_type="np") | ||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1140,7 +1191,9 @@ def test_local_custom_pipeline_file(self): | |||||||||||||||||||||||||||||||||||||||||||
| local_custom_pipeline_path = get_tests_dir("fixtures/custom_pipeline") | ||||||||||||||||||||||||||||||||||||||||||||
| local_custom_pipeline_path = os.path.join(local_custom_pipeline_path, "what_ever.py") | ||||||||||||||||||||||||||||||||||||||||||||
| pipeline = DiffusionPipeline.from_pretrained( | ||||||||||||||||||||||||||||||||||||||||||||
| "google/ddpm-cifar10-32", custom_pipeline=local_custom_pipeline_path | ||||||||||||||||||||||||||||||||||||||||||||
| "google/ddpm-cifar10-32", | ||||||||||||||||||||||||||||||||||||||||||||
| custom_pipeline=local_custom_pipeline_path, | ||||||||||||||||||||||||||||||||||||||||||||
| trust_remote_code=True, | ||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||
| pipeline = pipeline.to(torch_device) | ||||||||||||||||||||||||||||||||||||||||||||
| images, output_str = pipeline(num_inference_steps=2, output_type="np") | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to note this remains to control
repo_id:diffusers/src/diffusers/pipelines/pipeline_utils.py
Line 1699 in dc8d903
diffusers/src/diffusers/pipelines/pipeline_loading_utils.py
Lines 441 to 458 in dc8d903
It helps distinguish between:
a)
custom_pipelineis e.g.my_pipelineand that filename exists inpretrained_model_name's filesb)
custom_pipelineis Hub repo (andpipeline.pyis used)Maybe could be renamed
load_pipe_from_hub->hub_contains_custom_pipeline?