In PyTorch, whether an operation is differentiable comes down to whether gradients can still flow backward through it.
For example, the tensor z produced by a + operation still lets gradients pass through.
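The original example for this step is not included. Below is a minimal sketch under assumed inputs: two random tensors `x` and `y` created with `requires_grad=True`. It shows that the sum keeps a `grad_fn` and that gradients reach both inputs.

```python
import torch

# Two leaf tensors that track gradients (shapes chosen arbitrarily)
x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)

z = x + y
# z is recorded in the autograd graph: it requires grad and has a grad_fn
print(z.requires_grad)  # True
print(z.grad_fn)        # <AddBackward0 object at ...>

# Backpropagate a scalar; d(sum(x + y))/dx = 1 for every element
z.sum().backward()
print(x.grad)  # tensor([1., 1., 1.])
print(y.grad)  # tensor([1., 1., 1.])
```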
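Operations such as torch.argmax() and torch.eq(), on the other hand, do not pass gradients; they are not differentiable. Their outputs are integer or boolean tensors, which are cut off from the autograd graph.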
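Backpropagation therefore breaks at such an operation. If nothing else connects the result to the graph, calling backward() on it raises an error.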
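The behavior described above can be checked with a short sketch. It assumes a random input vector `x` and compares against its own maximum purely for illustration. Neither result has a `grad_fn`, and calling `backward()` on the argmax result raises a `RuntimeError`.

```python
import torch

x = torch.randn(5, requires_grad=True)

idx = torch.argmax(x)         # integer (int64) index
mask = torch.eq(x, x.max())   # boolean tensor

# Neither result is connected to the autograd graph
print(idx.requires_grad, idx.grad_fn)    # False None
print(mask.requires_grad, mask.grad_fn)  # False None

# Trying to backpropagate through the argmax result fails
error = None
try:
    idx.float().backward()
except RuntimeError as e:
    error = e
print(type(error).__name__)  # RuntimeError
```

Converting with `.float()` does not help: the conversion happens after the graph has already been cut, so `x.grad` stays `None`.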
By contrast, a soft argmax operation such as the one below
```python
import torch
import torch.nn as nn


def soft_argmax(x):
    """
    Arguments: tensor of shape (batch_size, channel, L), where L is the
               flattened spatial size (e.g. H*W)
    Return:    soft (differentiable) index of the maximum along the last
               dimension, shape (batch_size, channel)
    """
    # alpha is here to make the largest element really big, so it
    # would become very close to 1 after softmax
    alpha = 10000.0
    N, C, L = x.shape
    soft_max = nn.functional.softmax(x * alpha, dim=2)
    # Position vector 0, 1, ..., L-1 with shape (1, L); it broadcasts over (N, C, L)
    indices_kernel = torch.arange(start=0, end=L, dtype=x.dtype, device=x.device).unsqueeze(0)
    # Softmax-weighted sum of positions = expected index of the maximum
    conv = soft_max * indices_kernel
    indices = conv.sum(2)
    # For an H x W map flattened row-major, the 2D position would be
    # row = indices // W, col = indices % W (rounding gives no useful gradient).
    # indices[0][0]: index of the maximum in channel 0 of sample 0
    # indices[0][1]: index of the maximum in channel 1 of sample 0
    return indices


if __name__ == "__main__":
    x = torch.randn(1024, 16, 35 * 35, requires_grad=True)  # (batch_size, channel, H*W)
    coords = soft_argmax(x)  # coords has shape (batch_size, channel)
    print(coords)
```
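is differentiable. It is built only from softmax, element-wise multiplication and summation, all of which pass gradients. Instead of a hard index, it returns the softmax-weighted average of the positions, Σ_i i·softmax(αx)_i.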
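A small sketch can confirm this. It uses a compact rewrite of `soft_argmax` in which `alpha` becomes a parameter; that parameter is an assumption, since the original hard-codes 10000.0. The function is applied to a hand-picked input with a clear maximum. The soft index lands close to the hard argmax, carries a `grad_fn`, and `backward()` fills `x.grad`.

```python
import torch
import torch.nn.functional as F


def soft_argmax(x, alpha=10.0):
    """Differentiable approximation of argmax over the last dim of (N, C, L)."""
    L = x.shape[-1]
    weights = F.softmax(x * alpha, dim=-1)                       # nearly one-hot at the max
    positions = torch.arange(L, dtype=x.dtype, device=x.device)  # 0, 1, ..., L-1
    return (weights * positions).sum(dim=-1)                     # expected index, shape (N, C)


# One batch, one channel, four positions; the maximum is at index 1
x = torch.tensor([[[0.1, 2.0, 0.3, -1.0]]], requires_grad=True)

soft_idx = soft_argmax(x)
hard_idx = torch.argmax(x, dim=-1)

print(soft_idx)  # close to 1.0, and it has a grad_fn
print(hard_idx)  # tensor([[1]])

soft_idx.sum().backward()
print(x.grad)    # a gradient tensor of shape (1, 1, 4), not None
```

A very large `alpha`, such as the original 10000.0, makes the softmax almost exactly one-hot. That brings the result closer to the true argmax, but the gradients become vanishingly small. A moderate value trades some accuracy for more useful gradients.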
How to tell whether an operation is differentiable
Date: 2023-08-28