AI 공부 도전기

[Pytorch] make_grid 사용해서 이미지 살펴보기(torchvision.utils.make_grid)

 

     

개요

일반적으로 pytorch에서 Neural Network를 사용하여 이미지를 훈련시킬 때 중간중간의 결과가 어떻게 나오는지 확인하고 싶은 욕구가 생깁니다. 이와 관련하여 사용할 수 있는 함수가 바로 make_grid입니다. 정확히는 torchvision.utils.make_grid 함수를 통해 확인할 수 있습니다. 해당 글은 바로 그런 내용에 대해 살펴보도록 하겠습니다.

 

참고로 단순히 이미지를 여는 방법에 대해서 알고 싶으시다면 아래 링크를 확인해주세요.

aigong.tistory.com/182

 

Pytorch에서 이미지 여는 방법(PIL, matplotlib, torchvision)

Pytorch에서 이미지 여는 방법 목차 개요 단순히 어떤 폴더에 있는 이미지를 열고자 한다면 PIL을 사용하여 show 하는 방식이 가장 보편적인 방법일 것입니다. import PIL img = PIL.Image.open('./tree.jpg')..

aigong.tistory.com

make_grid 함수 살펴보기

 

해당 함수에서 가장 중요한 것은 바로 batch를 포함한 4D 이미지 tensor 혹은 같은 사이즈의 이미지 리스트를 반드시 포함해야 한다는 점입니다. 나머지 nrow, padding, normalize, range, scale_each, pad_value는 상황에 따라 자신의 취향에 맞게 사용하시면 됩니다.

 

필수적인 4D인 이유는 앞에서 설명드렸던 것처럼 여러개의 이미지를 훈련시키는 와중에 batch가 포함된 dataloader image tensor들이 있는 상황에서 도중에 보고자 할 때 사용하기 위한 목적성을 가진 함수이기 때문입니다. 물론 3D(channel x height x width) 이미지를 보고자 한다면 일부 수정을 통해 확인합니다.

 

FashionMNIST를 활용한 예

import torch
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms
import torch.utils.data.dataset as Dataset

tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,),(0.5,))
])

trainset = torchvision.datasets.FashionMNIST('./', download=True, train=True,transform=tf)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=5, shuffle=True)

for data in trainloader:
    img, label = data    
    plt.imshow(torchvision.utils.make_grid(img, normalize=True).permute(1,2,0))
    plt.show()
    break

torch가 가지고 있는 FashionMNIST를 다운받아 make_grid를 통해 grid tensor를 만들고 이를 matplotlib를 통해 출력하는 코드를 짜 봤습니다. 아래 보이시는 바와 같이 grid 격자 형태로 여러 개의 image가 보임을 확인할 수 있습니다. 이와 같이 훈련 도중 이미지가 잘 학습되고 있는지 원하는 이미지가 들어가고 있는지를 확인할 수 있습니다.

임의의 dataset을 활용한 예

 

import torch
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms
import torch.utils.data.dataset as Dataset

path = './dataset/pixabay'

tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,),(0.5,))
])

image_train_folder = torchvision.datasets.ImageFolder(path, transform=tf)
trainloader = torch.utils.data.DataLoader(image_train_folder, batch_size=5, shuffle=True)

for data in trainloader:
    img, label = data
    plt.imshow(torchvision.utils.make_grid(img, normalize=True).permute(1,2,0))
    plt.show()
    break

임의의 사진들을 모은 Dataset을 사용하여 결과를 출력할 때도 마찬가지로 사용합니다. 해당 내용의 범주는 넘어서지만 torchvision.datasets.ImageFolder를 사용하여 해당 폴더 내 이미지들을 읽어내고 읽어낸 것을 토대로 이미지를 grid 형태로 사용합니다.

이미지의 사이즈는 모두 1920x1280이지만 잘 출력됨을 확인할 수 있습니다.

1920x1280 Image

Optional Parameter

1) nrows

이제 나머지 parameter들에 대해서 살펴보겠습니다. nrows는 각 행에 몇 개의 grid가 표현되기를 원하냐는 것을 묻는 parameter입니다. 즉 column의 개수를 물어보는 것입니다. 저는 6개의 batch 중 nrows=2로 해보겠습니다. 제가 코드를 아래와 같이 바꿨을 때 결과는 어떻게 나올까요

trainloader = torch.utils.data.DataLoader(image_train_folder, batch_size=6)

for data in trainloader:
    img, label = data
    plt.imshow(torchvision.utils.make_grid(img, nrow=2, normalize=True).permute(1, 2, 0))
    plt.show()
    break

앞에서 설명한 것과 같이 열의 개수라고 생각되는 결과를 얻을 수 있습니다.

 

2) padding

 

이미지와 이미지 사이의 padding은 default로 2가 설정되어 있습니다. 즉 이미지 테두리에 2pixel의 둘레가 검은색 선으로 있다고 가정하시면 됩니다. 만약 padding을 100으로 늘리면 어떻게 될까요

보이는 바와 같이 두께가 더 두꺼워 지는 것을 확인할 수 있습니다. 다만 해당 이미지의 사이즈가 변형되지는 않습니다. 때문에 이를 잘 활용하시는 것이 좋을 것 같습니다. 또한 저는 이미지가 크므로 padding을 100으로 설정해야만 저렇게 뚜렷하게 보이지만 작은 이미지를 다루신다면 약간의 padding 변화만으로도 큰 변화를 얻으실 수 있습니다.

 

3) pad_value

torchvision.utils.make_grid(img, nrow=2, padding=20, pad_value=0.5, normalize=True).permute(1, 2, 0)

pad_value는 pad 색상과 관련된 부분으로 0이 default로 설정되어 검은색이 나타납니다. 일반적으로 0~1 사이의 값을 통해 검은색~흰색 사이의 값을 나타냅니다. 저는 0.5를 설정해보았습니다. 해당 그림은 아래와 같습니다.

pad에 약간의 회색 빛이 나는 것이 보이나요

사실 그렇게 큰 영향력을 끼치는 것 같지않아보입니다.

 

4) normalize, range, scale_each

 

normalize, range, scale_each의 경우 해당 이미지의 normalize를 다시 복구 시킴으로써 0~1 사이의 pixel value를 가지도록 혹은 min, max 값을 가지도록 설정하는 구간으로 사용자의 편의에 따라 사용할 수 있습니다. 저의 경우 그냥 normalize를 많이 사용하는 편입니다.

 

해당 글을 보시면서 이해가 잘 되었는지 잘 모르겠습니다.

이 글들이 여러분들께 도움이 되었으면 하는 바람으로 글을 마칩니다.

 

공유하기

facebook twitter kakaoTalk kakaostory naver band
loading