import torchfrom torch.utils.data import Datasetimport PILfrom PIL import Imageimport os
二、类的属性def __init__(self,root_dir,label_dir): self.root_dir=root_dir self.label_dir=label_dir self.path=os.path.join(self.root_dir,self.label_dir) self.img_path=os.listdir(self.path)
其中:
root_dir是根目录,可以是绝对路径,也可以是相对路径。
label_dir是标签
self.path=os.path.join(self.root_dir,self.label_dir)的作用是把根目录和标签拼在一起,所形成的就是一个保存有训练所需数据的文件夹的目录。
def __getitem__(self, index): img_name=self.img_path[index] img_item_path=os.path.join(self.root_dir,self.label_dir,img_name) img=Image.open(img_item_path) label=self.label_dir return img,label def __len__(self): return len(self.img_path)
其中:
__getitem__返回的是img和label,img是图像本身,label是标签。
①定义一个数据集:
root_dir="D:\Pycharm\pythonProject7\hymenoptera_data\train"ants_label_dir="ants"bees_label_dir="bees"ants_dataset=MyData(root_dir,ants_label_dir)bees_dataset=MyData(root_dir,bees_label_dir)
②数据集之间可以相加:
train_dataset=ants_dataset+bees_dataset
假设:ants_dataset有124项,则train_dataset[0]~train_dataset[123]都是ants_dataset的数据,123往后才是bees_dataset的数据。
③img=train_dataset[123]与img,label=train_dataset[123]的区别:
img=train_dataset[123]
执行完此语句,img是:
(
,
‘ants’)
此时的img并不是一张图片,而是图片和标签
img,label=train_dataset[123]
而执行完后者,img是: