이제 학습된 가중치를 불러와서, cv2 함수로 ROI 영역을 직접 지정해 detection 해 보겠다. 내 노트북에서는 계정 이름이 한글로 설정되어 있어 오류가 나는 것 같다(Windows에서 cv2.imread()는 한글 등 비ASCII 문자가 들어간 경로의 파일을 읽지 못하고 None을 반환하기 때문이다). 코드는 다음과 같다.
import numpy as np
import cv2
import tensorflow as tf
from tkinter import *
from tkinter import filedialog
import tkinter.scrolledtext as tkst
# Minimum confidence threshold (not referenced anywhere in this script).
min_confidence = 0.5
# Display width in pixels; detectAndDisplay() resizes the chosen image to this
# width and derives the height from the image's aspect ratio.
width = 800
height = 0  # placeholder; detectAndDisplay() computes its own local height
show_ratio = 1.0  # not referenced anywhere in this script
file_name = "./custom/fruit10.jpg"  # initial text of the "Image" label
classes_name = "./custom/classes.txt"  # class names, one per line
weight_name = "./model/fruit_custom.h5"  # Keras model loaded at startup
title_name = 'TensorFlow Custom Data Prediction'
classes = []
# OpenCV BGR drawing colours, indexed by class number. The first two entries
# (green, red) are the original palette; the extra entries keep models with
# more than two classes from running off the end of the list.
colors = [
    (0, 255, 0),
    (0, 0, 255),
    (255, 0, 0),
    (0, 255, 255),
    (255, 0, 255),
    (255, 255, 0),
]
read_image = None  # BGR image chosen via selectFile(); None until then
# Model input size: the selected ROI is resized to CW x CH before prediction.
CW = 32
CH = 32
CD = 3  # channel count; not referenced anywhere in this script
# Load the default Keras model and the class-name list at startup.
model = tf.keras.models.load_model(weight_name)
# Read the class file as UTF-8 explicitly. Without an encoding, open() uses the
# OS locale (e.g. cp949 on Korean Windows). 'utf-8-sig' also drops a leading BOM
# that some Windows editors write, which would otherwise end up in the first name.
# NOTE(review): assumes classes.txt is UTF-8 encoded — confirm for your data.
with open(classes_name, 'r', encoding='utf-8-sig') as txt:
    for line in txt:
        # One class name per line; the line index is the model's class index.
        classes.append(line.rstrip('\r\n'))
def selectWeightFile():
    """Ask for a Keras ``.h5`` model file and make it the active ``model``.

    The model is loaded before any global or label is changed. If the dialog
    is cancelled, the previous model and path stay as they are. If
    ``load_model`` raises, the exception propagates and the previous model and
    path are also left untouched.
    """
    global weight_name
    global model
    selected = filedialog.askopenfilename(initialdir = "./model",title = "Select Model file",filetypes = (("keras model files","*.h5"),("all files","*.*")))
    if not selected:
        # Dialog cancelled: askopenfilename returns an empty value.
        return
    new_model = tf.keras.models.load_model(selected)
    weight_name = selected
    model = new_model
    weight_path['text'] = weight_name
def selectClassesFile():
    """Ask for a class-name file (one name per line) and reload ``classes``.

    The file is read fully before the globals and the label are updated. A
    cancelled dialog or a read error therefore leaves the previous class list
    in place, rather than an empty one.
    """
    global classes_name
    global classes
    selected = filedialog.askopenfilename(initialdir = "./custom",title = "Select Classes file",filetypes = (("text files","*.txt"),("all files","*.*")))
    if not selected:
        # Dialog cancelled: askopenfilename returns an empty value.
        return
    # Explicit UTF-8 (BOM tolerant) instead of the OS locale encoding.
    # NOTE(review): assumes the class file is UTF-8 encoded — confirm.
    with open(selected, 'r', encoding='utf-8-sig') as txt:
        new_classes = [line.rstrip('\r\n') for line in txt]
    classes_name = selected
    classes = new_classes
    classes_path['text'] = classes_name
def selectFile():
    """Ask for an image, load it into ``read_image`` and run detection on it.

    If the dialog is cancelled, nothing happens. If the file cannot be
    decoded as an image, a message is printed and nothing else happens.
    """
    global read_image
    global file_name
    selected = filedialog.askopenfilename(initialdir = "./",title = "Select file",filetypes = (("jpeg files","*.jpg"),("all files","*.*")))
    if not selected:
        # Dialog cancelled: askopenfilename returns an empty value.
        return
    # cv2.imread() cannot open paths containing non-ASCII characters on
    # Windows (e.g. a Korean user-name folder) and silently returns None.
    # Reading the raw bytes with NumPy and decoding them avoids the OS path API.
    image = cv2.imdecode(np.fromfile(selected, dtype=np.uint8), cv2.IMREAD_COLOR)
    if image is None:
        print("Could not read image:", selected)
        return
    read_image = image
    file_name = selected
    file_path['text'] = file_name
    detectAndDisplay()
