U-Net-GAN
본 포스팅은 U-Net-GAN에 대한 내용입니다.
U-Net-GAN 이란?
U-Net-GAN 모델은 GAN의 생성기 부분에 U-Net 구조를 넣은 모델입니다. U-Net을 통해 집약된 정보를 얻고, 이를 이용해 가짜 이미지를 생성합니다. 생성한 이미지는 판별기를 거쳐 실제 이미지와 얼마나 유사한지 계산합니다. 유사도를 바탕으로 실제와 비슷한 가짜를 생성할 때 까지 반복합니다. 이는 기존의 노이즈에서 이미지를 생성하는 것 보다 학습 시간을 단축시켰다는 장점이 있습니다.
구현
다음 코드는 각 모델의 구현 예시입니다.
<생성기> - U-Net
class UNet(nn.Module):
def __init__(self):
super(UNet, self).__init__()
# 논문의 파란색 화살표
def CBR2d(in_channels, out_channels, kernel_size, stride, padding, bias=True):
layers = []
# Conv2d layer 정의
layers += [nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
bias=bias)]
# layers += [nn.BatchNorm2d(num_features=out_channels)]
layers += [nn.ReLU()]
cbr = nn.Sequential(*layers)
return cbr
# Contracting path
# 좌측 레이어 enc (인코더)
# 1층 좌측 첫번째 레이어 두개
self.enc1_1 = CBR2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True)
self.enc1_2 = CBR2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True)
# 다음 빨간색 화살표 max_pool 2*2
self.pool1 = nn.MaxPool2d(kernel_size=2)
# 2층 파란색 화살표
self.enc2_1 = CBR2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, bias=True)
self.enc2_2 = CBR2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, bias=True)
# 다음 빨간색 화살표 max_pool 2*2
self.pool2 = nn.MaxPool2d(kernel_size=2)
# 3층 파란색 화살표
self.enc3_1 = CBR2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1, bias=True)
self.enc3_2 = CBR2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=True)
# 다음 빨간색 화살표 max_pool 2*2
self.pool3 = nn.MaxPool2d(kernel_size=2)
# 4층 파란색 화살표
self.enc4_1 = CBR2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1, bias=True)
self.enc4_2 = CBR2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1, bias=True)
# 다음 빨간색 화살표 max_pool 2*2
self.pool4 = nn.MaxPool2d(kernel_size=2)
# 5층 파란색 화살표
self.enc5_1 = CBR2d(in_channels=512, out_channels=1024, kernel_size=3, stride=1, padding=1, bias=True)
# Expansive path
# 5층 파란색 2번쨰 화살표인데 디코더로
self.dec5_1 = CBR2d(in_channels=1024, out_channels=512, kernel_size=3, stride=1, padding=1, bias=True)
# 초록색 화살표
self.unpool4 = nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=2, stride=2, padding=0,
bias=True)
# enc4_2와 대칭이되는 점을 보면 dec4_2 input값은 512가 맞는데, unet 아키텍쳐를 보니
# enc4_2 에서 회색 화살표로 dec4_2로 와서 copy and crop이 일어남
# 따라서 dec4_2 in_channels = 1024로 설정
self.dec4_2 = CBR2d(in_channels=2 * 512, out_channels=512, kernel_size=3, stride=1, padding=1, bias=True)
self.dec4_1 = CBR2d(in_channels=512, out_channels=256, kernel_size=3, stride=1, padding=1, bias=True)
# 초록색 화살표
self.unpool3 = nn.ConvTranspose2d(in_channels=256, out_channels=256,
kernel_size=2, stride=2, padding=0, bias=True)
# 3층 파란색 화살표
self.dec3_2 = CBR2d(in_channels=2 * 256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=True)
self.dec3_1 = CBR2d(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding=1, bias=True)
# 초록색 화살표
self.unpool2 = nn.ConvTranspose2d(in_channels=128, out_channels=128,
kernel_size=2, stride=2, padding=0, bias=True)
# 2층 파란색 화살표
self.dec2_2 = CBR2d(in_channels=2 * 128, out_channels=128, kernel_size=3, stride=1, padding=1, bias=True)
self.dec2_1 = CBR2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True)
# 초록색 화살표
self.unpool1 = nn.ConvTranspose2d(in_channels=64, out_channels=64,
kernel_size=2, stride=2, padding=0, bias=True)
# 1층 파란색 화살표
self.dec1_2 = CBR2d(in_channels=2 * 64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True)
self.dec1_1 = CBR2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True)
# segmentation에 필요한 n개의 클래스에 대한 output 정의
self.fc = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True)
def forward(self, x):
# 좌측 1층 레이어 2개 연결 및 빨간색 화살표
enc1_1 = self.enc1_1(x)
enc1_2 = self.enc1_2(enc1_1)
pool1 = self.pool1(enc1_2)
# 좌측 2층 레이어 2개 연결 및 빨간색 화살표
enc2_1 = self.enc2_1(pool1)
enc2_2 = self.enc2_2(enc2_1)
pool2 = self.pool2(enc2_2)
# 좌측 3층 레이어 2개 연결 및 빨간색 화살표
enc3_1 = self.enc3_1(pool2)
enc3_2 = self.enc3_2(enc3_1)
pool3 = self.pool3(enc3_2)
# 좌측 4층 레이어 2개 연결 및 빨간색 화살표
enc4_1 = self.enc4_1(pool3)
enc4_2 = self.enc4_2(enc4_1)
pool4 = self.pool4(enc4_2)
# 좌측 5층 레이어
enc5_1 = self.enc5_1(pool4)
# 우측 5층 레이어 및 초록색 화살표
dec5_1 = self.dec5_1(enc5_1)
unpool4 = self.unpool4(dec5_1)
# 하얀색 부분 연결하기
cat4 = torch.cat((unpool4, enc4_2), dim=1)
# 파란색 화살표 실행
# cat에서 512 + 512 로 1024의 레이어 만들고 파란색 화살표 수행후 아웃풋값을 512로 만듬
dec4_2 = self.dec4_2(cat4)
# 여기까지 하면 우측 4층 레이어까지 생성
dec4_1 = self.dec4_1(dec4_2)
# 반복 3층
unpool3 = self.unpool3(dec4_1)
cat3 = torch.cat((unpool3, enc3_2), dim=1)
dec3_2 = self.dec3_2(cat3)
dec3_1 = self.dec3_1(dec3_2)
# 반복 2층
unpool2 = self.unpool2(dec3_1)
cat2 = torch.cat((unpool2, enc2_2), dim=1)
dec2_2 = self.dec2_2(cat2)
dec2_1 = self.dec2_1(dec2_2)
# 반복 1층
unpool1 = self.unpool1(dec2_1)
cat1 = torch.cat((unpool1, enc1_2), dim=1)
dec1_2 = self.dec1_2(cat1)
dec1_1 = self.dec1_1(dec1_2)
x = self.fc(dec1_1)
return x
<판별기> - GAN
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(int(np.prod(img_shape)), 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 3),
nn.Sigmoid(),
)
def forward(self, img):
img_flat = img.view(img.size(0), -1)
validity = self.model(img_flat)
return validity
참고문헌
[1] U-Net
[2] GAN
[3] U-Net 코드
[4] GAN 코드