diff --git a/tests/models/autoencoders/test_models_autoencoder_magvit.py b/tests/models/autoencoders/test_models_autoencoder_magvit.py index f7304df14048..db78e41f9562 100644 --- a/tests/models/autoencoders/test_models_autoencoder_magvit.py +++ b/tests/models/autoencoders/test_models_autoencoder_magvit.py @@ -13,24 +13,38 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest +import pytest +import torch from diffusers import AutoencoderKLMagvit +from diffusers.utils.torch_utils import randn_tensor -from ...testing_utils import enable_full_determinism, floats_tensor, torch_device -from ..test_modeling_common import ModelTesterMixin -from .testing_utils import AutoencoderTesterMixin +from ...testing_utils import enable_full_determinism, torch_device +from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin +from .testing_utils import NewAutoencoderTesterMixin enable_full_determinism() -class AutoencoderKLMagvitTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase): - model_class = AutoencoderKLMagvit - main_input_name = "sample" - base_precision = 1e-2 +class AutoencoderKLMagvitTesterConfig(BaseModelTesterConfig): + @property + def model_class(self): + return AutoencoderKLMagvit + + @property + def main_input_name(self) -> str: + return "sample" - def get_autoencoder_kl_magvit_config(self): + @property + def output_shape(self) -> tuple: + return (3, 9, 16, 16) + + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + def get_init_dict(self) -> dict: return { "in_channels": 3, "latent_channels": 4, @@ -53,45 +67,48 @@ def get_autoencoder_kl_magvit_config(self): "spatial_group_norm": True, } - @property - def dummy_input(self): + def get_dummy_inputs(self) -> dict: batch_size = 2 num_frames = 9 num_channels = 3 height = 16 width = 16 - - image = floats_tensor((batch_size, num_channels, num_frames, height, width)).to(torch_device) - + image = randn_tensor( + (batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device + ) return {"sample": image} - @property - def input_shape(self): - return (3, 9, 16, 16) - @property - def output_shape(self): - return (3, 9, 16, 16) +class TestAutoencoderKLMagvit(AutoencoderKLMagvitTesterConfig, ModelTesterMixin): + base_precision = 1e-2 - def prepare_init_args_and_inputs_for_common(self): - init_dict = self.get_autoencoder_kl_magvit_config() - inputs_dict = self.dummy_input - return init_dict, inputs_dict + +class TestAutoencoderKLMagvitTraining(AutoencoderKLMagvitTesterConfig, TrainingTesterMixin): + """Training tests for AutoencoderKLMagvit.""" def test_gradient_checkpointing_is_applied(self): expected_set = {"EasyAnimateEncoder", "EasyAnimateDecoder"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) - @unittest.skip("Not quite sure why this test fails. Revisit later.") - def test_effective_gradient_checkpointing(self): - pass + @pytest.mark.skip("Not quite sure why this test fails. Revisit later.") + def test_gradient_checkpointing_equivalence(self): + super().test_gradient_checkpointing_equivalence() + + +class TestAutoencoderKLMagvitMemory(AutoencoderKLMagvitTesterConfig, MemoryTesterMixin): + """Memory optimization tests for AutoencoderKLMagvit.""" + + +class TestAutoencoderKLMagvitSlicingTiling(AutoencoderKLMagvitTesterConfig, NewAutoencoderTesterMixin): + """Slicing and tiling tests for AutoencoderKLMagvit.""" - @unittest.skip("Unsupported test.") + @pytest.mark.skip("Unsupported test.") def test_forward_with_norm_groups(self): - pass + super().test_forward_with_norm_groups() - @unittest.skip( - "Unsupported test. Error: RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 9 but got size 12 for tensor number 1 in the list." + @pytest.mark.skip( + "Unsupported test. Error: RuntimeError: Sizes of tensors must match except in dimension 0. " + "Expected size 9 but got size 12 for tensor number 1 in the list." ) def test_enable_disable_slicing(self): - pass + super().test_enable_disable_slicing()