Error
Flask에서 model을 불러오려고 하니 다음과 같이 error가 발생하였다.
1
sentimentModel = torch.load('./model.pt', map_location='cpu')
Error 해결
torch.save(model, 'model.pt')
로 저장하여 model 전체를 불러오려고 하니 어떤 모델인지 정의 되어 있지 않아 에러가 나는거 같았다.torch.save(model.state_dict(), 'model.pt')
로 저장한 후 모델을 정의하고 정의된 모델에 load를 하여 가중치와 편향을 부여하는 방법으로 해결하였다.1 2 3
model = BERTClassifier(bertmodel, dr_rate=0.5).to(device) model.load_state_dict(torch.load('model_state_dict.pt', map_location='cpu'))