定义前向钩子函数（将指定层输出的特征图保存为热力图）
import itertools
import os

import cv2
import numpy as np
import torchvision.utils as vutil  # only needed for the optional per-channel grid dump

# Directory that all hook images are written to.
INSTANCE_FOLDER = "VIS_results"
# (width, height) of the saved heatmaps, in pixels -- cv2.resize takes (w, h).
HEATMAP_SIZE = (512, 512)
# OpenCV colormap used to colourise the heatmaps (cv2.COLORMAP_JET == 2).
HEATMAP_COLORMAP = cv2.COLORMAP_JET


def hook_func(module, input, output):
    """
    Forward hook (for ``register_forward_hook``) that saves ``output`` as heatmaps.

    For every sample in the batch, the channels of the feature map are
    averaged, the result is resized to ``HEATMAP_SIZE``, min-max normalised
    to [0, 255] and colourised with ``HEATMAP_COLORMAP``. Each heatmap is
    then written to a fresh PNG under ``INSTANCE_FOLDER``.

    To save every channel as a separate tile instead, use
    ``vutil.save_image(data.permute(1, 0, 2, 3), image_name, pad_value=0.5)``.

    Parameters:
    -----------
    module: module of neural network
    input: input of module (unused)
    output: output of module; expected to be a 4-D (N, C, H, W) tensor

    Raises:
    -------
    ValueError: if ``output`` is not 4-dimensional.
    """
    # float() first: cv2.resize does not accept float16 (e.g. AMP / half models).
    data = output.detach().float().cpu()
    if data.dim() != 4:
        raise ValueError(
            'hook_func expects a 4-D (N, C, H, W) output, got shape %s'
            % (tuple(data.shape),))
    # One heatmap per sample; indexing keeps (C, H, W) even when N or C is 1.
    for sample in data.numpy():
        # Average over channels in float. Casting to uint8 before the
        # normalisation below would wrap negative / large activations.
        pic = sample.mean(axis=0)
        feature = cv2.resize(pic, HEATMAP_SIZE)
        # Min-max normalise to [0, 1]; the epsilon avoids dividing by zero
        # on a constant feature map.
        feature = (feature - feature.min()) / (feature.max() - feature.min() + 1e-5)
        # [0, 1] -> [0, 255] uint8, the input type applyColorMap/imwrite expect.
        feature = np.round(feature * 255).astype(np.uint8)
        heatmap = cv2.applyColorMap(feature, HEATMAP_COLORMAP)
        cv2.imwrite(get_image_name_for_hook(module), heatmap)


def get_image_name_for_hook(module):
    """
    Return a PNG path under ``INSTANCE_FOLDER`` that does not exist yet.

    Files are named ``<base_name>_<k>.png``, where ``k`` is the smallest
    index >= 1 whose file does not already exist, so earlier images are
    never overwritten. The folder is created if needed.

    Parameters:
    -----------
    module: module of neural network
    """
    os.makedirs(INSTANCE_FOLDER, exist_ok=True)
    # Everything before the first '(' of the module's repr; for torch modules
    # this is presumably the class name (e.g. 'Conv2d').
    base_name = str(module).split('(')[0]
    for index in itertools.count(1):
        image_name = os.path.join(
            INSTANCE_FOLDER, '%s_%d.png' % (base_name, index))
        if not os.path.exists(image_name):
            return image_name
在验证循环中嵌入如下代码，为指定层注册上述钩子：
# Attach hook_func to selected sub-modules, run a few validation batches so
# the hooks fire, then detach the hooks again.
with torch.no_grad():
    # Alternative: select modules by type instead of by name, e.g.
    #   modules_for_plot = (torch.nn.ReLU, torch.nn.Conv2d,
    #                       torch.nn.MaxPool2d, torch.nn.AdaptiveAvgPool2d)
    # and test ``isinstance(module, modules_for_plot)`` below.
    # NOTE(review): the 'module.' prefix presumably comes from a
    # DataParallel/DDP wrapper; without it no name matches and nothing is saved.
    names_for_plot = {
        'module.classifier.fusion',
        'module.classifier.context',
        'module.classifier.context.2',
        'module.classifier.context.2.aspp',
    }
    max_vis_batches = 20  # only visualise the first N validation batches
    # Keep the handles so the hooks can be removed; otherwise they stay
    # attached and pile up if this code runs again (e.g. every epoch).
    hook_handles = [
        module.register_forward_hook(hook_func)
        for name, module in model.named_modules()
        if name in names_for_plot
    ]
    try:
        # enumerate(tqdm(...)) rather than tqdm(enumerate(...)) so the
        # progress bar can read len(loader) and show a total.
        for i, (images, labels) in enumerate(tqdm(loader)):
            if i >= max_vis_batches:
                break
            # ... the usual validation forward pass goes here ...
    finally:
        for handle in hook_handles:
            handle.remove()
部分参考：https://blog.csdn.net/bby1987/article/details/109590108