欢迎您访问365答案网,请分享给你的朋友!
生活常识 学习资料

pytorch学习之Dataset类

时间:2023-06-01

一、

from torch.utils.data import Datasetfrom PIL import Imageimport osclass MyData(Dataset): def __init__(self, root_dir, label_dir): # 提供一个全局变量 self.root_dir = root_dir self.label_dir = label_dir self.path = os.path.join(self.root_dir,self.label_dir) self.img_path = os.listdir(self.path) def __getitem__(self,idx): img_name = self.img_path[idx] img_item_path = os.path.join(self.root_dir,self.label_dir,img_name) img = Image.open(img_item_path) label = self.label_dir return img,label def __len__(self): return len(self.img_path)root_dir = r"Dataset_data/train"ants_label_dir = "ants"bees_label_dir = "bees"ants_dataset = MyData(root_dir,ants_label_dir)bees_dataset = MyData(root_dir,bees_label_dir)

 

例子1:
from PIL import Image
path = "d:a"                   #获取图片的地址
img = Image.open(path)
img.size                         #尺寸
Img.show()                #显示数据,此处Image.show(Img)错误

例子2:import os
想获取图片的地址
1.获取所有图片的地址的列表list
2.通过相应的索引获取图片的地址
dir_path = "dataset/train/ants"
import os
img_path_list = os.listdir(dir_path)

想获取所有图片的地址
import os

Copyright © 2016-2020 www.365daan.com All Rights Reserved. 365答案网 版权所有 备案号:

部分内容来自互联网,版权归原作者所有,如有冒犯请联系我们,我们将在三个工作时内妥善处理。