欢迎您访问365答案网,请分享给你的朋友!
生活常识 学习资料

数据集划分

时间:2023-06-03

train是训练集,val是训练过程中的测试集,是为了让你在边训练边看到训练的结果,及时判断学习状态。test就是训练模型结束后,用于评价模型结果的测试集。只有train就可以训练,val不是必须的,比例也可以设置很小。

验证数据集可以理解为训练数据集的一块

制作图书馆数据集代码如下:

### Data Format for Semantic Segmentation

The raw data will be processed by generator shell scripts、There will be two subdirs('train' & 'val')

```

train or val dir {

image: contains the images for train or val.

label: contains the label png files(mode='P') for train or val.

mask: contains the mask png files(mode='P') for train or val.

}

```

""" -*- coding: utf-8 -*- author: Hao Hu @date 2022/1/20 11:02 AM"""import cv2import numpy as npfrom matplotlib import pyplot as pltimport os.path as ospimport osfrom tqdm import tqdmfrom PIL import Imageimport PILfrom concurrent.futures import ThreadPoolExecutordef grab_cut(img_path): """使用了grab_cut算法获得物体和背景轮廓""" img_ori = cv2.imread(img_path) # 将img二值化 retVal, image = cv2.threshold(img_ori, 50, 100, cv2.THRESH_BINARY) mask = np.zeros(image.shape[:2], np.uint8) bgdModel = np.zeros((1, 65), np.float64) fgdModel = np.zeros((1, 65), np.float64) ix = int(img_ori.shape[0] / 22) iy = int(img_ori.shape[1] / 20) w = iy * 20 h = ix * 22 rect = (ix, iy, int(w), int(h)) # cv2.rectangle(img, (ix*2, iy*3), (int(w*0.9), int(h*0.9)), (0, 255, 0), 2) # 默认几个点作为物体和背景像素点 # (ix*15,iy*26),(ix*21,iy*15),(ix*21,iy*10)为背景像素点 cv2.circle(mask, (ix*15, iy*26), 15, [0,0,0], -1) cv2.grabCut(image, mask, rect, bgdModel, fgdModel, 5, cv2.GC_INIT_WITH_RECT) mask2 = np.where((mask == 2) | (mask == 0), 0, 1).astype('uint8') mask2[ix * 21, iy * 19] = 1 #plt.imshow(mask2), plt.colorbar(), plt.show() img = image * mask2[:, :, np.newaxis] return img,image,mask2,img_oridef get_mask_box(mask): contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) contours = list(contours) contours.sort(key=lambda x: cv2.contourArea(x), reverse=True) cnt = cv2.approxPolyDP(contours[0], epsilon=100, closed=True) cnt = cv2.minAreaRect(cnt) box = np.int0(cv2.boxPoints(cnt)) return mask, boxdef imwrite_the_label_img(ori_folder,end_folder_path,img_NAME): img_path = osp.join(ori_folder,img_NAME) img,image,mask,img_ori = grab_cut(img_path) _, box=get_mask_box(mask) re = cv2.drawContours(image.copy(), [box], 0, (0, 255, 0), -1) end_path = osp.join(end_folder_path, img_NAME[:-2]+'.png') cv2.imwrite((end_path), re) # 将图片转为model = P re = PIL.Image.open(end_path) re = re.convert('P') re.save(end_path)if __name__ == '__main__': ori_folder = '/cloud_disk/users/huh/dataset/lib_dataset/train/image' img_list = os.listdir(ori_folder) end_folder_path = '/cloud_disk/users/huh/dataset/lib_dataset/train/label' executor = ThreadPoolExecutor(max_workers=100) # 最大线程数量 for img_NAME in tqdm(img_list): executor.map(imwrite_the_label_img(ori_folder,end_folder_path,img_NAME))

Copyright © 2016-2020 www.365daan.com All Rights Reserved. 365答案网 版权所有 备案号:

部分内容来自互联网,版权归原作者所有,如有冒犯请联系我们,我们将在三个工作时内妥善处理。