Object Detection

object detection project - keras를 사용한 detection(2)

jwjwvison 2021. 8. 26. 00:09

 이제 학습된 가중치를 불러와서, cv2 함수로 직접 지정한 ROI 영역을 detection 해 보겠다. 내 노트북에서는 계정 이름이 한글로 설정되어 있어서 오류가 나는 것 같다. 코드는 다음과 같다.

import numpy as np
import cv2
import tensorflow as tf
from tkinter import *
from tkinter import filedialog
import tkinter.scrolledtext as tkst

# Display / detection settings.
min_confidence = 0.5  # NOTE(review): not referenced anywhere else in this script
width, height = 800, 0  # display width in px; detectAndDisplay computes its own local height
show_ratio = 1.0  # NOTE(review): not referenced anywhere else in this script

# Default paths shown in the GUI at start-up.
file_name = "./custom/fruit10.jpg"
classes_name = "./custom/classes.txt"
weight_name = "./model/fruit_custom.h5"
title_name = 'Tenserflow Custom Data Prediction'

# Mutable runtime state.
classes = []        # class names; index i labels model output i
read_image = None   # image most recently loaded by selectFile (OpenCV BGR array)

# Box / text colours (OpenCV BGR tuples), indexed by class number.
colors = [(0, 255, 0), (0, 0, 255)]

# Network input patch size: width, height, and (presumably) channel count.
CW, CH, CD = 32, 32, 3

# Load the default TF model and class list at start-up so detection works
# before anything is picked in the GUI.
model = tf.keras.models.load_model(weight_name)
with open(classes_name, 'r') as txt:
    # One class name per line; a name's position is the model output index it
    # labels (detectAndDisplay looks names up by argmax index).
    classes = [line.rstrip("\n") for line in txt]
# Drop trailing blank lines (e.g. an extra newline at the end of the file).
# They would otherwise become empty class names with no matching model output,
# and detectAndDisplay indexes the score vector by class position.
while classes and not classes[-1].strip():
    classes.pop()

def selectWeightFile():
    """Ask for a Keras model file and make it the active model.

    On success, updates the module-level ``weight_name`` and ``model`` and
    shows the chosen path in the GUI's weight label. If the dialog is
    cancelled, the current model and label are left unchanged.
    """
    global weight_name
    global model
    path = filedialog.askopenfilename(initialdir="./model", title="Select Model file",
                                      filetypes=(("keras model files", "*.h5"), ("all files", "*.*")))
    # A cancelled dialog yields an empty value ('' or, on some Tk builds, ()).
    if not path:
        return
    # Load first, so a failing load leaves the previous model/path untouched.
    new_model = tf.keras.models.load_model(path)
    weight_name = path
    model = new_model
    weight_path['text'] = weight_name

def selectClassesFile():
    """Ask for a class-names text file and make it the active class list.

    The file holds one class name per line; line order must match the model's
    output order. On success, updates the module-level ``classes_name`` and
    ``classes`` and shows the path in the GUI. If the dialog is cancelled, the
    current class list is kept. It is no longer emptied before the file is
    chosen.
    """
    global classes_name
    global classes
    path = filedialog.askopenfilename(initialdir="./custom", title="Select Classes file",
                                      filetypes=(("text files", "*.txt"), ("all files", "*.*")))
    # A cancelled dialog yields an empty value ('' or, on some Tk builds, ()).
    if not path:
        return
    # Read into a local list first, so a failed read leaves the old state intact.
    with open(path, 'r') as txt:
        names = [line.rstrip("\n") for line in txt]
    # Trailing blank lines would become empty class names with no matching
    # model output; drop them.
    while names and not names[-1].strip():
        names.pop()
    classes_name = path
    classes = names
    classes_path['text'] = classes_name
            
def selectFile():
    """Ask for an image file, load it into ``read_image`` and run detection on it.

    The file is read as raw bytes with ``np.fromfile`` and decoded with
    ``cv2.imdecode`` instead of ``cv2.imread``. ``cv2.imread`` is known to fail
    (return None) on Windows for paths containing non-ASCII characters, such as
    a Korean account name in the path. Cancelling the dialog, or choosing a file
    OpenCV cannot decode, leaves the previous image untouched and does not start
    detection.
    """
    global read_image
    file_name = filedialog.askopenfilename(initialdir="./", title="Select file",
                                           filetypes=(("jpeg files", "*.jpg"), ("all files", "*.*")))
    # A cancelled dialog yields an empty value ('' or, on some Tk builds, ()).
    if not file_name:
        return
    data = np.fromfile(file_name, dtype=np.uint8)
    # IMREAD_COLOR matches cv2.imread's default flag; an empty buffer cannot be
    # decoded, so skip imdecode for it.
    image = cv2.imdecode(data, cv2.IMREAD_COLOR) if data.size else None
    if image is None:
        print("Could not read image: %s" % file_name)
        return
    read_image = image
    file_path['text'] = file_name
    detectAndDisplay()

def detectAndDisplay():
    """Classify a user-selected region of ``read_image`` and show the annotated result.

    The module-level ``read_image`` (set by ``selectFile``) is resized to
    ``width`` pixels wide, keeping its aspect ratio. An OpenCV ROI selector is
    then opened on it. The chosen patch is resized to ``CW`` x ``CH``, scaled
    to [0, 1] and passed to ``model`` as a batch of one. The output window
    shows:

    * the ROI box and "<class>: <score>%" for the top prediction,
    * every class score, stacked upward from the bottom-left corner,
    * the network-input patch pasted into the top-left corner.

    Blocks until a key is pressed in the output window, then closes it.
    Returns early with a console message if no image is loaded or the ROI
    selection is empty (cancelled, or clicked without dragging).
    """
    if read_image is None:
        print("No image loaded. Please select an image first.")
        return

    h, w = read_image.shape[:2]
    # Fixed display width; height follows the aspect ratio. Keep it at least
    # 1 px so a very wide image never gives cv2.resize a zero size.
    height = max(1, int(h * width / w))
    img = cv2.resize(read_image, (width, height))

    roi_window = "Select Resign Of interest and Press Enter or Space key"
    box = cv2.selectROI(roi_window, img, fromCenter=False,
                        showCrosshair=True)
    cv2.destroyWindow(roi_window)
    startX, startY, box_w, box_h = (int(v) for v in box)
    # A cancelled selection (or a click without dragging) yields a zero-sized
    # box; cropping it gives an empty array that cv2.resize rejects.
    if box_w <= 0 or box_h <= 0:
        print("Empty region selected. Nothing to classify.")
        return
    endX = startX + box_w
    endY = startY + box_h

    image = cv2.resize(img[startY:endY, startX:endX], (CW, CH),
                       interpolation=cv2.INTER_AREA)
    # Batch of one, scaled to the range [0, 1].
    # NOTE(review): pixels are in OpenCV's BGR order -- confirm the model was
    # trained on BGR input as well.
    test_images = np.array([image]).astype("float32") / 255.0

    result = model.predict(test_images)
    scores = result[0]
    result_number = int(np.argmax(scores))
    # Fall back to the index if the classes file lists fewer names than the
    # model has outputs.
    top_name = classes[result_number] if result_number < len(classes) else str(result_number)
    # Only len(colors) colours are defined; cycle through them instead of
    # raising IndexError when there are more classes.
    top_color = colors[result_number % len(colors)]

    print(result, result_number)
    print("%s : %.2f %2s" % (top_name, scores[result_number]*100, '%'))

    text = "{}: {}%".format(top_name, round(scores[result_number]*100, 2))
    # Put the label above the box unless that would push it off the top edge.
    y = startY - 10 if startY - 10 > 10 else startY + 10
    cv2.rectangle(img, (startX, startY), (endX, endY), top_color, 2)
    cv2.putText(img, text, (startX, y),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, top_color, 2)

    # Per-class scores. zip() stops at the shorter of classes / model outputs,
    # so a length mismatch between them cannot raise IndexError.
    for i, (name, score) in enumerate(zip(classes, scores)):
        print("%10s %15.7f %2s" % (name, score*100, '%'))
        text = "{}: {}".format(name, score)
        cv2.putText(img, text, (10, height - ((i * 20) + 20)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, colors[i % len(colors)], 2)

    # Show the actual network input in the top-left corner, if it fits.
    if img.shape[0] >= CH and img.shape[1] >= CW:
        img[0:CH, 0:CW] = image
    cv2.imshow("Prediction Output", img)
    cv2.waitKey(0)
    cv2.destroyWindow("Prediction Output")

        
def _add_path_row(row, title, path_text, on_select):
    """Add one "<title> | <path> | Select" row to the main window.

    ``on_select`` is called when the row's Select button is pressed. Returns
    the (title label, path label) pair so the path label can be updated later.
    """
    title_label = Label(main, text=title)
    title_label.grid(row=row, column=0, columnspan=1)
    path_label = Label(main, text=path_text)
    path_label.grid(row=row, column=1, columnspan=2)
    select_button = Button(main, text="Select", height=1, command=on_select)
    select_button.grid(row=row, column=3, columnspan=1, sticky=(N, S, W, E))
    return title_label, path_label


# Main window: a title banner, three file-selection rows, and an
# instructions/log pane underneath.
main = Tk()
main.title(title_name)
main.geometry()  # returns the current geometry string; the result is unused

label = Label(main, text=title_name, font=("Courier", 18))
label.grid(row=0, column=0, columnspan=4)

weight_title, weight_path = _add_path_row(1, 'Weight', weight_name, selectWeightFile)
classes_title, classes_path = _add_path_row(2, 'Classes', classes_name, selectClassesFile)
file_title, file_path = _add_path_row(3, 'Image', file_name, selectFile)

log_ScrolledText = tkst.ScrolledText(main, height=20)
log_ScrolledText.grid(row=4, column=0, columnspan=4, sticky=(N, S, W, E))
log_ScrolledText.configure(font='TkFixedFont')
# Text tags; only 'TITLE' is used in this script.
log_ScrolledText.tag_config('HEADER', foreground='gray', font=("Helvetica", 14))
log_ScrolledText.tag_config('TITLE', foreground='orange', font=("Helvetica", 18),
                            underline=1, justify='center')
log_ScrolledText.insert(END, '\n\n1. Please select an Image\n2. Drag an Object\n3. Press Enter or Space key\n', 'TITLE')

main.mainloop()

 기능은 이전 custom-yolo GUI와 비슷하다.