
Editor's Note (极市导读)
The first part of this article dissects CornerNet in detail, starting from its concrete network architecture and loss function and working alongside the code. The second part focuses on how ground-truth labels are mapped into supervision signals that match the network's output format, and examines the definition of the loss function in depth.
1. Preface
Why I looked into CornerNet:
1. As a beginner who entered deep learning through object detection, never having touched anchor-free work in this field made me look amateurish and unprofessional (of course, that is purely in my head).
2. While reading papers in multi-object tracking (MOT), I kept running into keywords such as Objects as Points and anchor-free, so this also lays a foundation for learning MOT.
My route for learning it:
1. Search Zhihu/CSDN for analysis blogs to build a first impression.
2. Read the original paper carefully.
3. Run the code, read the code, and build a deeper understanding.
2. Some Background
1. A new anchor-free line of thinking for object detection
2. The proposal of corner pooling
3. The proposal of the CornerNet network
We propose CornerNet, a new approach to object detection where we detect an object bounding box as a pair of keypoints, the top-left corner and the bottom-right corner, using a single convolutional neural network.
The network also predicts an embedding vector for each detected corner [27] such that the distance between the embeddings of two corners from the same object is small.
1. heatmaps
2. embeddings
3. offsets
We predict two sets of heatmaps, one for top-left corners and one for bottom-right corners. Each set of heatmaps has C channels, where C is the number of categories, and is of size H × W. There is no background channel. Each channel is a binary mask indicating the locations of the corners for a class.
3. The CornerNet Architecture
1. Backbone: an hourglass network
2. Head: two branches, one for top-left corners and one for bottom-right corners; each branch has its own corner pooling followed by three output branches
1. Before entering the first hourglass module, the image resolution is reduced to 1/4 of the original, using a 7x7 convolution with stride 2 followed by a residual unit with stride 2.
2. Inside the hourglass, stride-2 convolutional layers replace max pooling for downsampling.
3. There are 5 downsampling steps in total; the feature channels after them are [256, 384, 384, 384, 512].
4. Upsampling uses nearest-neighbor interpolation, followed by two residual units.
# Before the first hourglass module: reduce the input resolution to 1/4
# (511 -> 256 via the stride-2 7x7 conv, then 256 -> 128 via the stride-2 residual)
self.pre = nn.Sequential(
    convolution(7, 3, 128, stride=2),
    residual(3, 128, 256, stride=2)
) if pre is None else pre
We apply a 3 × 3 Conv-BN module to both the input and output of the first hourglass module. We then merge them by element-wise addition followed by a ReLU and a residual block with 256 channels, which is then used as the input to the second hourglass module. The depth of the hourglass network is 104. Unlike many other state-of-the-art detectors, we only use the features from the last layer of the whole network to make predictions.
1. A 3x3 Conv-BN module is applied to both the input and the output of the first hourglass module.
2. The two results are merged by element-wise addition, followed by a ReLU and a residual block with 256 channels.
3. The output of step 2 serves as the input to the second hourglass module.
4. For prediction, only the feature map from the last layer of the whole network is used (see the sketch below).
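As a reading aid, here is a minimal sketch of that inter-module fusion, following the quoted passage. The class name is hypothetical; `convolution` and `residual` are the helper blocks used elsewhere in this post, so treat this as illustrative rather than the repo's exact code.

import torch.nn as nn

class HourglassFusion(nn.Module):
    # Sketch only: Conv-BN on both the input and the output of the first
    # hourglass, element-wise add, ReLU, then a 256-channel residual block.
    def __init__(self, dim=256):
        super(HourglassFusion, self).__init__()
        self.inp_conv_bn = nn.Sequential(
            nn.Conv2d(dim, dim, (3, 3), padding=(1, 1), bias=False),
            nn.BatchNorm2d(dim)
        )
        self.out_conv_bn = nn.Sequential(
            nn.Conv2d(dim, dim, (3, 3), padding=(1, 1), bias=False),
            nn.BatchNorm2d(dim)
        )
        self.relu = nn.ReLU(inplace=True)
        self.res = residual(3, dim, dim)  # `residual` helper from the repo

    def forward(self, inter, cnv):
        # inter: input of the first hourglass; cnv: its output
        return self.res(self.relu(self.inp_conv_bn(inter) + self.out_conv_bn(cnv)))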
class kp_module(nn.Module):
    """
    A single hourglass module.
    """
    def __init__(
        self, n, dims, modules, layer=residual,
        make_up_layer=make_layer, make_low_layer=make_layer,
        make_hg_layer=make_layer, make_hg_layer_revr=make_layer_revr,
        make_pool_layer=make_pool_layer, make_unpool_layer=make_unpool_layer,
        make_merge_layer=make_merge_layer, **kwargs
    ):
        super(kp_module, self).__init__()

        self.n = n  # 5
        # modules = [2, 2, 2, 2, 2, 4]: number of residual blocks per level
        curr_mod = modules[0]
        next_mod = modules[1]
        # dims = [256, 256, 384, 384, 384, 512]
        curr_dim = dims[0]
        next_dim = dims[1]

        self.up1 = make_up_layer(
            3, curr_dim, curr_dim, curr_mod,
            layer=layer, **kwargs
        )  # a stack of curr_mod residual blocks, kernel_size=3
        self.max1 = make_pool_layer(curr_dim)  # MaxPool2d(kernel_size=2, stride=2)
        self.low1 = make_hg_layer(
            3, curr_dim, next_dim, curr_mod,
            layer=layer, **kwargs
        )  # a stack of curr_mod residual blocks, kernel_size=3
        self.low2 = kp_module(
            n - 1, dims[1:], modules[1:], layer=layer,
            make_up_layer=make_up_layer,
            make_low_layer=make_low_layer,
            make_hg_layer=make_hg_layer,
            make_hg_layer_revr=make_hg_layer_revr,
            make_pool_layer=make_pool_layer,
            make_unpool_layer=make_unpool_layer,
            make_merge_layer=make_merge_layer,
            **kwargs
        ) if self.n > 1 else \
        make_low_layer(
            3, next_dim, next_dim, next_mod,
            layer=layer, **kwargs
        )  # recursion: keep decreasing n until n > 1 no longer holds
        self.low3 = make_hg_layer_revr(
            3, next_dim, curr_dim, curr_mod,
            layer=layer, **kwargs
        )
        # nn.Upsample(scale_factor=2)
        self.up2 = make_unpool_layer(curr_dim)
        self.merge = make_merge_layer(curr_dim)

    def forward(self, x):
        up1 = self.up1(x)
        max1 = self.max1(x)
        low1 = self.low1(max1)
        low2 = self.low2(low1)
        low3 = self.low3(low2)
        up2 = self.up2(low3)
        return self.merge(up1, up2)  # element-wise add
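Putting the commented hyper-parameters together, the backbone module would be instantiated roughly like this (a sketch under the defaults shown above; whether the default make_* builders accept exactly this call depends on the repo version):

import torch

# n=5 nested levels, widths [256, 256, 384, 384, 384, 512],
# and [2, 2, 2, 2, 2, 4] residual blocks per level, as in the comments above
hourglass = kp_module(5, [256, 256, 384, 384, 384, 512], [2, 2, 2, 2, 2, 4])
feat = torch.randn(2, 256, 128, 128)  # (B, C, H, W)
out = hourglass(feat)                 # same spatial size on the way out: (2, 256, 128, 128)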
self.kps = nn.ModuleList([
    kp_module(
        n, dims, modules, layer=kp_layer,
        make_up_layer=make_up_layer,
        make_low_layer=make_low_layer,
        make_hg_layer=make_hg_layer,
        make_hg_layer_revr=make_hg_layer_revr,
        make_pool_layer=make_pool_layer,
        make_unpool_layer=make_unpool_layer,
        make_merge_layer=make_merge_layer
    ) for _ in range(nstack)  # the hourglass network stacks nstack such modules
])
# two branch feature maps, used to predict top-left and bottom-right corners respectively
tl_cnv = tl_cnv_(cnv)
br_cnv = br_cnv_(cnv)
1. corner pooling
2. the three output branches
class tl_pool(pool):
    def __init__(self, dim):
        super(tl_pool, self).__init__(dim, TopPool, LeftPool)

class pool(nn.Module):
    def __init__(self, dim, pool1, pool2):
        super(pool, self).__init__()
        self.p1_conv1 = convolution(3, dim, 128)
        self.p2_conv1 = convolution(3, dim, 128)

        self.p_conv1 = nn.Conv2d(128, dim, (3, 3), padding=(1, 1), bias=False)
        self.p_bn1 = nn.BatchNorm2d(dim)

        self.conv1 = nn.Conv2d(dim, dim, (1, 1), bias=False)
        self.bn1 = nn.BatchNorm2d(dim)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = convolution(3, dim, dim)

        self.pool1 = pool1()
        self.pool2 = pool2()

    def forward(self, x):
        # pool 1
        p1_conv1 = self.p1_conv1(x)
        pool1 = self.pool1(p1_conv1)
        # pool 2
        p2_conv1 = self.p2_conv1(x)
        pool2 = self.pool2(p2_conv1)
        # pool 1 + pool 2
        p_conv1 = self.p_conv1(pool1 + pool2)
        p_bn1 = self.p_bn1(p_conv1)
        # residual connection
        conv1 = self.conv1(x)
        bn1 = self.bn1(conv1)
        relu1 = self.relu1(p_bn1 + bn1)
        conv2 = self.conv2(relu1)
        return conv2
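TopPool and LeftPool above are compiled C++/CUDA extensions in the official repo. A behaviorally equivalent pure-PyTorch sketch (my reading of the op, not the repo's kernel; needs PyTorch >= 1.5 for cummax) scans a running max along each axis:

import torch

def top_pool(x):
    # out[:, :, i, j] = max over k >= i of x[:, :, k, j]
    # i.e., a running max scanned from the bottom row up to the top
    return x.flip(2).cummax(dim=2)[0].flip(2)

def left_pool(x):
    # out[:, :, i, j] = max over k >= j of x[:, :, i, k]
    # i.e., a running max scanned from the rightmost column to the left
    return x.flip(3).cummax(dim=3)[0].flip(3)

x = torch.randn(1, 1, 4, 4)           # (B, C, H, W)
tl_feat = top_pool(x) + left_pool(x)  # what `pool1 + pool2` computes inside tl_pool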
# each of the two branches then feeds three prediction branches
tl_heat, br_heat = tl_heat_(tl_cnv), br_heat_(br_cnv)
tl_tag, br_tag = tl_tag_(tl_cnv), br_tag_(br_cnv)
tl_regr, br_regr = tl_regr_(tl_cnv), br_regr_(br_cnv)
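Each of those heads (heat, tag, regr) is a small stack: a 3x3 convolution followed by a 1x1 projection to the desired number of channels (80 for heatmaps, 1 for embeddings, 2 for offsets). A hedged sketch of how the repo builds them (the helper name and the `with_bn` keyword are my recollection of the repo's `convolution` helper; verify against the source):

import torch.nn as nn

def make_prediction_head(cnv_dim, curr_dim, out_dim):
    # Sketch of a prediction head: 3x3 conv (no BN) then a 1x1 projection;
    # out_dim = 80 (heat), 1 (tag) or 2 (regr)
    return nn.Sequential(
        convolution(3, cnv_dim, curr_dim, with_bn=False),
        nn.Conv2d(curr_dim, out_dim, (1, 1))
    )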
4. In-Depth Analysis of the CornerNet Loss Function
(1) Map the ground-truth labels (the object's class and location) into supervision signals that match the network's output format.
(2) Build the loss function from the network's forward outputs and the supervision from (1).
(3) Run gradient descent on that loss to update the network parameters.
5. Implementation Details
5.1 Mapping Ground-Truth Labels into Supervision (Matching the Network's Output Format)
(1) heatmaps for the top-left corners
(2) embeddings for the top-left corners
(3) offsets for the top-left corners
and likewise:
(1) heatmaps for the bottom-right corners
(2) embeddings for the bottom-right corners
(3) offsets for the bottom-right corners
During training, we set the input resolution of the network to 511×511, which leads to an output resolution of 128×128.
(1) top-left corner heatmaps, of size (batch_size, 128, 128, 80)
(2) top-left corner embeddings, of size (batch_size, 128, 128, 1)
(3) top-left corner offsets, of size (batch_size, 128, 128, 2)
(4) bottom-right corner heatmaps, of size (batch_size, 128, 128, 80)
(5) bottom-right corner embeddings, of size (batch_size, 128, 128, 1)
(6) bottom-right corner offsets, of size (batch_size, 128, 128, 2)
def kp_detection(db, k_ind, data_aug, debug):
    data_rng = system_configs.data_rng
    batch_size = system_configs.batch_size

    categories = db.configs["categories"]         # 80
    input_size = db.configs["input_size"]         # 511
    output_size = db.configs["output_sizes"][0]   # [128, 128]
    border = db.configs["border"]                 # 128
    lighting = db.configs["lighting"]             # True
    rand_crop = db.configs["rand_crop"]           # False
    rand_color = db.configs["rand_color"]         # False
    rand_scales = db.configs["rand_scales"]       # False
    gaussian_bump = db.configs["gaussian_bump"]   # True
    gaussian_iou = db.configs["gaussian_iou"]     # 0.7
    gaussian_rad = db.configs["gaussian_radius"]  # -1

    max_tag_len = 128  # max possible number of targets in one image

    # allocating memory
    images = np.zeros((batch_size, 3, input_size[0], input_size[1]), dtype=np.float32)
    tl_heatmaps = np.zeros((batch_size, categories, output_size[0], output_size[1]), dtype=np.float32)
    br_heatmaps = np.zeros((batch_size, categories, output_size[0], output_size[1]), dtype=np.float32)
    tl_regrs = np.zeros((batch_size, max_tag_len, 2), dtype=np.float32)
    br_regrs = np.zeros((batch_size, max_tag_len, 2), dtype=np.float32)
    tl_tags = np.zeros((batch_size, max_tag_len), dtype=np.int64)
    br_tags = np.zeros((batch_size, max_tag_len), dtype=np.int64)
    tag_masks = np.zeros((batch_size, max_tag_len), dtype=np.uint8)
    tag_lens = np.zeros((batch_size, ), dtype=np.int32)  # number of targets for each image in the batch

    db_size = db.db_inds.size
    for b_ind in range(batch_size):  # b_ind is the index of the image within the batch
        if not debug and k_ind == 0:
            db.shuffle_inds()

        db_ind = db.db_inds[k_ind]
        k_ind = (k_ind + 1) % db_size

        # reading image
        image_file = db.image_file(db_ind)
        image = cv2.imread(image_file)

        # reading detections
        detections = db.detections(db_ind)

        # cropping an image randomly
        if not debug and rand_crop:
            image, detections = random_crop(image, detections, rand_scales, input_size, border=border)
        else:
            image, detections = _full_image_crop(image, detections)

        image, detections = _resize_image(image, detections, input_size)
        detections = _clip_detections(image, detections)

        width_ratio = output_size[1] / input_size[1]   # scale ratio (width)
        height_ratio = output_size[0] / input_size[0]  # scale ratio (height)

        # flipping an image randomly
        if not debug and np.random.uniform() > 0.5:
            image[:] = image[:, ::-1, :]
            width = image.shape[1]
            detections[:, [0, 2]] = width - detections[:, [2, 0]] - 1

        if not debug:
            image = image.astype(np.float32) / 255.
            if rand_color:
                color_jittering_(data_rng, image)
            if lighting:
                lighting_(data_rng, image, 0.1, db.eig_val, db.eig_vec)
            normalize_(image, db.mean, db.std)
        images[b_ind] = image.transpose((2, 0, 1))

        for ind, detection in enumerate(detections):
            # prepare the ground-truth heatmap
            category = int(detection[-1]) - 1  # the detected target's category
            xtl, ytl = detection[0], detection[1]  # coordinates of the top-left corner
            xbr, ybr = detection[2], detection[3]  # coordinates of the bottom-right corner

            fxtl = (xtl * width_ratio)   # map the coordinates onto the output feature map
            fytl = (ytl * height_ratio)
            fxbr = (xbr * width_ratio)
            fybr = (ybr * height_ratio)

            xtl = int(fxtl)  # the integer position at which the corner actually lands
            ytl = int(fytl)
            xbr = int(fxbr)
            ybr = int(fybr)

            if gaussian_bump:
                # heatmap with a Gaussian bump
                width = detection[2] - detection[0]
                height = detection[3] - detection[1]

                width = math.ceil(width * width_ratio)    # round up
                height = math.ceil(height * height_ratio)

                if gaussian_rad == -1:
                    radius = gaussian_radius((height, width), gaussian_iou)  # compute the radius
                    radius = max(0, int(radius))
                else:
                    radius = gaussian_rad

                draw_gaussian(tl_heatmaps[b_ind, category], [xtl, ytl], radius)
                draw_gaussian(br_heatmaps[b_ind, category], [xbr, ybr], radius)
            else:
                # without the Gaussian bump, the corner position is 1 and everything else 0
                tl_heatmaps[b_ind, category, ytl, xtl] = 1
                br_heatmaps[b_ind, category, ybr, xbr] = 1

            # index of the current target within this image
            tag_ind = tag_lens[b_ind]
            # offsets between the exact (float) corner coordinates and the integer ones
            tl_regrs[b_ind, tag_ind, :] = [fxtl - xtl, fytl - ytl]
            br_regrs[b_ind, tag_ind, :] = [fxbr - xbr, fybr - ybr]
            # embedding: a neat trick -- flatten the feature map and encode the corner
            # position as an index into that flattened space
            tl_tags[b_ind, tag_ind] = ytl * output_size[1] + xtl
            br_tags[b_ind, tag_ind] = ybr * output_size[1] + xbr
            # one more target (detection) for this image
            tag_lens[b_ind] += 1

    for b_ind in range(batch_size):
        # mark the valid target slots; the number of 1s equals the number of targets
        tag_len = tag_lens[b_ind]
        tag_masks[b_ind, :tag_len] = 1

    images = torch.from_numpy(images)
    tl_heatmaps = torch.from_numpy(tl_heatmaps)
    br_heatmaps = torch.from_numpy(br_heatmaps)
    tl_regrs = torch.from_numpy(tl_regrs)
    br_regrs = torch.from_numpy(br_regrs)
    tl_tags = torch.from_numpy(tl_tags)
    br_tags = torch.from_numpy(br_tags)
    tag_masks = torch.from_numpy(tag_masks)

    return {
        "xs": [images, tl_tags, br_tags],
        "ys": [tl_heatmaps, br_heatmaps, tag_masks, tl_regrs, br_regrs]
    }, k_ind
Looking back at kp_detection, the heatmap supervision is produced inside the per-detection loop:
for ind, detection in enumerate(detections):
    # prepare the ground-truth heatmap
    category = int(detection[-1]) - 1  # the detected target's category
    xtl, ytl = detection[0], detection[1]  # coordinates of the top-left corner
    xbr, ybr = detection[2], detection[3]  # coordinates of the bottom-right corner

    fxtl = (xtl * width_ratio)   # map the coordinates onto the output feature map
    fytl = (ytl * height_ratio)
    fxbr = (xbr * width_ratio)
    fybr = (ybr * height_ratio)

    xtl = int(fxtl)  # the integer position at which the corner actually lands
    ytl = int(fytl)
    xbr = int(fxbr)
    ybr = int(fybr)

    if gaussian_bump:
        # heatmap with a Gaussian bump
        width = detection[2] - detection[0]
        height = detection[3] - detection[1]

        width = math.ceil(width * width_ratio)    # round up
        height = math.ceil(height * height_ratio)

        if gaussian_rad == -1:
            radius = gaussian_radius((height, width), gaussian_iou)  # compute the radius
            radius = max(0, int(radius))
        else:
            radius = gaussian_rad

        draw_gaussian(tl_heatmaps[b_ind, category], [xtl, ytl], radius)
        draw_gaussian(br_heatmaps[b_ind, category], [xbr, ybr], radius)
    else:
        # without the Gaussian bump, the corner position is 1 and everything else 0
        tl_heatmaps[b_ind, category, ytl, xtl] = 1
        br_heatmaps[b_ind, category, ybr, xbr] = 1
(1) Take the top-left corner coordinates (xtl, ytl) and the bottom-right corner coordinates (xbr, ybr).
(2) Map the coordinates, per class, onto the heatmap as the actual integer positions, i.e. (xtl, ytl) and (xbr, ybr) in the code.
(3) Also record the exact mapped points as floating-point values, i.e. (fxtl, fytl) and (fxbr, fybr) in the code.
(4) Centered on the integer corners (xtl, ytl) and (xbr, ybr), spread a Gaussian over the neighborhood, so values are largest at the true corner and decay with distance (see the sketch below).
(5) The result is the heatmap supervision.
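For step (4), gaussian_radius and draw_gaussian come from the repo. A simplified sketch of what draw_gaussian does (my paraphrase, not a verbatim copy of the repo function):

import numpy as np

def draw_gaussian_sketch(heatmap, center, radius):
    # Stamp a 2-D Gaussian (sigma tied to the radius) around the integer
    # corner, keeping the element-wise max so nearby objects do not erase
    # each other on the shared per-class heatmap.
    diameter = 2 * radius + 1
    sigma = diameter / 6.0
    y, x = np.ogrid[-radius:radius + 1, -radius:radius + 1]
    gaussian = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
    cx, cy = center
    H, W = heatmap.shape
    l, r = min(cx, radius), min(W - cx, radius + 1)  # clip the stamp at the borders
    t, b = min(cy, radius), min(H - cy, radius + 1)
    region = heatmap[cy - t:cy + b, cx - l:cx + r]
    stamp = gaussian[radius - t:radius + b, radius - l:radius + r]
    np.maximum(region, stamp, out=region)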
max_tag_len = 128  # max possible number of targets in one image
tl_regrs = np.zeros((batch_size, max_tag_len, 2), dtype=np.float32)
br_regrs = np.zeros((batch_size, max_tag_len, 2), dtype=np.float32)
tl_tags = np.zeros((batch_size, max_tag_len), dtype=np.int64)
br_tags = np.zeros((batch_size, max_tag_len), dtype=np.int64)
tag_masks = np.zeros((batch_size, max_tag_len), dtype=np.uint8)
tag_lens = np.zeros((batch_size, ), dtype=np.int32)  # number of targets for each image in the batch
# index of the current target within this image
tag_ind = tag_lens[b_ind]
# offsets between the exact (float) corner coordinates and the integer ones
tl_regrs[b_ind, tag_ind, :] = [fxtl - xtl, fytl - ytl]
br_regrs[b_ind, tag_ind, :] = [fxbr - xbr, fybr - ybr]
# embedding: flatten the feature map and encode the corner position as an
# index into that flattened space
tl_tags[b_ind, tag_ind] = ytl * output_size[1] + xtl
br_tags[b_ind, tag_ind] = ybr * output_size[1] + xbr
# one more target (detection) for this image
tag_lens[b_ind] += 1
(1) ytl * output_size[1] + xtl
(2) ybr * output_size[1] + xbr
(a toy check of this encoding follows)
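A quick toy check of this row-major flattening, with W = output_size[1] = 128 and a made-up corner position:

W = 128                  # width of the output feature map
ytl, xtl = 37, 90        # a hypothetical integer corner position
ind = ytl * W + xtl      # the value stored in tl_tags
assert (ind // W, ind % W) == (ytl, xtl)  # decoding recovers (row, col)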
tag_masks = np.zeros((batch_size, max_tag_len), dtype=np.uint8)

for b_ind in range(batch_size):
    # mark the valid target slots; the number of 1s equals the number of targets
    tag_len = tag_lens[b_ind]
    tag_masks[b_ind, :tag_len] = 1
images = torch.from_numpy(images)
tl_heatmaps = torch.from_numpy(tl_heatmaps)
br_heatmaps = torch.from_numpy(br_heatmaps)
tl_regrs = torch.from_numpy(tl_regrs)
br_regrs = torch.from_numpy(br_regrs)
tl_tags = torch.from_numpy(tl_tags)
br_tags = torch.from_numpy(br_tags)
tag_masks = torch.from_numpy(tag_masks)

return {
    "xs": [images, tl_tags, br_tags],
    "ys": [tl_heatmaps, br_heatmaps, tag_masks, tl_regrs, br_regrs]
}, k_ind
5.2 The Loss Function in Detail
class AELoss(nn.Module):
    def __init__(self, pull_weight=1, push_weight=1, regr_weight=1, focal_loss=_neg_loss):
        super(AELoss, self).__init__()
        self.pull_weight = pull_weight
        self.push_weight = push_weight
        self.regr_weight = regr_weight
        self.focal_loss = focal_loss
        self.ae_loss = _ae_loss
        self.regr_loss = _regr_loss

    def forward(self, outs, targets):
        stride = 6

        tl_heats = outs[0::stride]  # stride here is the slice step; there is no end index
        br_heats = outs[1::stride]
        tl_tags = outs[2::stride]
        br_tags = outs[3::stride]
        tl_regrs = outs[4::stride]
        br_regrs = outs[5::stride]

        gt_tl_heat = targets[0]
        gt_br_heat = targets[1]
        gt_mask = targets[2]
        gt_tl_regr = targets[3]
        gt_br_regr = targets[4]

        # focal loss
        focal_loss = 0

        tl_heats = [_sigmoid(t) for t in tl_heats]
        br_heats = [_sigmoid(b) for b in br_heats]

        focal_loss += self.focal_loss(tl_heats, gt_tl_heat)
        focal_loss += self.focal_loss(br_heats, gt_br_heat)

        # tag loss
        pull_loss = 0
        push_loss = 0

        for tl_tag, br_tag in zip(tl_tags, br_tags):
            pull, push = self.ae_loss(tl_tag, br_tag, gt_mask)
            pull_loss += pull
            push_loss += push
        pull_loss = self.pull_weight * pull_loss
        push_loss = self.push_weight * push_loss

        # regression loss
        regr_loss = 0
        for tl_regr, br_regr in zip(tl_regrs, br_regrs):
            regr_loss += self.regr_loss(tl_regr, gt_tl_regr, gt_mask)
            regr_loss += self.regr_loss(br_regr, gt_br_regr, gt_mask)
        regr_loss = self.regr_weight * regr_loss

        loss = (focal_loss + pull_loss + push_loss + regr_loss) / len(tl_heats)
        return loss.unsqueeze(0)
stride = 6

tl_heats = outs[0::stride]  # stride here is the slice step; there is no end index
br_heats = outs[1::stride]
tl_tags = outs[2::stride]
br_tags = outs[3::stride]
tl_regrs = outs[4::stride]
br_regrs = outs[5::stride]

gt_tl_heat = targets[0]
gt_br_heat = targets[1]
gt_mask = targets[2]
gt_tl_regr = targets[3]
gt_br_regr = targets[4]

# focal loss
focal_loss = 0

tl_heats = [_sigmoid(t) for t in tl_heats]
br_heats = [_sigmoid(b) for b in br_heats]

focal_loss += self.focal_loss(tl_heats, gt_tl_heat)
focal_loss += self.focal_loss(br_heats, gt_br_heat)
def _neg_loss(preds, gt):
    pos_inds = gt.eq(1)
    neg_inds = gt.lt(1)

    neg_weights = torch.pow(1 - gt[neg_inds], 4)  # negatives vastly outnumber positives, so down-weight them

    loss = 0
    for pred in preds:
        pos_pred = pred[pos_inds]
        neg_pred = pred[neg_inds]

        pos_loss = torch.log(pos_pred) * torch.pow(1 - pos_pred, 2)
        neg_loss = torch.log(1 - neg_pred) * torch.pow(neg_pred, 2) * neg_weights

        num_pos = pos_inds.float().sum()
        pos_loss = pos_loss.sum()
        neg_loss = neg_loss.sum()

        if pos_pred.nelement() == 0:
            loss = loss - neg_loss
        else:
            loss = loss - (pos_loss + neg_loss) / num_pos
    return loss
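Written out, this is the corner variant of focal loss from the paper, with α = 2 and β = 4; p_cij is the predicted score at (i, j) for class c, y_cij is the Gaussian-softened ground truth, and N is the number of objects in the image:

$$
L_{det} = \frac{-1}{N}\sum_{c=1}^{C}\sum_{i=1}^{H}\sum_{j=1}^{W}
\begin{cases}
(1 - p_{cij})^{\alpha}\,\log(p_{cij}) & \text{if } y_{cij} = 1 \\
(1 - y_{cij})^{\beta}\,(p_{cij})^{\alpha}\,\log(1 - p_{cij}) & \text{otherwise}
\end{cases}
$$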
(1) top-left corner heatmaps, of size (batch_size, 128, 128, 80)
(2) top-left corner embeddings, of size (batch_size, 128, 128, 1)
(3) top-left corner offsets, of size (batch_size, 128, 128, 2)
(4) bottom-right corner heatmaps, of size (batch_size, 128, 128, 80)
(5) bottom-right corner embeddings, of size (batch_size, 128, 128, 1)
(6) bottom-right corner offsets, of size (batch_size, 128, 128, 2)
Of these, the ones gathered at ground-truth corner positions are:
(1) top-left corner embeddings, of size (batch_size, 128, 128, 1)
(2) top-left corner offsets, of size (batch_size, 128, 128, 2)
(3) bottom-right corner embeddings, of size (batch_size, 128, 128, 1)
(4) bottom-right corner offsets, of size (batch_size, 128, 128, 2)
# each of the two branches produces three prediction outputs
tl_heat, br_heat = tl_heat_(tl_cnv), br_heat_(br_cnv)   # bs x 128 x 128 x 80
tl_tag, br_tag = tl_tag_(tl_cnv), br_tag_(br_cnv)       # bs x 128 x 128 x 1
tl_regr, br_regr = tl_regr_(tl_cnv), br_regr_(br_cnv)   # bs x 128 x 128 x 2

# from the output feature maps, gather the values at the positions of the
# ground-truth bbox corners (works for embeddings as well as offsets)
tl_tag = _tranpose_and_gather_feat(tl_tag, tl_inds)
br_tag = _tranpose_and_gather_feat(br_tag, br_inds)
tl_regr = _tranpose_and_gather_feat(tl_regr, tl_inds)
br_regr = _tranpose_and_gather_feat(br_regr, br_inds)

outs += [tl_heat, br_heat, tl_tag, br_tag, tl_regr, br_regr]
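_tranpose_and_gather_feat (sic, as named in the repo) reshapes the NCHW map into (B, H*W, C) and picks out the rows at the flattened corner indices stored in tl_tags/br_tags. A sketch of what it does (from memory of the repo; verify against the source):

def _gather_feat(feat, ind):
    # feat: (B, H*W, C); ind: (B, max_tag_len) flattened corner positions
    dim = feat.size(2)
    ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
    return feat.gather(1, ind)

def _tranpose_and_gather_feat(feat, ind):
    # (B, C, H, W) -> (B, H, W, C) -> (B, H*W, C), then gather one row per target
    feat = feat.permute(0, 2, 3, 1).contiguous()
    feat = feat.view(feat.size(0), -1, feat.size(3))
    return _gather_feat(feat, ind)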
we only apply the losses at the ground-truth corner location.
# tag loss
pull_loss = 0
push_loss = 0

for tl_tag, br_tag in zip(tl_tags, br_tags):
    pull, push = self.ae_loss(tl_tag, br_tag, gt_mask)
    pull_loss += pull
    push_loss += push
pull_loss = self.pull_weight * pull_loss
push_loss = self.push_weight * push_loss
pull, push = self.ae_loss(tl_tag, br_tag, gt_mask)
(1) the two network outputs tl_tag and br_tag
(2) the tag_masks built in 5.1
def _ae_loss(tag0, tag1, mask):
    # tag0/tag1: embeddings gathered at the gt top-left / bottom-right corners
    num = mask.sum(dim=1, keepdim=True).float()  # number of real targets per image
    tag0 = tag0.squeeze()
    tag1 = tag1.squeeze()

    # pull: draw the two embeddings of the same object toward their mean
    tag_mean = (tag0 + tag1) / 2

    tag0 = torch.pow(tag0 - tag_mean, 2) / (num + 1e-4)
    tag0 = tag0[mask].sum()
    tag1 = torch.pow(tag1 - tag_mean, 2) / (num + 1e-4)
    tag1 = tag1[mask].sum()
    pull = tag0 + tag1

    # push: separate the mean embeddings of different objects (margin = 1)
    mask = mask.unsqueeze(1) + mask.unsqueeze(2)
    mask = mask.eq(2)  # pairs where both targets are real
    num = num.unsqueeze(2)
    num2 = (num - 1) * num
    dist = tag_mean.unsqueeze(1) - tag_mean.unsqueeze(2)
    dist = 1 - torch.abs(dist)
    dist = nn.functional.relu(dist, inplace=True)
    dist = dist - 1 / (num + 1e-4)   # offset compensating for the k == j diagonal terms
    dist = dist / (num2 + 1e-4)
    dist = dist[mask]
    push = dist.sum()
    return pull, push
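This implements the associative-embedding losses from the paper, where e_tk and e_bk are the embeddings of the top-left and bottom-right corners of object k, e_k is their mean, N is the number of objects, and the push margin Δ = 1 (the `1 - torch.abs(dist)` line):

$$
L_{pull} = \frac{1}{N}\sum_{k=1}^{N}\left[(e_{t_k} - e_k)^2 + (e_{b_k} - e_k)^2\right]
$$
$$
L_{push} = \frac{1}{N(N-1)}\sum_{k=1}^{N}\sum_{\substack{j=1 \\ j \neq k}}^{N}\max\left(0,\; \Delta - |e_k - e_j|\right)
$$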
# regression loss
regr_loss = 0
for tl_regr, br_regr in zip(tl_regrs, br_regrs):
    regr_loss += self.regr_loss(tl_regr, gt_tl_regr, gt_mask)
    regr_loss += self.regr_loss(br_regr, gt_br_regr, gt_mask)
regr_loss = self.regr_weight * regr_loss

loss = (focal_loss + pull_loss + push_loss + regr_loss) / len(tl_heats)
return loss.unsqueeze(0)
def _regr_loss(regr, gt_regr, mask):
    num = mask.float().sum()  # total number of real targets in the batch
    mask = mask.unsqueeze(2).expand_as(gt_regr)

    regr = regr[mask]        # keep only the entries at ground-truth corner slots
    gt_regr = gt_regr[mask]

    regr_loss = nn.functional.smooth_l1_loss(regr, gt_regr, size_average=False)
    regr_loss = regr_loss / (num + 1e-4)
    return regr_loss
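The supervision here is the fractional part lost when mapping a corner from the input to the output resolution (the fxtl - xtl terms computed in 5.1), trained with SmoothL1 at ground-truth corner locations only. In the paper's notation, with n the downsampling factor (4 here):

$$
\boldsymbol{o}_k = \left(\frac{x_k}{n} - \left\lfloor\frac{x_k}{n}\right\rfloor,\; \frac{y_k}{n} - \left\lfloor\frac{y_k}{n}\right\rfloor\right),\qquad
L_{off} = \frac{1}{N}\sum_{k=1}^{N}\text{SmoothL1}\left(\boldsymbol{o}_k, \hat{\boldsymbol{o}}_k\right)
$$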
loss = (focal_loss + pull_loss + push_loss + regr_loss) / len(tl_heats)
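So the full training objective, averaged over the nstack intermediate outputs, is the weighted sum below; pull_weight, push_weight and regr_weight correspond to α, β and γ, which the paper sets to α = β = 0.1 and γ = 1:

$$
L = L_{det} + \alpha L_{pull} + \beta L_{push} + \gamma L_{off}
$$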
6. Summary
This post walked through CornerNet end to end: the stacked-hourglass backbone, the corner pooling heads that produce heatmaps, embeddings, and offsets for the top-left and bottom-right corners, how ground-truth boxes are turned into Gaussian-softened heatmap, embedding-index, and offset supervision, and how the variant focal loss, the pull/push associative-embedding losses, and the SmoothL1 offset loss combine into the training objective.