GAN/이론

14. CycleGAN- 구현

jwjwvison 2021. 7. 2. 21:53
from __future__ import print_function,division
import scipy
from tensorflow.keras.datasets import mnist
from tensorflow_addons.layers import InstanceNormalization
from tensorflow.keras.layers import Input,Dense,Reshape,Flatten,Dropout,Concatenate
from tensorflow.keras.layers import BatchNormalization,Activation,ZeroPadding2D
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import UpSampling2D,Conv2D
from tensorflow.keras.models import Sequential,Model
from tensorflow.keras.optimizers import Adam
import datetime
import matplotlib.pyplot as plt
import sys
import numpy as np
import os
import scipy
from glob import glob
import numpy as np

class DataLoader():
    def __init__(self, dataset_name, img_res=(128, 128)):
        self.dataset_name = dataset_name
        self.img_res = img_res

    def load_data(self, domain, batch_size=1, is_testing=False):
        data_type = "train%s" % domain if not is_testing else "test%s" % domain
        path = glob('./datasets/%s/%s/*' % (self.dataset_name, data_type))

        batch_images = np.random.choice(path, size=batch_size)

        imgs = []
        for img_path in batch_images:
            img = self.imread(img_path)
            if not is_testing:
                img = scipy.misc.imresize(img, self.img_res)

                if np.random.random() > 0.5:
                    img = np.fliplr(img)
            else:
                img = scipy.misc.imresize(img, self.img_res)
            imgs.append(img)

        imgs = np.array(imgs)/127.5 - 1.

        return imgs

    def load_batch(self, batch_size=1, is_testing=False):
        data_type = "train" if not is_testing else "val"
        path_A = glob('./datasets/%s/%sA/*' % (self.dataset_name, data_type))
        path_B = glob('./datasets/%s/%sB/*' % (self.dataset_name, data_type))

        self.n_batches = int(min(len(path_A), len(path_B)) / batch_size)
        total_samples = self.n_batches * batch_size

        # Sample n_batches * batch_size from each path list so that model sees all
        # samples from both domains
        path_A = np.random.choice(path_A, total_samples, replace=False)
        path_B = np.random.choice(path_B, total_samples, replace=False)

        for i in range(self.n_batches-1):
            batch_A = path_A[i*batch_size:(i+1)*batch_size]
            batch_B = path_B[i*batch_size:(i+1)*batch_size]
            imgs_A, imgs_B = [], []
            for img_A, img_B in zip(batch_A, batch_B):
                img_A = self.imread(img_A)
                img_B = self.imread(img_B)

                img_A = scipy.misc.imresize(img_A, self.img_res)
                img_B = scipy.misc.imresize(img_B, self.img_res)

                if not is_testing and np.random.random() > 0.5:
                        img_A = np.fliplr(img_A)
                        img_B = np.fliplr(img_B)

                imgs_A.append(img_A)
                imgs_B.append(img_B)

            imgs_A = np.array(imgs_A)/127.5 - 1.
            imgs_B = np.array(imgs_B)/127.5 - 1.

            yield imgs_A, imgs_B

    def load_img(self, path):
        img = self.imread(path)
        img = scipy.misc.imresize(img, self.img_res)
        img = img/127.5 - 1.
        return img[np.newaxis, :, :, :]

    def imread(self, path):
        return scipy.misc.imread(path, mode='RGB').astype(np.float)
