아이공의 AI 공부 도전기

torch.autograd.Variable 알아보기

 

목차

 

     

PyTorch Autograd Variable 한글 튜토리얼

9bow.github.io/PyTorch-tutorials-kr-0.3.1/beginner/blitz/autograd_tutorial.html

 

Autograd: 자동 미분 — PyTorch Tutorials 0.3.1 documentation

PyTorch의 모든 신경망의 중심에는 autograd 패키지가 있습니다. 먼저 이것을 가볍게 살펴본 뒤, 첫번째 신경망을 학습시켜보겠습니다. autograd 패키지는 Tensor의 모든 연산에 대해 자동 미분을 제공합

9bow.github.io

우선 설명에 앞서 torch.autograd.Variable는 2021.01.12 기준 1.7 version에서는 더 이상 사용되지 않습니다. 현재는 모든 tensor가 자동적으로 Variable의 성질을 가지기 때문에 굳이 사용할 필요가 없으나 과거 버전의 코드를 수행함에 있어 Variable의 성질을 이해할 필요가 있기에 정리를 시도합니다. 

Autograd : 자동 미분

신경망은 Backpropagation이 핵심이며 이를 수행하기 위해서는 미분이 필수적입니다. 그렇기에 tensorflow 혹은 pytorch는 사용자의 편의를 위해 자동적으로 미분이 가능하도록 이미 구현된 코드를 간편하게 사용할 수 있습니다.

이는 모든 Tensor 연산에 있어 자동 미분이 가능하다는 의미입니다. 

물론 단순히 편의만을 위한 것은 아닙니다. 효율적인 연산을 통해 더 빠르게 미분을 구할 수 있게 도와줍니다.

autograd 이 가운데에서도 Variable이라는 클래스가 존재합니다.

 

torch.autograd.Variable

Variable 클래스 내 함수

 Variable 클래스는 각 tensor의 값을 볼 수 있는 data, 미분을 보는 grad, backward를 통한 미분을 계산한 함수 정보인 grad_fn 총 3개의 함수를 사용할 수 있습니다.

Variable 클래스 사용법

1) Variable에 Tensor를 감싼다. 그리고 requires_grad=True를 통해 미분을 진행하겠다는 의사를 표시한다.

ex)

import torch
from torch.autograd import Variable

x = Variable(torch.randint(2,2), requires_grad=True)

2) Variable을 활용한 수식들을 적는다.

ex) 

out = x**3 + 7*x + 10

3) 수식의 결과를 backpropagation을 진행한다. 이를 위해서 함수 backward를 사용한다.

ex)

out.backward()

4) 초기 x의 grad에 대해 살펴보며 해당 변화량을 확인한다. 

ex)

print(x.grad)

 

Neural Network에서의 Variable 사용

일반적으로 Variable을 활용하여 weight들이나 z noise의 변화량을 변화시키는 방향으로 진행하며 이때 최종 out에 해당하는 것은 loss를 사용합니다. 즉, loss의 backward()를 통해 backpropagation을 구하고 이를 통해 weight들을 자신의 데이터에 맞게 변화한다는 것으로 이해하시면 됩니다.

 

현재 Variable 클래스 상황(Pytorch 1.7 기준)

 

과거와 같이 Variable을 사용할 수는 있으나 굳이 필요 없다는 점을 pytorch 공식 document에서는 설명하고 있습니다. 이에 따라 단순히 tensor에서 requires_grad를 통한 설정을 함으로써 값을 구하시면 됩니다. 또한 Variable에서 사용하던 backward(), detach(), register_hook()를 tensor에서도 동일하게 적용 가능하다고 합니다.

결론

Variable은 과거 사용된 autograd 방식으로 현재는 torch에서 모든 tensor에 autograd가 가능하도록 설정되어 있다. 비록 그 필요성은 사라졌지만 Variable은 여전히 사용할 수 있으며 해당 함수들 또한 사용이 가능하다.

 

기타 예시 실험(과거 PyTorch official tutorial code)

공유하기

facebook twitter kakaoTalk kakaostory naver band
loading