2023-05-31,   김혜원

본 포스팅은 U-Net-GAN에 대한 내용입니다.


 

U-Net-GAN 이란?

 

U-Net-GAN 모델은 GAN의 생성기 부분에 U-Net 구조를 넣은 모델입니다. U-Net을 통해 집약된 정보를 얻고, 이를 이용해 가짜 이미지를 생성합니다. 생성한 이미지는 판별기를 거쳐 실제 이미지와 얼마나 유사한지 계산합니다. 유사도를 바탕으로 실제와 비슷한 가짜를 생성할 때 까지 반복합니다. 이는 기존의 노이즈에서 이미지를 생성하는 것 보다 학습 시간을 단축시켰다는 장점이 있습니다.

unet-gan

 

구현

 

다음 코드는 각 모델의 구현 예시입니다.

 

<생성기> - U-Net

class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()

        # 논문의 파란색 화살표
        def CBR2d(in_channels, out_channels, kernel_size, stride, padding, bias=True):
            layers = []
            # Conv2d layer 정의
            layers += [nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                kernel_size=kernel_size, stride=stride, padding=padding,
                                bias=bias)]
            # layers += [nn.BatchNorm2d(num_features=out_channels)]
            layers += [nn.ReLU()]

            cbr = nn.Sequential(*layers)

            return cbr

        # Contracting path
        # 좌측 레이어 enc (인코더)
        # 1층 좌측 첫번째 레이어 두개
        self.enc1_1 = CBR2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True)
        self.enc1_2 = CBR2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True)

        # 다음 빨간색 화살표 max_pool 2*2
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        # 2층 파란색 화살표
        self.enc2_1 = CBR2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, bias=True)
        self.enc2_2 = CBR2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, bias=True)

        # 다음 빨간색 화살표 max_pool 2*2
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        # 3층 파란색 화살표
        self.enc3_1 = CBR2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1, bias=True)
        self.enc3_2 = CBR2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=True)

        # 다음 빨간색 화살표 max_pool 2*2
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        # 4층 파란색 화살표
        self.enc4_1 = CBR2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1, bias=True)
        self.enc4_2 = CBR2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1, bias=True)

        # 다음 빨간색 화살표 max_pool 2*2
        self.pool4 = nn.MaxPool2d(kernel_size=2)

        # 5층 파란색 화살표
        self.enc5_1 = CBR2d(in_channels=512, out_channels=1024, kernel_size=3, stride=1, padding=1, bias=True)

        # Expansive path

        # 5층 파란색 2번쨰 화살표인데 디코더로
        self.dec5_1 = CBR2d(in_channels=1024, out_channels=512, kernel_size=3, stride=1, padding=1, bias=True)

        # 초록색 화살표
        self.unpool4 = nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=2, stride=2, padding=0,
                                        bias=True)

        # enc4_2와 대칭이되는 점을 보면 dec4_2 input값은 512가 맞는데, unet 아키텍쳐를 보니
        # enc4_2 에서 회색 화살표로 dec4_2로 와서 copy and crop이 일어남
        # 따라서 dec4_2 in_channels = 1024로 설정
        self.dec4_2 = CBR2d(in_channels=2 * 512, out_channels=512, kernel_size=3, stride=1, padding=1, bias=True)
        self.dec4_1 = CBR2d(in_channels=512, out_channels=256, kernel_size=3, stride=1, padding=1, bias=True)

        # 초록색 화살표
        self.unpool3 = nn.ConvTranspose2d(in_channels=256, out_channels=256,
                                        kernel_size=2, stride=2, padding=0, bias=True)
        # 3층 파란색 화살표
        self.dec3_2 = CBR2d(in_channels=2 * 256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=True)
        self.dec3_1 = CBR2d(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding=1, bias=True)

        # 초록색 화살표
        self.unpool2 = nn.ConvTranspose2d(in_channels=128, out_channels=128,
                                        kernel_size=2, stride=2, padding=0, bias=True)

        # 2층 파란색 화살표
        self.dec2_2 = CBR2d(in_channels=2 * 128, out_channels=128, kernel_size=3, stride=1, padding=1, bias=True)
        self.dec2_1 = CBR2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True)

        # 초록색 화살표
        self.unpool1 = nn.ConvTranspose2d(in_channels=64, out_channels=64,
                                        kernel_size=2, stride=2, padding=0, bias=True)

        # 1층 파란색 화살표
        self.dec1_2 = CBR2d(in_channels=2 * 64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True)
        self.dec1_1 = CBR2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True)

        # segmentation에 필요한 n개의 클래스에 대한 output 정의
        self.fc = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, x):
        # 좌측 1층 레이어 2개 연결 및 빨간색 화살표
        enc1_1 = self.enc1_1(x)
        enc1_2 = self.enc1_2(enc1_1)
        pool1 = self.pool1(enc1_2)

        # 좌측 2층 레이어 2개 연결 및 빨간색 화살표
        enc2_1 = self.enc2_1(pool1)
        enc2_2 = self.enc2_2(enc2_1)
        pool2 = self.pool2(enc2_2)

        # 좌측 3층 레이어 2개 연결 및 빨간색 화살표
        enc3_1 = self.enc3_1(pool2)
        enc3_2 = self.enc3_2(enc3_1)
        pool3 = self.pool3(enc3_2)

        # 좌측 4층 레이어 2개 연결 및 빨간색 화살표
        enc4_1 = self.enc4_1(pool3)
        enc4_2 = self.enc4_2(enc4_1)
        pool4 = self.pool4(enc4_2)

        # 좌측 5층 레이어
        enc5_1 = self.enc5_1(pool4)

        # 우측 5층 레이어 및 초록색 화살표
        dec5_1 = self.dec5_1(enc5_1)
        unpool4 = self.unpool4(dec5_1)

        # 하얀색 부분 연결하기
        cat4 = torch.cat((unpool4, enc4_2), dim=1)

        # 파란색 화살표 실행
        # cat에서 512 + 512 로 1024의 레이어 만들고 파란색 화살표 수행후 아웃풋값을 512로 만듬
        dec4_2 = self.dec4_2(cat4)

        # 여기까지 하면 우측 4층 레이어까지 생성
        dec4_1 = self.dec4_1(dec4_2)

        # 반복 3층
        unpool3 = self.unpool3(dec4_1)
        cat3 = torch.cat((unpool3, enc3_2), dim=1)
        dec3_2 = self.dec3_2(cat3)
        dec3_1 = self.dec3_1(dec3_2)

        # 반복 2층
        unpool2 = self.unpool2(dec3_1)
        cat2 = torch.cat((unpool2, enc2_2), dim=1)
        dec2_2 = self.dec2_2(cat2)
        dec2_1 = self.dec2_1(dec2_2)

        # 반복 1층
        unpool1 = self.unpool1(dec2_1)
        cat1 = torch.cat((unpool1, enc1_2), dim=1)
        dec1_2 = self.dec1_2(cat1)
        dec1_1 = self.dec1_1(dec1_2)

        x = self.fc(dec1_1)

        return x


<판별기> - GAN
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 3),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity

참고문헌

[1] U-Net

[2] GAN

[3] U-Net 코드

[4] GAN 코드

업데이트: