PyTorch 保存模型与加载模型

首页 / 新闻资讯 / 正文

def save_model(save_dir, phase, name, epoch, f1score, model,
               model_name=None, best_threshold=0.3):
    """Save a CPU copy of ``model``'s weights together with training metadata.

    The checkpoint is written to ``<save_dir>/<model_name>/<phase>/<name>.ckpt``
    as a dict with the keys ``'state_dict'``, ``'epoch'`` and ``'f1score'``.
    When ``name`` contains ``'best'`` and ``f1score`` is strictly greater than
    ``best_threshold``, an additional epoch-tagged copy
    ``<name>_<epoch>.ckpt`` is written so that earlier best checkpoints are
    not overwritten.

    Args:
        save_dir: Root checkpoint directory; created (including any missing
            parent directories) if it does not exist.
        phase: Name of the sub-directory for the current phase.
        name: File stem of the checkpoint.
        epoch: Epoch number stored in the checkpoint and used in the
            epoch-tagged file name.
        f1score: Score stored in the checkpoint and compared against
            ``best_threshold``.
        model: Module whose ``state_dict()`` is saved.
        model_name: Name of the per-model sub-directory. Defaults to the
            global ``args.model``, matching the original behaviour.
        best_threshold: Exclusive lower bound on ``f1score`` for writing the
            epoch-tagged copy of a ``'best'`` checkpoint. Defaults to 0.3.
    """
    if model_name is None:
        # Backward-compatible default: the calling script keeps its parsed
        # command-line options in a global ``args`` namespace.
        model_name = args.model
    save_dir = os.path.join(save_dir, model_name, phase)
    # One makedirs call replaces three check-then-mkdir steps: it also creates
    # missing parents and cannot fail if the directory appears concurrently.
    os.makedirs(save_dir, exist_ok=True)

    state_dict = model.state_dict()
    # Move every tensor to CPU so loading the checkpoint does not require the
    # device (e.g. a GPU) the model was trained on. Keys are snapshotted
    # first so the mapping is never iterated while being written to.
    for key in list(state_dict):
        state_dict[key] = state_dict[key].cpu()
    state_dict_all = {
        'state_dict': state_dict,
        'epoch': epoch,
        'f1score': f1score,
    }
    torch.save(state_dict_all, os.path.join(save_dir, '{:s}.ckpt'.format(name)))
    if 'best' in name and f1score > best_threshold:
        # Keep a per-epoch copy of good "best" checkpoints.
        torch.save(
            state_dict_all,
            os.path.join(save_dir, '{:s}_{:s}.ckpt'.format(name, str(epoch))),
        )

PyTorch 保存模型

PyTorch 加载模型以继续训练

    # Resume training from a saved checkpoint: restore the model weights, the
    # recorded F1 score, and the epoch to continue from. The keys read here
    # ('state_dict', 'f1score', 'epoch') are the ones save_model() writes.
    # NOTE(review): `args`, `model`, `best_f1score` and `start_epoch` belong to
    # the enclosing training function, which is not shown — confirm there.
    if args.resume:
        # Despite its name, `state_dict` holds the whole checkpoint dict; the
        # model weights are under its 'state_dict' key.
        # NOTE(review): torch.load unpickles the file — only resume from
        # checkpoints you trust.
        state_dict = torch.load(args.resume)
        model.load_state_dict(state_dict['state_dict'])
        best_f1score = state_dict['f1score']
        # Continue with the epoch after the one stored in the checkpoint.
        start_epoch = state_dict['epoch'] + 1

 

Top