Semantic Segmentation

2. 데이터셋과 데이터 로더 구현

jwjwvison 2021. 11. 20. 20:15
  • 화상 데이터 및 어노테이션 데이터 파일의 경로 리스트 작성
# 패키지 import
import os.path as osp
from PIL import Image

import torch.utils.data as data

import matplotlib.pyplot as plt
def make_datapath_list(rootpath):
    """
    학습, 검증용 화상 데이터와 어노테이션 데이터의 파일 경로 리스트를 작성한다.

    Parameters
    ----------
    rootpath : str
        데이터 폴더의 경로

    Returns
    -------
    ret : train_img_list, train_anno_list, val_img_list, val_anno_list
        데이터의 경로를 저장한 리스트
    """

    # 화상 파일과 어노테이션 파일의 경로 템플릿을 작성
    imgpath_template = osp.join(rootpath, 'JPEGImages', '%s.jpg')
    annopath_template = osp.join(rootpath, 'SegmentationClass', '%s.png')

    # 훈련 및 검증 파일 각각의 ID(파일 이름)를 취득
    train_id_names = osp.join(rootpath + 'ImageSets/Segmentation/train.txt')
    val_id_names = osp.join(rootpath + 'ImageSets/Segmentation/val.txt')

    # 훈련 데이터의 화상 파일과 어노테이션 파일의 경로 리스트를 작성
    train_img_list = list()
    train_anno_list = list()

    for line in open(train_id_names):
        file_id = line.strip()  # 공백과 줄바꿈 제거
        img_path = (imgpath_template % file_id)  # 화상의 경로
        anno_path = (annopath_template % file_id)  # 어노테이션의 경로
        train_img_list.append(img_path)
        train_anno_list.append(anno_path)

    # 검증 데이터의 화상 파일과 어노테이션 파일의 경로 리스트 작성
    val_img_list = list()
    val_anno_list = list()

    for line in open(val_id_names):
        file_id = line.strip()  # 공백과 줄바꿈 제거
        img_path = (imgpath_template % file_id)  # 화상의 경로
        anno_path = (annopath_template % file_id)  # 어노테이션의 경로
        val_img_list.append(img_path)
        val_anno_list.append(anno_path)

    return train_img_list, train_anno_list, val_img_list, val_anno_list

 

  • 데이터셋 작성

먼저 화상 데이터와 어노테이션 데이터를 세트로 변환해야 한다. Compose 클래스를 준비하고 Compose 내에서 데이터 변환을 실시한다.

class DataTransform():
  def __init__(self,input_size,color_mean,color_std):
    self.data_transform={
        'train' : Compose([
                           Scale(scale=[0.5,1.5]),  # 화상 확대
                           RandomRotation(angle=[-10,10]),  # 회전
                           RandomMirror(), # 랜덤 미러
                           Resize(input_size),  # 리사이즈
                           Normalize_Tensor(color_mean,color_std)  # 색상 정보의 표준화와 텐서화

        ]),
        'val': Compose([
                        Resize(input_size),  # 리사이즈(input_size)
                        Normalize_Tensor(color_mean,color_std)  # 색상 정보의 표준화와 텐서화
        ])
    }

  def __call__(self,phase,img,anno_class_img):
    return self.data_transform[phase](img,anno_class_img)

 

class VOCDataset(data.Dataset):
    """
    VOC2012의 Dataset을 만드는 클래스. PyTorch의 Dataset 클래스를 상속받는다.

    Attributes
    ----------
    img_list : 리스트
        어노테이션의 경로를 저장한 리스트
    anno_list : 리스트
        어노테이션의 경로를 저장한 리스트
    phase : 'train' or 'test'
        학습 또는 훈련을 설정한다.
    transform : object
        전처리 클래스의 인스턴스
    """

    def __init__(self, img_list, anno_list, phase, transform):
        self.img_list = img_list
        self.anno_list = anno_list
        self.phase = phase
        self.transform = transform

    def __len__(self):
        '''화상의 매수를 반환'''
        return len(self.img_list)

    def __getitem__(self, index):
        '''
        전처리한 화상의 텐서 형식 데이터와 어노테이션을 취득
        '''
        img, anno_class_img = self.pull_item(index)
        return img, anno_class_img

    def pull_item(self, index):
        '''화상의 텐서 형식 데이터, 어노테이션을 취득한다'''

        # 1. 화상 읽기
        image_file_path = self.img_list[index]
        img = Image.open(image_file_path)   # [높이][폭][색RGB]

        # 2. 어노테이션 화상 읽기
        anno_file_path = self.anno_list[index]
        anno_class_img = Image.open(anno_file_path)   # [높이][폭]

        # 3. 전처리 실시
        img, anno_class_img = self.transform(self.phase, img, anno_class_img)

        return img, anno_class_img

 

 제대로 동작하는지 확인해보자.

# 동작 확인

# (RGB) 색의 평균치와 표준편차
color_mean = (0.485, 0.456, 0.406)
color_std = (0.229, 0.224, 0.225)

# 데이터 세트 작성
train_dataset = VOCDataset(train_img_list, train_anno_list, phase="train", transform=DataTransform(
    input_size=475, color_mean=color_mean, color_std=color_std))

val_dataset = VOCDataset(val_img_list, val_anno_list, phase="val", transform=DataTransform(
    input_size=475, color_mean=color_mean, color_std=color_std))

# 데이터를 추출하는 예
print(val_dataset.__getitem__(0)[0].shape)
print(val_dataset.__getitem__(0)[1].shape)
print(val_dataset.__getitem__(0))

  • 데이터 로더 작성

 마지막으로 DataLoader를 만든다.

# 데이터 로더 작성
batch_size = 8

train_dataloader = data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True)

val_dataloader = data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False)

# 사전 오브젝트로 정리한다
dataloaders_dict = {"train": train_dataloader, "val": val_dataloader}

# 동작 확인
batch_iterator = iter(dataloaders_dict["val"])  # 반복자로 변환
imges, anno_class_imges = next(batch_iterator)  # 첫번째 요소를 꺼낸다
print(imges.size())  # torch.Size([8, 3, 475, 475])
print(anno_class_imges.size())  # torch.Size([8, 3, 475, 475])

 

'Semantic Segmentation' 카테고리의 다른 글

6. Decoder, AuxLoss 모듈  (0) 2021.11.21
5. Pyramid Pooling 모듈  (0) 2021.11.20
4. Feature 모듈 설명 및 구현(ResNet)  (0) 2021.11.20
3. PSPNet 네트워크 구성 및 구현  (0) 2021.11.20
1. 시맨틱 분할이란  (0) 2021.11.20