AI 공부 도전기

torch.Tensor.scatter_ 알아보기

2020.4.19 기준 Version 1.4.0

https://pytorch.org/docs/stable/tensors.html?highlight=scatter#torch.Tensor.scatter_

1. scatter와 scatter_의 비교 


현 버전 기준 scatter_는 inplace=True인 상태로 scatter과 다르게 바로 해당 tensor에 적용한 것을 의미합니다.

parameter의 순서 또한 동일


cf) torch.gather()라는 API 또한 존재하는데 그 내용이 scatter_와 일맥상통하다. 다만 차이가 있는데 작동 방식이 완전 반대 방법이라는 것이다.

자세한 내용은 아래 주소를 참조

https://pytorch.org/docs/stable/torch.html#torch.gather

2. parameter의 순서와 내용



이번에는 parameter의 순서와 내용에 대해 알아보자.


1. dim차원 축을 이야기한다. 

당연히 정수이어야하며 다루는 tensor의 dimension보다 작거나 같아야한다.

참고로 2차원에서 dim=0은 아래방향(행방향)쪽이다.

dim=1은 가로방향(열방향)쪽이다.

이 방향은 numpy에서도 통용되며 tensorflow dim에서도 많이 사용한다.



홈페이지에 나와있는 값들이 적용되는 방식을 3D에 대해 나와있는데 당연한 이야기를 하고있다.

dim=0이라면 제일 앞의 값의 index 위치에 따라 지정해준다는 의미이고 다른 dim에 대해서도 마찬가지 이야기를 하고있다.


2. index는 위치라고 생각하면 된다. 

LongTensor type을 받는다.

scatter할 tensor의 위치 개념으로 3번 소제목인 사용 방법 예시에서 이해하는 편이 더 수월하다.

기억해야하는 것은 scatter하고자하는 Tensor와의 dimension size와 parameter dim의 axis 방향을 잘 고려해야한다는 점이다.


3. src나 valuescatter할 값이라고 생각하자. 

type은 모든 Tensor뿐 아니라 float type까지 가능하다.

다만 주의해야하는 점은 scatter하고자 하는 Tensor보다 size가 작아야한다는 점이다.

3. 사용 방법 - 홈페이지 예시

우선 홈페이지에 나와있는 예시를 먼저 살펴보자


x를 normal distribution에 따라 random 변수를 2행5열의 2D로 받았고 torch.zeros를 통해 3행5열로 만든 0으로 가득찬 2D 행렬에 scatter하고 싶다는 의미로 보인다.

이 때의 dim=0 아래 방향이고 index는 [[0,1,2,0,0],[2,0,0,1,2]]이고 이 때 x에 대응하는 값을 넣고싶다는 의미이다.

index를 2D 형태로 적어보면 아래와 같다.

[[0,1,2,0,0]

 [2,0,0,1,2]]


(유의 : 첫 행, 열을 0부터 시작하는 것으로 설명함. 0,1,2,3,... 순)

dim=0이므로 아래방향 index 첫 열(0 열)부터 확인하면 [[0],[2]]이다.

즉, torch.zeros(3,5)의 0행 0열에는 x의 0행 0열의 값을 넣고, 

torch.zeros(3,5)의 2행 0열에는 x의 1행 0열의 값을 넣는다는 의미이다.


1열에 대해서도 잠시보면 [[1],[0]]이다. 

즉, torch.zeros(3,5)의 1 1열에는 x의 0행 1열의 값을 넣고, 

torch.zeros(3,5)의 0 1열에는 x의 1행 1열의 값을 넣는다는 의미이다.


이런 방식이 모든 열에 적용된 값이 위의 결과이다.


scatter_를 사용하는 방식은 One hot Encoding할 때도 편리하다.


MNIST의 예로 label data를 입력받는데 이 때의 값들을 index로 할 수 있게끔 reshape해주고 index에 넣어주면 위와 같이 멋진 one hot encoding이 가능하다.


그러나 scatter_를 사용할 때 index에 대해 많은 신경을 써야한다.

4. 주의 사항

scatter_를 처음 사용하다보면 짜증나는 것은 index를 잘못 생각할 때가 많다는 점이다.

그로인해 생기는 Error가 아래 예와 같다.


1. index tensor는 반드시 output tensor와 dimension이 같지 않거나 비어있지 않을 때 생기는 오류


RuntimeError: invalid argument 3: Index tensor must either be empty or have same dimensions as output tensor 


해결책 : output tensor의 dimension과 같게 tensor를 더해준다.


2. parameter dim의 방향에 따른 indexing이 output tensor와 dimension이 맞지 않을 때 생기는 오류 


RuntimeError: Expected tensor [2, 5] and index [1, 4] to have the same size in dimension 0


해결책 : dim의 정수 숫자를 고려해서 scatter하고자하는 tensor의 위치가 index의 값과 일치하는지를 생각해야한다.


ex)


2행5열짜리 x Tensor를 만들었다.

dim=0 방향일 때 index를 [3,0,1,1]로 만들었다. 당연히 output tensor x와 dim이 다르니 dimension을 맞춰줘야한다.(1 해결)


그럼에도 2번에서 이야기했던 문제가 생겼다.

당연히 dim이 0일 때 5개의 열이 존재해야하고 2행까지 밖에 없으므로 0또는 1의 값만이 존재해야하도록 변경(2 해결)


이런식으로 문제를 해결해나가면 조금 더 좋은 API 활용일 수 있을 것 같다.

공유하기

facebook twitter kakaoTalk kakaostory naver band
loading