GAN/이론

10. SGAN - 구현

jwjwvison 2021. 6. 28. 22:05

 이 튜토리얼에서는 MNIST 데이터셋에서 100개의 훈련 샘플만 사용해 손글씨 숫자를 분류하는 SGAN 모델을 만들어보겠다. 

 

<모델 구조>

다음 그림은 이 튜토리얼에서 구현할 SGAN 모델을 고수준에서 그린 것이다. 이 그림은 이장의 서두에서 소개한 일반적인 개념도보다 조금 더 복잡하다.

 진짜 레이블을 분류하는 다중 분류 문제를 풀기 위해 판별자는 소프트맥스 함수를 사용한다. 이 함수는 지정된 클래스 개수(여기에서는 10개) 만큼 확률 분포를 반환한다. 어떤 레이블에 할당된 확률이 높을수록 판별자가 샘플이 해당 클래스에 속한다고 크게 확신한다. 분류 오차를 계산하려면 출력 확률과 원-핫 인코딩된 타깃 레이블 사이의 교차 엔트로피 손실(cross-entropy-loss)을 사용한다.

 진짜 대 가짜 확률을 출력하기 위해 판별자는 시그모이드 함수를 사용하고 이진 교차 엔트로피 손실을 역전파하여 모델 파라미터를 훈련한다.

 

<구현>

 

%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

from tensorflow.keras import backend as K

from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import (Activation,BatchNormalization,Concatenate,Dense,Dropout,
                                     Flatten,Input,Lambda,Reshape)

from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import Conv2D,Conv2DTranspose
from tensorflow.keras.models import Model,Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical

 

 입력 이미지 크기 잡음 벡터 z의 크기, 준지도 분류를 위한 진짜 클래스 개수(판별자가 분류하도록 학습할 숫자당 한 개씩)를 다음 코드처럼 지정한다.

img_rows=28
img_cols=28
channels=1

img_shape=(img_rows,img_cols,channels)  

z_dim=100

num_classes=10

 <데이터셋>

 MNIST 훈련 데이터셋에 레이블된 훈련 이미지 50,000개가 있지만 그중 일부만 훈련에 사용한다(num_labeled 매개변수로 지정한다). 그리고 나머지 샘플은 모두 레이블이 없는 것처럼 다룬다. 레이블된 데이터 배치를 만들 때 처음 num_labeled개의 이미지를 사용하고 레이블이 없는 샘플 배치를 만들 때 나머지 (50000-num_labeled) 이미지를 사용한다.

 

 Dataset 오브젝트는 num_labeled 훈련 샘플을 반환하는 함수와 MNIST 데이터셋에서 레이블된 테스트 이미지 10,000개를 반환하는 함수를 제공한다. 훈련이 끝난 후 테스트 세트를 사용하여 모델의 분류 능력이 본 적 없는 샘플에 얼마나 잘 일반화되는지 평가해보자.

class Dataset:
  def __init__ (self,num_labeled):

    self.num_labeled=num_labeled

    (self.x_train,self.y_train),(self.x_test,self.y_test) = mnist.load_data()
    

    def preprocess_imgs(x):
      x=(x.astype(np.float32) - 127.5) /127.5  # [0,255] 사이 흑백 픽셀 값을 [-1,1] 사이로 변환
      x=np.expand_dims(x,axis=3)  # 너비x높이x채널로 이미지 차원을 확장
      return x

    def preprocess_labels(y):
      return y.reshape(-1,1)

    self.x_train=preprocess_imgs(self.x_train)
    self.y_train=preprocess_labels(self.y_train)

    self.x_test=preprocess_imgs(self.x_test)
    self.y_test=preprocess_labels(self.y_test)

  def batch_labeled(self,batch_size):
    idx=np.random.randint(0,self.num_labeled,batch_size)  # 레이블된 이미지와 레이블의 랜덤 배치 만들기
    imgs=self.x_train[idx]
    labels=self.y_train[idx]
    return imgs,labels

  def batch_unlabeled(self,batch_size):
    idx=np.random.randint(self.num_labeled,self.x_train.shape[0],batch_size) # 레이블이 없는 이미지의 랜덤 배치 만들기
    imgs=self.x_train[idx]
    return imgs

  def training_set(self):
    x_train=self.x_train[range(self.num_labeled)]
    y_train=self.y_train[range(self.num_labeled)]
    return x_train,y_train

  def test_set(self):
    return self.x_test,self.y_test

 

 이 튜토리얼에서는 레이블된 MNIST 이미지 100개만 훈련에 사용하겠다.

