Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions monai/networks/nets/fullyconnectednet.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,12 +172,13 @@ def decode_forward(self, z: torch.Tensor, use_sigmoid: bool = True) -> torch.Ten
return x

def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
if not self.training:
# At inference the latent code is the posterior mean; the random
# term is only added during training (the reparameterization trick).
return mu
std = torch.exp(0.5 * logvar)

if self.training: # multiply random noise with std only during training
std = torch.randn_like(std).mul(std)

return std.add_(mu)
eps = torch.randn_like(std)
return mu + eps * std

def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
mu, logvar = self.encode_forward(x)
Expand Down
11 changes: 6 additions & 5 deletions monai/networks/nets/varautoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,13 @@ def decode_forward(self, z: torch.Tensor, use_sigmoid: bool = True) -> torch.Ten
return x

def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
if not self.training:
# At inference the latent code is the posterior mean; the random
# term is only added during training (the reparameterization trick).
return mu
std = torch.exp(0.5 * logvar)

if self.training: # multiply random noise with std only during training
std = torch.randn_like(std).mul(std)

return std.add_(mu)
eps = torch.randn_like(std)
return mu + eps * std

def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
mu, logvar = self.encode_forward(x)
Expand Down
21 changes: 21 additions & 0 deletions tests/networks/nets/test_fullyconnectednet.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,27 @@ def test_vfc_shape(self, input_param, input_shape, expected_shape):
result = net.forward(torch.randn(input_shape).to(device))[0]
self.assertEqual(result.shape, expected_shape)

def test_vfc_reparameterize_eval_returns_mu(self):
# At eval the latent code must equal mu (deterministic); at train it must
# be stochastic. Same #8413 reparameterize bug as VarAutoEncoder.
net = VarFullyConnectedNet(
in_channels=10, out_channels=10, latent_size=30, encode_channels=(15, 20, 25), decode_channels=(15, 20, 25)
).to(device)
data = torch.randn(3, 10).to(device)

with eval_mode(net):
_, mu1, _, z1 = net(data)
_, _, _, z2 = net(data)
self.assertTrue(torch.allclose(z1, mu1))
self.assertTrue(torch.allclose(z1, z2))

net.train()
with torch.no_grad():
_, mu_t, _, zt1 = net(data)
_, _, _, zt2 = net(data)
self.assertFalse(torch.allclose(zt1, mu_t))
self.assertFalse(torch.allclose(zt1, zt2))


if __name__ == "__main__":
unittest.main()
22 changes: 22 additions & 0 deletions tests/networks/nets/test_varautoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,28 @@ def test_script(self):
test_data = torch.randn(2, 1, 32, 32)
test_script_save(net, test_data, rtol=1e-3, atol=1e-3)

def test_reparameterize_eval_returns_mu(self):
# At eval the latent code must equal mu (deterministic, no noise added);
# at train it must be stochastic. Regression test for #8413, where eval
# returned mu + std.
net = VarAutoEncoder(
spatial_dims=2, in_shape=(1, 32, 32), out_channels=1, latent_size=4, channels=(4, 8), strides=(2, 2)
).to(device)
data = torch.randn(2, 1, 32, 32).to(device)

with eval_mode(net):
_, mu1, _, z1 = net(data)
_, _, _, z2 = net(data)
self.assertTrue(torch.allclose(z1, mu1))
self.assertTrue(torch.allclose(z1, z2))

net.train()
with torch.no_grad():
_, mu_t, _, zt1 = net(data)
_, _, _, zt2 = net(data)
self.assertFalse(torch.allclose(zt1, mu_t))
self.assertFalse(torch.allclose(zt1, zt2))


if __name__ == "__main__":
unittest.main()
Loading