欢迎光临散文网 会员登陆 & 注册

PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】

2023-06-30 16:16 作者:没睡午觉可真不行  | 我要投稿

'''

# -*-coding:utf-8 -*-

# author:sakia


from torch.utils.data import Dataset

import CV2

from PIL import Image

import os



class MyData(Dataset):


  # 提供全局变量。,这里把label_dir改成相应的类的绝对路径label_dir

  # 比如 root_dir : "dataset/train" , label_dir = "ants_image"

  def __init__(self, root_dir, img_dir, label_dir):

    self.root_dir = root_dir

    self.img_dir = img_dir

    self.label_dir = label_dir

    self.path = os.path.join(self.root_dir, self.img_dir)

    self.lpath = os.path.join(self.root_dir, self.label_dir)

    # 获取所有图片

    self.img_path = os.listdir(self.path)

    self.label_path = os.listdir(self.lpath)


  # 获取每一个图片

  def __getitem__(self, idx):

    img_name = self.img_path[idx]

    img_item_path = os.path.join(self.root_dir, self.img_dir, img_name)

    img = Image.open(img_item_path)

    label_name = self.label_path[idx]

    label_item_path = os.path.join(self.root_dir, self.label_dir, label_name)

    with open(label_item_path, 'r') as f:

      label = f.read().strip()

    return img, label

   

  # 列表长度

  def __len__(self):

    return len(self.img_path)



root_dir = "练手数据集\\train"

ants_img_dir = "ants_image"

ants_label_dir = "ants_label"

bees_img_dir = "bees_image"

bees_label_dir = "bees_label"

#

ants_dataset = MyData(root_dir, ants_img_dir, ants_label_dir)

bees_dataset = MyData(root_dir, bees_img_dir, bees_label_dir)


# 手工创建数据集可以创建这个方法有用

train_dataset = ants_dataset + bees_dataset


img, label = train_dataset[1]


p7 标签和图像分开的代码

'''

PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】的评论 (共 条)

分享到微博请遵守国家法律