AI 공부 도전기

[Pytorch] torch.Tensor.contiguous 이유와 사용법 

 

     

torch.Tensor.contiguous란

https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html

 

torch.Tensor.contiguous — PyTorch 1.11.0 documentation

Shortcuts

pytorch.org

 

 

Pytorch의 새로운 모듈이나 parameter를 조절하는 가운데 contiguous를 사용하는 경우가 종종 있다.

이에 따라 contiguous의 필요성에 대해 조사하는 과정을 거쳤다.

 

영문으로 contiguous는 "인접한, 근접한"의 뜻을 가지고 있다.

위 이미지와 링크에서 확인할 수 있듯 contiguous는 tensor에 사용하며 메모리 tensor를 contiguous할 수 있도록 재정의 하는 것을 의미한다.

이와 관련하여 글만으로는 이해가 어려워 아래 예제를 가져왔다.

 

기본적으로 정의된 a라는 tensor가 있다고 해보겠다.

여기서 해당 tensor의 메모리 정렬은 순차적으로 4씩 증가하는 것을 확인할 수 있다.

참고로 torch.Tensor.data_ptr()는 저장되는 tensor의 메모리 주소를 반환한다.

 

import torch

a = torch.randn(2,3)
print(a.size()) # torch.Size([2, 3])
print(a)

"""
tensor([[2.2655, 0.4245, 1.9640],
        [0.2972, 0.8499, 0.2631]])
"""

for i in range(2):
    for j in range(3):
        print(a[i][j].data_ptr())

"""
1822596866880
1822596866884
1822596866888
1822596866892
1822596866896
1822596866900
"""

 

torch.Tensor.contiguous 사용 이유

 

만약 narrow(), expand(), view(), transpose()를 통해 tensor의 모양을 변화시킬 경우 새로운 텐서를 생성하는 것이 아니라 저장된 tensor memory 주소는 그대로 둔채 모양만 바꾼다.

즉, original tensor와 modified tensor의 메모리는 같이 공유한 상태이다.

 

a.transpose_(0,1)

for i in range(3):
    for j in range(2):
        print(a[i][j].data_ptr())

"""
1822596866880
1822596866892
1822596866884
1822596866896
1822596866888
1822596866900
"""

print('Is a contiguous? ', a.is_contiguous()) # Is a contiguous?  False

 

한 번의 function의 사용은 에러를 일으키지 않지만 비슷한 모양 변화를 시키는 경우 contiguous하지 않다는 에러를 접할 수 있다.

 

# RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
# a.view(2,3)

 

이때 우리가 오늘 배운 contiguous 함수를 사용한다.

 

a.contiguous().view(2,3)

print(a.size()) # # torch.Size([3, 2])

 

이렇게 강제로 메모리를 새롭게 할당하여 연속적으로 만든 후 view를 사용하면 그 모양대로 수정이 가능하다.

물론 reshape을 통해서도 수정이 가능하다.

 

예제 통합

 

import torch

a = torch.randn(2,3)
print(a.size()) # torch.Size([2, 3])
print(a)

"""
tensor([[ 1.9853, -1.2616,  0.2045],
        [-0.1057, -1.9393, -1.3162]])
"""

a.transpose_(0,1)
print(a.size()) # torch.Size([3, 2])
print(a)

"""
tensor([[ 1.9853, -0.1057],
        [-1.2616, -1.9393],
        [ 0.2045, -1.3162]])
"""

for i in range(3):
    for j in range(2):
        print(a[i][j].data_ptr())

"""
1581749542400
1581749542412
1581749542404
1581749542416
1581749542408
1581749542420
"""

# torch.Tensor.data_ptr() : Returns the address of the first element of self tensor.
print()

b = torch.randn(3,2)

for i in range(3):
    for j in range(2):
        print(a[i][j].data_ptr())

"""
1581749542400
1581749542412
1581749542404
1581749542416
1581749542408
1581749542420
"""

print(a.stride()) # (1, 3)
print(b.stride()) # (2, 1)

print('Is a contiguous? ', a.is_contiguous()) # Is a contiguous?  False
print('Is b contiguous? ', b.is_contiguous()) # Is b contiguous?  True

a = a.contiguous()

print('Is a contiguous? ', a.is_contiguous()) # Is a contiguous?  True

for i in range(3):
    for j in range(2):
        print(a[i][j].data_ptr())

"""
1581749547776
1581749547780
1581749547784
1581749547788
1581749547792
1581749547796
"""

 

참조 

 

https://stackoverflow.com/questions/48915810/pytorch-what-does-contiguous-do

 

PyTorch - What does contiguous() do?

I was going through this example of a LSTM language model on github (link). What it does in general is pretty clear to me. But I'm still struggling to understand what calling contiguous() does, which

stackoverflow.com

https://jimmy-ai.tistory.com/122

 

[Pytorch] contiguous 원리와 의미

torch의 contiguous에 대해서 안녕하세요. 이번 시간에는 파이토치에서 메모리 내에서의 자료형 저장 상태로 등장하는 contiguous의 원리와 의미에 대해서 간단히 살펴보도록 하겠습니다. contiguous 여부

jimmy-ai.tistory.com

 

 

 

 

 

공유하기

facebook twitter kakaoTalk kakaostory naver band
loading