From f74108c1bfafa3588e034df7a3b6b501817d9950 Mon Sep 17 00:00:00 2001 From: rtmalikian Date: Thu, 18 Jun 2026 14:47:38 -0700 Subject: [PATCH] fix: return mu (not mu+std) in reparameterize at eval time MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In VarAutoEncoder and VarFullyConnectedNet, the reparameterize method computed and always returned . At training time, was correctly replaced with random noise scaled by the standard deviation. But at eval time, the raw standard deviation was still added to the mean — giving instead of just . The reparameterization trick (Kingma & Welling, 2014) defines: - Training: z = μ + ε · σ (ε ~ N(0,1)) - Inference: z = μ (no stochastic component) This fix restructures the method to return directly when not training, avoiding the unnecessary computation and eliminating the incorrect result at inference. Fixes #8413 Signed-off-by: Raphael Malikian --- monai/networks/nets/fullyconnectednet.py | 9 ++++----- monai/networks/nets/varautoencoder.py | 9 ++++----- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/monai/networks/nets/fullyconnectednet.py b/monai/networks/nets/fullyconnectednet.py index be179e5b59..ace07e9cc3 100644 --- a/monai/networks/nets/fullyconnectednet.py +++ b/monai/networks/nets/fullyconnectednet.py @@ -172,12 +172,11 @@ def decode_forward(self, z: torch.Tensor, use_sigmoid: bool = True) -> torch.Ten return x def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: - std = torch.exp(0.5 * logvar) + if self.training: # reparameterization trick only during training + std = torch.exp(0.5 * logvar) + return mu + torch.randn_like(std) * std - if self.training: # multiply random noise with std only during training - std = torch.randn_like(std).mul(std) - - return std.add_(mu) + return mu def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: mu, logvar = self.encode_forward(x) diff --git a/monai/networks/nets/varautoencoder.py b/monai/networks/nets/varautoencoder.py index 0674094aa7..28fbded86e 100644 --- a/monai/networks/nets/varautoencoder.py +++ b/monai/networks/nets/varautoencoder.py @@ -142,12 +142,11 @@ def decode_forward(self, z: torch.Tensor, use_sigmoid: bool = True) -> torch.Ten return x def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: - std = torch.exp(0.5 * logvar) + if self.training: # reparameterization trick only during training + std = torch.exp(0.5 * logvar) + return mu + torch.randn_like(std) * std - if self.training: # multiply random noise with std only during training - std = torch.randn_like(std).mul(std) - - return std.add_(mu) + return mu def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: mu, logvar = self.encode_forward(x)