1.保存加载checkpoint文件

# 方式一:保存加载整个state_dict(推荐)
# 保存
torch.save(model.state_dict(), PATH)
# 加载
model.load_state_dict(torch.load(PATH))
# 测试时不启用 BatchNormalization 和 Dropout
model.eval()
# 方式二:保存加载整个模型
# 保存
torch.save(model, PATH)
# 加载
model = torch.load(PATH)
model.eval()
# 方式三:保存用于继续训练的checkpoint或者多个模型
# 保存
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            ...
            }, PATH)
# 加载
checkpoint = torch.load(PATH)
start_epoch=checkpoint['epoch']
model.load_state_dict(checkpoint['model_state_dict'])
# 测试时
model.eval()
# 或者训练时
model.train()

2.跨gpu和cpu

# GPU上保存,CPU上加载
# 保存
torch.save(model.state_dict(), PATH)
# 加载
device = torch.device('cpu')
model.load_state_dict(torch.load(PATH, map_location=device))
# 如果是多gpu保存,需要去除关键字中的module,见第4部分
# GPU上保存,GPU上加载
# 保存
torch.save(model.state_dict(), PATH)
# 加载
device = torch.device("cuda")
model.load_state_dict(torch.load(PATH))
model.to(device)
# CPU上保存,GPU上加载
# 保存
torch.save(model.state_dict(), PATH)
# 加载
device = torch.device("cuda")
# 选择希望使用的GPU
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  
model.to(device)

3.查看checkpoint文件内容

# 打印模型的 state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

4.常见问题

多gpu

报错为KeyError: ‘unexpected key “module.conv1.weight” in state_dict’

原因:当使用多gpu时,会使用torch.nn.DataParallel,所以checkpoint中有module字样

#解决1:加载时将module去掉

# 创建一个不包含`module.`的新OrderedDict
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # 去掉 `module.`
    new_state_dict[name] = v
# 加载参数
model.load_state_dict(new_state_dict)
# 解决2:保存checkpoint时不保存module
torch.save(model.module.state_dict(), PATH)

pytorch保存和加载文件的方法,从断点处继续训练

'''本文件用于举例说明pytorch保存和加载文件的方法''' 
import torch as torch
import torchvision as tv
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import os
  
# 参数声明
batch_size = 32
epochs = 10
WORKERS = 0  # dataloder线程数
test_flag = False  # 测试标志,True时加载保存好的模型进行测试
ROOT = '/home/pxt/pytorch/cifar'  # MNIST数据集保存路径
log_dir = '/home/pxt/pytorch/logs/cifar_model.pth'  # 模型保存路径
# 加载MNIST数据集
transform = tv.transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
 
train_data = tv.datasets.CIFAR10(root=ROOT, train=True, download=True, transform=transform)
test_data = tv.datasets.CIFAR10(root=ROOT, train=False, download=False, transform=transform)
 
train_load = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=WORKERS)
test_load = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=WORKERS)
 
 
# 构造模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv4 = nn.Conv2d(256, 256, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(256 * 8 * 8, 1024)
        self.fc2 = nn.Linear(1024, 256)
        self.fc3 = nn.Linear(256, 10)
 
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = self.pool(F.relu(self.conv4(x)))
        x = x.view(-1, x.size()[1] * x.size()[2] * x.size()[3])
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
  
model = Net().cpu()
 
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
 
 
# 模型训练
def train(model, train_loader, epoch):
    model.train()
    train_loss = 0
    for i, data in enumerate(train_loader, 0):
        x, y = data
        x = x.cpu()
        y = y.cpu()
 
        optimizer.zero_grad()
        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()
        train_loss  = loss
        print('正在进行第{}个epoch中的第{}次循环'.format(epoch,i))
 
    loss_mean = train_loss / (i   1)
    print('Train Epoch: {}\t Loss: {:.6f}'.format(epoch, loss_mean.item()))
 
 
# 模型测试
def test(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for i, data in enumerate(test_loader, 0):
            x, y = data
            x = x.cpu()
            y = y.cpu()
 
            optimizer.zero_grad()
            y_hat = model(x)
            test_loss  = criterion(y_hat, y).item()
            pred = y_hat.max(1, keepdim=True)[1]
            correct  = pred.eq(y.view_as(pred)).sum().item()
        test_loss /= (i   1)
        print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(test_data), 100. * correct / len(test_data)))
  
