
MMdetection Proposal: Principle and Code Analysis

Date: 2023-04-24
1. Algorithm Principle

The proposal stage takes the N levels of score, bbox_pred and anchor tensors, plus the image shape, as input. Proposals are obtained by applying the predicted box offsets (bbox_pred) to the anchors; NMS is then run on these proposals, and finally the top num of them are kept.
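The "anchor plus predicted offset" step is the standard Faster R-CNN style delta decoding (DeltaXYWHBBoxCoder in mmdetection). Below is a minimal sketch of that math, ignoring the coder's mean/std normalization and width/height clipping; the helper name delta2bbox_sketch is mine, not mmdetection's API.

import torch

def delta2bbox_sketch(anchors, deltas):
    # anchors, deltas: (N, 4) tensors; deltas are (dx, dy, dw, dh)
    # relative to each anchor's center and size.
    aw = anchors[:, 2] - anchors[:, 0]
    ah = anchors[:, 3] - anchors[:, 1]
    ax = anchors[:, 0] + 0.5 * aw
    ay = anchors[:, 1] + 0.5 * ah
    cx = ax + deltas[:, 0] * aw        # shift the center by a fraction of the anchor size
    cy = ay + deltas[:, 1] * ah
    w = aw * deltas[:, 2].exp()        # scale width/height exponentially
    h = ah * deltas[:, 3].exp()
    return torch.stack(
        [cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h], dim=1)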

2. Execution Steps

1. For each of the N levels, sort score, bbox_pred and anchor by score in descending order and keep the top nms_pre entries (typically 1000), giving N * nms_pre candidates in total.
2. Obtain proposals by applying the box offsets (bbox_pred) to the anchors, and drop degenerate boxes whose width or height is not positive.
3. Add a sufficiently large offset to each level's proposals so that boxes from different levels can never overlap; this turns the multi-class (per-level) NMS into a single-class NMS (see the sketch after this list).
4. Concatenate the N levels of scores and proposals, sort by score in descending order, and run NMS.
5. Keep the top num proposals and subtract the offset that was added earlier.
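The offset trick in step 3 is exactly what mmcv's batched_nms does internally. A minimal standalone sketch built on torchvision.ops.nms; the function name nms_by_level and its parameters are mine:

import torch
from torchvision.ops import nms

def nms_by_level(boxes, scores, level_ids, iou_thr):
    # Shift each level's boxes by a distinct multiple of the largest
    # coordinate, so boxes from different levels can never overlap and a
    # single class-agnostic NMS behaves like an independent NMS per level.
    # Assumes boxes is non-empty.
    offsets = level_ids.to(boxes.dtype) * (boxes.max() + 1)
    shifted = boxes + offsets[:, None]
    keep = nms(shifted, scores, iou_thr)  # indices sorted by descending score
    return keep  # valid for the original, un-shifted boxes as well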

3. Python Source Code Analysis

# Path: mmdetection/mmdet/models/dense_heads/cascade_rpn_head.py: StageCascadeRPNHead::_get_bboxes_single
level_ids = []
mlvl_scores = []
mlvl_bbox_preds = []
mlvl_valid_anchors = []
for idx in range(len(cls_scores)):  # len(cls_scores) is the number of cascade levels N
    rpn_cls_score = cls_scores[idx]  # this level's score
    rpn_bbox_pred = bbox_preds[idx]  # this level's bbox_pred
    # per level, score has shape (num_anchors * num_classes, H, W) and
    # bbox_pred has shape (num_anchors * 4, H, W)
    assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
    rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
    if self.use_sigmoid_cls:  # binary classification of the score via sigmoid
        rpn_cls_score = rpn_cls_score.reshape(-1)
        scores = rpn_cls_score.sigmoid()
    else:  # binary classification of the score via softmax
        rpn_cls_score = rpn_cls_score.reshape(-1, 2)
        # We set FG labels to [0, num_class-1] and BG label to
        # num_class in RPN head since mmdet v2.5, which is unified to
        # be consistent with other head since mmdet v2.0. In mmdet v2.0
        # to v2.4 we keep BG label as 0 and FG label as 1 in rpn head.
        scores = rpn_cls_score.softmax(dim=1)[:, 0]
    rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4)
    anchors = mlvl_anchors[idx]
    if 0 < nms_pre < scores.shape[0]:
        # sort is faster than topk
        # _, topk_inds = scores.topk(cfg.nms_pre)
        ranked_scores, rank_inds = scores.sort(descending=True)
        topk_inds = rank_inds[:nms_pre]
        scores = ranked_scores[:nms_pre]
        rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
        anchors = anchors[topk_inds, :]
    mlvl_scores.append(scores)
    mlvl_bbox_preds.append(rpn_bbox_pred)
    mlvl_valid_anchors.append(anchors)
    level_ids.append(
        scores.new_full((scores.size(0), ), idx, dtype=torch.long))
scores = torch.cat(mlvl_scores)
anchors = torch.cat(mlvl_valid_anchors)
rpn_bbox_pred = torch.cat(mlvl_bbox_preds)
# obtain proposals from the anchors and the predicted box offsets (bbox_pred)
proposals = self.bbox_coder.decode(
    anchors, rpn_bbox_pred, max_shape=img_shape)
ids = torch.cat(level_ids)
if cfg.min_bbox_size >= 0:  # drop boxes with non-positive width or height
    w = proposals[:, 2] - proposals[:, 0]
    h = proposals[:, 3] - proposals[:, 1]
    valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size)
    if not valid_mask.all():
        proposals = proposals[valid_mask]
        scores = scores[valid_mask]
        ids = ids[valid_mask]
# NMS
if proposals.numel() > 0:
    dets, _ = batched_nms(proposals, scores, ids, cfg.nms)
else:
    return proposals.new_zeros(0, 5)
# keep the top max_per_img proposals
return dets[:cfg.max_per_img]
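mmcv's batched_nms returns a (dets, keep) pair, where dets already packs boxes and scores into an (n, 5) tensor. For readers without mmcv, here is a hedged standalone equivalent of the tail of this function using torchvision.ops.batched_nms, which returns only the keep indices; the function name finalize_proposals and its default values are illustrative:

import torch
from torchvision.ops import batched_nms

def finalize_proposals(proposals, scores, level_ids,
                       iou_thr=0.7, min_bbox_size=0, max_per_img=1000):
    # Drop degenerate boxes, mirroring the cfg.min_bbox_size filter above.
    w = proposals[:, 2] - proposals[:, 0]
    h = proposals[:, 3] - proposals[:, 1]
    valid = (w > min_bbox_size) & (h > min_bbox_size)
    proposals, scores, level_ids = proposals[valid], scores[valid], level_ids[valid]
    if proposals.numel() == 0:
        return proposals.new_zeros(0, 5)
    # batched_nms applies the per-id offset trick internally and returns
    # indices sorted by decreasing score.
    keep = batched_nms(proposals, scores, level_ids, iou_thr)[:max_per_img]
    return torch.cat([proposals[keep], scores[keep, None]], dim=1)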

4. CPU Source Code Analysis

// Output 0 has shape (num, 5): num is the number of boxes; the first
// four values of each row are the box coordinates, the last is its score.
Tensor* output = ctx->Output(0, {num, 5});
float* output_ptr = output->template MutableData<float>();  // raw output pointer

float* score_ptr = new float[level * nms_pre];  // level = number of levels, nms_pre = boxes kept per level
// If a level has fewer than nms_pre boxes, the surplus slots keep a zero score.
memset(score_ptr, 0, sizeof(float) * level * nms_pre);
float* score_sorted_ptr = new float[level * nms_pre];         // scores after sorting
float* bbox_pred = new float[level * nms_pre * 4];            // bbox_pred: one value per coordinate, 4 per box
float* anchor = new float[level * nms_pre * 4];               // anchor coordinates, 4 per box
float* proposal_ptr = new float[level * nms_pre * 4];         // proposal coordinates after applying the offsets
float* proposal_sorted_ptr = new float[level * nms_pre * 4];  // proposal coordinates after sorting

// step 1: gather and sort the N levels of score, bbox_pred and anchor
vector<std::thread> vec_thread;  // per-level worker threads (dispatch not shown in this excerpt)
for (int i = 0; i < level; i++) {
    const float* input_score = Input(i)->template Data<float>();
    const float* input_bbox = Input(i + level)->template Data<float>();
    const float* input_anchor = Input(i + level * 2)->template Data<float>();
    // sort the scores and keep the top nms_pre of them
    vector<KeyValuePair> vec_node;
    vec_node.resize(Input(i)->Size());  // sort all of this level's scores
    vector<int> sorted_id = SortedIdx(input_score, vec_node, nms_pre);
}

typedef struct {
    float key;   // score
    int value;   // index of the box that produced it
} KeyValuePair;

bool compareNode(KeyValuePair node1, KeyValuePair node2) {
    return node1.key > node2.key;  // descending by score
}

vector<int> SortedIdx(const float* input_score, vector<KeyValuePair>& vec_node, int nms_pre) {
    for (size_t i = 0; i < vec_node.size(); i++) {
        vec_node[i].key = input_score[i];
        vec_node[i].value = static_cast<int>(i);
    }
    std::sort(vec_node.begin(), vec_node.end(), compareNode);
    vector<int> sorted_id(nms_pre);
    for (int i = 0; i < nms_pre; i++)
        sorted_id[i] = vec_node[i].value;
    return sorted_id;
}
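The fixed-size buffers above rely on every level contributing exactly nms_pre slots, with zero-score padding for levels that have fewer boxes. A hedged PyTorch sketch of that layout, with the function name topk_pad_per_level being mine:

import torch

def topk_pad_per_level(level_scores, nms_pre):
    # level_scores: list of 1-D score tensors, one per level.
    # Every level fills exactly nms_pre (score, index) slots; levels with
    # fewer boxes are padded with zero scores, matching the memset above.
    padded_scores = torch.zeros(len(level_scores), nms_pre)
    padded_inds = torch.zeros(len(level_scores), nms_pre, dtype=torch.long)
    for lvl, scores in enumerate(level_scores):
        k = min(nms_pre, scores.numel())
        vals, inds = scores.sort(descending=True)
        padded_scores[lvl, :k] = vals[:k]
        padded_inds[lvl, :k] = inds[:k]
    return padded_scores, padded_inds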
