아이공의 AI 공부 도전기

Pytorch MNIST에서의 다음 Error 해결법  : RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]

2020.05.10일 기준

MNIST를 불러와서 dataloader에 넣고 불러올 때 위 에러가 뜰 때가 있다.

어디에서 문제가 생긴 것일까?

우리의 문제는 왜 발생했을까

torchvision.dataset에서 다운 받은 MNIST 혹은 어디에서 받은 MNIST든 gray dimension(channel)인 1을 가지는 값을 가질 것이다.

그래서 결과적으로 dataset의 shape 혹은 size는 60000,28,28이 될 것이다.

즉, gray dimension은 무시.

(만약 색이 있는 이미지 파일이라면 60000,3,28,28의 형태가 되었을 것이다.)

이것을 바탕으로 생각하면 문제의 발생을 이해하기 쉽다.


그 이유는 대체로 dataset을 load할 때 torchvision.transforms 함수의 normalize에서 실수를 했기 때문이다.

앞에서 이야기한대로 MNIST는 gray scale만 가지는 값이지만 


torchvision.transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))


이렇게 설정했을 가능성이 높다.

3 dimension(빛의 3원색)으로 색이 있는 이미지 파일을 나타내는 normalize를 한 것!


해결책

때문에 우리는 아래와 같이 바꿔준다면 해결될 것이다.


torchvision.transforms.Normalize((0.5,),(0.5,))


즉 gray scale만을 normalize하겠어요 라는 의미이다.


그렇게 된다면 후에 dataloader를 통해 불러올 때 대체로 해결된다.

공유하기

facebook twitter kakaoTalk kakaostory naver band
loading