def main():
    # 如果test_flag=True,则加载已保存的模型并进行测试,测试以后不进行此模块以后的步骤
    if test_flag:
        # 加载保存的模型直接进行测试机验证
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        test(model, test_load)
        return
 
    # 如果有保存的模型,则加载模型,并在其基础上继续训练
    if os.path.exists(log_dir):
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        print('加载 epoch {} 成功!'.format(start_epoch))
    else:
        start_epoch = 0
        print('无保存了的模型,将从头开始训练!')
 
    for epoch in range(start_epoch 1, epochs):
        train(model, train_load, epoch)
        test(model, test_load)
        # 保存模型
        state = {'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
        torch.save(state, log_dir)
 
if __name__ == '__main__':
    main()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持Devmax。

pytorch实现加载保存查看checkpoint文件的更多相关文章

  1. vue自定义加载指令v-loading占位图指令v-showimg

    这篇文章主要为大家介绍了vue自定义加载指令和v-loading占位图指令v-showimg的示例详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪

  2. JavaScript实现动态加载删除表格

    这篇文章主要为大家详细介绍了JavaScript实现动态加载删除表格,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

  3. Android自定义View实现圆形加载进度条

    这篇文章主要为大家详细介绍了Android自定义View实现圆形加载进度条,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

  4. Android Flutter绘制有趣的 loading加载动画

    在网络速度较慢的场景,一个有趣的加载会提高用户的耐心和对 App 的好感。本篇我们利用Flutter 的 PathMetric来玩几个有趣的 loading 效果,感兴趣的可以动手尝试一下

  5. Python pkg_resources模块动态加载插件实例分析

    当编写应用软件时,我们通常希望程序具有一定的扩展性,额外的功能——甚至所有非核心的功能,都能通过插件实现,具有可插拔性。特别是使用 Python 编写的程序,由于语言本身的动态特性,为我们的插件方案提供了很多种实现方式

  6. 浅谈jQuery双事件多重加载的问题

    下面小编就为大家带来一篇浅谈jQuery双事件多重加载的问题。小编觉得挺不错的,现在就分享给大家,也给大家做个参考。一起跟随小编过来看看吧

  7. Python使用pytorch动手实现LSTM模块

    这篇文章主要介绍了Python使用pytorch动手实现LSTM模块,LSTM是RNN中一个较为流行的网络模块。主要包括输入,输入门,输出门,遗忘门,激活函数,全连接层(Cell)和输出

  8. Pytorch搭建yolo3目标检测平台实现源码

    这篇文章主要为大家介绍了Pytorch搭建yolo3目标检测平台实现源码,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪

  9. 用ajax动态加载需要的js文件

    这篇文章给大家介绍了用ajax动态加载需要的js文件的相关知识,感兴趣的朋友跟随脚本之家小编一起学习吧

  10. Ajax点击不断加载数据列表

    这篇文章主要介绍了Ajax点击不断加载数据列表的相关资料,需要的朋友可以参考下

随机推荐

  1. 10 个Python中Pip的使用技巧分享

    众所周知,pip 可以安装、更新、卸载 Python 的第三方库,非常方便。本文小编为大家总结了Python中Pip的使用技巧,需要的可以参考一下

  2. python数学建模之三大模型与十大常用算法详情

    这篇文章主要介绍了python数学建模之三大模型与十大常用算法详情,文章围绕主题展开详细的内容介绍,具有一定的参考价值,感想取得小伙伴可以参考一下

  3. Python爬取奶茶店数据分析哪家最好喝以及性价比

    这篇文章主要介绍了用Python告诉你奶茶哪家最好喝性价比最高,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习吧

  4. 使用pyinstaller打包.exe文件的详细教程

    PyInstaller是一个跨平台的Python应用打包工具,能够把 Python 脚本及其所在的 Python 解释器打包成可执行文件,下面这篇文章主要给大家介绍了关于使用pyinstaller打包.exe文件的相关资料,需要的朋友可以参考下

  5. 基于Python实现射击小游戏的制作

    这篇文章主要介绍了如何利用Python制作一个自己专属的第一人称射击小游戏,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起动手试一试

  6. Python list append方法之给列表追加元素

    这篇文章主要介绍了Python list append方法如何给列表追加元素,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

  7. Pytest+Request+Allure+Jenkins实现接口自动化

    这篇文章介绍了Pytest+Request+Allure+Jenkins实现接口自动化的方法,文中通过示例代码介绍的非常详细。对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下

  8. 利用python实现简单的情感分析实例教程

    商品评论挖掘、电影推荐、股市预测……情感分析大有用武之地,下面这篇文章主要给大家介绍了关于利用python实现简单的情感分析的相关资料,文中通过示例代码介绍的非常详细,需要的朋友可以参考下

  9. 利用Python上传日志并监控告警的方法详解

    这篇文章将详细为大家介绍如何通过阿里云日志服务搭建一套通过Python上传日志、配置日志告警的监控服务,感兴趣的小伙伴可以了解一下

  10. Pycharm中运行程序在Python console中执行,不是直接Run问题

    这篇文章主要介绍了Pycharm中运行程序在Python console中执行,不是直接Run问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

返回
顶部