GAN/이론

4. 손으로 쓴 숫자 훈련 (1)

jwjwvison 2021. 11. 3. 21:18

 전체적인 구조는 위와 같다. 생성기의 역할은 동일한 크기의 이미지를 만드는 것이다.

 

import torch
import torch.nn as nn
from torch.utils.data import Dataset

import pandas, numpy, random
import matplotlib.pyplot as plt

 

  • 데이터셋 클래스

 csv 파일에서 MNIST 데이터를 불러와 본격적으로 사용하기 위해 파이토치의 torch.utils.data.Dataset 클래스를 사용한다.

 

 이 클래스는 데이터를 텐서로 묶고 각 레코드마다 정답 레이블, 0부터 1사이의 값으로재조정된 이미지 픽셀 값, 원핫 인코딩이 된 텐서를 반환한다.

class MnistDataset(Dataset):
    
    def __init__(self, csv_file):
        self.data_df = pandas.read_csv(csv_file, header=None)
        pass
    
    def __len__(self):
        return len(self.data_df)
    
    def __getitem__(self, index):
        # image target (label)
        label = self.data_df.iloc[index,0]
        target = torch.zeros((10))
        target[label] = 1.0
        
        # image data, normalised from 0-255 to 0-1
        image_values = torch.FloatTensor(self.data_df.iloc[index,1:].values) / 255.0
        
        # return label, image data tensor and target tensor
        return label, image_values, target
    
    def plot_image(self, index):
        img = self.data_df.iloc[index,1:].values.reshape(28,28)
        plt.title("label = " + str(self.data_df.iloc[index,0]))
        plt.imshow(img, interpolation='none', cmap='Blues')
        pass
    
    pass
# load data

mnist_dataset=MnistDataset('/content/drive/MyDrive/GAN-pytorch/mnist_data/mnist_train.csv')
mnist_dataset.plot_image(17)

 

 

  • MNIST 판별기
# discriminator class

class Discriminator(nn.Module):
    
    def __init__(self):
        # initialise parent pytorch class
        super().__init__()
        
        # define neural network layers
        self.model = nn.Sequential(
            nn.Linear(784, 200),
            nn.Sigmoid(),
            nn.Linear(200, 1),
            nn.Sigmoid()
        )
        
        # create loss function
        self.loss_function = nn.MSELoss()

        # create optimiser, simple stochastic gradient descent
        self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01)

        # counter and accumulator for progress
        self.counter = 0;
        self.progress = []

        pass
    
    
    def forward(self, inputs):
        # simply run model
        return self.model(inputs)
    
    
    def train(self, inputs, targets):
        # calculate the output of the network
        outputs = self.forward(inputs)
        
        # calculate loss
        loss = self.loss_function(outputs, targets)

        # increase counter and accumulate error every 10
        self.counter += 1;
        if (self.counter % 10 == 0):
            self.progress.append(loss.item())
            pass
        if (self.counter % 10000 == 0):
            print("counter = ", self.counter)
            pass

        # zero gradients, perform a backward pass, update weights
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

        pass
    
    
    def plot_progress(self):
        df = pandas.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=(0, 1.0), figsize=(16,8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5))
        pass
    
    pass

 

  • 판별기 테스트하기

 생성기를 만들기 전에, 판별기가 실제 이미지와 임의의 노이즈 간에 구별을 할 수 있는지 한번 살펴보고 넘어가야 한다.

D=Discriminator()

for label,image_data_tensor,target_tensor in mnist_dataset:
  D.train(image_data_tensor,torch.FloatTensor([1.0]))
  D.train(generate_random(784),torch.FloatTensor([0.0]))
  pass

 각각의 실제 이미지마다 generate_random(784)를 이용해 임의의 가짜 노이즈 이미지 픽셀 값을 만든다. 판별기는 이 노이즈 값에 대해서는 0.0을 출력하도록 훈련한다.

 

 이제 훈련을 진행하면서 손실이 어떻게 변하는지 차트를 통해 확인한다.

 

 이제 훈련된 판별기에 임의로 선택한 이미지를 수동으로 넣어 결과를 한번 확인해보자.

# 결과 확인
 
for i in range(4):
  image_data_tensor=mnist_dataset[random.randint(0,60000)][1]
  print(D.forward(image_data_tensor).item())
  pass

