아이공의 AI 공부 도전기

torch.clamp 사용법 살펴보기

 

Official Pytorch API Document URL : 

pytorch.org/docs/stable/generated/torch.clamp.html

 

torch.clamp — PyTorch 1.7.0 documentation

Shortcuts

pytorch.org

     

 

torch.clamp 함수 설명

 

torch.clamp는 min 혹은 max의 범주에 해당하도록 값을 변경하는 것을 의미합니다.

 

TIP
 
 

clamp라는 영어 자체가 "좀쇠로 고정시키다.", "꽉 물다[잡다]"와 같은 뜻이라는 점을 생각해보았을 때 그 해당 범주 값 이하 또는 이상이 되지 않도록 꽉 물어놓은 곳이라고 생각하면 편하실 수 있습니다.

 

조금 더 쉽게 이해하기 위해 예를 들어 아래와 같은 tensor가 존재한다고 가정해봅시다.

tensor([-2.5, 3, -1, 10])

이 때 torch.clamp 함수를 사용하여 min=0.5라고 한다면 최소가 0.5가 되도록 그 이하의 값들을 교체합니다.

tensor([0.5, 3, 0.5, 10])

위와 같이 -2.5는 min=0.5보다 작으므로 해당 값을 0.5로 변경해줍니다.

마찬가지로 -1 역시 0.5보다 작으므로 0.5로 변환됩니다.

 

max는 반대로 생각하시면 쉽겠죠.

사용법 3가지 

1) input, min, max 모두 넣는 방법

torch.clamp(input, min, max)

 

해당 방법은 위에 그림에 해당하는 부분으로

input(tensor)

min(number : 실수 또는 정수)

max(number : 실수 또는 정수)

를 넣으면 됩니다.

아래에 코드 예가 있으니 확인해주세요.

 

2) input, min만 넣는 방법

torch.clamp(input, min)

 

해당 방법은 input과 min에 적합한 값을 넣어주면 됩니다.

아래에 코드 예가 있으니 확인해주세요.

 

3) input, max만 넣는 방법

torch.clamp(input, min)

 

해당 방법은 당연히 input과 max에 적합한 값을 넣어주면 됩니다.

아래에 코드 예가 있으니 확인해주세요.

예시

코드 예 1

import torch

a = torch.randn(6)
print(a)
print(torch.clamp(a, min=-0.5, max=0.5))
print(torch.clamp(a, min=0.5))
print(torch.clamp(a, max=0.5))

 

코드 예 2

import torch

a = torch.randn(2,3)
print(a)
print(torch.clamp(a, min=-0.5, max=0.5))
print(torch.clamp(a, min=0.5))
print(torch.clamp(a, max=0.5))

 

 

TIP
 
 

단순히 input과 숫자만 넣었을 때 즉, torch.clamp(a, 0.5)와 같이 넣었을 때 0.5는 min과 max 어떤 것에 해당하는 숫자일까요? 정답은 min입니다.

 

여기까지 torch.clamp에 대해서 알아보았습니다. 

해당 clamp 함수로 activation function ReLU 함수를 구현할 수 있어 예제로 많이 사용됨을 확인할 수 있습니다. 한 번 해보시는 것은 어떤가요?

공유하기

facebook twitter kakaoTalk kakaostory naver band
loading