num_labeled=100  # 사용할 레이블된 샘플 개수(나머지는 레이블 없이 사용)

dataset=Dataset(num_labeled)

 

<생성자>

생성자 네트워크는 DCGAN에서 만든 것과 동일하다. 다음 코드처럼 생성자는 전치 합성곱 층을 사용해 랜덤한 입력 벡터를 28x28x1 크기 이미지로 변환한다.

def build_generator(z_dim):
  model=Sequential()

  model.add(Dense(256*7*7,input_dim=z_dim))
  model.add(Reshape((7,7,256)))  #  완전 연결 층을 사용해 입력을 7x7x256 크기 텐서로 바꾼다
  model.add(Conv2DTranspose(128,kernel_size=3,strides=2,padding='same')) # 7x7x256 에서 14x14x128 텐서로 바꾸는 전치 합성곱 층
  model.add(BatchNormalization())
  model.add(LeakyReLU(alpha=0.01))

  model.add(Conv2DTranspose(64,kernel_size=3,strides=1,padding='same'))
  # 14x14x128 에서 14x14x64 텐서로 바꾸는 전치 합성곱 층
  model.add(BatchNormalization())  # 배치 정규화
  model.add(LeakyReLU(alpha=0.01))
  
  model.add(Conv2DTranspose(1,kernel_size=3,strides=2,padding='same'))
  # 14x14x64 에서 28x28x1 텐서로 바꾸는 전치 합성곱 층
  model.add(Activation('tanh'))

  return model

 

<판별자>

 판별자는 SGAN 모델에서 가장 복잡한 부분이다. SGAN 판별자는 두 가지 목표를 갖는다는 것을 기억해야 한다. 진짜와 가짜 샘플을 구별한다. 이를 위해 SGAN 판별자는 시그모이드 함수를 사용해 이진 분류를 위한 하나의 확률을 출력한다. 진짜 샘플일 경우 레이블을 정확히 분류한다. 이를 위해 SGAN 판별자는 소프트맥스 함수를 사용해 타깃 클래스마다 하나씩 확률을 출력한다.

 

 먼저 판별자 네트워크의 핵심 부분을 정의해보자.

def build_discriminator_net(img_shape):

  model=Sequential()

  model.add(  # 28x28x1 에서 14x14x32 텐서로 바꾸는 합성곱층
      Conv2D(32,kernel_size=3,strides=2,input_shape=img_shape,padding='same')
  )
  model.add(LeakyReLU(alpha=0.01))
  
  model.add( # 14x14x32 에서 7x7x64 텐서로 바꾸는 합성곱 층
      Conv2D(64,kernel_size=3,strides=2,input_shape=img_shape,padding='same')
  )
  model.add(LeakyReLU(alpha=0.01))

  model.add(  # 7x7x64 에서 3x3x128 텐서로 바꾸는 합성곱 층
      Conv2D(128,kernel_size=3,strides=2,input_shape=img_shape,padding='same')
  )
  model.add(LeakyReLU(alpha=0.01))

  model.add(Dropout(0.5))
  
  model.add(Flatten())  # 텐서 펼치기

  model.add(Dense(num_classes)) # num_classes개의 뉴런을 가진 완전 연결 층

  return model

 앞의 신경망은 10개의 뉴런을 가진 완전 연결 층으로 끝난다. 그다음 이 뉴런에서 두 개의 판별자 출력을 계산하는 신경망을 정의해야 한다. 하나는 지도 학습이고 (소프트맥스 함수로) 다중 분류를 수행해야한다. 다른 하나는 비지도 학습이고 (시그모이드 함수를 사용해) 이진 분류를 수행한다.

 

<지도 학습 판별자>

다음 코드에서 앞서 구현한 판별자 모델을 받아 지도 학습에 해당하는 판별자 모델을 만든다.

def build_discriminator_supervised(discriminator_net):
  model=Sequential()
  model.add(discriminator_net)
  model.add(Activation('softmax'))  # 진짜 클래스에 대한 예측 확률을 출력하는 소프트맥스 활성화 함수
  
  return model

 

<비지도 학습 판별자>

