使用PaddleX对大量图片进行分类(仅包含预测的内容)
# -*- coding: UTF-8 -*-
"""Classify every image in one folder with a trained PaddleX model (prediction only)."""
import os
from shutil import copyfile

import cv2
import numpy as np
import paddlex as pdx

# Load the exported PaddleX model once; every prediction below reuses it.
model = pdx.load_model('./inference_model')


def paddle_predict(source_path, result_path, threshold_value):
    """Classify the images directly inside ``source_path``.

    Each file is decoded and passed to the model. When the score of the
    first prediction is greater than ``threshold_value`` the file is copied
    to ``result_path/<category>/``. Files that cannot be processed are
    reported on stdout and skipped, so one bad file never stops the batch.

    Args:
        source_path: Folder containing the images (sub-folders are ignored).
        result_path: Output folder; created if it does not exist.
        threshold_value: Minimum score (exclusive) for a file to be copied.
    """
    os.makedirs(result_path, exist_ok=True)

    for filename in os.listdir(source_path):
        file_path = os.path.join(source_path, filename)
        if not os.path.isfile(file_path):
            # This variant does not descend into sub-folders.
            continue
        try:
            # Reading the bytes with NumPy and decoding in memory avoids
            # cv2.imread, which is known to fail on non-ASCII (e.g. Chinese)
            # paths on Windows.
            raw = np.fromfile(file_path, dtype=np.uint8)
            # IMREAD_COLOR always yields a 3-channel 8-bit BGR array, so
            # grayscale / alpha-channel / 16-bit files reach the model in the
            # same layout as ordinary photos.
            # NOTE(review): assumes the model expects 3-channel BGR input.
            im = cv2.imdecode(raw, cv2.IMREAD_COLOR)
            if im is None:
                raise ValueError('file could not be decoded as an image')

            result = model.predict(im.astype('float32'))
            print(result)

            best = result[0]
            if best['score'] > threshold_value:
                # Create the category folder only when something is actually
                # copied into it, so low-score results leave no empty folders.
                category_dir = os.path.join(result_path, best['category'])
                os.makedirs(category_dir, exist_ok=True)
                copyfile(file_path, os.path.join(category_dir, filename))
                # To move instead of copy, delete the source here:
                # os.remove(file_path)
        except Exception as err:
            # Best-effort batch job: report the file and the reason, go on.
            print('ERROR:' + filename + ' -> ' + repr(err))


if __name__ == '__main__':
    paddle_predict("C:/test", "C:/output_new", 0.9)
首先使用PaddleX进行训练,然后使用该Python程序对大量图片进行分类。
本程序支持中文路径和错误处理功能,能够稳定用于生产用途。
另附上支持遍历源路径下所有子文件夹的版本:
# -*- coding: UTF-8 -*-
"""Classify every image under a folder tree with a trained PaddleX model."""
import os
from shutil import copyfile

import cv2
import numpy as np
import paddlex as pdx

# Load the exported PaddleX model once; every prediction below reuses it.
model = pdx.load_model('./inference_model')


