본문 바로가기
AI 논문 리뷰(AI Paper Review)/생성 모델(Generative Model)

[논문정리][GAN] Conditional GAN: Conditional Generative Adversarial Nets

by stevenkim_ 2023. 11. 14.

*본 내용은 논문의 상세한 분석이 아닌, 간단한 복기용 정리입니다.

 

 

 

GAN Architecture

우선 일반 GAN부터 보자면, 목적함수는 다음과 같습니다. 왼쪽 판별자 박스 먼저 보겠습니다. P_data는 원본 데이터의 분포를 의미하는데요. 이 원본 데이터 중에서 한 개의 데이터인 x를 뽑아, D라고 나와있는 판별자에 이 하나의 이미지인 x를 넣습니다. 그럼 Discriminatorx에 대해 판별을 합니다. 판별자는 진짜 이미지에는 1, 가짜 이미지에는 0을 부여하는데요. 그래서 판별자의 아웃풋은 0에서 1 사이의 확률값으로 나오고, 그것에 log를 취한 것의 기댓값을 구합니다.

오른쪽 생성자에서는 p_z에서 하나의 노이즈 데이터인 z를 뽑습니다. p_z는 정규분포나 Uniform 분포를 주로 사용하는 데이터 분포입니다. 이 공간을 잠재 공간 (Latent Space)라고 부르는데요. Latent Space 에서 무작위로 뽑은 랜덤 벡터 z, 이번에는 생성자 G에 넣어서 가짜 이미지를 만듭니다. 이렇게 만들어진 이미지를 다시 판별자 D에 넣고, 나온 확률값을 1에서 뺍니다. 그걸 다시 log를 취합니다. 이것의 기댓값입니다.

목적함수의 맨 왼쪽에 표시 되어있다시피, 이렇게 만들어진 목적함수 V, 생성자인 G는 낮추고자 하고, 판별자인 DV를 높이는 것을 목표합니다.

 


 

Conditional GAN(CGAN) Architecture

다음으로 Conditional GAN에 대해서 설명하겠습니다. CGANGAN에서 조건정보인 y가 조건부 방식으로 추가되어 있습니다. 입력으로 latent space에서 뽑아낸 vectorz 뿐만 아니라, Condition vectory도 함께 넣어줍니다.

예를들어 MNIST 데이터셋을 이용한다고 했을 때, MNIST 데이터는 사람 손글씨 0부터 9까지의 데이터잖아요? 이중에서 예를들어 3을 만들고 싶다면, Conditional Vector3을 입력해주면 됩니다. 이런 식으로 y에 특정 class를 넣어, 출력하고 싶은 데이터를 마음대로 조절하는 것이 Conditional GAN입니다.

Discriminator에서도 마찬가지로 입력한 조건을 기준으로, 이미지가 진짜인지 가짜인지 판별하는 식으로 학습을 진행합니다. (3인지 아닌지 판단합니다.)

 


 

그림을 보시면 각 행마다 하나의 label에 조건을 붙인 것입니다. 순서대로 0부터 9까지 원하는 class의 이미지를 생성한 것을 볼 수 있습니다. 이제 Conditional GAN의 코드를 보겠습니다.

 


 

코드 리뷰

먼저 필요한 라이브러리들을 모두 불러와줍니다. MNIST 데이터셋을 사용할 것이기 때문에, torchvision 라이브러리를 불러옵니다.

 

먼저 파라미터를 설정해줍니다. 만들고자 하는 클래스는 총 10개로 설정했습니다. 그 다음 nz는 생성자의 인풋으로 들어가는 latent vector “z”를 뽑는 noise 분포의 차원이라고 생각하시면 됩니다. 100으로 설정해주고요. Input 사이즈는 채널 128 by 28 사이즈로 설정합니다.

생성자 구현 코드를 보면, 하나의 선형 함수를 거친 뒤에, Batch Normalization을 수행하고, 활성화 함수로는 LeakyReLU를 사용했습니다. 이 과정을 반복해 연속적인 여러 블록을 거쳐주고, 마지막으로는 1 x 28 x 28짜리의 하나의 MNIST 데이터를 생성할 수 있도록 합니다. 그 다음 마지막에 하이퍼볼릭 탄젠트 함수를 붙여서, -1부터 1 사이의 값을 가질 수 있도록 합니다.

이제 Forward 함수를 확인해보시면, noise 분포에서 뽑은 vector z 였죠. 여기선 noise로 표현되어 있는데, 이 벡터에다가 추가하고자 하는 조건인 label input으로 받습니다. 이 두개를 받아 torch.cat으로 결합해줍니다. 이 결합해준 데이터를 모델에 넣고, view 함수를 이용해서 이미지의 형태를 가질 수 있도록 만듭니다.

 

 

여기 보시는 것은 비교를 위해서, 일반 GAN의 생성자 코드인데요. 다른 사람이 쓴 코드를 가져와서 표현 방식은 약간 다른데, 내용은 같습니다.

여기는 아까 Linear을 빠져나와서 Batch Normalization LeakyReLU를 수행해주는 것을 block으로 묶었는데요, 이 과정은 동일합니다. 이렇게 여러 블록을 거쳐 나온 모델을 하이퍼볼릭 탄젠트 함수를 이용해 변형해주고, forward 함수에 집어넣는데, 이때 일반 GAN에서는 label이랑 concat 하는 과정 없이 그냥 노이즈 데이터만 집어넣어 줍니다. 

 

 

판별자 같은 경우는 생성자와 반대로, 한 장의 이미지가 들어 왔을 때, 그 이미지를 판별하기 위해 여러개의 Linear LeakyReLU activation function을 붙여서 결과적으로 sigmoid를 이용해 확률 값으로 나올 수 있도록 만듭니다. Forward 함수를 보시면, 여기도 마찬가지로 Conditional GAN이기 때문에 조건부 label과 인풋 이미지를 결합해줍니다. 그것을 다시 “discriminator” 모델에 넣어서 결과를 구해줍니다.

 

 

이제부터는 일반 GAN과 코드를 완전히 공유하는 부분이라서 복습 개념으로 코드를 준비했습니다.

이제 정의했던 생성자, 판별자 클래스를 이용해서 각각의 instance를 초기화해주는데요. , 여기 밑에는 cuda function을 이용해서 gpu에 올립니다. 손실함수는 여기 BCELoss, Binary Cross Entropy 로 설정을 해주고, 생성자와 판별자는 각각 Adam 옵티마이저를 이용해서 학습할 수 있도록 합니다. 여기 학습률이랑 베타 파라미터 값은 일반적으로 GAN에 제일 많이 사용되는 파라미터 그대로 가지고 왔습니다.

 

 

이제 모델 학습 과정인데요. 진짜 레이블과 가짜 레이블을 만들어서 학습에 사용합니다.

또 여기 밑에서 볼 부분은 여기 생성자의 손실 값 계산부분에서, 생성자 입장에서는 자신이 만든 이미지가 real 이미지로 분류될 수 있도록 해야 하기 때문에, 그쪽 방향으로 학습을 진행하고 있고요. 여기 판별자에서는 진짜 이미지는 real, 가짜 이미지는 fake로 잘 분류를 해줘야 하기 때문에, 그쪽 방향으로 학습을 진행해줍니다.