memo: PyTorch | Contiguous & Stride

Stride

  • (2023-09-28) Contiguous means the relative order of a tensor’s elements (the underlying 1D memory arrangement laid down when the tensor was first created) hasn’t been changed, even though the shape may have been mutated. (The data in the memory blocks is never itself reshaped.)

    E.g., the layout of a = torch.arange(4) means 0, 1, 2, 3 sit next to each other in memory in this order. a.stride() is (1,).

    1. b = a.reshape(2,2) $\begin{bmatrix} 0 & 1 \\ 2 & 3 \end{bmatrix}$ is still contiguous, because only the data’s dimensional interpretation changed; the read order of 0, 1, 2, 3 didn’t. b.stride() is (2,1).

    2. torch.flip(b, [0]) $\begin{bmatrix} 2 & 3 \\ 0 & 1 \end{bmatrix}$ is contiguous, because flip copies the data, so the elements are again laid out next to each other in memory and the .stride() remains (2,1).

    3. The stride describes how to walk the underlying memory layout; if an operation changes the stride so that it no longer matches the row-major order of the current shape, the tensor is no longer contiguous.

      Operations like transpose and permute change the stride without moving any data, so the result is no longer row-contiguous.

      So b.transpose(0,1) or b.permute(1,0) $\begin{bmatrix} 0 & 2 \\ 1 & 3 \end{bmatrix}$ is not contiguous any more.

    4. .contiguous() changes the stride to match the current shape (copying the data if needed). What does .contiguous() do in PyTorch?

      (2023-12-22) .contiguous() copies the data into a new contiguous block of memory, which can be checked via .storage(). The new tensor and the original no longer affect each other (see the sketch after this list). Pytorch - Contiguous vs Non-Contiguous Tensor / View — Understanding view(), reshape(), transpose() - Kathryn

    5. .permute() decouples .stride() from the tensor’s shape: the strides are reordered while the data stays where it is.

      That means when you try to take the tensor back to its original size with view(), the strides no longer correspond to what that shape expects, as in the example below.

      a = torch.arange(8).reshape(2, 4)  # shape (2,4), stride (4,1), contiguous
      b = a.transpose(0, 1)              # shape (4,2), stride (1,4), not contiguous
      b.view(2, 4)                       # raises a RuntimeError,
      # because the strides don't match the target shape

      b.contiguous().view(2, 4)          # works: copy into contiguous memory first, then view

      # or use
      b.reshape(2, 4)                    # reshape copies the data when .view() would fail
      

      Therefore, .contiguous() is needed before .view() after a transpose() or permute().
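
    A quick sketch (mine, not from the linked posts) tying these points together: reshape keeps contiguity, transpose breaks it, and .contiguous() copies the data into fresh storage, which can be verified by comparing data_ptr():

    import torch

    a = torch.arange(4)           # storage: 0 1 2 3, stride (1,)
    b = a.reshape(2, 2)           # stride (2, 1), still contiguous
    c = b.transpose(0, 1)         # stride (1, 2), no longer contiguous
    print(b.is_contiguous(), c.is_contiguous())  # True False

    d = c.contiguous()            # copies 0 2 1 3 into a fresh memory block
    print(d.stride())                     # (2, 1) again, matching the (2, 2) shape
    print(c.data_ptr() == a.data_ptr())   # True: reshape and transpose are just views
    print(d.data_ptr() == a.data_ptr())   # False: .contiguous() allocated new storage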


Read-Write Head

  • (2023-10-13) The stride is the step size of the read/write head: how far it moves along the 1D memory strip to index (“piece together”) a given dimension of the tensor. How does NumPy’s transpose() method permute the axes of an array?

    • Computing a tensor’s stride: for a contiguous tensor with x.shape = (2,2,4), x.stride() = (8,4,1), so the stride of a dimension equals the product of the sizes of the dimensions after it: 8 = 2×4, 4 = 4×1, 1 = 1 (see the sketch after this list).

      【pytorch Tensor shape 变化 view 与 reshape(contiguous 的理解)】

    • transpose swaps two dimensions, i.e., it swaps the strides of those two dimensions; different dimensions step across memory by different distances.

    • .T reverses the strides of all dimensions, e.g., a shape (2,2,4) becomes (4,2,2) and its stride (8,4,1) becomes (1,4,8).
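
      A minimal sketch (the helper name contiguous_strides is mine) that computes the contiguous strides from a shape and checks how transpose / permute / .T rearrange them:

      import torch

      def contiguous_strides(shape):
          # stride of a dim = product of the sizes of all dims after it
          strides, acc = [], 1
          for size in reversed(shape):
              strides.append(acc)
              acc *= size
          return tuple(reversed(strides))

      x = torch.empty(2, 2, 4)
      print(contiguous_strides((2, 2, 4)))  # (8, 4, 1)
      print(x.stride())                     # (8, 4, 1)
      print(x.transpose(0, 2).stride())     # (1, 4, 8): swaps the strides of dims 0 and 2
      print(x.permute(2, 0, 1).stride())    # (1, 8, 4): reorders all strides
      print(x.T.stride())                   # (1, 4, 8): reverses all strides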

  • (2023-10-23) Examples:

    >>> a = torch.arange(6).reshape(2,3,1) # stride: (3,1,1)
    tensor([[[0],
             [1],
             [2]],

            [[3],
             [4],
             [5]]])

    >>> b = a.transpose(0,1)  # stride: (1,3,1), not contiguous
    tensor([[[0],
             [3]],

            [[1],
             [4]],

            [[2],
             [5]]])

    >>> c = a.transpose(1,2)  # stride: (3,1,1), contiguous
    tensor([[[0, 1, 2]],

            [[3, 4, 5]]])

    >>> d = a.transpose(0,2)  # stride: (1,1,3), not contiguous
    tensor([[[0, 3],
             [1, 4],
             [2, 5]]])
    >>> e = a.T    # stride: (1,1,3)
    
    • .transpose swaps two strides, .T reverses all strides, and permute reorders strides, whereas .view changes the stride to match the new shape while keeping the elements “bordering each other” in memory.

    • contiguous (“bordering each other”): two elements that are adjacent within a row of the matrix are also adjacent in memory, and the end of one row is adjacent to the start of the next row.

  • (2023-10-24) Contiguous means that no matter how the shape changes, the movement of the read/write head is the same as indexing a flattened tensor (see the sketch below).
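
    A small sketch (the helper read is mine) of what the read/write head does: the element at index (i, j) sits at storage offset i*stride[0] + j*stride[1], and for a contiguous tensor that offset is exactly the index into the flattened tensor:

    import torch

    a = torch.arange(6).reshape(2, 3)   # contiguous, stride (3, 1)
    b = a.t()                           # shape (3, 2), stride (1, 3), same storage as a

    def read(t, *idx):
        # a plain 1D walk over the underlying storage
        flat = torch.as_strided(t, (t.numel(),), (1,), 0)
        # the head jumps sum(i * stride_i) elements into that storage
        offset = t.storage_offset() + sum(i * s for i, s in zip(idx, t.stride()))
        return flat[offset]

    print(read(a, 1, 2), a[1, 2])   # tensor(5) tensor(5): offset = 1*3 + 2*1 = 5
    print(read(b, 2, 1), b[2, 1])   # tensor(5) tensor(5): offset = 2*1 + 1*3 = 5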


Rectangle’s Shape

  • (2023-12-10) Changing the shape of a tensor is like stretching a rectangle of fixed area: although its height and width change, the relative order of the elements inside does not.

    • Imagine dragging the bottom right corner to change the height and width of the rectangle.

    Alternatively, one can imagine the data as grains of sand enclosed by a resizable frame. Once the shape changes, the rows are filled up first:

    (figure: the R/W head reading the 1D memory sequence 1…18 into rectangles of different shapes, e.g. 3×6 and 2×9)

    The reading order is unchanged, keeping the sequence from 1 to 18:

    (3×6)   1  2  3  4  5  6
            7  8  9 10 11 12
           13 14 15 16 17 18

    (2×9)   1  2  3  4  5  6  7  8  9
           10 11 12 13 14 15 16 17 18

    (I can’t remember whether such a toy actually exists: a rectangle framing some “pieces”; when you drag one corner of the rectangle, since its area is fixed, the pieces inside rearrange themselves. Mapped onto .view(), each row gets filled up first.)

    Since view() doesn’t change how the read/write head steps through memory (it never copies data), the target shape has to be compatible with the current strides. torch view - Docs

    xy = torch.cartesian_prod(torch.arange(3),
                              torch.arange(2))
    # tensor([[0, 0],
    #         [0, 1],
    #         [1, 0],
    #         [1, 1],
    #         [2, 0],
    #         [2, 1]])
    # memory: 0 0 0 1 1 0 1 1 2 0 2 1
    # shape: (6,2), stride: (2,1)

    print(xy.view(2,6))   # contiguous, stride: (6,1)
    # tensor([[0, 0, 0, 1, 1, 0],
    #         [1, 1, 2, 0, 2, 1]])
    # read straight through in one pass

    xy.t().is_contiguous()  # False, stride: (1,2)
    # tensor([[0, 0, 1, 1, 2, 2],
    #         [0, 1, 0, 1, 0, 1]])
    

    view fills an empty box from the innermost dimension to the outermost by consuming the 1D memory data in order, as the check below shows.
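
    One way to check this fill order (a sketch continuing the example above): flattening a view gives back exactly the original memory order, while flattening the transpose does not.

    import torch

    xy = torch.cartesian_prod(torch.arange(3), torch.arange(2))  # same xy as above
    print(torch.equal(xy.view(2, 6).flatten(), xy.flatten()))    # True: view keeps the 1D reading order
    print(torch.equal(xy.t().flatten(), xy.flatten()))           # False: transpose visits elements in a different order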

  • (2023-12-22) The tensors created by torch.meshgrid() are not contiguous.

    y, x = torch.meshgrid(torch.arange(3),
                          torch.arange(2),
                          indexing="ij")  # "ij" is the default; passing it explicitly silences the warning
    # y: vertical coordinate
    # tensor([[0, 0],
    #         [1, 1],
    #         [2, 2]])
    # Memory is: 0 1 2
    # Not contiguous. stride on each dim: (1,0)

    y.storage()
    #  0
    #  1
    #  2
    # [torch.LongStorage of size 3]

    # x: horizontal coordinate
    # tensor([[0, 1],
    #         [0, 1],
    #         [0, 1]])
    # Memory is: 0 1
    # Not contiguous. stride on each dim: (0,1)

    x.storage()
    #  0
    #  1
    # [torch.LongStorage of size 2]

    y.contiguous().stride()   # stride: (2,1)
    

    When reading y and x, the read/write head has to jump back or re-read some elements instead of iterating over the 1D memory sequence exactly once, so they’re not contiguous (see the as_strided sketch below).
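
    A sketch (mine) that rebuilds y directly from a 3-element storage with as_strided, making the repeated reads explicit through a zero stride:

    import torch

    base = torch.arange(3)                # the 3-element storage behind y
    y2 = base.as_strided((3, 2), (1, 0))  # stride 0 on dim 1: the head re-reads the same element
    print(y2)                             # tensor([[0, 0],
                                          #         [1, 1],
                                          #         [2, 2]])
    print(y2.is_contiguous())             # False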

  • repeat will copy the data, while expand won’t; expand just sets the stride of the expanded singleton dimension to 0 (singleton meaning that the dimension has size 1).

    y.unsqueeze(2).repeat(1,1,2) # shape: (3,2,2), contiguous
    # tensor([[[0, 0],
    #          [0, 0]],
    #
    #         [[1, 1],
    #          [1, 1]],
    #
    #         [[2, 2],
    #          [2, 2]]])
    # stride: (4,2,1)
    
    y.unsqueeze(2).expand(-1,-1,2) # shape: (3,2,2), Not contiguous
    # tensor([[[0, 0],
    #          [0, 0]],
    #
    #         [[1, 1],
    #          [1, 1]],
    #
    #         [[2, 2],
    #          [2, 2]]])
    # stride: (1,0,0)
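
    A follow-up sketch (mine) confirming the copy-versus-view distinction with data_ptr(): the expanded tensor still points at y’s storage, while repeat owns a new allocation.

    import torch

    y, _ = torch.meshgrid(torch.arange(3), torch.arange(2), indexing="ij")
    r = y.unsqueeze(2).repeat(1, 1, 2)    # real copy: owns its own storage
    e = y.unsqueeze(2).expand(-1, -1, 2)  # view: shares the 3-element storage of y

    print(r.data_ptr() == y.data_ptr())   # False: repeat allocated new memory
    print(e.data_ptr() == y.data_ptr())   # True:  expand reuses y's storage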
    