Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion monai/losses/adversarial_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ def forward(
target_is_real = True # With generator, we always want this to be true!
warnings.warn(
"Variable target_is_real has been set to False, but for_discriminator is set"
"to False. To optimise a generator, target_is_real must be set to True."
"to False. To optimise a generator, target_is_real must be set to True.",
stacklevel=2,
)

if not isinstance(input, list):
Expand Down
14 changes: 7 additions & 7 deletions monai/losses/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
n_pred_ch = input.shape[1]
if self.softmax:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `softmax=True` ignored.")
warnings.warn("single channel prediction, `softmax=True` ignored.", stacklevel=2)
else:
input = torch.softmax(input, 1)

Expand All @@ -165,13 +165,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:

if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
else:
target = one_hot(target, num_classes=n_pred_ch)

if not self.include_background:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `include_background=False` ignored.")
warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
else:
# if skipping background, removing first channel
target = target[:, 1:]
Expand Down Expand Up @@ -405,7 +405,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
n_pred_ch = input.shape[1]
if self.softmax:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `softmax=True` ignored.")
warnings.warn("single channel prediction, `softmax=True` ignored.", stacklevel=2)
else:
input = torch.softmax(input, 1)

Expand All @@ -414,13 +414,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:

if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
else:
target = one_hot(target, num_classes=n_pred_ch)

if not self.include_background:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `include_background=False` ignored.")
warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
else:
# if skipping background, removing first channel
target = target[:, 1:]
Expand Down Expand Up @@ -987,7 +987,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if self.to_onehot_y:
n_pred_ch = input.shape[1]
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
else:
target = one_hot(target, num_classes=n_pred_ch)
dice_loss = self.dice(input, target)
Expand Down
4 changes: 2 additions & 2 deletions monai/losses/focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:

if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
else:
target = one_hot(target, num_classes=n_pred_ch)

if not self.include_background:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `include_background=False` ignored.")
warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
else:
# if skipping background, removing first channel
target = target[:, 1:]
Expand Down
6 changes: 3 additions & 3 deletions monai/losses/hausdorff_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
n_pred_ch = input.shape[1]
if self.softmax:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `softmax=True` ignored.")
warnings.warn("single channel prediction, `softmax=True` ignored.", stacklevel=2)
else:
input = torch.softmax(input, 1)

Expand All @@ -163,13 +163,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:

if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
else:
target = one_hot(target, num_classes=n_pred_ch)

if not self.include_background:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `include_background=False` ignored.")
warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
else:
# If skipping background, removing first channel
target = target[:, 1:]
Expand Down
6 changes: 3 additions & 3 deletions monai/losses/mcc_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
n_pred_ch = input.shape[1]
if self.softmax:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `softmax=True` ignored.")
warnings.warn("single channel prediction, `softmax=True` ignored.", stacklevel=2)
else:
input = torch.softmax(input, 1)

Expand All @@ -142,13 +142,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:

if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
else:
target = one_hot(target, num_classes=n_pred_ch)

if not self.include_background:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `include_background=False` ignored.")
warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
else:
target = target[:, 1:]
input = input[:, 1:]
Expand Down
5 changes: 3 additions & 2 deletions monai/losses/perceptual.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def __init__(
)
if not channel_wise:
warnings.warn(
"MedicalNet networks supp, ort channel-wise loss. Consider setting channel_wise=True.", stacklevel=2
"MedicalNet networks support channel-wise loss. Consider setting channel_wise=True.", stacklevel=2
)

# Channel-wise only for MedicalNet
Expand All @@ -127,7 +127,8 @@ def __init__(
torch.hub.set_dir(cache_dir)
# raise a warning that this may change the default cache dir for all torch.hub calls
warnings.warn(
f"Setting cache_dir to {cache_dir}, this may change the default cache dir for all torch.hub calls."
f"Setting cache_dir to {cache_dir}, this may change the default cache dir for all torch.hub calls.",
stacklevel=2,
)

self.spatial_dims = spatial_dims
Expand Down
8 changes: 5 additions & 3 deletions monai/losses/spatial_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,18 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
mask: the shape should be B1H[WD] or 11H[WD].
"""
if mask is None:
warnings.warn("No mask value specified for the MaskedLoss.")
warnings.warn("No mask value specified for the MaskedLoss.", stacklevel=2)
return self.loss(input, target)

if input.dim() != mask.dim():
warnings.warn(f"Dim of input ({input.shape}) is different from mask ({mask.shape}).")
warnings.warn(f"Dim of input ({input.shape}) is different from mask ({mask.shape}).", stacklevel=2)
if input.shape[0] != mask.shape[0] and mask.shape[0] != 1:
raise ValueError(f"Batch size of mask ({mask.shape}) must be one or equal to input ({input.shape}).")
if target.dim() > 1:
if mask.shape[1] != 1:
raise ValueError(f"Mask ({mask.shape}) must have only one channel.")
if input.shape[2:] != mask.shape[2:]:
warnings.warn(f"Spatial size of input ({input.shape}) is different from mask ({mask.shape}).")
warnings.warn(
f"Spatial size of input ({input.shape}) is different from mask ({mask.shape}).", stacklevel=2
)
return self.loss(input * mask, target * mask)
6 changes: 3 additions & 3 deletions monai/losses/tversky.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
n_pred_ch = input.shape[1]
if self.softmax:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `softmax=True` ignored.")
warnings.warn("single channel prediction, `softmax=True` ignored.", stacklevel=2)
else:
input = torch.softmax(input, 1)

Expand All @@ -127,13 +127,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:

if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
else:
target = one_hot(target, num_classes=n_pred_ch)

if not self.include_background:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `include_background=False` ignored.")
warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
else:
# if skipping background, removing first channel
target = target[:, 1:]
Expand Down
6 changes: 3 additions & 3 deletions monai/losses/unified_focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:

if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
else:
y_true = one_hot(y_true, num_classes=n_pred_ch)

Expand Down Expand Up @@ -122,7 +122,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:

if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
else:
y_true = one_hot(y_true, num_classes=n_pred_ch)

Expand Down Expand Up @@ -223,7 +223,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
n_pred_ch = y_pred.shape[1]
if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
else:
y_true = one_hot(y_true, num_classes=n_pred_ch)

Expand Down
Loading