def detectAndDisplay():
    """Classify a user-selected region of ``read_image`` and show the result.

    All inputs are module globals: ``read_image``, ``model``, ``classes``,
    ``colors``, ``width``, ``CW`` and ``CH``. The steps are:

      1. Resize ``read_image`` to ``width`` pixels wide, keeping the aspect ratio.
      2. Let the user drag a box with ``cv2.selectROI``.
      3. Resize the crop to ``CW`` x ``CH`` and scale pixels to [0, 1].
      4. Run ``model.predict`` on a batch of one.
      5. Draw the box with the top class, list every class score at the bottom
         left, and paste the model-input thumbnail in the top-left corner.
      6. Wait in ``cv2.waitKey(0)`` until a key is pressed.

    Returns early, after printing a message, if no image is loaded or the ROI
    selection yields an empty rectangle (e.g. it was cancelled).
    """
    if read_image is None:
        print("No image loaded - select an image first.")
        return

    h, w = read_image.shape[:2]
    height = int(h * width / w)
    img = cv2.resize(read_image, (width, height))

    roi_window = "Select Region Of Interest and Press Enter or Space key"
    box = cv2.selectROI(roi_window, img, fromCenter=False, showCrosshair=True)
    # Close the selection window; otherwise it stays open behind the result.
    cv2.destroyWindow(roi_window)

    startX, startY, boxW, boxH = (int(v) for v in box)
    if boxW <= 0 or boxH <= 0:
        # selectROI returns an empty rectangle when the selection is cancelled.
        # Resizing the resulting empty crop would raise inside cv2.resize.
        print("No region selected.")
        return
    endX = startX + boxW
    endY = startY + boxH

    image = cv2.resize(img[startY:endY, startX:endX], (CW, CH),
                       interpolation=cv2.INTER_AREA)
    # Batch of one sample as float32, scaled to the range [0, 1].
    test_images = np.array([image]).astype("float32") / 255.0

    result = model.predict(test_images)
    scores = result[0]
    result_number = int(np.argmax(scores))
    print(result, result_number)
    print("%s : %.2f %2s" % (classes[result_number], scores[result_number]*100, '%'))

    # colors may hold fewer entries than there are classes, so cycle through it.
    box_color = colors[result_number % len(colors)]
    text = "{}: {}%".format(classes[result_number], round(scores[result_number]*100, 2))
    # Put the label above the box unless that would leave the top of the image.
    y = startY - 10 if startY - 10 > 10 else startY + 10
    cv2.rectangle(img, (startX, startY), (endX, endY), box_color, 2)
    cv2.putText(img, text, (startX, y),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, box_color, 2)

    # List every class score, stacked upward from the bottom-left corner.
    for i, class_name in enumerate(classes):
        print("%10s %15.7f %2s" % (class_name, scores[i]*100, '%'))
        text = "{}: {}".format(class_name, scores[i])
        cv2.putText(img, text, (10, height - ((i * 20) + 20)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, colors[i % len(colors)], 2)

    # Show the actual model input as a thumbnail in the top-left corner.
    img[0:CH, 0:CW] = image
    cv2.imshow("Prediction Output", img)
    cv2.waitKey(0)
# Build the Tk window. There is one row per input (model weights, class list,
# image), each with a path label and a "Select" button. Below them is a
# scrolled text area that shows the usage instructions.
# Only launched when the file is run as a script, not when it is imported.
if __name__ == "__main__":
    main = Tk()
    main.title(title_name)
    # Let the path column and the log row absorb extra space on resize, so the
    # NSWE-sticky log area actually grows with the window.
    main.columnconfigure(1, weight=1)
    main.rowconfigure(4, weight=1)

    label = Label(main, text=title_name)
    label.config(font=("Courier", 18))
    label.grid(row=0, column=0, columnspan=4)

    # Row 1: model weights file.
    weight_title = Label(main, text='Weight')
    weight_title.grid(row=1, column=0, columnspan=1)
    weight_path = Label(main, text=weight_name)
    weight_path.grid(row=1, column=1, columnspan=2)
    Button(main, text="Select", height=1, command=selectWeightFile).grid(
        row=1, column=3, columnspan=1, sticky=(N, S, W, E))

    # Row 2: class-name file.
    classes_title = Label(main, text='Classes')
    classes_title.grid(row=2, column=0, columnspan=1)
    classes_path = Label(main, text=classes_name)
    classes_path.grid(row=2, column=1, columnspan=2)
    Button(main, text="Select", height=1, command=selectClassesFile).grid(
        row=2, column=3, columnspan=1, sticky=(N, S, W, E))

    # Row 3: image; selecting one immediately starts ROI detection.
    file_title = Label(main, text='Image')
    file_title.grid(row=3, column=0, columnspan=1)
    file_path = Label(main, text=file_name)
    file_path.grid(row=3, column=1, columnspan=2)
    Button(main, text="Select", height=1, command=selectFile).grid(
        row=3, column=3, columnspan=1, sticky=(N, S, W, E))

    # Row 4: instructions / log area.
    log_ScrolledText = tkst.ScrolledText(main, height=20)
    log_ScrolledText.grid(row=4, column=0, columnspan=4, sticky=(N, S, W, E))
    log_ScrolledText.configure(font='TkFixedFont')
    log_ScrolledText.tag_config('HEADER', foreground='gray', font=("Helvetica", 14))
    log_ScrolledText.tag_config('TITLE', foreground='orange', font=("Helvetica", 18), underline=1, justify='center')
    log_ScrolledText.insert(END, '\n\n1. Please select an Image\n2. Drag an Object\n3. Press Enter or Space key\n', 'TITLE')

    main.mainloop()
기능은 이전 custom-yolo GUI와 비슷하다.
'Object Detection' 카테고리의 다른 글
직접 쓴 손 글씨(숫자) 인식하기 - 글씨 추출 및 검출 (0) | 2021.08.27 |
---|---|
직접 쓴 손 글씨(숫자) 인식하기 - 필터링을 통한 인식률 향상시키는 방법 (0) | 2021.08.26 |
object detection project - keras를 사용한 detection(1) (0) | 2021.08.24 |
Custom YOLO project - custom data 학습 (0) | 2021.08.21 |
Custom YOLO project - 코랩에 darknet 불러오기 (0) | 2021.08.21 |