for i in range(4):
  print(D.forward(generate_random(784)).item())
  pass

 

  • MNIST 생성기 

 생성기는 MNIST 데이터셋과 같은 형식으로 데이터를 만들어야 한다. 생성기가 생성한 이미지는 28x28 크기이며 총 784개의 픽셀 값을 지녀야 한다는 뜻이다.

 

 전에 훈련했던 1010GAN은 생성기가 1010패턴으로 항상 데이터를 만들었다. 하지만 생성기가 항상 정확히 같은 값을 출력하는 것은 우리가 원하는 바가 아니다. 생성기는 훈련 데이터의 여러 양상을 다양하게 반영하도록 이미지를 생성해야 한다. 3,7,4,9 등등 실제 숫자로 보이는 이미지들을 생성해야 한다.

 

 이는 생성기의 입력으로 상수 0.5가 들어가면 안 된다는 점을 뜻한다. 매 훈련 사이클마다 임의적인 입력을 사용하면 된다.

# generator class

class Generator(nn.Module):
    
    def __init__(self):
        # initialise parent pytorch class
        super().__init__()
        
        # define neural network layers
        self.model = nn.Sequential(
            nn.Linear(1, 200),
            nn.Sigmoid(),
            nn.Linear(200, 784),
            nn.Sigmoid()
        )
        
        # create optimiser, simple stochastic gradient descent
        self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01)

        # counter and accumulator for progress
        self.counter = 0;
        self.progress = []
        
        pass
    
    
    def forward(self, inputs):        
        # simply run model
        return self.model(inputs)
    
    
    def train(self, D, inputs, targets):
        # calculate the output of the network
        g_output = self.forward(inputs)
        
        # pass onto Discriminator
        d_output = D.forward(g_output)
        
        # calculate error
        loss = D.loss_function(d_output, targets)

        # increase counter and accumulate error every 10
        self.counter += 1;
        if (self.counter % 10 == 0):
            self.progress.append(loss.item())
            pass

        # zero gradients, perform a backward pass, update weights
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

        pass
    
    
    def plot_progress(self):
        df = pandas.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=(0, 1.0), figsize=(16,8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5))
        pass
    
    pass

 

  • 생성기 결과 확인하기
G=Generator()
output=G.forward(generate_random(1))
img=output.detach().numpy().reshape(28,28)
plt.imshow(img,interpolation='none',cmap='Blues')

 생성기가 훈련 전이라는 사실을 알 수 있다. 이는 정상적이며, 오히려 이미지에 패턴이 있다면 어딘가에 실수가 있었다는 뜻이다.

 

  • GAN 훈련하기
# 판별기 및 생성기 생성

D=Discriminator()
G=Generator()

# 판별기와 생성기 훈련
for label,image_data_tensor,target_tensor in mnist_dataset:

  # 참에 대해 판별기 훈련
  D.train(image_data_tensor,torch.FloatTensor([1.0]))

  # 거짓에 대해 판별기 훈련
  # G의 기울기가 계산되지 않도록 detach() 함수를 이용
  D.train(G.forward(generate_random(1)).detach(),torch.FloatTensor([0.0]))

  # 생성기 훈련
  G.train(D,generate_random(1),torch.FloatTensor([1.0]))

  pass

 

 이제 판별기를 훈련시키면서 나온 손실값들을 그려보자.

 손실은 0에 가까워졌고, 한동안 이를 유지한다. 이 시기는 판별기가 생성기를 성능으로 앞서는 부분이다. 이후 손실값이 0.25 바로 아래까지 치솟게 되는데 이는 판별기와 생성기가 균형이 맞기 시작했다는 의미로 볼 수 있다. 이후에는 판별기가 앞서나가 손실값이 다시 낮은 상태에서 머물러 있다.

 

 손실값이 0.25 정도를 보이면서 판별기와 생성기의 성능이 균형이 맞는 상태가 우리가 원하는 바임을 기억하자.

 초기에 손실값이 치솟는 이유는, 판별기가 생성기에서 나온 이미지들을 잘 구별하기 때문이다. 그 이후 손실은 0.25 근처로 하락하여 판별기와 생성기 간 균형이 잘 맞는 상태가 된다. 이후 중반 이후에는 다시 손실이 상승하는데, 이는 판별기의 성능이 생성기보다 더 나은 구간이라고 볼 수 있다.

 

 

서로 다른 임의 시드에서 각기 다른 이미지가 생성되리라 예상되므로, 여러 장의 이미지를 골라서 그려보자.

f,axarr=plt.subplots(2,3,figsize=(16,8))
for i in range(2):
  for j in range(3):
    output=G.forward(generate_random(1))
    img=output.detach().numpy().reshape(28,28)
    axarr[i,j].imshow(img,interpolation='none',cmap='Blues')
    pass
  pass

 

'GAN > 이론' 카테고리의 다른 글

6. 얼굴 이미지(HDF5 데이터 형식, GPU 가속)  (0) 2021.11.06
5. 손으로 쓴 숫자 훈련(2)  (0) 2021.11.03
3. 단순한 1010 패턴  (0) 2021.10.31
2. GAN 개념  (0) 2021.10.31
1. CUDA 기초  (0) 2021.10.31