Diffusion model에 대해 찾아보니 VAE(Variational Auto Encoder)와 관련된 개념이 많아 공부하게 되었다.
일단 전체적인 구조를 보면 다음과 같다.
VAE는 개념상으로 Encoder 와 Decoder 로 이루어져있는데, 입력 데이터 $x_i$에 대해 각각의 역할을 다음으로 정리해 볼 수 있다.
Encoder
입력 $x_i$를 통해 확률공간$q_{\phi}(z_i \mid x_i)$을 만든다. 이 확률공간은 latent space로 사용되며 가우시안 분포의 형태를 가진다.
가우시안 분포를 만들기 위해선 평균과 표준편차가 필요하므로, 그림처럼 Encoder는 $x_i$를 통해 평균벡터$\mu_i$와 표준편차벡터$\sigma_i$를 생성한다.
Reparameterization Trick
Encoder가 생성해낸 latent space에서 latent vector인 $z_i$를 추출하는데 사용되는 방법이다.
이 과정이 필요한 이유는, Encoder가 생성해낸 두 벡터에 대한 연산을 정의해야 역전파시 미분이 가능해지기 때문이다.
Decoder
latent vector인 $z_i$를 통해 output을 생성해낸다.
논문의 Problem Scenario 를 보면 “주어진 데이터 $x_i$가 어떻게 생성되었을까?” 라는 질문과 그것에 대한 가정이 전제가 된다.
논문의 전제
$x_i$는 관측 불가하며 연속적인 $z$ 라는 랜덤 변수를 통해 생성된다.
완벽한 주어진 데이터에 대한 완전한 모방이 가능한 파라미터 $\theta^*$ 가 주어질때 $x_i$가 생성되는 과정은 다음과 같다.
VAE의 목표
주어진 데이터 $x_i$에 대해 $p_{\theta}(x_i)$를 최대화하는 파라미터 $\theta$를 찾는다.
즉 주어진 데이터와 가장 유사한 generative model을 생성하는것을 확률적으로 접근한 것 이다.
2번을 통해 $p_{\theta}(x)$를 최대화 시키는것이 VAE의 목적임을 알았다. 하지만 $p_{\theta}(x)$를 직접 모델링하는데에는 몇가지 걸림돌이 있다.
위의 문제점을 해결하기 위해 VAE는 Encoder($q_{\phi}(z \mid x)$)를 통해 $p_{\theta}(z \mid x)$를 모방한다.
따라서 VAE는 Encoder를 통해 추출된 $z$를 기반으로 $p_{\theta}(x)$를 최대화하는 목표를 가진다.
수식으로 나타내면 $E_{z \sim q_{\phi}(z \mid x)}[p_{\theta}(x)]$를 최대화 하는것.
$E_{z \sim q_{\phi}(z \mid x)}[ p_{\theta}(x)]$를 최대화 하는것은 $E_{z \sim q_{\phi}(z \mid x)}[ \log p_{\theta}(x) ]$를 최대화 하는것과 같다.
여기서 KL은 kullback leibler divergence를 의미한다.
\(E_{z \sim q_{\phi}(z \mid x)}[ \log p_{\theta}(x) ] = E_{z}[ \log p_{\theta}(x \mid z) ] - E_{z}[ \log \frac{q_{\phi}(z \mid x)}{p_{\theta}(z)} ] + E_{z}[ \log \frac{q_{\phi}(z \mid x)}{p_{\theta}(z \mid x)} ] \\
= E_{z}[ \log p_{\theta}(x \mid z) ] - \int q_{\phi}(z \mid x) \log \frac{q_{\phi}(z \mid x)}{p_{\theta}(z)} dz + \int q_{\phi}(z \mid x) \log \frac{q_{\phi}(z \mid x)}{p_{\theta}(z \mid x)} \\
= E_{z}[ \log p_{\theta}(x \mid z) ] - KL(q_{\phi}(z \mid x) \mid \mid p_{\theta}(z)) + KL(q_{\phi}(z \mid x) \mid \mid p_{\theta}(z \mid x))\)
위처럼 계산된 수식에 아래의 두 가지 요인을 적용하여 loss를 단순화 한다.
따라서 마지막 항을 제거하고 $E_{z \sim q_{\phi}(z \mid x)}[ \log p_{\theta}(x) ]$의 하한값을 최대화 시키는 문제로 수식을 변형 할 수 있다.
최종적으로 lower bound로서 loss를 정의할 수 있다.
\(E_{z \sim q_{\phi}(z \mid x)}[ \log p_{\theta}(x) ] \ge E_{z}[ \log p_{\theta}(x \mid z) ] - KL(q_{\phi}(z \mid x) \mid \mid p_{\theta}(z))\)
\(E_{z}[ \log p_{\theta}(x \mid z) ] - KL(q_{\phi}(z \mid x) \mid \mid p_{\theta}(z))\)
위 수식을 최대화 하는 문제를 아래의 값을 최소화 시키는것 으로 변형 할 수 있다.(일반적인 loss를 계산할 때 처럼)
\({-} E_{z}[ \log p_{\theta}(x \mid z) ] + KL(q_{\phi}(z \mid x) \mid \mid p_{\theta}(z))\)
이제 각각의 항들을 계산하면된다.
각 항은 모델의 출력값과 입력값을 비교하는 Reconstruction Error 와 Encoder가 만들어내는 분포와 $p_{\theta}(z)$의 차이를 줄이는 Regularization 으로 볼 수 있다.
\({-} E_{z \sim q_{\phi}(z \mid x)}[ \log p_{\theta}(x \mid z) ] = - \int q_{\phi}(z \mid x) \log p_{\theta}(x \mid z) dz\)
위에서 모든 $z$에 대한 계산이 불가능하기 때문에 Monte-carlo technique를 사용한다.
이 방법은 전체 $z$를 구하는대신에 $L$개의 $z$에 대한 계산만 진행하는것 이다.
VAE는 $L$을 1로 두고 계산한다.
\({-} \log p_{\theta}(x \mid z)\)
계산하기 위해 또 생각해야할 것이 있는데, 바로 $x$의 분포이다. $x_i$가 주어진 입력 데이터 벡터일때.
$x_i$의 각 값들이 Bernoulli Distribution이라 가정하면
모델의 출력은 $p_i$라는 하나의 벡터가 되야한다.
Reconstruction Error는 Cross Entropy로 유도된다.
\({-} \log p_{\theta}(x_i \mid z_i)
= - \log ( \prod_{j=1}^D p_{\theta}(x_{i, j} \mid z_i))
= - \sum_{j=1}^D \log(p_{\theta}(x_{i, j} \mid z_i))
= - \sum_{j=1}^D \log(p_{i,j}^{x_{i,j}}(1-p_{i,j})^{1-x_{i,j}})
= - \sum_{j=1}^D \{ x_{i,j}\log p_{i,j} + (1 - x_{i,j}) \log (1-p_{i,j}) \}\)
$x_i$의 각 값들이 Gaussian Distribution이라 가정하면
모델은 $\acute{\mu_i}$와 $\acute{\sigma_i}$ 두 벡터를 출력해야한다.
Reconstruction Error는 표준편차가 1일때 MSE가 된다.
\({-} E_{z}[ \log p_{\theta}(x \mid z) ]
\approx {-} \log p_{\theta}(x_i \mid z_i)
= - \log ( \mathcal{N}(x_i; \acute{\mu_i}, \acute{\sigma_i}^2 I) )
= \sum_{j=1}^{D} \{ \frac{1}{2} \log \acute{\sigma_{i,j}}^2 + \frac{(x_{i,j} - \acute{\mu_{i,j}})^2}{2 \acute{\sigma_{i,j}}^2} \}\)
위 항을 최소화하는데 두가지 가정이 필요하다.
위 두 가정을 통해 Regularization식을 $\mu_i$ 벡터와 $\sigma_i$ 벡터에 대한 식으로 변형 가능하다. \(KL(q_{\phi}(z_i \mid x_i) \mid \mid p_{\theta}(z_i)) = KL(\mathcal{N}(\mu_i, \sigma_i^2 I) \mid \mid p_{\theta}(z_i)) = \frac{1}{2} \sum_{j=1}^J \{ \mu_{i,j}^2 + \sigma_{i,j}^2 - \ln (\sigma_{i,j}^2) - 1 \}\)
각 input data $x_i$에 대해.
\(Loss = {-} \sum_{j=1}^D \{ x_{i,j}\log p_{i,j} + (1 - x_{i,j})\log(1 - p_{i,j}) \} + \frac{1}{2}\sum_{j=1}^J\{ \mu_{i,j}^2 + \sigma_{i,j}^2 - \ln(\sigma_{i,j}^2) - 1 \}\)
이때 $\mu_{i,j}, \sigma_{i,j}, \epsilon_i \in R^J$ 이고 $x_i, p_i \in R^D$ 이다.
각 input data $x_i$에 대해.
\(Loss = \sum_{j=1}^{D} \{ \frac{1}{2} \log \acute{\sigma_{i,j}}^2 + \frac{(x_{i,j} - \acute{\mu_{i,j}})^2}{2 \acute{\sigma_{i,j}}^2} \} +
\frac{1}{2}\sum_{j=1}^J\{ \mu_{i,j}^2 + \sigma_{i,j}^2 - \ln(\sigma_{i,j}^2) - 1 \}\)
이때 $\mu_{i,j}, \sigma_{i,j}, \epsilon_i \in R^J$ 이고 $x_i,\acute{\mu_i}, \acute{\sigma_i} \in R^D$ 이다.
Gaussian Encoder + Bernoulli Decoder를 pytorch로 구현해봤다.
import torch.nn as nn
import torch
class BernoulliMLP(nn.Module):
def __init__(self, device):
super(BernoulliMLP, self).__init__()
self.device = device
self.J = 10 # Encoder가 출력하는 벡터들의 크기
self.D = 116412 # 입력 벡터의 크기 ex. (3, 218, 178)의 사이즈를 가지는 이미지
self.encoder = nn.Sequential(
nn.Linear(in_features=116412, out_features=1164),
nn.ReLU(),
nn.Linear(in_features=1164, out_features=116),
nn.ReLU(),
nn.Linear(in_features=116, out_features=self.J * 2)
)
self.decoder = nn.Sequential(
nn.Linear(in_features=self.J, out_features=116),
nn.ReLU(),
nn.Linear(in_features=116, out_features=1164),
nn.ReLU(),
nn.Linear(in_features=1164, out_features=116412)
)
def forward(self, x):
batch_size = len(x)
x = x.view(batch_size, -1)
####
# Encoder에서 평균과 표준편차 벡터를 추출한다.
####
code = self.encoder(x) # [batch, J * 2]
mean, std = code[:, 0 : self.J], code[:, self.J : ]
####
# 벡터 z를 추출하기 위해 표준정규분포에서 epsilon을 추출한다.
####
epsilon = torch.normal(mean=torch.zeros(size=(batch_size, self.J)), std=torch.ones(size=(batch_size, self.J))).to(self.device)
####
# reparameterization technique
####
z = mean + std * epsilon
####
# Decoder로 출력값을 얻는다.
# 이때 Bernoulli 분포의 값은 0혹은 1 이므로 둘 사이의 값이 되도록 sigmoid를 적용한다.
####
p = self.decoder(z)
p = torch.sigmoid(p)
return p, mean, std
import torch.nn as nn
import torch
class BernoulliLoss(nn.Module):
def __init__(self):
super(BernoulliLoss, self).__init__()
def forward(self, x, p, mean, std):
"""_summary_
아래의 직접
Args:
x (_type_): input data
p (_type_): output of bernoulli decoder
mean (_type_): mean tensor from gaussian encoder
std (_type_): std tensor from gaussian encoder
"""
batch_size = len(x)
x = torch.flatten(input=x, start_dim=1) # 배치단위이기 때문에 dim=1 부터 flatten
p = torch.flatten(input=p, start_dim=1)
reconstruction_err = - (1 - x) * torch.log(1 - p + 1e-8) - x * torch.log(p + 1e-8)
reconstruction_err = torch.sum(reconstruction_err)
regularization_err = 0.5 * (mean**2 + std**2 - torch.log(std**2 + 1e-8) - 1)
regularization_err = torch.sum(regularization_err)
return (reconstruction_err + regularization_err) / batch_size