분명 torch.save를 통해 model, optimizer를 저장했고 torch.load를 통해 불러오면 될 것이다라는 개념으로 접근 했지만 갑작스레 위와 같은 문제를 겪으실 분들을 위해 기록을 남기고자 합니다.
다를 수 있지만 저에게 있어 해당 문제는 다중 GPU를 사용하고자 torch.nn.DataParallel를 사용하면서 생긴 문제라고 판단했습니다. nn.DataParallel 모델로 감싸진 모델을 저장한 상태는 불러올 때 model.module 형태로 가져옵니다. 이는 위에 나와있는 것처럼 module 형태로 불러오는 것을 확인할 수 있습니다.(위 이미지 Unexpected key를 확인해보세요.)
조금 더 풀어서 설명하면 1개의 GPU에서 단순히 torch.save, torch.load를 진행한다면 크게 문제없이 진행될 것입니다. 그러나 훈련할 때는 Multiple-GPU로 진행하고 다시 불러올 때 1개의 GPU device인 상태에서 load하고자 할 때 문제가 생깁니다.
저장을 잘하면 사실 어렵게 해결책을 적을 필요도 없습니다. 때문에 다시 저장하고 불러오는 방식 또한 하나의 해결책이 될 것 같아 적어봤습니다.
# original method to save the model parameters
torch.save(model.state_dict(), '~.pt')
# suggested method to save the model parameters
torch.save(model.module.state_dict(), '~.pt')
일반적으로 torch.save(model.state_dict(), '~.pt') 방식을 사용합니다만 해당 방법은 제목과 같은 에러를 일으키고 문제가 생기기에 저장부터 잘해봅시다.
# save parameters
if isinstance(G, nn.DataParallel):
torch.save(G.module.state_dict(), model_save_name)
else:
torch.save(G.state_dict(), model_save_name)
#loading pretrained parameters
if isinstance(G, nn.DataParallel):
G.module.load_state_dict(state_dict)
else:
G.load_state_dict(state_dict)
optimizer와 scheduler는 그냥 state_dict()와 load_state_dict()를 저장, 사용함
해당 소스는 아래 링크에 남겨두겠습니다.
같은 환경임에도 실행되지 않는다면 코드 순서의 문제일 경우가 큽니다. 이 때는 단순히 아래와 같은 순서로 바꿔서 진행해주시면 됩니다.
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
Mymodel = model(...).to(device)
optimizer = torch.optim.Adam(...)
Mymodel = nn.DataParallel(Mymodel, ...)
Mymodel, optimizer = torch.load_state_dict(torch.load( ... ))
model, optimizer를 먼저 정의하고 다중 GPU 설정 또한 진행한 후 load를 실행하는 것입니다.
가능하다면 저장한 환경과 동일한 곳에서 돌리는 것이 사실 가장 쉬운 방법일 것입니다. 때문에 사실 해당 내용에 대해서는 뭐라 드릴 말씀이 없을 것 같습니다. 다중 GPU를 사용하여 torch.save를 진행했기 때문에 해당 GPU의 개수를 맞춰주면 됩니다.
보통은 CPU 혹은 특정 GPU로 load하기 위해 사용하는 방법인데 해당 device로 불러오는 방식으로 load하는 것이 또 하나의 해결책이 될 수 있습니다. 시도해보시길 바랍니다.
tutorials.pytorch.kr/beginner/saving_loading_models.html
보통 tutorial에서는 strict=False 설정과 관련하여 모델을 부분적으로 불러오거나 일부 불러오는 전이학습에서 사용한다고 되어 있습니다. 그러나 위 방법에서는 간혹 해당 방법으로 모델만 불러올 수 있어 사용이 가능한 방법이라 생각됩니다.
또 다른 해결책으로는 .module key를 제거하는 방법으로 아래 링크에서 그런 언급을 하는데 이 부분은 어떻게 진행해야할지 저도 잘 모르겠습니다.
해당 문제에 대한 다른 분들의 해결책을 공유해주시면 감사하겠습니다.