Skip to content

Commit 4583c12

Browse files
committed
perf(loss): 将边界框坐标重置回标准化之前计算IoU
1 parent 3e78f16 commit 4583c12

File tree

1 file changed

+53
-12
lines changed

1 file changed

+53
-12
lines changed

py/lib/models/multi_part_loss.py

+53-12
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
class MultiPartLoss(nn.Module):
2121

22-
def __init__(self, S=7, B=2, C=20, lambda_coord=5, lambda_noobj=0.5):
22+
def __init__(self, img_w, img_h, S=7, B=2, C=20, lambda_coord=5, lambda_noobj=0.5):
2323
super(MultiPartLoss, self).__init__()
2424
self.S = S
2525
self.B = B
@@ -28,6 +28,12 @@ def __init__(self, S=7, B=2, C=20, lambda_coord=5, lambda_noobj=0.5):
2828
self.coord = lambda_coord
2929
self.noobj = lambda_noobj
3030

31+
self.img_w = img_w
32+
self.img_h = img_h
33+
34+
self.grid_w = img_w / S
35+
self.grid_h = img_h / S
36+
3137
def forward(self, preds, targets):
3238
"""
3339
:param preds: (N, S*S, B*5+C) 其中
@@ -172,10 +178,8 @@ def _process3(self, preds, targets):
172178
# [N, S*S, B] -> [N*S*S, B]
173179
pred_confidences = preds[:, :, self.C: (self.B + self.C)].reshape(-1, self.B)
174180
# 提取每个网格的预测边界框坐标
175-
# [N, S*S, B*4] -> [N*S*S, B*4] -> [N*S*S, B, 4]
176-
pred_bboxs = preds[:, :, (self.B + self.C): (self.B * 5 + self.C)] \
177-
.reshape(-1, self.B * 4) \
178-
.reshape(-1, self.B, 4)
181+
# [N, S*S, B*4] -> [N, S*S, B, 4]
182+
pred_bboxs = preds[:, :, (self.B + self.C): (self.B * 5 + self.C)].reshape(N, self.S * self.S, self.B, 4)
179183

180184
## 目标
181185
# 提取每个网格的分类概率
@@ -185,18 +189,20 @@ def _process3(self, preds, targets):
185189
# [N, S*S, B] -> [N*S*S, B]
186190
target_confidences = targets[:, :, self.C: (self.B + self.C)].reshape(-1, self.B)
187191
# 提取每个网格的边界框坐标
188-
# [N, S*S, B*4] -> [N*S*S, B*4] -> [N*S*S, B, 4]
189-
target_bboxs = targets[:, :, (self.B + self.C): (self.B * 5 + self.C)] \
190-
.reshape(-1, self.B * 4) \
191-
.reshape(-1, self.B, 4)
192+
# [N, S*S, B*4] -> [N, S*S, B, 4]
193+
target_bboxs = targets[:, :, (self.B + self.C): (self.B * 5 + self.C)].reshape(N, self.S * self.S, self.B, 4)
192194

193195
## 首先计算所有边界框的置信度损失(假定不存在obj)
194196
loss = self.noobj * self.sum_squared_error(pred_confidences, target_confidences)
195197

196198
# 计算每个预测边界框与对应目标边界框的IoU
197-
iou_scores = self.iou(pred_bboxs.reshape(-1, 4), target_bboxs.reshape(-1, 4)).reshape(-1, 2)
199+
# [N*S*S*B]
200+
iou_scores = self.compute_ious(pred_bboxs.clone(), target_bboxs.clone())
201+
# [N, S*S, B, 4] -> [N*S*S, B, 4]
202+
pred_bboxs = pred_bboxs.reshape(-1, self.B, 4)
203+
target_bboxs = target_bboxs.reshape(-1, self.B, 4)
198204
# 选取每个网格中IoU最高的边界框
199-
top_idxs = torch.argmax(iou_scores, dim=1)
205+
top_idxs = torch.argmax(iou_scores.reshape(-1, self.B), dim=1)
200206
top_len = len(top_idxs)
201207
# 获取相应的置信度以及边界框
202208
top_pred_confidences = pred_confidences[range(top_len), top_idxs]
@@ -247,6 +253,41 @@ def bbox_loss(self, pred_boxs, target_boxs):
247253

248254
return loss
249255

256+
def compute_ious(self, pred_boxs, target_boxs):
257+
"""
258+
将边界框变形回标准化之前,然后计算IoU
259+
:param pred_boxs: [N, S*S, B, 4]
260+
:param target_boxs: [N, S*S, B, 4]
261+
:return: [N*S*S*B]
262+
"""
263+
N = pred_boxs.shape[0]
264+
for i in range(N):
265+
for j in range(self.S * self.S):
266+
col = j % self.S
267+
row = int(j / self.S)
268+
for k in range(self.B):
269+
pred_box = pred_boxs[i, j, k]
270+
target_box = target_boxs[i, j, k]
271+
272+
# 变形会标准化之前
273+
# x_center
274+
pred_box[0] = (pred_box[0] + col) * self.grid_w
275+
target_box[0] = (target_box[0] + col) * self.grid_w
276+
# y_center
277+
pred_box[1] = (pred_box[1] + row) * self.grid_h
278+
target_box[1] = (target_box[1] + row) * self.grid_h
279+
# w
280+
pred_box[2] = pred_box[2] * self.img_w
281+
target_box[2] = target_box[2] * self.img_w
282+
# h
283+
pred_box[3] = pred_box[3] * self.img_h
284+
target_box[3] = target_box[3] * self.img_h
285+
286+
pred_boxs = pred_boxs.reshape(-1, 4)
287+
target_boxs = target_boxs.reshape(-1, 4)
288+
289+
return self.iou(pred_boxs, target_boxs)
290+
250291
def iou(self, pred_boxs, target_boxs):
251292
"""
252293
计算候选建议和标注边界框的IoU
@@ -291,7 +332,7 @@ def load_data(data_root_dir, cate_list, S=7, B=2, C=20):
291332
C = 3
292333
cate_list = ['cucumber', 'eggplant', 'mushroom']
293334

294-
criterion = MultiPartLoss(S=7, B=2, C=3)
335+
criterion = MultiPartLoss(448, 448, S=7, B=2, C=3)
295336
# preds = torch.arange(637).reshape(1, 7 * 7, 13) * 0.01
296337
# targets = torch.ones((1, 7 * 7, 13)) * 0.01
297338
# loss = criterion(preds, targets)

0 commit comments

Comments
 (0)