class CycleGAN():
  def __init__(self):
    self.img_rows=128
    self.img_cols=128
    self.channels=3
    self.img_shape=(self.img_rows,self.img_cols,self.channels)

    self.dataset_name='apple2orange'   # 데이터 로더 설정
    self.data_loader=DataLoader(dataset_name=self.dataset_name,  # DataLoader 객체를 사용해 전처리된 데이터 임포트
                                img_res=(self.img_rows,self.img_cols))
    
    patch=int(self.img_rows/2**4)   #D(PatchGAN)의 출력 크기를 계산
    self.disc_path=(patch,patch,1)

    self.gf=32  # G의 첫 번째 층에 있는 필터의 개수
    self.df=64  # D의 첫 번째 층에 있는 필터의 개수

    self.lambda_cycle=10.0  # 사이클-일관성 손실 가중치
    self.lambda_id=0.9 * self.lambda_cycle  # 동일성 손실 가중치
    
    optimizer=Adam(0.0002,0.5)

 두 개의 새로운 변수 lambda_cycle와 lambda_id가 있다. 두 번째 하이퍼파라미터는 동일성 손실에 영향을 미친다. CycleGAN 저자들은 이 값이 (특히 훈련 과정 초기에) 변화에 얼마나 극적으로 영향을 미치는지 이야기한다. 이 값이 낮으면 불필요한 변화가 생긴다. 예를 들어 초기에 색이 완전히 반전된다.

 

 lambda_cycle는 사이클-일관성 손실을 얼마나 엄격히 강제할지 제어한다. 이 값을 높게 설정하면 원본 이미지와 재구성 이미지가 가능한 한 아주 가깝게 만든다.

 

<신경망 구성>

    # 판별자를 만들고 컴파일한다
    self.d_A=self.build_discriminator()
    self.d_B=self.build_discriminator()
    self.d_A.compile(loss='mse',
                     optimizer=optimizer,
                     metrics=['accuracy'])
    self.d_B.compile(loss='mse',
                     optimizer=optimizer,
                     metrics=['accuracy'])
    
    # 여기서부터 생성자의 계산 그래프를 만든다. 처음 두 라인이 생성자를 만든다
    self.g_AB=self.build_generator()
    self.g_BA=self.build_generator()

    # 두 도메인의 입력 이미지
    img_A=Input(shape=self.img_shape)
    img_B=Input(shape=self.img_shape)

    # 이미지를 다른 도메인으로 변환
    fake_B=self.g_AB(img_A)
    fake_A=self.g_BA(img_B)

    # 원본 도메인으로 이미지를 다시 변환
    reconstr_A=self.g_BA(fake_B)
    reconstr_B=self.g_AB(fake_A)

    # 동일한 이미지 매핑
    img_A_id=self.g_BA(img_A)
    img_B_id=self.g_AB(img_B)

    # 연결 모델에서는 생성자만 훈련
    self.d_A.trainable=False
    self.d_B.trainable=False

    # 판별자가 변환된 이미지의 유효성을 결정
    valid_A=self.d_A(fake_A)
    valid_B=self.d_B(fake_B)

    self.combined=Model(inputs=[img_A,img_B],
                        outputs=[valid_A,valid_B,reconstr_A,reconstr_B,
                                 img_A_id,img_B_id])
    
    self.combined.compile(loss=['mse','mse',
                                'mae','mae',
                                'mae','mae'],
                          loss_weights=[1,1,
                                       self.lambda_cycle,self.lambda_cycle,
                                       self.lambda_id,self.lambda_id],
                          optimizer=optimizer)

 이전 코드에서 명확하게 언급할 점 하나는 combined 모델의 출력이 6개라는 것이다. 이 모델은 (판별자의) 유효성(validity), 재구성, 동일성 손실을 위한 출력이 필요하다. A-B-A 사이클일 때 한벌, B-A-B 사이클에서 한 벌이 필요하므로 총 6개가 된다. 처음 2개의 손실은 제곱 오차이고 나머지는 평균 절댓값 오차(MAE) 이다. 

 

<생성자>