다음 코드는 판별자 기반 모델 위에 비지도 학습에 해당하는 판별자 모델을 만든다. predict(x) 함수는 (기반 모델에서 온) 10개 뉴런의 출력을 진짜 대 가짜의 이진 예측으로 변환한다.

def build_discriminator_unsupervised(discriminator_net):
  model=Sequential()
  model.add(discriminator_net)

  def predict(x):
    prediction=1.0 - (1.0/(K.sum(K.exp(x),axis=-1,keepdims=True) + 1.0))
    # 진짜 클래스에 대한 확률 분포를 진짜 대 가짜의 이진 확률로 변환한다

    return prediction

  model.add(Lambda(predict))

  return model

 

<GAN 모델 구성>

 그다음 판별자와 생성자 모델을 구성하고 컴파일한다. 지도 손실과 비지도 손실을 위해 categorical_crossentropy와 binary_crossentropy 손실 함수를 사용한다.

def build_gan(generator,discriminator):
  model=Sequential()

  model.add(generator)
  model.add(discriminator)

  return model
  
discriminator_net=build_discriminator_net(img_shape)
# 판별자 기반 모델: 이 층들은 지도 학습 훈련과 비지도 학습 훈련에 공유된다

discriminator_supervised=build_discriminator_supervised(discriminator_net)
discriminator_supervised.compile(loss='categorical_crossentropy',
                                 metrics=['accuracy'],
                                 optimizer=Adam(learning_rate=0.0003))

discriminator_unsupervised=build_discriminator_unsupervised(discriminator_net)  
# 비지도 학습 훈련을 위해 판별자를 만들고 컴파일
discriminator_unsupervised.compile(loss='binary_crossentropy',optimizer=Adam())

generator=build_generator(z_dim)  # 생성자를 만든다.
discriminator_unsupervised.trainable=False # 생성자 훈련을 위해 판별자의 모델 파라미터를 동결

gan=build_gan(generator,discriminator_unsupervised)
gan.compile(loss='binary_crossentropy',optimizer=Adam())
# 생성자를 훈련하기 위해 고정된 판별자로 GAN 모델을 만들고 컴파일한다. 참고: 비지도 학습용 판별자를 사용

 

<훈련>

 다음 의사코드는 SGAN 훈련 알고리즘을 나타낸다.

supervised_losses=[]
iteration_checkpoints=[]

def train(iterations,batch_size,sample_interval):
  real=np.ones((batch_size,1))
  fake=np.zeros((batch_size,1))

  for iteration in range(iterations):
    imgs,labels=dataset.batch_labeled(batch_size)  # 레이블된 샘플을 가져온다
    labels=to_categorical(labels,num_classes=num_classes) # 레이블을 원-핫 인코딩한다
    imgs_unlabeled=dataset.batch_unlabeled(batch_size) # 레이블이 없는 샘플을 가져온다

    z=np.random.normal(0,1,(batch_size,z_dim)) # 가짜 이미지의 배치를 생성
    gen_imgs=generator.predict(z)

    d_loss_supervised,accuracy=discriminator_supervised.train_on_batch(imgs,labels)
    # 레이블된 진짜 샘플에서 훈련

    d_loss_real=discriminator_unsupervised.train_on_batch(imgs_unlabeled,real)
    # 레이블이 없는 진짜 샘플에서 훈련

    d_loss_fake=discriminator_unsupervised.train_on_batch(gen_imgs,fake)
    # 가짜 샘플에서 훈련
    d_loss_unsupervised=0.5 * np.add(d_loss_real,d_loss_fake)

    z=np.random.normal(0,1,(batch_size,z_dim))
    gen_imgs=generator.predict(z)

    g_loss=gan.train_on_batch(z,np.ones((batch_size,1)))  # 생성자를 훈련

    if (iteration + 1) % sample_interval == 0 :
      supervised_losses.append(d_loss_supervised) # 훈련이 끝난 후 그래프를 그리기 위해 판별자의 지도 학습 분류 손실을 기록
      iteration_checkpoints.append(iteration+1)

      print(
          '%d [D 손실: %.4f, 정확도: %.2f%%] [D 손실: %.4f] [G 손실: %f]'
          % (iteration + 1,d_loss_supervised,100 * accuracy,
             d_loss_unsupervised,g_loss)
      )

 

<모델 훈련>