def paddle_predict(source_path, result_path, threshold_value):
    """Classify all images under ``source_path``, including sub-folders.

    Each file is decoded and passed to the model. When the score of the
    first prediction is greater than ``threshold_value`` the file is copied
    to ``result_path/<category>/``. Files that cannot be processed are
    reported on stdout and skipped, so one bad file never stops the batch.

    NOTE: files are copied under their bare file name, so two files with the
    same name in different sub-folders that land in the same category
    overwrite each other.

    Args:
        source_path: Root folder that is walked recursively.
        result_path: Output folder; created if it does not exist.
        threshold_value: Minimum score (exclusive) for a file to be copied.
    """
    os.makedirs(result_path, exist_ok=True)

    for dirpath, _dirnames, filenames in os.walk(source_path):
        for filename in filenames:
            file_path = os.path.join(dirpath, filename)
            try:
                # Reading the bytes with NumPy and decoding in memory avoids
                # cv2.imread, which is known to fail on non-ASCII (e.g.
                # Chinese) paths on Windows.
                raw = np.fromfile(file_path, dtype=np.uint8)
                # IMREAD_COLOR always yields a 3-channel 8-bit BGR array, so
                # grayscale / alpha-channel / 16-bit files reach the model in
                # the same layout as ordinary photos.
                # NOTE(review): assumes the model expects 3-channel BGR input.
                im = cv2.imdecode(raw, cv2.IMREAD_COLOR)
                if im is None:
                    raise ValueError('file could not be decoded as an image')

                result = model.predict(im.astype('float32'))
                print(file_path)

                best = result[0]
                if best['score'] > threshold_value:
                    # Create the category folder only when something is
                    # actually copied into it.
                    category_dir = os.path.join(result_path, best['category'])
                    os.makedirs(category_dir, exist_ok=True)
                    copyfile(file_path, os.path.join(category_dir, filename))
                    # To move instead of copy, delete the source here:
                    # os.remove(file_path)
            except Exception as err:
                # Best-effort batch job: report the file and the reason.
                print('ERROR:' + file_path + ' -> ' + repr(err))


if __name__ == '__main__':
    paddle_predict("C:/test", "C:/output_new", 0.9)
为了方便展示运行效果,对程序做了进一步优化:先统计待处理的文件总数,再用一个单独的线程每秒输出一次处理速度(FPS)、预计剩余时间和出错文件数。
# -*- coding: UTF-8 -*-
"""Classify a folder tree with PaddleX while a second thread prints progress."""
import os
import threading
import time
from shutil import copyfile

import cv2
import numpy as np
import paddlex as pdx

# Progress counters shared between the worker and the monitor thread.
TEMP_SEC_CALC = 0        # files handled since the monitor's last report
TEMP_SEC_CALC_ERROR = 0  # files that failed so far
TEMP_SUM = 0             # files found by the initial scan
TEMP_SUM_DONE = 0        # files handled so far (successes and failures)

# Set by the worker when it stops, so the monitor always terminates even if
# the number of files changed between the scan and the processing pass.
_PREDICT_FINISHED = threading.Event()

# Load the exported PaddleX model once; every prediction below reuses it.
model = pdx.load_model('./inference_model')


def time_convert(seconds):
    """Format a duration in seconds as ``HH:MM:SS`` (wraps around at 24 h)."""
    seconds = seconds % (24 * 3600)
    hour = seconds // 3600
    seconds %= 3600
    minutes = seconds // 60
    seconds %= 60
    return "%02d:%02d:%02d" % (hour, minutes, seconds)


def paddle_predict_calc():
    """Print throughput, estimated remaining time and error count once a second.

    Runs until every scanned file has been handled or the worker signals
    that it has stopped, whichever comes first.
    """
    global TEMP_SEC_CALC
    while TEMP_SUM != TEMP_SUM_DONE and not _PREDICT_FINISHED.is_set():
        handled = TEMP_SEC_CALC
        if handled:
            remaining = (TEMP_SUM - TEMP_SUM_DONE) / handled
            print('FPS:' + str(handled) + ' Remain:' + time_convert(remaining)
                  + ' ERROR:' + str(TEMP_SEC_CALC_ERROR))
        TEMP_SEC_CALC = 0
        # Sleep for up to one second, but wake immediately when the worker ends.
        _PREDICT_FINISHED.wait(1)


