- 모드 붕괴
모드 붕괴(model collapse)현상은 GAN을 훈련할 때 자주 맞닥뜨리는 현상이다.
모드 붕괴가 일어나면 생성기는 오직 하나만 만들게 되거나, 선택지의 극히 일부만 만들게 된다.
모드 붕괴가 왜 발생하는지에 대한 하나의 그럴듯한 이론은, 생성기가 판별기보다 더 앞서간 후에 항상 실제에 가깝게 결과가 나오는 '꿀 지점'을 발견하여 그 이미지를 계속 만들어내게 된다는 것이다. 이 현상을 완화하는 방법을 빠르게 생각해보면, 판별기를 생성기보다 좀 더 자주 훈련시키는 것이다. 하지만 사실 훈련의 양보다는 질이 더 중요하다.
- GAN 훈련 성능 향상하기
1. 이진 교차 엔트로피 BCELoss() 를 이용해 평균제곱오차 MSELoss() 손실함수를 대체하는 방법이다. 이는 MSELoss() 보다 훨씬 더 보상과 벌이 강하다.
2. LeakyReLU() 활성화 함수를 생성기와 판별기에 적용하는 것이다.
3. 신경망에서 나오는 신호에 대해 정규화를 진행하여 평균을 0으로 맞추고, 분산을 제한하여 극단적인 값을 피하는 방법이다.
판별자 네트워크
class Discriminator(nn.Module):
def __init__(self):
# initialise parent pytorch class
super().__init__()
# define neural network layers
self.model = nn.Sequential(
nn.Linear(784, 200),
nn.LeakyReLU(0.02),
nn.LayerNorm(200),
nn.Linear(200, 1),
nn.Sigmoid()
)
# create loss function
self.loss_function = nn.BCELoss()
# create optimiser, simple stochastic gradient descent
self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)
# counter and accumulator for progress
self.counter = 0;
self.progress = []
pass
def forward(self, inputs):
# simply run model
return self.model(inputs)
def train(self, inputs, targets):
# calculate the output of the network
outputs = self.forward(inputs)
# calculate loss
loss = self.loss_function(outputs, targets)
# increase counter and accumulate error every 10
self.counter += 1;
if (self.counter % 10 == 0):
self.progress.append(loss.item())
pass
if (self.counter % 10000 == 0):
print("counter = ", self.counter)
pass
# zero gradients, perform a backward pass, update weights
self.optimiser.zero_grad()
loss.backward()
self.optimiser.step()
pass
def plot_progress(self):
df = pandas.DataFrame(self.progress, columns=['loss'])
df.plot(ylim=(0), figsize=(16,8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5, 1.0, 5.0))
pass
pass
생성자 네트워크
class Generator(nn.Module):
def __init__(self):
# initialise parent pytorch class
super().__init__()
# define neural network layers
self.model = nn.Sequential(
nn.Linear(100, 200),
nn.LeakyReLU(0.02),
nn.LayerNorm(200),
nn.Linear(200, 784),
nn.Sigmoid()
)
# create optimiser, simple stochastic gradient descent
self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)
# counter and accumulator for progress
self.counter = 0;
self.progress = []
pass
def forward(self, inputs):
# simply run model
return self.model(inputs)
def train(self, D, inputs, targets):
# calculate the output of the network
g_output = self.forward(inputs)
# pass onto Discriminator
d_output = D.forward(g_output)
# calculate error
loss = D.loss_function(d_output, targets)
# increase counter and accumulate error every 10
self.counter += 1;
if (self.counter % 10 == 0):
self.progress.append(loss.item())
pass
# zero gradients, perform a backward pass, update weights
self.optimiser.zero_grad()
loss.backward()
self.optimiser.step()
pass
def plot_progress(self):
df = pandas.DataFrame(self.progress, columns=['loss'])
df.plot(ylim=(0), figsize=(16,8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5, 1.0, 5.0))
pass
pass
아직은 모든 이미지가 같은 형상이다.
생성의 첫 단계는 시드인데 아마도 하나의 값만으로 생성기가 10개 숫자에 대한 784개 픽셀을 온전히 만드는 것은 역부족인것 같다.
시도해볼 만한 쉬운 방법은 입력 시드에 충분히 많은 숫자를 넣는 것이다. 일단 임의의 값으로 입력 노드를 100개 만들어 시작해 보겠다.
이전보다 더 글씨같지만 여전히 모두 같은 이미지이다.
판별기와 생성기에 시드를 주기 위한 임의의 값에 대해 계속 생각해보면, 이 둘에 입력되는 값이 달라야 할 것 같다는 생각이 든다.
1. 판별기에 입력되는 임의의 이미지 픽셀 값은 0에서 1사이에서 고르게 선택해야 한다. 범위가 0부터 1인 이유는 이것이 실제 데이터셋에서 관찰되는 값이기 때문이다.
2. 생성기에 투입되는 임의의 값은 0부터 1 사이의 값이 아니어도 된다. 신경망에서 평균이 0이고 분산이 제한된 정규화된 값들이 학습에 유리하다는 것을 우리는 이미 알고 있다.
# function to generate uniform random data
def generate_random_image(size):
random_data = torch.rand(size)
return random_data
def generate_random_seed(size):
random_data=torch.randn(size)
return random_data
이제 판별기에 투입할 때마다 generate_random_image(784)를 사용하고, 생성기에는 generate_random_seed(100)을 사용하면 된다.
%%time
# create Discriminator and Generator
D = Discriminator()
G = Generator()
epochs = 4
for epoch in range(epochs):
print ("epoch = ", epoch + 1)
# train Discriminator and Generator
for label, image_data_tensor, target_tensor in mnist_dataset:
# train discriminator on true
D.train(image_data_tensor, torch.FloatTensor([1.0]))
# train discriminator on false
# use detach() so gradients in G are not calculated
D.train(G.forward(generate_random_seed(100)).detach(), torch.FloatTensor([0.0]))
# train generator
G.train(D, generate_random_seed(100), torch.FloatTensor([1.0]))
pass
pass
이제 손실 차트를 확인해보자. 현재는 BCELoss() 를 썼기 때문에 값들이 항상 0에서 1사이에 머물러 있지는 않는다.
생성기의 손실은 어느 정도 고정된 값을 중심으로 약간의 변화가 있는 게 오히려 더 바람직하다.
'GAN > 이론' 카테고리의 다른 글
7. 합성곱 GAN (전치 합성곱) (0) | 2021.11.06 |
---|---|
6. 얼굴 이미지(HDF5 데이터 형식, GPU 가속) (0) | 2021.11.06 |
4. 손으로 쓴 숫자 훈련 (1) (0) | 2021.11.03 |
3. 단순한 1010 패턴 (0) | 2021.10.31 |
2. GAN 개념 (0) | 2021.10.31 |