GAN/이론

8. 조건부 GAN

jwjwvison 2021. 11. 8. 18:10

 앞에서 개발한 MNIST GAN은 굉장히 다양한 범위의 이미지들을 만들었다. 이 사실은 상당히 고무적인데, GAN 설계 시 부딪히는 지속적인 문제가 바로 다양성 없이 이미지를 생성하는 모드 붕괴이기 때문이다.

 

GAN이 생성하는 어떻게든 이미지를 단일한 클래스로 고정한 채로 다양한 이미지를 생성할 수 있다고 하면 상당히 유용할 것이다.

 

  • 조건부 GAN 구조

 훈련된 GAN 생성기가 주어진 클래스에 해당하는 이미지를 생성하게 하려면, 일단 무슨 클래스를 목표로 하는지 생성기에 알려줘야 한다.

 판별기는 클래스 레이블과 이미지 사이의 관계를 학습해야 한다.

 

  • 판별기

 이미지 픽셀 데이터와 클래스 레이블 정보를 동시에 받도록 판별기를 업데이트 해야 한다. 간단한 방법은 forward() 함수에서 이미지 텐서와 레이블 텐서를 동시에 받게 하고 이를 단순히 결합하는 것이다. 레이블 텐서는 원핫 인코딩 되어 있는 텐서로서 데이터셋 클래스에서 이미 준비해둔 텐서이다.

class Discriminator(nn.Module):
    
    def __init__(self):
        # initialise parent pytorch class
        super().__init__()
        
        # define neural network layers
        self.model = nn.Sequential(
            nn.Linear(784+10, 200),
            nn.LeakyReLU(0.02),

            nn.LayerNorm(200),

            nn.Linear(200, 1),
            nn.Sigmoid()
        )
        
        # create loss function
        self.loss_function = nn.BCELoss()

        # create optimiser, simple stochastic gradient descent
        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)

        # counter and accumulator for progress
        self.counter = 0;
        self.progress = []

        pass
    
    
    def forward(self, image_tensor, label_tensor):
        # combine seed and label
        inputs = torch.cat((image_tensor, label_tensor))
        return self.model(inputs)
    
    
    def train(self, inputs, label_tensor, targets):
        # calculate the output of the network
        outputs = self.forward(inputs, label_tensor)
        
        # calculate loss
        loss = self.loss_function(outputs, targets)

        # increase counter and accumulate error every 10
        self.counter += 1;
        if (self.counter % 10 == 0):
            self.progress.append(loss.item())
            pass
        if (self.counter % 10000 == 0):
            print("counter = ", self.counter)
            pass

        # zero gradients, perform a backward pass, update weights
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

        pass
    
    
    def plot_progress(self):
        df = pandas.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=(0), figsize=(16,8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5, 1.0, 5.0))
        pass
    
    pass

 

 이제 판별기를 테스트해보자. 훈련 반복문 안에서 레이블 텐서를 추가로 train() 함수에 전달하도록 수정해야 한다.

%%time
# test discriminator can separate real data from random noise

D = Discriminator()
D.to(device)
for label, image_data_tensor, label_tensor in mnist_dataset:
    # real data
    D.train(image_data_tensor, label_tensor, torch.cuda.FloatTensor([1.0]))
    # fake data
    D.train(generate_random_image(784), generate_random_one_hot(10), torch.cuda.FloatTensor([0.0]))
    pass

 

  • 생성기

 시드와 레이블 텐서를 생성기에 투입하게 했으니 두 텐서를 결합해서 신경망에 전달하기 위해 forward() 함수를 고쳐야한다.

class Generator(nn.Module):
    
    def __init__(self):
        # initialise parent pytorch class
        super().__init__()
        
        # define neural network layers
        self.model = nn.Sequential(
            nn.Linear(100+10, 200),
            nn.LeakyReLU(0.02),

            nn.LayerNorm(200),

            nn.Linear(200, 784),
            nn.Sigmoid()
        )
        
        # create optimiser, simple stochastic gradient descent
        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)

        # counter and accumulator for progress
        self.counter = 0;
        self.progress = []
        
        pass
    
    
    def forward(self, seed_tensor, label_tensor):        
        # combine seed and label
        inputs = torch.cat((seed_tensor, label_tensor))
        return self.model(inputs)


    def train(self, D, inputs, label_tensor, targets):
        # calculate the output of the network
        g_output = self.forward(inputs, label_tensor)
        
        # pass onto Discriminator
        d_output = D.forward(g_output, label_tensor)
        
        # calculate error
        loss = D.loss_function(d_output, targets)

        # increase counter and accumulate error every 10
        self.counter += 1;
        if (self.counter % 10 == 0):
            self.progress.append(loss.item())
            pass

        # zero gradients, perform a backward pass, update weights
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

        pass
    
    def plot_images(self, label):
        label_tensor = torch.zeros((10))
        label_tensor[label] = 1.0
        # plot a 3 column, 2 row array of sample images
        f, axarr = plt.subplots(2,3, figsize=(16,8))
        for i in range(2):
            for j in range(3):
                axarr[i,j].imshow(G.forward(generate_random_seed(100), label_tensor).detach().cpu().numpy().reshape(28,28), interpolation='none', cmap='Blues')
                pass
            pass
        pass
    
    def plot_progress(self):
        df = pandas.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=(0), figsize=(16,8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5, 1.0, 5.0))
        pass
    
    pass

 생성기에서 생성된 이미지들을 판별기의 forward() 함수에 레이블과 함께 넘겨 주도록 했다. 이는 생성기가 다른 레이블로 잘못 판단하는 것을 막아준다.

 

  • 훈련 반복문
%%time 

# train Discriminator and Generator

epochs = 2

for epoch in range(epochs):
  print ("epoch = ", epoch + 1)

  # train Discriminator and Generator

  for label, image_data_tensor, label_tensor in mnist_dataset:
    # train discriminator on true
    D.train(image_data_tensor, label_tensor, torch.cuda.FloatTensor([1.0]))

    # random 1-hot label for generator
    random_label = generate_random_one_hot(10)
    
    # train discriminator on false
    # use detach() so gradients in G are not calculated
    D.train(G.forward(generate_random_seed(100), random_label).detach(), random_label, torch.cuda.FloatTensor([0.0]))
    
    # different random 1-hot label for generator
    random_label = generate_random_one_hot(10)

    # train generator
    G.train(D, generate_random_seed(100), random_label, torch.cuda.FloatTensor([1.0]))

    pass
    
  pass

 

 

  • 결과 확인

 

 조건부 GAN은 실제로 숫자 9를 생성했고, 거기에서 그치지 않고 더 좋은 점은 모든 이미지가 다 똑같이 생기지는 않았다는 것이다.

'GAN > 이론' 카테고리의 다른 글

9. MSE 손실, BCE 손실  (0) 2021.11.08
7. 합성곱 GAN (전치 합성곱)  (0) 2021.11.06
6. 얼굴 이미지(HDF5 데이터 형식, GPU 가속)  (0) 2021.11.06
5. 손으로 쓴 숫자 훈련(2)  (0) 2021.11.03
4. 손으로 쓴 숫자 훈련 (1)  (0) 2021.11.03