def paddle_predict(source_path, result_path, threshold_value):
    """Classify all images under ``source_path`` and update the counters.

    A file is copied to ``result_path/<category>/`` when the score of the
    first prediction is greater than ``threshold_value``. Failures are only
    counted (not printed) so they do not disturb the progress output.

    Args:
        source_path: Root folder that is walked recursively.
        result_path: Output folder; created if it does not exist.
        threshold_value: Minimum score (exclusive) for a file to be copied.
    """
    global TEMP_SEC_CALC_ERROR
    global TEMP_SUM_DONE
    global TEMP_SEC_CALC
    try:
        os.makedirs(result_path, exist_ok=True)
        for dirpath, _dirnames, filenames in os.walk(source_path):
            for filename in filenames:
                file_path = os.path.join(dirpath, filename)
                try:
                    # Reading the bytes with NumPy and decoding in memory
                    # avoids cv2.imread, which is known to fail on non-ASCII
                    # (e.g. Chinese) paths on Windows.
                    raw = np.fromfile(file_path, dtype=np.uint8)
                    # IMREAD_COLOR always yields a 3-channel 8-bit BGR array.
                    # NOTE(review): assumes the model expects 3-channel BGR.
                    im = cv2.imdecode(raw, cv2.IMREAD_COLOR)
                    if im is None:
                        raise ValueError('file could not be decoded as an image')

                    result = model.predict(im.astype('float32'))
                    best = result[0]
                    if best['score'] > threshold_value:
                        category_dir = os.path.join(result_path, best['category'])
                        os.makedirs(category_dir, exist_ok=True)
                        copyfile(file_path, os.path.join(category_dir, filename))
                        # To move instead of copy, delete the source here:
                        # os.remove(file_path)
                except Exception:
                    TEMP_SEC_CALC_ERROR = TEMP_SEC_CALC_ERROR + 1
                # Every file counts as handled, whether it succeeded or not.
                TEMP_SUM_DONE = TEMP_SUM_DONE + 1
                TEMP_SEC_CALC = TEMP_SEC_CALC + 1
    finally:
        _PREDICT_FINISHED.set()


def predict(source_path, result_path, threshold_value):
    """Count the files, then start the worker and the progress monitor threads.

    Returns as soon as both threads are started; the interpreter keeps
    running until they finish because neither is a daemon thread.

    Args:
        source_path: Root folder that is walked recursively.
        result_path: Output folder; created if it does not exist.
        threshold_value: Minimum score (exclusive) for a file to be copied.
    """
    global TEMP_SUM
    for _dirpath, _dirnames, filenames in os.walk(source_path):
        TEMP_SUM = TEMP_SUM + len(filenames)
    print('总共发现了:'+str(TEMP_SUM)+'个文件!')
    time.sleep(2)  # leave the total on screen briefly before progress starts

    _PREDICT_FINISHED.clear()
    main_func = threading.Thread(
        target=paddle_predict,
        args=(source_path, result_path, threshold_value),
    )
    calc_func = threading.Thread(target=paddle_predict_calc)
    main_func.start()
    calc_func.start()


if __name__ == '__main__':
    predict("C:/test", "C:/output_new1", 0.9)
运行效果
为了方便对网页图片进行判断,又修改出一个精简版,直接通过图片的URL地址进行预测:
# -*- coding: UTF-8 -*-
"""Classify a single image fetched from a URL with a trained PaddleX model."""
import cv2
import paddlex as pdx

# Load the exported PaddleX model once; every prediction below reuses it.
Paddle_Func = pdx.load_model('./inference_model')


def Paddle_Url_Predit(Pic_Url):
    """Fetch the image at ``Pic_Url``, run the model on it and print the result.

    Prints ``Download Failure!`` when the URL cannot be opened or no frame
    can be read from it, and ``Unknown Error!`` on any other failure.

    Args:
        Pic_Url: Address of the image, passed to ``cv2.VideoCapture``.
    """
    cap = None
    try:
        # VideoCapture is used here as a simple way to fetch and decode a
        # remote image: the picture is read as a one-frame stream.
        cap = cv2.VideoCapture(Pic_Url)
        if not cap.isOpened():
            print('Download Failure!')
            return
        ret, im = cap.read()
        if not ret or im is None:
            # The stream opened but delivered no frame.
            print('Download Failure!')
            return
        result = Paddle_Func.predict(im.astype('float32'))
        print(result)
    except Exception:
        print('Unknown Error!')
    finally:
        # Always free the capture handle, including on early return or error.
        if cap is not None:
            cap.release()


