Object Detection

Custom YOLO project - 대용량 xml파일 라벨링 해주기

jwjwvison 2021. 8. 21. 18:43

 이전 포스팅에서는 몇장 안되는 이미지의 객체들을 쉽게 손으로 라벨링 할 수 있었다. 그러나 대용량의 이미지를 직접 손으로 하는 것은 매우 힘들것이다. 만약 xml파일이 있다면 다음 코드를 통해 yolo format의 이미지 txt 파일로 yolo 좌표체계를 추출할 수 있다.

from xml.dom import minidom
import os

current_path=os.path.abspath(os.curdir)
print('Current path is {}'.format(current_path))
YOLO_FORMAT_PATH=current_path + '/apple-data/train/images'
XML_FORMAT_PATH=current_path + '/apple-data/train/annotations'
file_count=0
classes={}
classes['apple']=0
classes['damaged_apple'] = 1

def getYoloCordinates(size,box):
    width_ratio=1.0/size[0]
    height_ratio=1.0/size[1]

    x=(box[0]+box[1]/2.0)  #center x
    y=(box[2]+box[3]/2.0)  #center y
    w=box[1]-box[0]        #width
    h=box[3]-box[2]        #height
    
    # 위 수치는 절대값이기 때문에 ratio를 곱해줘야한다
    x=x*width_ratio
    w=w*width_ratio
    y=y*height_ratio
    h=h*height_ratio
    return(x,y,w,h)

os.chdir(XML_FORMAT_PATH)

with open(YOLO_FORMAT_PATH + '/' + 'classes.txt','w') as txt:
    for item in classes:
        txt.write(item+'\n')
        print('{%s} is added in classes.txt' %item)

# xml파일들을 하나하나씩 읽는다
for current_dir,dirs,files in os.walk('.'):
    for file in files:
        if file.endswith('.xml'):
            xmldoc=minidom.parse(file)
            yolo_format=(YOLO_FORMAT_PATH + '/' + file[:-4] + '.txt') #.xml부분을 .txt로 바꿔준다

            with open(yolo_format,'w') as f:
                objects=xmldoc.getElementsByTagName('object') #xml파일에서 object태그로 간다
                size=xmldoc.getElementsByTagName('size')[0]
                width=int((size.getElementsByTagName('width')[0]).firstChild.data)
                height=int((size.getElementsByTagName('height')[0]).firstChild.data)

                for item in objects:
                    name=(item.getElementsByTagName('name')[0]).firstChild.data
                    if name in classes:
                        class_name=str(classes[name])
                    else:
                        class_name='-1'
                        print("[warning] Class name {'%s'} is not in classes" %name)

                    # xml과 yolo의 좌표 체계가 다르다
                    xmin = ((item.getElementsByTagName('bndbox')[0]).getElementsByTagName('xmin')[0]).firstChild.data
                    ymin = ((item.getElementsByTagName('bndbox')[0]).getElementsByTagName('ymin')[0]).firstChild.data
                    xmax = ((item.getElementsByTagName('bndbox')[0]).getElementsByTagName('xmax')[0]).firstChild.data
                    ymax = ((item.getElementsByTagName('bndbox')[0]).getElementsByTagName('ymax')[0]).firstChild.data
                    xml_cordinates=(float(xmin),float(xmax),float(ymin),float(ymax))
                    yolo_cordinates=getYoloCordinates((width,height),xml_cordinates)

                    f.write(class_name + ' ' + ' '.join([('%.6f' %a) for a in yolo_cordinates]))
                    file_count+=1
            
            print('{}. [{}] is created'.format(file_count,yolo_format))

 이를 통해 앞에 이미지로부터 txt 좌표파일을 얻는것처럼 똑같이 얻을수 있다.