훈련을 위해 레이블된 샘프이 100개뿐이므로 작은 배치 크기를 사용한다. 반복 횟수는 시행 착오를 거쳐 결정한다. 판별자의 지도 학습 손실 값이 평탄해질 때까지 이 횟수를 늘린다. 하지만 (과대적합의 위험을 줄이기 위해) 이 지점을 너무 지나서까지 늘리진 않는다.

iterations=8000
batch_size=32
sample_interval=800

train(iterations,batch_size,sample_interval)

losses = np.array(supervised_losses)

# 판별자의 지도 학습 손실을 그립니다.
plt.figure(figsize=(15, 5))
plt.plot(iteration_checkpoints, losses, label="Discriminator loss")

plt.xticks(iteration_checkpoints, rotation=90)

plt.title("Discriminator – Supervised Loss")
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.legend()
plt.show()

 

 SGAN이 분류를 얼마나 잘 수행하는지 알아보자. 훈련하는 동안 지도 학습 정확도는 100%를 달성한다. 중요한 것은 이 분류기가 이전에 본 적 없는 훈련 세트 데이터에 얼마나 잘 일반화되는가이다. 다음 코드를 실행해보자.

x,y=dataset.test_set()
y=to_categorical(y,num_classes=num_classes)

_,accuracy=discriminator_supervised.evaluate(x,y)
print("테스트 정확도: %.2f%%" % (100*accuracy))

 이 SGAN은 테스트 세트에 있는 약 91%의 샘플을 정확하게 분류할 수 있다. 이것이 얼마나 놀라운 결과인지 완전히 지도 학습으로 훈련한 분류기와 성능을 비교해보자.

 

<지도 학습 분류기와 비교하기>

 가능한 한 공정하게 비교하기 위해 다음 코드처럼 지도 학습 판별자 모델에 사용했던 것과 같은 구조를 사용해 완전한 지도 학습 분류기를 만든다. 이를 통해 GAN 방식의 준지도 학습으로 분류기의 능력이 얼마나 향상되었는지 구분할 수 있다.

 

mnist_classifier=build_discriminator_supervised(
    build_discriminator_net(img_shape))  # SGAN 판별자와 같은 네트워크 구조를 가진 지도 학습 분류기
mnist_classifier.compile(loss='categorical_crossentropy',
                         metrics=['accuracy'],
                         optimizer=Adam())
imgs,labels=dataset.training_set()

# 레이블을 원-핫 인코딩한다.
labels=to_categorical(labels,num_classes=num_classes)

# 분류기를 훈련한다.
training=mnist_classifier.fit(x=imgs,y=labels,batch_size=32,epochs=30,verbose=1)
losses=training.history['loss']
accuracies=training.history['accuracy']
x,y=dataset.test_set()
y=to_categorical(y,num_classes=num_classes)

# 테스트 세트에 대한 분류 정확도를 계산한다.
_,accuracy=mnist_classifier.evaluate(x,y)
print('Test Accuract: %.2f%%' % (100 * accuracy))

 약 70%만 정확하게 분류한다. SGAN 보다 20% 포인트나 나쁘다. 다르게 말하면 SGAN이 정확도를 30%나 향상시켰다.

 

 동일한 조건과 훈련하에서 레이블된 샘플 10,000개(앞에서보다 100배 많다)를 사용하는 완전 지도 학습 분류기는 98% 정확도를 달성한다. 하지만 더는 준지도 학습 설정은 아니다.

 

<결론>

이 장에서 준지도 학습을 위해 판별자가 진짜 샘플의 클래스 레이블을 출력하도록 GAN을 사용하는 방법을 알아보았다. 적은 수의 훈련 샘플에서 SGAN으로 훈련한 분류기 성능이 완전한 지도 학습 분류기보다 훨씬 뛰어난 것을 보았다.

 

 GAN의 혁신 측면에서 보면 SGAN의 핵심 차별 요소는 판별자 훈련에 레이블을 사용하는 것이다.

'GAN > 이론' 카테고리의 다른 글

12. CGAN - 구현  (0) 2021.06.30
11. CGAN - 이론  (0) 2021.06.30
9. SGAN - 이론  (0) 2021.06.28
8. 텐서플로 허브를 사용한 실습  (0) 2021.06.27
7. ProGAN - 주요한 혁신들(2)  (0) 2021.06.27