그다음 생성자 코드를 구성한다. 스킵연결을 사용한 U-Net 구조이다.

 그다음 실제 생성자를 만든다.

 일반적인 2D 합성곱을 사용해 여분의 특성 맵을 제거하고 128x128x3(높이x너비x컬러 채널) 크기가 된다.

  def build_generator(self):
  
    def conv2d(layer_input,filters,f_size=4,normalization=True):
      ''' 다운 샘플링하는 동안 사용되는 층'''
      d=Conv2D(filters,kernel_size=f_size,strides=2,padding='same')(layer_input)
      d=LeakyReLU(alpha=0.2)(d)
      if normalization:
        d=InstanceNormalization()(d)

      return d

    def deconv2d(layer_input,skip_input,filters,f_size=4,dropout_rate=0):
      ''' 업샘플링 하는 동안 사용되는 층'''
      u=UpSampling2D(size=2)(layer_input)
      u=Conv2D(filters,kernel_size=f_size,strides=1,
              padding='same',activation='relu')(u)

      if dropout_rate:
        u=Dropout(dropout_rate)(u)
      u=InstanceNormalization()(u)
      u=Concatenate()([u,skip_input])

      return u

    d0=Input(shape=self.img_shape)

    # 다운샘플링
    d1=conv2d(d0,self.gf)
    d2=conv2d(d1,self.gf * 2)
    d3=conv2d(d2,self.gf * 4)
    d4=conv2d(d3,self.gf * 8)
    # 업샘플링
    u1=deconv2d(d4,d3,self.gf * 4)
    u2=deconv2d(u1,d2,self.gf * 2)
    u3=deconv2d(u2,d1,self.gf)

    u4=UpSampling2D(size=2)(u3)
    output_img=Conv2D(self.channels,kernel_size=4,
                      strides=1,padding='same',activation='tanh')(u4)


    return Model(d0,output_img)

 

<판별자>

이제 판별자 메서드를 만들어보자. 2D 합성곱, LeakyReLU 그리고 선택적으로 Instance Normalization 층을 만드는 헬퍼 함수를 사용한다.

  def build_discriminator(self):

        def d_layer(layer_input, filters, f_size=4, normalization=True):
            """Discriminator layer"""
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if normalization:
                d = InstanceNormalization()(d)
            return d

        img = Input(shape=self.img_shape)

        d1 = d_layer(img, self.df, normalization=False)
        d2 = d_layer(d1, self.df*2)
        d3 = d_layer(d2, self.df*4)
        d4 = d_layer(d3, self.df*8)

        validity = Conv2D(1, kernel_size=4, strides=1, padding='same')(d4)

        return Model(img, validity)

 

<훈련>

  def train(self,epochs,batch_size=1,sample_interval=50):
    valid=np.ones((batch_size,) + self.disc_path)   # 적대 손실에 대한 정답
    fake=np.zeros((batch_size,) + self.disc_path)

    for epoch in range(epochs):
      for batch_i,(imgs_A,imgs_B) in enumerate(
          self.data_loader.load_batch(batch_size)):
        # 판별자 훈련을 시작한다. 이 두 라인은 이미지를 상대 도메인으로 변환한다.
        fake_B=self.g_AB.predict(imgs_A) 
        fake_A=self.g_BA.predict(imgs_B)

        dA_loss_real=self.d_A.train_on_batch(imgs_A,valid)
        dA_loss_fake=self.d_A.train_on_batch(fake_A,fake)
        dA_loss=0.5 * np.add(dA_loss_real + dA_loss_fake)

        dB_loss_real=self.d_B.train_on_batch(imgs_B,valid)
        dB_loss_fake=self.d_B.train_on_batch(fake_B,fake)
        dB_loss=0.5 * np.add(dB_loss_real,dB_loss_fake)
        d_loss= 0.5 * np.add(dA_loss,dB_loss)  # 판별자 전체 손실

        g_loss=self.combined.train_on_batch([imgs_A,imgs_B],
                                            [valid,valid,imgs_A,imgs_B,
                                             imgs_A,imgs_B])
        
        if batch_i % sample_interval ==0:
          self.sample_images(epoch,batch_i)

'GAN > 이론' 카테고리의 다른 글

1. CUDA 기초  (0) 2021.10.31
=====================================================  (0) 2021.10.31
13. CycleGAN - 이론  (0) 2021.07.02
12. CGAN - 구현  (0) 2021.06.30
11. CGAN - 이론  (0) 2021.06.30