아이공's AI Study Challenge Log

[PyTorch] How to count model parameters

 

Reference URL: discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/7

1) Multiply out each parameter tensor's shape

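def get_n_params(model):
    pp = 0  # running total over all parameter tensors
    for p in model.parameters():
        nn = 1  # element count of this tensor (a local int, not torch.nn)
        for s in p.size():  # multiply every dimension, e.g. 64*3*3*3
            nn = nn * s
        pp += nn
    return pp

model_a = model()  # model: your own nn.Module subclass
get_n_params(model_a)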

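e.g. p.size() = torch.Size([64, 3, 3, 3]) (the weight of a Conv2d layer with 64 output channels, 3 input channels and a 3x3 kernel)

nn = 64*3*3*3 = 1728

pp += nn

The for loop repeats this for every parameter tensor and accumulates the total in pp.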
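To check the arithmetic above, here is a minimal sketch that runs the same loop on a single `nn.Conv2d(3, 64, kernel_size=3)` layer. This layer is a toy model chosen only for illustration: its weight has the shape `[64, 3, 3, 3]` from the example, plus a bias of shape `[64]`, so the expected total is 64*3*3*3 + 64 = 1792. The inner counter is renamed `n` here so it cannot be confused with the usual `torch.nn` alias.

```python
import torch.nn as nn


def get_n_params(model):
    """Same loop as above: multiply out each parameter tensor's shape."""
    pp = 0
    for p in model.parameters():
        n = 1
        for s in p.size():  # e.g. torch.Size([64, 3, 3, 3])
            n *= s
        pp += n
    return pp


# Toy model for illustration only: one 3x3 conv, 3 -> 64 channels.
conv = nn.Conv2d(3, 64, kernel_size=3)

for name, p in conv.named_parameters():
    print(name, tuple(p.size()))
# weight (64, 3, 3, 3)
# bias (64,)

print(get_n_params(conv))  # 64*3*3*3 + 64 = 1792
```

The bias is easy to forget when counting by hand; the loop picks it up automatically because it is also returned by `model.parameters()`.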

 

2) Use Tensor.numel()

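# count trainable model parameters only (tensors with requires_grad=True)
# p.numel() returns the number of elements, i.e. the product of p.size()
def count_trainable_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# count total model parameters, trainable and non-trainable (frozen)
# (distinct names: defining both as count_parameters would overwrite the first)
def count_total_parameters(model):
    return sum(p.numel() for p in model.parameters())

model_a = model()  # model: your own nn.Module subclass
count_trainable_parameters(model_a)
count_total_parameters(model_a)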
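The trainable and total counts only differ when some parameters are frozen. Below is a minimal sketch using a hypothetical toy model (a `Linear(10, 5)` followed by `BatchNorm1d(5)`) whose first layer is frozen with `requires_grad_(False)`. The two functions get distinct names here, since defining both as `count_parameters` would make the second definition overwrite the first. It also shows that BatchNorm's running statistics are registered as buffers, not parameters, so neither function counts them.

```python
import torch.nn as nn


def count_trainable_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def count_total_parameters(model):
    return sum(p.numel() for p in model.parameters())


# Hypothetical toy model for illustration.
model_a = nn.Sequential(
    nn.Linear(10, 5),   # weight 5x10 + bias 5 = 55 parameters
    nn.BatchNorm1d(5),  # weight 5 + bias 5 = 10 parameters
)

# Freeze the Linear layer (as one might when fine-tuning).
model_a[0].requires_grad_(False)

print(count_total_parameters(model_a))      # 65
print(count_trainable_parameters(model_a))  # 10

# Buffers (running_mean, running_var, num_batches_tracked) are not parameters.
n_buffers = sum(b.numel() for b in model_a.buffers())
print(n_buffers)  # 5 + 5 + 1 = 11
```

If you also want buffers in the total (for example, to estimate checkpoint size), add `model.buffers()` to the sum explicitly.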
