Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 49 additions & 32 deletions tests/models/autoencoders/test_models_autoencoder_magvit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,38 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import pytest
import torch

from diffusers import AutoencoderKLMagvit
from diffusers.utils.torch_utils import randn_tensor

from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
from .testing_utils import NewAutoencoderTesterMixin


enable_full_determinism()


class AutoencoderKLMagvitTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLMagvit
main_input_name = "sample"
base_precision = 1e-2
class AutoencoderKLMagvitTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return AutoencoderKLMagvit

@property
def main_input_name(self) -> str:
return "sample"

def get_autoencoder_kl_magvit_config(self):
@property
def output_shape(self) -> tuple:
return (3, 9, 16, 16)

@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)

def get_init_dict(self) -> dict:
return {
"in_channels": 3,
"latent_channels": 4,
Expand All @@ -53,45 +67,48 @@ def get_autoencoder_kl_magvit_config(self):
"spatial_group_norm": True,
}

@property
def dummy_input(self):
def get_dummy_inputs(self) -> dict:
batch_size = 2
num_frames = 9
num_channels = 3
height = 16
width = 16

image = floats_tensor((batch_size, num_channels, num_frames, height, width)).to(torch_device)

image = randn_tensor(
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
)
return {"sample": image}

@property
def input_shape(self):
return (3, 9, 16, 16)

@property
def output_shape(self):
return (3, 9, 16, 16)
class TestAutoencoderKLMagvit(AutoencoderKLMagvitTesterConfig, ModelTesterMixin):
base_precision = 1e-2

def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_magvit_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict

class TestAutoencoderKLMagvitTraining(AutoencoderKLMagvitTesterConfig, TrainingTesterMixin):
"""Training tests for AutoencoderKLMagvit."""

def test_gradient_checkpointing_is_applied(self):
expected_set = {"EasyAnimateEncoder", "EasyAnimateDecoder"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

@unittest.skip("Not quite sure why this test fails. Revisit later.")
def test_effective_gradient_checkpointing(self):
pass
@pytest.mark.skip("Not quite sure why this test fails. Revisit later.")
def test_gradient_checkpointing_equivalence(self):
super().test_gradient_checkpointing_equivalence()


class TestAutoencoderKLMagvitMemory(AutoencoderKLMagvitTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for AutoencoderKLMagvit."""


class TestAutoencoderKLMagvitSlicingTiling(AutoencoderKLMagvitTesterConfig, NewAutoencoderTesterMixin):
"""Slicing and tiling tests for AutoencoderKLMagvit."""

@unittest.skip("Unsupported test.")
@pytest.mark.skip("Unsupported test.")
def test_forward_with_norm_groups(self):
pass
super().test_forward_with_norm_groups()

@unittest.skip(
"Unsupported test. Error: RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 9 but got size 12 for tensor number 1 in the list."
@pytest.mark.skip(
"Unsupported test. Error: RuntimeError: Sizes of tensors must match except in dimension 0. "
"Expected size 9 but got size 12 for tensor number 1 in the list."
)
def test_enable_disable_slicing(self):
pass
super().test_enable_disable_slicing()
Loading