PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】

'''
# -*-coding:utf-8 -*-
# author:sakia
from torch.utils.data import Dataset
import CV2
from PIL import Image
import os
class MyData(Dataset):
    """Dataset whose images and text labels live in two sibling directories.

    Expected layout (example)::

        root_dir/img_dir/    e.g. "dataset/train/ants_image/"
        root_dir/label_dir/  e.g. "dataset/train/ants_label/"

    Sample ``i`` pairs the i-th image file with the i-th label file, where
    both directory listings are sorted by filename.

    NOTE(review): pairing by sorted position assumes every image has exactly
    one label file whose name sorts to the same position (e.g. identical
    stems with different extensions) -- confirm this holds for your data.
    """

    def __init__(self, root_dir, img_dir, label_dir):
        """Index the files under ``root_dir/img_dir`` and ``root_dir/label_dir``.

        Args:
            root_dir: Directory containing both sub-directories,
                e.g. "dataset/train".
            img_dir: Image sub-directory name, relative to ``root_dir``,
                e.g. "ants_image".
            label_dir: Label sub-directory name, relative to ``root_dir``,
                e.g. "ants_label".
        """
        self.root_dir = root_dir
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.img_dir)
        self.lpath = os.path.join(self.root_dir, self.label_dir)
        # os.listdir returns entries in arbitrary, filesystem-dependent order;
        # sort both listings so index i refers to corresponding files.
        self.img_path = sorted(os.listdir(self.path))
        self.label_path = sorted(os.listdir(self.lpath))

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx``.

        The image is opened with ``PIL.Image.open`` from the image directory;
        the label is the whitespace-stripped text of the label file at the
        same sorted position.
        """
        img_item_path = os.path.join(self.path, self.img_path[idx])
        img = Image.open(img_item_path)
        label_item_path = os.path.join(self.lpath, self.label_path[idx])
        # Explicit encoding so the result does not depend on the OS locale
        # (e.g. GBK on Chinese Windows).
        with open(label_item_path, "r", encoding="utf-8") as f:
            label = f.read().strip()
        return img, label

    def __len__(self):
        """Return the number of samples, i.e. the number of image files."""
        return len(self.img_path)
# Example usage: one dataset per class, then concatenate them.
# os.path.join keeps the path portable instead of hard-coding the Windows
# "\\" separator (on Windows the resulting string is identical).
root_dir = os.path.join("练手数据集", "train")
ants_img_dir = "ants_image"
ants_label_dir = "ants_label"
bees_img_dir = "bees_image"
bees_label_dir = "bees_label"
# Each class has its own image and label sub-directories under root_dir.
ants_dataset = MyData(root_dir, ants_img_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_img_dir, bees_label_dir)
# Adding two torch Datasets yields a ConcatDataset: indices below
# len(ants_dataset) come from ants_dataset, the rest from bees_dataset.
train_dataset = ants_dataset + bees_dataset
img, label = train_dataset[1]
P7：图像与标签分开存放（分别位于不同文件夹）时的数据集代码
'''