아이공의 AI 공부 도전기

[Solution][Pytorch] RuntimeError: Error(s) in loading state_dict for ... :    Missing key(s) in state_dict: ... Unexpected key(s) in state_dict: ...

     

분명 torch.save를 통해 model, optimizer를 저장했고 torch.load를 통해 불러오면 될 것이다라는 개념으로 접근 했지만 갑작스레 위와 같은 문제를 겪으실 분들을 위해 기록을 남기고자 합니다.

문제 이유

다를 수 있지만 저에게 있어 해당 문제는 다중 GPU를 사용하고자 torch.nn.DataParallel를 사용하면서 생긴 문제라고 판단했습니다. nn.DataParallel 모델로 감싸진 모델을 저장한 상태는 불러올 때 model.module 형태로 가져옵니다. 이는 위에 나와있는 것처럼 module 형태로 불러오는 것을 확인할 수 있습니다.(위 이미지 Unexpected key를 확인해보세요.)

 

조금 더 풀어서 설명하면 1개의 GPU에서 단순히 torch.save, torch.load를 진행한다면 크게 문제없이 진행될 것입니다. 그러나 훈련할 때는 Multiple-GPU로 진행하고 다시 불러올 때 1개의 GPU device인 상태에서 load하고자 할 때 문제가 생깁니다. 

해결책

1) torch.save(model.module.state_dict(), '~.pt') 방법을 통해 다시 저장하고 불러오기 

저장을 잘하면 사실 어렵게 해결책을 적을 필요도 없습니다. 때문에 다시 저장하고 불러오는 방식 또한 하나의 해결책이 될 것 같아 적어봤습니다.

# original method to save the model parameters
torch.save(model.state_dict(), '~.pt')

# suggested method to save the model parameters
torch.save(model.module.state_dict(), '~.pt') 

일반적으로 torch.save(model.state_dict(), '~.pt') 방식을 사용합니다만 해당 방법은 제목과 같은 에러를 일으키고 문제가 생기기에 저장부터 잘해봅시다. 

# save parameters
if isinstance(G, nn.DataParallel):
    torch.save(G.module.state_dict(), model_save_name)
else:
    torch.save(G.state_dict(), model_save_name)
    
#loading pretrained parameters    
if isinstance(G, nn.DataParallel):
    G.module.load_state_dict(state_dict)
else:
    G.load_state_dict(state_dict)

optimizer와 scheduler는 그냥 state_dict()와 load_state_dict()를 저장, 사용함

해당 소스는 아래 링크에 남겨두겠습니다.

discuss.pytorch.org/t/missing-keys-unexpected-keys-in-state-dict-when-loading-self-trained-model/22379/8

 

Missing keys & unexpected keys in state_dict when loading self trained model

In my case it would not throw any errors anymore, but it wouldn’t correctly load the state_dict either. In the end I just had to do this to just remove the “.module” part.

discuss.pytorch.org

2) 코드 순서 변경 - 같은 GPU 개수 전제 하

같은 환경임에도 실행되지 않는다면 코드 순서의 문제일 경우가 큽니다. 이 때는 단순히 아래와 같은 순서로 바꿔서 진행해주시면 됩니다.

device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")

Mymodel = model(...).to(device)

optimizer = torch.optim.Adam(...)

Mymodel = nn.DataParallel(Mymodel, ...)

Mymodel, optimizer = torch.load_state_dict(torch.load( ... ))

model, optimizer를 먼저 정의하고 다중 GPU 설정 또한 진행한 후 load를 실행하는 것입니다.

3) GPU 추가 - GPU 개수가 다르다는 전제 하

가능하다면 저장한 환경과 동일한 곳에서 돌리는 것이 사실 가장 쉬운 방법일 것입니다. 때문에 사실 해당 내용에 대해서는 뭐라 드릴 말씀이 없을 것 같습니다. 다중 GPU를 사용하여 torch.save를 진행했기 때문에 해당 GPU의 개수를 맞춰주면 됩니다.

4) torch.load_state_dict(torch.load(PATH, map_location=device) 사용

 보통은 CPU 혹은 특정 GPU로 load하기 위해 사용하는 방법인데 해당 device로 불러오는 방식으로 load하는 것이 또 하나의 해결책이 될 수 있습니다. 시도해보시길 바랍니다. 

tutorials.pytorch.kr/beginner/saving_loading_models.html

 

모델 저장하기 & 불러오기 — PyTorch Tutorials 1.6.0 documentation

Note Click here to download the full example code 모델 저장하기 & 불러오기 Author: Matthew Inkawhich번역: 박정환 이 문서에서는 PyTorch 모델을 저장하고 불러오는 다양한 방법을 제공합니다. 이 문서 전체를 다

tutorials.pytorch.kr

5) torch.load_state_dict(torch.load(PATH), strict=False) 사용

보통 tutorial에서는 strict=False 설정과 관련하여 모델을 부분적으로 불러오거나 일부 불러오는 전이학습에서 사용한다고 되어 있습니다. 그러나 위 방법에서는 간혹 해당 방법으로 모델만 불러올 수 있어 사용이 가능한 방법이라 생각됩니다. 

6) .module key 제거

 또 다른 해결책으로는 .module key를 제거하는 방법으로 아래 링크에서 그런 언급을 하는데 이 부분은 어떻게 진행해야할지 저도 잘 모르겠습니다.

discuss.pytorch.org/t/missing-keys-unexpected-keys-in-state-dict-when-loading-self-trained-model/22379/4

 

Missing keys & unexpected keys in state_dict when loading self trained model

Thanks for your suggestions! @ptrblck

discuss.pytorch.org

 

해당 문제에 대한 다른 분들의 해결책을 공유해주시면 감사하겠습니다.

 

 

 

공유하기

facebook twitter kakaoTalk kakaostory naver band
loading