温馨提示×

CentOS上PyTorch的模型保存与加载方法

小樊
67
2025-03-07 20:28:18
栏目: 智能运维

在CentOS上使用PyTorch进行模型的保存与加载是一个常见的操作。以下是详细的步骤和示例代码,帮助你完成这一任务。

安装PyTorch

首先,确保你已经安装了PyTorch。你可以使用以下命令来安装:

pip install torch torchvision 

保存模型

在PyTorch中,你可以使用torch.save()函数来保存模型。以下是一个简单的示例:

import torch import torch.nn as nn # 定义一个简单的神经网络 class SimpleNet(nn.Module): def __init__(self): super(SimpleNet, self).__init__() self.fc = nn.Linear(784, 10) def forward(self, x): x = x.view(-1, 784) x = self.fc(x) return x # 创建模型实例 model = SimpleNet() # 假设模型已经训练好 # 保存模型 torch.save(model, 'model.pth') 

加载模型

加载模型时,你可以使用torch.load()函数。以下是一个示例:

# 加载模型 model = torch.load('model.pth') # 如果模型是在不同的设备上训练的(例如GPU),需要将其移动到当前设备 model.to('cpu') # 或者 'cuda' 如果你在GPU上工作 

检查模型是否加载成功

你可以通过前向传播一些数据来检查模型是否加载成功:

# 假设我们有一些输入数据 input_data = torch.randn(1, 1, 28, 28) # 示例输入数据 # 使用加载的模型进行前向传播 output = model(input_data) print(output) 

完整示例

以下是一个完整的示例,包括模型的定义、训练、保存和加载:

import torch import torch.nn as nn import torch.optim as optim # 定义一个简单的神经网络 class SimpleNet(nn.Module): def __init__(self): super(SimpleNet, self).__init__() self.fc = nn.Linear(784, 10) def forward(self, x): x = x.view(-1, 784) x = self.fc(x) return x # 创建模型实例 model = SimpleNet() # 定义损失函数和优化器 criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.01) # 假设我们有一些训练数据 inputs = torch.randn(64, 1, 28, 28) labels = torch.randint(0, 10, (64,)) # 训练模型 for epoch in range(5): optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() print(f'Epoch {epoch+1}, Loss: {loss.item()}') # 保存模型 torch.save(model, 'model.pth') # 加载模型 model = torch.load('model.pth') model.to('cpu') # 或者 'cuda' 如果你在GPU上工作 # 检查模型是否加载成功 output = model(inputs) print(output) 

通过以上步骤,你可以在CentOS上轻松地保存和加载PyTorch模型。希望这些信息对你有所帮助!

0