diff --git a/monai/networks/nets/fullyconnectednet.py b/monai/networks/nets/fullyconnectednet.py index be179e5b59..ace07e9cc3 100644 --- a/monai/networks/nets/fullyconnectednet.py +++ b/monai/networks/nets/fullyconnectednet.py @@ -172,12 +172,11 @@ def decode_forward(self, z: torch.Tensor, use_sigmoid: bool = True) -> torch.Ten return x def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: - std = torch.exp(0.5 * logvar) + if self.training: # reparameterization trick only during training + std = torch.exp(0.5 * logvar) + return mu + torch.randn_like(std) * std - if self.training: # multiply random noise with std only during training - std = torch.randn_like(std).mul(std) - - return std.add_(mu) + return mu def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: mu, logvar = self.encode_forward(x) diff --git a/monai/networks/nets/varautoencoder.py b/monai/networks/nets/varautoencoder.py index 0674094aa7..28fbded86e 100644 --- a/monai/networks/nets/varautoencoder.py +++ b/monai/networks/nets/varautoencoder.py @@ -142,12 +142,11 @@ def decode_forward(self, z: torch.Tensor, use_sigmoid: bool = True) -> torch.Ten return x def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: - std = torch.exp(0.5 * logvar) + if self.training: # reparameterization trick only during training + std = torch.exp(0.5 * logvar) + return mu + torch.randn_like(std) * std - if self.training: # multiply random noise with std only during training - std = torch.randn_like(std).mul(std) - - return std.add_(mu) + return mu def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: mu, logvar = self.encode_forward(x)