참조 URL: https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/7
1) 각 파라미터 텐서의 size()를 곱해서 직접 합산하는 방법
def get_n_params(model):
    """Return the total number of scalar parameters in *model*.

    Each parameter tensor contributes the product of its dimensions, which is
    exactly what ``Tensor.numel()`` returns (1 for a 0-dim tensor, the same
    value the empty product of a manual size loop gives).

    Args:
        model: An object exposing ``parameters()`` (e.g. a ``torch.nn.Module``)
            whose items support ``numel()``.

    Returns:
        The sum of element counts over all parameters, with no filtering on
        ``requires_grad`` (trainable and frozen parameters are both counted).
    """
    # Iterate the generator directly (no throwaway list() copy) and avoid a
    # local named ``nn``, which would shadow the conventional ``torch.nn`` alias.
    return sum(p.numel() for p in model.parameters())
# NOTE(review): `model` is not defined in this file; presumably it stands for
# your own nn.Module subclass — replace/confirm before running.
model_a = model()
# A bare expression's value is only echoed in a REPL/notebook; print it so the
# count is also visible when this file runs as a script.
print(get_n_params(model_a))
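ex) Conv2d 가중치처럼 p.size() = torch.Size([64, 3, 3, 3]) 인 경우:
    nn = 64*3*3*3 = 1728
    pp += nn
→ for문으로 model.parameters()의 모든 파라미터 텐서에 대해 위 과정을 반복 (repeat for every parameter tensor)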
2) numel()로 각 텐서의 원소 수를 바로 구해 합산하는 방법
def count_parameters(model, *, trainable_only=False):
    """Return the number of scalar parameters in *model*.

    A single function replaces two same-named definitions, where the second
    silently overwrote the first and made the trainable-only count
    unreachable.

    Args:
        model: An object exposing ``parameters()`` (e.g. a ``torch.nn.Module``)
            whose items support ``numel()`` and ``requires_grad``.
        trainable_only: If True, count only parameters whose
            ``requires_grad`` is set (trainable). If False (the default),
            count every parameter, trainable and non-trainable alike.

    Returns:
        The summed element count (0 for a model with no parameters).
    """
    return sum(
        p.numel()
        for p in model.parameters()
        # With the flag off, every parameter passes the filter.
        if not trainable_only or p.requires_grad
    )
# NOTE(review): `model` is not defined in this file; presumably it stands for
# your own nn.Module subclass — replace/confirm before running.
model_a = model()
# Print the result so it is visible outside a REPL/notebook, where a bare
# expression's value would otherwise be discarded.
print(count_parameters(model_a))