- 화상 데이터 및 어노테이션 데이터 파일의 경로 리스트 작성
# 패키지 import
import os.path as osp
from PIL import Image
import torch.utils.data as data
import matplotlib.pyplot as plt
def make_datapath_list(rootpath):
"""
학습, 검증용 화상 데이터와 어노테이션 데이터의 파일 경로 리스트를 작성한다.
Parameters
----------
rootpath : str
데이터 폴더의 경로
Returns
-------
ret : train_img_list, train_anno_list, val_img_list, val_anno_list
데이터의 경로를 저장한 리스트
"""
# 화상 파일과 어노테이션 파일의 경로 템플릿을 작성
imgpath_template = osp.join(rootpath, 'JPEGImages', '%s.jpg')
annopath_template = osp.join(rootpath, 'SegmentationClass', '%s.png')
# 훈련 및 검증 파일 각각의 ID(파일 이름)를 취득
train_id_names = osp.join(rootpath + 'ImageSets/Segmentation/train.txt')
val_id_names = osp.join(rootpath + 'ImageSets/Segmentation/val.txt')
# 훈련 데이터의 화상 파일과 어노테이션 파일의 경로 리스트를 작성
train_img_list = list()
train_anno_list = list()
for line in open(train_id_names):
file_id = line.strip() # 공백과 줄바꿈 제거
img_path = (imgpath_template % file_id) # 화상의 경로
anno_path = (annopath_template % file_id) # 어노테이션의 경로
train_img_list.append(img_path)
train_anno_list.append(anno_path)
# 검증 데이터의 화상 파일과 어노테이션 파일의 경로 리스트 작성
val_img_list = list()
val_anno_list = list()
for line in open(val_id_names):
file_id = line.strip() # 공백과 줄바꿈 제거
img_path = (imgpath_template % file_id) # 화상의 경로
anno_path = (annopath_template % file_id) # 어노테이션의 경로
val_img_list.append(img_path)
val_anno_list.append(anno_path)
return train_img_list, train_anno_list, val_img_list, val_anno_list
- 데이터셋 작성
먼저 화상 데이터와 어노테이션 데이터를 세트로 변환해야 한다. Compose 클래스를 준비하고 Compose 내에서 데이터 변환을 실시한다.
class DataTransform():
def __init__(self,input_size,color_mean,color_std):
self.data_transform={
'train' : Compose([
Scale(scale=[0.5,1.5]), # 화상 확대
RandomRotation(angle=[-10,10]), # 회전
RandomMirror(), # 랜덤 미러
Resize(input_size), # 리사이즈
Normalize_Tensor(color_mean,color_std) # 색상 정보의 표준화와 텐서화
]),
'val': Compose([
Resize(input_size), # 리사이즈(input_size)
Normalize_Tensor(color_mean,color_std) # 색상 정보의 표준화와 텐서화
])
}
def __call__(self,phase,img,anno_class_img):
return self.data_transform[phase](img,anno_class_img)
class VOCDataset(data.Dataset):
"""
VOC2012의 Dataset을 만드는 클래스. PyTorch의 Dataset 클래스를 상속받는다.
Attributes
----------
img_list : 리스트
어노테이션의 경로를 저장한 리스트
anno_list : 리스트
어노테이션의 경로를 저장한 리스트
phase : 'train' or 'test'
학습 또는 훈련을 설정한다.
transform : object
전처리 클래스의 인스턴스
"""
def __init__(self, img_list, anno_list, phase, transform):
self.img_list = img_list
self.anno_list = anno_list
self.phase = phase
self.transform = transform
def __len__(self):
'''화상의 매수를 반환'''
return len(self.img_list)
def __getitem__(self, index):
'''
전처리한 화상의 텐서 형식 데이터와 어노테이션을 취득
'''
img, anno_class_img = self.pull_item(index)
return img, anno_class_img
def pull_item(self, index):
'''화상의 텐서 형식 데이터, 어노테이션을 취득한다'''
# 1. 화상 읽기
image_file_path = self.img_list[index]
img = Image.open(image_file_path) # [높이][폭][색RGB]
# 2. 어노테이션 화상 읽기
anno_file_path = self.anno_list[index]
anno_class_img = Image.open(anno_file_path) # [높이][폭]
# 3. 전처리 실시
img, anno_class_img = self.transform(self.phase, img, anno_class_img)
return img, anno_class_img
제대로 동작하는지 확인해보자.
# 동작 확인
# (RGB) 색의 평균치와 표준편차
color_mean = (0.485, 0.456, 0.406)
color_std = (0.229, 0.224, 0.225)
# 데이터 세트 작성
train_dataset = VOCDataset(train_img_list, train_anno_list, phase="train", transform=DataTransform(
input_size=475, color_mean=color_mean, color_std=color_std))
val_dataset = VOCDataset(val_img_list, val_anno_list, phase="val", transform=DataTransform(
input_size=475, color_mean=color_mean, color_std=color_std))
# 데이터를 추출하는 예
print(val_dataset.__getitem__(0)[0].shape)
print(val_dataset.__getitem__(0)[1].shape)
print(val_dataset.__getitem__(0))
- 데이터 로더 작성
마지막으로 DataLoader를 만든다.
# 데이터 로더 작성
batch_size = 8
train_dataloader = data.DataLoader(
train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = data.DataLoader(
val_dataset, batch_size=batch_size, shuffle=False)
# 사전 오브젝트로 정리한다
dataloaders_dict = {"train": train_dataloader, "val": val_dataloader}
# 동작 확인
batch_iterator = iter(dataloaders_dict["val"]) # 반복자로 변환
imges, anno_class_imges = next(batch_iterator) # 첫번째 요소를 꺼낸다
print(imges.size()) # torch.Size([8, 3, 475, 475])
print(anno_class_imges.size()) # torch.Size([8, 3, 475, 475])
'Semantic Segmentation' 카테고리의 다른 글
6. Decoder, AuxLoss 모듈 (0) | 2021.11.21 |
---|---|
5. Pyramid Pooling 모듈 (0) | 2021.11.20 |
4. Feature 모듈 설명 및 구현(ResNet) (0) | 2021.11.20 |
3. PSPNet 네트워크 구성 및 구현 (0) | 2021.11.20 |
1. 시맨틱 분할이란 (0) | 2021.11.20 |