在 CentOS 上使用 PyTorch 进行数据处理，首先需要确保系统中安装了合适版本的 Python 和 PyTorch。以下是详细的步骤指南：
# 1. Update system packages and install the system Python 3 (optional when
#    using Miniconda below, but handy for quick checks).
sudo yum update -y
sudo yum install -y python3 python3-pip
python3 --version

# 2. Install Miniconda. The installer is interactive: accept the licence and
#    choose an install prefix (the default is ~/miniconda3).
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
bash Miniconda3-latest-Linux-x86_64.sh

# The installer does not modify the current shell, so `conda activate` would
# fail here. Load conda's shell hook first (assumes the default prefix).
eval "$("$HOME/miniconda3/bin/conda" shell.bash hook)"

# 3. Create and activate an isolated environment. Recent PyTorch releases no
#    longer support Python 3.8, so use a newer interpreter.
conda create -n torch_env python=3.10
conda activate torch_env

# 4. Install PyTorch -- run ONLY ONE of the two options below; running both
#    makes the second install replace the first.
#    PyTorch is now distributed via pip wheels (the legacy
#    `conda install ... cudatoolkit=X -c pytorch` channel is no longer updated).
#    NOTE(review): recent PyTorch wheels may require a newer glibc than
#    CentOS 7 ships; on CentOS 7 you may need an older PyTorch release.

# Option A: CPU-only build.
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

# Option B: NVIDIA GPU build. Replace cu126 with a CUDA build that PyTorch
# publishes and that your installed NVIDIA driver supports; the current list
# is on https://pytorch.org/get-started/locally/
# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
# Check the installation from a shell before running this script:
#   python -c "import torch; print(torch.__version__)"
import os

import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from torchvision.io import read_image

# --- 1. Load a built-in dataset ---------------------------------------------
# ToTensor() converts each PIL image to a float tensor scaled to [0, 1];
# Normalize(mean=0.5, std=0.5) then maps the values to [-1, 1]. One-element
# tuples are used because FashionMNIST images have a single grayscale channel.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# download=True fetches the files into ./data on the first run and reuses
# them afterwards.
train_data = datasets.FashionMNIST(
    root="./data", train=True, download=True, transform=transform
)
test_data = datasets.FashionMNIST(
    root="./data", train=False, download=True, transform=transform
)


# --- 2. Custom dataset backed by a CSV annotations file ---------------------
class CustomImageDataset(Dataset):
    """Map-style dataset of image files listed in a CSV annotations file.

    Column 0 of each CSV row is an image file name relative to ``img_dir``;
    column 1 is that image's label. The file is read with ``pandas.read_csv``
    defaults, so its first line is treated as a header row.

    Args:
        annotations_file: Path to the CSV annotations file.
        img_dir: Directory that the file names in column 0 are joined onto.
        transform: Optional callable applied to each image.
        target_transform: Optional callable applied to each label.
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of annotated images (CSV data rows)."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair described by CSV row ``idx``."""
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        # read_image returns a tensor, not a PIL image, so `transform` must
        # accept tensors (the ToTensor()-based pipeline above would not work).
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label


# --- 3. Visualize random training samples -----------------------------------
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}

rows, cols = 3, 3
figure, axes = plt.subplots(rows, cols, figsize=(8, 8))
# Fill every one of the rows * cols cells of the grid.
for ax in axes.flat:
    # torch.randint expects `size` as a tuple; .item() yields a Python int.
    sample_idx = torch.randint(len(train_data), size=(1,)).item()
    img, label = train_data[sample_idx]
    # Drop the single channel dimension so imshow receives an (H, W) image.
    ax.imshow(img.squeeze(), cmap="gray")
    ax.set_title(labels_map[label])
    ax.axis("off")
# On a headless server with no display, use figure.savefig("samples.png")
# instead of plt.show().
plt.show()

# --- 4. Batch the data with DataLoader --------------------------------------
# Shuffle the training set every epoch; keep the test set in a fixed order so
# evaluation results are reproducible. num_workers=2 loads batches in worker
# processes; on platforms that start workers with "spawn" (Windows/macOS),
# wrap this script's top-level code in `if __name__ == "__main__":`.
train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True, num_workers=2)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=2)

# Pull a single batch to inspect the tensor shapes.
images, labels = next(iter(train_dataloader))
print(f"Feature batch shape: {images.size()}")
print(f"Labels batch shape: {labels.size()}")
以上步骤涵盖了在 CentOS 上安装 PyTorch 并进行数据处理的基本流程：先确认系统环境配置正确，再根据硬件（CPU 或 GPU）选择合适的命令安装 PyTorch，最后通过示例代码演示数据处理的基本操作，包括加载内置数据集、定义自定义数据集、可视化样本，以及使用 DataLoader 批量读取数据。