欢迎您访问365答案网,请分享给你的朋友!
生活常识 学习资料

图像分割的tf-serving

时间:2023-07-14

记录一下使用tensorflow-serving部署图像分割的过程

一、将h5权重文件转成saved_model可以部署的模型

changeH5tosavedModel.py

import tensorflow as tffrom nets.unet import Unet as unetif __name__ == '__main__': model = unet((512, 512, 3), 2, 'vgg') model.load_weights('EP100-loss0.196-valoss0.284.h5') tf.saved_model.save(model, "test/1")

二、利用docker开启tensorflow serving服务

docker run -p 8501:8501 --mount type=bind,source=E:projectFilesstandardunetV1/test,target=/models/unetV1 -e MODEL_NAME=unetV1 -t tensorflow/serving

gpu(目前只在linux下测试了,因为win10似乎安装不能nvidia-docker):

首先安装必要的东西:

docker run --rm --gpus all nvidia/cuda:11.0-base nvidia-smi

然后拉取tensorflow-serving gpu镜像:

docker pull tensorflow/serving:latest-gpu

最后开启模型服务

docker run --gpus all -p 8501:8501 --mount type=bind,source=/home/hbli/pythonFiles/unetV1/test,target=/models/unetV1 -e MODEL_NAME=unetV1 -t tensorflow/serving:latest-gpu

MODEL_NAME是自己定的,target最后的unetV1的名字和MODEL_NAME一致,source是被部署的模型所在的文件夹。其他都一样。

三、客户端进行访问

httpClient.py

""" 图像分割的serving """import cv2import numpy as npimport requestsimport jsonimport timefrom PIL import Imageimport colorsysimport matplotlib.pyplot as pltimport osdef resize_image(image, size): """ 等比例resize """ iw, ih = image.size w, h = size scale = min(w/iw, h/ih) nw = int(iw*scale) nh = int(ih*scale) image = image.resize((nw,nh), Image.BICUBIC) new_image = Image.new('RGB', size, (128,128,128)) new_image.paste(image, ((w-nw)//2, (h-nh)//2)) return new_image, nw, nhdef preprocess_input(image): image = image / 127.5 - 1 return imageinput_shape = (512,512) # 与训练的时候一致num_classes = 2 # 类别+1def preProcessing(filepath): inputs = cv2.imread(filepath) old_img = Image.open(filepath) h,w = inputs.shape[0],inputs.shape[1] # print(f'初始图像size: {h},{w}') """ 数据预处理 """ image_data, nw, nh = resize_image(old_img, (input_shape[1], input_shape[0])) image_data = np.expand_dims(preprocess_input(np.array(image_data, np.float32)), 0) return old_img,(h,w),(nw,nh),image_datadef mainProcess(): start = time.time() ####--------------------------核心代码----------------------------------------#### """ REST API端口 """ url = 'http://localhost:8501/v1/models/unetV1:predict' data = json.dumps({'inputs':image_data.tolist()}) # 要求输入的数据是json格式 response = requests.post(url,data=data) result = json.loads(response.content) outputs = result['outputs'][0] output_array = np.array(outputs) # list转numpy数组 ####--------------------------核心代码---------------------------------------#### print(f'花费时间:{time.time()-start:.2f}s') # print(type(output_array)) return output_arraydef postProcessing(): """ 对预测结果进行后处理 """ # resize回图像原始的大小 pr = cv2.resize(output_array, (w, h), interpolation = cv2.INTER_LINEAR) pr = pr.argmax(axis=-1) # 取出每一个像素点的种类 seg_img = np.zeros((np.shape(pr)[0], np.shape(pr)[1], 3)) if num_classes <= 21: colors = [ (0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128), (0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0), (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128), (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128), (128, 64, 12)] else: hsv_tuples = [(x / num_classes, 1., 1.) for x in range(num_classes)] colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples)) colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), colors)) for c in range(num_classes): seg_img[:,:,0] += ((pr[:,: ] == c )*(colors[c][0] )).astype('uint8') seg_img[:,:,1] += ((pr[:,: ] == c )*(colors[c][1] )).astype('uint8') seg_img[:,:,2] += ((pr[:,: ] == c )*(colors[c][2] )).astype('uint8') resultImage = Image.fromarray(np.uint8(seg_img)) image = Image.blend(old_img,resultImage,0.7) return imagedef saveAndShow(image): savename = os.path.basename(filepath)[:-4]+"httpResult.jpg" savePath = 'servingOut/' if not os.path.exists(savePath): os.mkdir(savePath) image.save(savePath+savename) plt.title(os.path.basename(filepath)) plt.imshow(image) plt.show() if __name__ == '__main__': while True: try: filepath = input('请输入待预测图像路径(输入c退出): ') if filepath == 'c': break old_img,(h,w),(nw,nh),image_data = preProcessing(filepath=filepath) output_array = mainProcess() image = postProcessing() saveAndShow(image) except Exception as e: print(e) continue

Copyright © 2016-2020 www.365daan.com All Rights Reserved. 365答案网 版权所有 备案号:

部分内容来自互联网,版权归原作者所有,如有冒犯请联系我们,我们将在三个工作时内妥善处理。