19
19
20
20
class MultiPartLoss (nn .Module ):
21
21
22
- def __init__ (self , S = 7 , B = 2 , C = 20 , lambda_coord = 5 , lambda_noobj = 0.5 ):
22
+ def __init__ (self , img_w , img_h , S = 7 , B = 2 , C = 20 , lambda_coord = 5 , lambda_noobj = 0.5 ):
23
23
super (MultiPartLoss , self ).__init__ ()
24
24
self .S = S
25
25
self .B = B
@@ -28,6 +28,12 @@ def __init__(self, S=7, B=2, C=20, lambda_coord=5, lambda_noobj=0.5):
28
28
self .coord = lambda_coord
29
29
self .noobj = lambda_noobj
30
30
31
+ self .img_w = img_w
32
+ self .img_h = img_h
33
+
34
+ self .grid_w = img_w / S
35
+ self .grid_h = img_h / S
36
+
31
37
def forward (self , preds , targets ):
32
38
"""
33
39
:param preds: (N, S*S, B*5+C) 其中
@@ -172,10 +178,8 @@ def _process3(self, preds, targets):
172
178
# [N, S*S, B] -> [N*S*S, B]
173
179
pred_confidences = preds [:, :, self .C : (self .B + self .C )].reshape (- 1 , self .B )
174
180
# 提取每个网格的预测边界框坐标
175
- # [N, S*S, B*4] -> [N*S*S, B*4] -> [N*S*S, B, 4]
176
- pred_bboxs = preds [:, :, (self .B + self .C ): (self .B * 5 + self .C )] \
177
- .reshape (- 1 , self .B * 4 ) \
178
- .reshape (- 1 , self .B , 4 )
181
+ # [N, S*S, B*4] -> [N, S*S, B, 4]
182
+ pred_bboxs = preds [:, :, (self .B + self .C ): (self .B * 5 + self .C )].reshape (N , self .S * self .S , self .B , 4 )
179
183
180
184
## 目标
181
185
# 提取每个网格的分类概率
@@ -185,18 +189,20 @@ def _process3(self, preds, targets):
185
189
# [N, S*S, B] -> [N*S*S, B]
186
190
target_confidences = targets [:, :, self .C : (self .B + self .C )].reshape (- 1 , self .B )
187
191
# 提取每个网格的边界框坐标
188
- # [N, S*S, B*4] -> [N*S*S, B*4] -> [N*S*S, B, 4]
189
- target_bboxs = targets [:, :, (self .B + self .C ): (self .B * 5 + self .C )] \
190
- .reshape (- 1 , self .B * 4 ) \
191
- .reshape (- 1 , self .B , 4 )
192
+ # [N, S*S, B*4] -> [N, S*S, B, 4]
193
+ target_bboxs = targets [:, :, (self .B + self .C ): (self .B * 5 + self .C )].reshape (N , self .S * self .S , self .B , 4 )
192
194
193
195
## 首先计算所有边界框的置信度损失(假定不存在obj)
194
196
loss = self .noobj * self .sum_squared_error (pred_confidences , target_confidences )
195
197
196
198
# 计算每个预测边界框与对应目标边界框的IoU
197
- iou_scores = self .iou (pred_bboxs .reshape (- 1 , 4 ), target_bboxs .reshape (- 1 , 4 )).reshape (- 1 , 2 )
199
+ # [N*S*S*B]
200
+ iou_scores = self .compute_ious (pred_bboxs .clone (), target_bboxs .clone ())
201
+ # [N, S*S, B, 4] -> [N*S*S, B, 4]
202
+ pred_bboxs = pred_bboxs .reshape (- 1 , self .B , 4 )
203
+ target_bboxs = target_bboxs .reshape (- 1 , self .B , 4 )
198
204
# 选取每个网格中IoU最高的边界框
199
- top_idxs = torch .argmax (iou_scores , dim = 1 )
205
+ top_idxs = torch .argmax (iou_scores . reshape ( - 1 , self . B ) , dim = 1 )
200
206
top_len = len (top_idxs )
201
207
# 获取相应的置信度以及边界框
202
208
top_pred_confidences = pred_confidences [range (top_len ), top_idxs ]
@@ -247,6 +253,41 @@ def bbox_loss(self, pred_boxs, target_boxs):
247
253
248
254
return loss
249
255
256
def compute_ious(self, pred_boxs, target_boxs):
    """
    Map boxes from their normalized encoding back to image coordinates,
    then compute the IoU of every predicted box against its target box.

    Encoding handled here (as established by the arithmetic below):
    x_center / y_center are offsets inside the owning grid cell, expressed
    in units of one cell; w / h are fractions of the full image size.

    The inputs are NOT modified: new tensors are built instead of writing
    into the arguments, so callers do not need to pass clones (passing
    clones remains harmless).

    :param pred_boxs: [N, S*S, B, 4], last dim is (x_center, y_center, w, h)
    :param target_boxs: [N, S*S, B, 4], same layout as pred_boxs
    :return: the result of self.iou on the two [N*S*S*B, 4] box tensors
             (one score per box, i.e. [N*S*S*B], per the original docstring)
    """
    # Cell j of the flattened S*S grid sits at column j % S and row j // S.
    # Build both index vectors once instead of recomputing them per box.
    grid_range = torch.arange(self.S)
    col_index = grid_range.repeat(self.S).reshape(-1, 1)             # [S*S, 1]
    row_index = grid_range.repeat_interleave(self.S).reshape(-1, 1)  # [S*S, 1]

    def _to_image_scale(boxs):
        # Match the box tensor so the arithmetic stays on its device/dtype.
        cols = col_index.to(device=boxs.device, dtype=boxs.dtype)
        rows = row_index.to(device=boxs.device, dtype=boxs.dtype)

        # [N, S*S, B] + [S*S, 1] broadcasts over the batch and box dims.
        x_center = (boxs[..., 0] + cols) * self.grid_w
        y_center = (boxs[..., 1] + rows) * self.grid_h
        width = boxs[..., 2] * self.img_w
        height = boxs[..., 3] * self.img_h

        # [N, S*S, B, 4] -> [N*S*S*B, 4]
        return torch.stack((x_center, y_center, width, height), dim=-1).reshape(-1, 4)

    return self.iou(_to_image_scale(pred_boxs), _to_image_scale(target_boxs))
290
+
250
291
def iou (self , pred_boxs , target_boxs ):
251
292
"""
252
293
计算候选建议和标注边界框的IoU
@@ -291,7 +332,7 @@ def load_data(data_root_dir, cate_list, S=7, B=2, C=20):
291
332
C = 3
292
333
cate_list = ['cucumber' , 'eggplant' , 'mushroom' ]
293
334
294
- criterion = MultiPartLoss (S = 7 , B = 2 , C = 3 )
335
+ criterion = MultiPartLoss (448 , 448 , S = 7 , B = 2 , C = 3 )
295
336
# preds = torch.arange(637).reshape(1, 7 * 7, 13) * 0.01
296
337
# targets = torch.ones((1, 7 * 7, 13)) * 0.01
297
338
# loss = criterion(preds, targets)
0 commit comments