From 2e09ed0976247065d975d26ec5af21a3316bd555 Mon Sep 17 00:00:00 2001
From: Lanre Shittu <136805224+Shizoqua@users.noreply.github.com>
Date: Fri, 19 Jun 2026 20:23:56 +0100
Subject: [PATCH] fix(networks): return mu from VAE reparameterize at inference

VarAutoEncoder.reparameterize added the standard deviation to mu at eval
time (`std.add_(mu)` with no noise term), so inference returned mu + std
instead of the posterior mean. At inference the latent code should be mu;
the random term belongs only to training (the reparameterization trick).

Return mu directly when not training, and compute mu + eps * std
out-of-place otherwise. VarFullyConnectedNet.reparameterize had the
identical bug and is fixed the same way. Adds regression tests asserting
eval is deterministic and equals mu while training stays stochastic.

Fixes #8413.

Signed-off-by: Lanre Shittu <136805224+Shizoqua@users.noreply.github.com>
---
 monai/networks/nets/fullyconnectednet.py      | 11 +++++-----
 monai/networks/nets/varautoencoder.py         | 11 +++++-----
 tests/networks/nets/test_fullyconnectednet.py | 21 ++++++++++++++++++
 tests/networks/nets/test_varautoencoder.py    | 22 +++++++++++++++++++
 4 files changed, 55 insertions(+), 10 deletions(-)

diff --git a/monai/networks/nets/fullyconnectednet.py b/monai/networks/nets/fullyconnectednet.py
index be179e5b59..8c331fb38d 100644
--- a/monai/networks/nets/fullyconnectednet.py
+++ b/monai/networks/nets/fullyconnectednet.py
@@ -172,12 +172,13 @@ def decode_forward(self, z: torch.Tensor, use_sigmoid: bool = True) -> torch.Ten
         return x
 
     def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
+        if not self.training:
+            # At inference the latent code is the posterior mean; the random
+            # term is only added during training (the reparameterization trick).
+            return mu
         std = torch.exp(0.5 * logvar)
-
-        if self.training:  # multiply random noise with std only during training
-            std = torch.randn_like(std).mul(std)
-
-        return std.add_(mu)
+        eps = torch.randn_like(std)
+        return mu + eps * std
 
     def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         mu, logvar = self.encode_forward(x)
diff --git a/monai/networks/nets/varautoencoder.py b/monai/networks/nets/varautoencoder.py
index 0674094aa7..9a8110aae9 100644
--- a/monai/networks/nets/varautoencoder.py
+++ b/monai/networks/nets/varautoencoder.py
@@ -142,12 +142,13 @@ def decode_forward(self, z: torch.Tensor, use_sigmoid: bool = True) -> torch.Ten
         return x
 
     def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
+        if not self.training:
+            # At inference the latent code is the posterior mean; the random
+            # term is only added during training (the reparameterization trick).
+            return mu
         std = torch.exp(0.5 * logvar)
-
-        if self.training:  # multiply random noise with std only during training
-            std = torch.randn_like(std).mul(std)
-
-        return std.add_(mu)
+        eps = torch.randn_like(std)
+        return mu + eps * std
 
     def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         mu, logvar = self.encode_forward(x)
diff --git a/tests/networks/nets/test_fullyconnectednet.py b/tests/networks/nets/test_fullyconnectednet.py
index 863d1399a9..0eca66a0f2 100644
--- a/tests/networks/nets/test_fullyconnectednet.py
+++ b/tests/networks/nets/test_fullyconnectednet.py
@@ -64,6 +64,27 @@ def test_vfc_shape(self, input_param, input_shape, expected_shape):
             result = net.forward(torch.randn(input_shape).to(device))[0]
             self.assertEqual(result.shape, expected_shape)
 
+    def test_vfc_reparameterize_eval_returns_mu(self):
+        # At eval the latent code must equal mu (deterministic); at train it must
+        # be stochastic. Same #8413 reparameterize bug as VarAutoEncoder.
+        net = VarFullyConnectedNet(
+            in_channels=10, out_channels=10, latent_size=30, encode_channels=(15, 20, 25), decode_channels=(15, 20, 25)
+        ).to(device)
+        data = torch.randn(3, 10).to(device)
+
+        with eval_mode(net):
+            _, mu1, _, z1 = net(data)
+            _, _, _, z2 = net(data)
+        self.assertTrue(torch.allclose(z1, mu1))
+        self.assertTrue(torch.allclose(z1, z2))
+
+        net.train()
+        with torch.no_grad():
+            _, mu_t, _, zt1 = net(data)
+            _, _, _, zt2 = net(data)
+        self.assertFalse(torch.allclose(zt1, mu_t))
+        self.assertFalse(torch.allclose(zt1, zt2))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/networks/nets/test_varautoencoder.py b/tests/networks/nets/test_varautoencoder.py
index 459c537c55..d24ece1659 100644
--- a/tests/networks/nets/test_varautoencoder.py
+++ b/tests/networks/nets/test_varautoencoder.py
@@ -122,6 +122,28 @@ def test_script(self):
         test_data = torch.randn(2, 1, 32, 32)
         test_script_save(net, test_data, rtol=1e-3, atol=1e-3)
 
+    def test_reparameterize_eval_returns_mu(self):
+        # At eval the latent code must equal mu (deterministic, no noise added);
+        # at train it must be stochastic. Regression test for #8413, where eval
+        # returned mu + std.
+        net = VarAutoEncoder(
+            spatial_dims=2, in_shape=(1, 32, 32), out_channels=1, latent_size=4, channels=(4, 8), strides=(2, 2)
+        ).to(device)
+        data = torch.randn(2, 1, 32, 32).to(device)
+
+        with eval_mode(net):
+            _, mu1, _, z1 = net(data)
+            _, _, _, z2 = net(data)
+        self.assertTrue(torch.allclose(z1, mu1))
+        self.assertTrue(torch.allclose(z1, z2))
+
+        net.train()
+        with torch.no_grad():
+            _, mu_t, _, zt1 = net(data)
+            _, _, _, zt2 = net(data)
+        self.assertFalse(torch.allclose(zt1, mu_t))
+        self.assertFalse(torch.allclose(zt1, zt2))
+
 
 if __name__ == "__main__":
     unittest.main()