if __name__ == '__main__':
    Paddle_Url_Predit('图片url地址')
进一步添加Socket传输图片的url地址
服务端:
# -*- coding: UTF-8 -*-
"""Socket server: receives image URLs and replies with the PaddleX prediction."""
import socket

import cv2
import paddlex as pdx

# Load the exported PaddleX model once; every prediction below reuses it.
Paddle_Func = pdx.load_model('./inference_model')


def Paddle_Url_Predit(Pic_Url):
    """Fetch the image at ``Pic_Url`` and return the model's prediction.

    Args:
        Pic_Url: Address of the image, passed to ``cv2.VideoCapture``.

    Returns:
        The value returned by ``Paddle_Func.predict`` on success, the string
        ``'Download Failure!'`` when the URL cannot be opened or read, or
        ``'Unknown Error!'`` on any other failure.
    """
    cap = None
    try:
        cap = cv2.VideoCapture(Pic_Url)
        if not cap.isOpened():
            return 'Download Failure!'
        ret, im = cap.read()
        if not ret or im is None:
            # The stream opened but delivered no frame.
            return 'Download Failure!'
        return Paddle_Func.predict(im.astype('float32'))
    except Exception:
        return 'Unknown Error!'
    finally:
        # Always free the capture handle, including on early return or error.
        if cap is not None:
            cap.release()


def _serve_client(conn):
    """Answer URL requests on one connection until the client leaves."""
    while True:
        try:
            # NOTE: a single recv of at most 1024 bytes is treated as one
            # complete request, so longer URLs would be truncated.
            ret = conn.recv(1024)
        except OSError:
            print('Connect Error!')
            return
        if not ret:
            # An empty read means the client closed the connection.
            return

        message = ret.decode('utf-8', errors='replace').strip()
        if message == 'bye':
            # The Python client ends its session when the server replies
            # 'bye'; echo it back so the client can shut down cleanly.
            try:
                conn.sendall(b'bye')
            except OSError:
                print('Connect Error!')
            return

        Socket_rst = Paddle_Url_Predit(message)
        print(Socket_rst)
        try:
            conn.sendall(bytes(str(Socket_rst), encoding='utf-8'))
        except OSError:
            print('Connect Error!')
            return


def main():
    """Listen on 127.0.0.1:8898 and serve clients one after another."""
    # The context managers close the sockets on every exit path,
    # including Ctrl+C.
    with socket.socket() as sk:
        sk.bind(('127.0.0.1', 8898))
        sk.listen()
        while True:
            conn, _addr = sk.accept()
            with conn:
                _serve_client(conn)


if __name__ == '__main__':
    main()
客户端(Python,参考自https://zhuanlan.zhihu.com/p/279968757):
"""Interactive socket client: sends image URLs and prints the server's reply."""
import socket


def main():
    """Connect to 127.0.0.1:8898 and exchange one reply per input line.

    The session ends when the server replies ``bye`` or closes the
    connection.
    """
    # The context manager closes the socket on every exit path.
    with socket.socket() as sk:
        sk.connect(('127.0.0.1', 8898))
        while True:
            info = input('>>>')
            if not info:
                # Sending zero bytes transmits nothing, so the recv below
                # would wait forever for a reply that never comes.
                continue
            sk.sendall(bytes(info, encoding='utf-8'))

            # NOTE: a single recv of at most 1024 bytes is treated as the
            # whole reply, so longer replies are truncated.
            ret = sk.recv(1024)
            if not ret:
                # An empty read means the server closed the connection.
                break
            if ret == b'bye':
                try:
                    sk.sendall(b'bye')
                except OSError:
                    # The server may already have closed its end.
                    pass
                break
            # A reply cut at the 1024-byte limit may end in the middle of a
            # multi-byte character; replace it instead of crashing.
            print(ret.decode('utf-8', errors='replace'))


if __name__ == '__main__':
    main()
客户端(PHP):
待续