欢迎您访问365答案网,请分享给你的朋友!
生活常识 学习资料

怎样判断一个操作是否是可导的

时间:2023-08-28

在Pytorch中,看一个操作是否可导,即经过这个操作梯度是否还能顺利传递。

 可以看到,经过+操作后得到的z,仍能保持梯度的传递

而像torch.argmax(),  torch.eq() 这些操作就不行了,这些操作就是不可导的

即,遇到不可导的,你反向传播都会出问题,程序自己就会报错

 

而如果像soft argmax

import torchimport torch.nn as nndef soft_argmax(x):"""Arguments: voxel patch in shape (batch_size, channel, H, W, depth)Return: 3D coordinates in shape (batch_size, channel, 3)"""# alpha is here to make the largest element really big, so it# would become very close to 1 after softmaxalpha = 10000.0 N,C,L = x.shapesoft_max = nn.functional.softmax(x*alpha,dim=2)soft_max = soft_max.view(x.shape)indices_kernel = torch.arange(start=0, end=L).unsqueeze(0)# indices_kernel = indices_kernel.view((H,W,D))# indices_kernel = indices_kernel.view(H,W)conv = soft_max*indices_kernelindices = conv.sum(2)# z = indices%D# y = (indices).floor()%W# x = (((indices).floor())/W).floor()%H# coords = torch.stack([x,y,z],dim=2)# coords = torch.stack([x,y],dim=2) #coords[0][0]代表第一个channel的最大点的坐标值 #coords[0][1]代表第2个channel的最大点的坐标值return indices if __name__ == "__main__":x = torch.randn(1024,16,35*35,requires_grad=True) # (batch_size, channel, H, W, depth)coords = soft_argmax(x) #coords是[b,c,2]print(coords)

 操作就是可导的

Copyright © 2016-2020 www.365daan.com All Rights Reserved. 365答案网 版权所有 备案号:

部分内容来自互联网,版权归原作者所有,如有冒犯请联系我们,我们将在三个工作时内妥善处理。