欢迎您访问365答案网,请分享给你的朋友!
生活常识 学习资料

PyTorch:tensor操作contiguous

时间:2023-05-26
本文目录

tensor在内存中的存储

信息区和存储区shape && stride contiguous

什么时候用contiguous呢?为什么要用contiguous为什么contiguous能有效? tensor在内存中的存储 信息区和存储区

tensor在内存中的存储包含信息区和存储区

信息区(Tensor)包含tensor的形状size,步长stride,数据类型type等存储区(Storage)包含存储的数据

高维数组在内存中是按照行优先顺序存储的,什么是行优先顺序?假设我们有一个(3, 4)的tensor,他其实是按照一维数组的方式存储的,只不过在tensor的信息区记录了他的size和stride导致实际上展示出的数组是二维的,size为(3, 4)

二维数组

内存中的一维形式

接下来我们看一个例子,例子表明tensor中的元素在内存上是连续的,并且也证明了确实是行优先顺序存储

tensor = torch.tensor([[[1 ,2, 3, 4], [5, 6, 7, 8], [9, 10,11,12]], [[13,14,15,16],[17,18,19,20],[21,22,23,24]]])print(tensor.is_contiguous())for i in range(2): for j in range(3): for k in range(4): print(tensor[i][j][k].data_ptr(), end=' ')'''True140430616343104 140430616343112 140430616343120 140430616343128140430616343136 140430616343144 140430616343152 140430616343160 140430616343168 140430616343176 140430616343184 140430616343192 140430616343200 140430616343208 140430616343216 140430616343224 140430616343232 140430616343240 140430616343248 140430616343256 140430616343264 140430616343272 140430616343280 140430616343288 '''

shape && stride

继续上述的例子我们来看一下在信息区的shape和stride属性,对于(2, 3, 4)维的tensor他的shape为(2, 3 ,4),stride为(12, 4, 1)

shape

shape很容易理解,就是tensor的维度,上述例子为(2, 3, 4)的tensor,维度就为(2, 3, 4)

stride

stride代表着多维索引的步长,每一步都代表内存上的偏移量+1,对于(2, 3, 4)维度的tensor:stride+1代表着(dim2)+1,stride+4代表其余dim不变,(dim1)+1,stride+12代表其余dim不变,(dim0)+1,如下图所示

图示stride

stride计算方法

s t r i d e i = s t r i d e i + 1 ∗ s i z e i + 1      i ∈ [ 0 , n − 2 ] stride_{i} = stride_{i+1} * size_{i+1}~~~~iin[0, n-2] stridei​=stridei+1​∗sizei+1​    i∈[0,n−2]

对于shape(2, 3, 4)的tensor,计算如下(stride3=1)

s t r i d e 2 = s t r i d e 3 ∗ s h a p e 3 = 1 ∗ 4 = 4 s t r i d e 1 = s t r i d e 2 ∗ s h a p e 2 = 4 ∗ 3 = 12 stride_{2} = stride_{3} * shape_{3}=1*4=4 \ stride_{1} = stride_{2} * shape_{2}=4*3=12 stride2​=stride3​∗shape3​=1∗4=4stride1​=stride2​∗shape2​=4∗3=12

stride = [1] # 初始化第一个元素 # 从后往前遍历迭代生成 stride for i in range(len(tensor.size())-2, -1, -1): stride.insert(0, stride[0] * tensor.shape[i+1]) print(stride) # [12, 4, 1]print(tensor.stride()) # (12, 4, 1)

理解了tensor在内存中的存储之后,我们再来看contiguous

contiguous

contiguous

返回一个连续内存的tensor

Returns a contiguous in memory tensor containing the same data as self tensor、If self tensor is already in the specified memory format, this function returns the self tensor.

什么时候用contiguous呢?

简单理解就是tensor在内存地址中的存储顺序与实际的一维索引顺序不一致时使用,如下所示,对上面的tensor进行一维索引,结果为[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],对原tensor运用transpose进行转置,在对其进行一维索引,结果为[1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12],这时索引顺序发生了变化,所以需要用contiguous

注意:不论怎么变化每个元素对应的地址是不变的,比如11对应的地址为x11,transpose之后11依然对应x11,那么变化的是什么呢?还记得tensor分为信息区和存储区吗,存储区是不变化的,变化的是信息区的shape,stride等信息,有时间以后做介绍~

代码示例

tensor = torch.tensor([[1,2,3,4],[5,6,7,8],[9,10,11,12]])print(tensor)print(tensor.is_contiguous())tensor = tensor.transpose(1, 0)print(tensor)print(tensor.is_contiguous())'''tensor([[ 1, 2, 3, 4], [ 5, 6, 7, 8], [ 9, 10, 11, 12]])Truetensor([[ 1, 5, 9], [ 2, 6, 10], [ 3, 7, 11], [ 4, 8, 12]])False'''

为什么要用contiguous

有人可能会有疑问,既然上述情况索引与之前不一样了(不连续了),为什么要让他变连续呢?因为pytorch的某些操作需要索引和内存连续,比如view

代码示例(接着上面的例子)

tensor = tensor.contiguous()print(tensor.is_contiguous())tensor = tensor.view(3, 4)print(tensor)'''Truetensor([[ 1, 5, 9, 2], [ 6, 10, 3, 7], [11, 4, 8, 12]])'''

如果不用contiguous会报以下错误

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces)、Use .reshape(...) instead.

为什么contiguous能有效?

contiguous用了一种简单粗暴的方法,既然你之前的索引和内存不连续了,那我就重新开辟一块连续的内存给他加上索引即可

代码示例,从下面代码中stride变化可以看出,transpose之后的tensor确实是改变了信息区的信息

tensor = torch.tensor([[1,2,3,4],[5,6,7,8],[9,10,11,12]])print(tensor.is_contiguous())# Truefor i in range(3): for j in range(4): print(tensor[i][j], tensor[i][j].data_ptr(), end=' ') print()print(tensor.stride()) # (4, 1)'''Truetensor(1) 140430616321664 tensor(2) 140430616321672 tensor(3) 140430616321680 tensor(4) 140430616321688 tensor(5) 140430616321696 tensor(6) 140430616321704 tensor(7) 140430616321712 tensor(8) 140430616321720 tensor(9) 140430616321728 tensor(10) 140430616321736 tensor(11) 140430616321744 tensor(12) 140430616321752 (4, 1)'''tensor = tensor.transpose(1, 0)print(tensor.is_contiguous())# Falsefor i in range(4): for j in range(3): print(tensor[i][j], tensor[i][j].data_ptr(), end=' ') print()print(tensor.stride()) # (1, 4) changed'''Falsetensor(1) 140430616321664 tensor(5) 140430616321696 tensor(9) 140430616321728 tensor(2) 140430616321672 tensor(6) 140430616321704 tensor(10) 140430616321736 tensor(3) 140430616321680 tensor(7) 140430616321712 tensor(11) 140430616321744 tensor(4) 140430616321688 tensor(8) 140430616321720 tensor(12) 140430616321752 (1, 4)'''tensor = tensor.contiguous()print(tensor.is_contiguous())# Truefor i in range(4): for j in range(3): print(tensor[i][j], tensor[i][j].data_ptr(), end=' ') print()print(tensor.stride()) # (3, 1)'''Truetensor(1) 140431681244608 tensor(5) 140431681244616 tensor(9) 140431681244624 tensor(2) 140431681244632 tensor(6) 140431681244640 tensor(10) 140431681244648 tensor(3) 140431681244656 tensor(7) 140431681244664 tensor(11) 140431681244672 tensor(4) 140431681244680 tensor(8) 140431681244688 tensor(12) 140431681244696 (3, 1)'''

Copyright © 2016-2020 www.365daan.com All Rights Reserved. 365答案网 版权所有 备案号:

部分内容来自互联网,版权归原作者所有,如有冒犯请联系我们,我们将在三个工作时内妥善处理。