- Decoder 및 AuxLoss 모듈 구조
이 두 모듈은 Pyramid Pooling 또는 Feature 모듈에서 출력된 텐서 정보를 Decode(읽기) 한다. 텐서 정보를 읽은 후 픽셀별로 물체 라벨을 클래스 분류로 추정하고 마지막으로 화상 크기를 원래의 475x475로 업샘플링 한다.
Decoder 및 AuxLoss 모듈에서 최종 출력되는 텐서는 21x475x475(클래스 수 x 높이 x 폭)이다. 출력 텐서는 화상의 각 픽셀에 21종류 클래스의 확률이 대응하는 값(신뢰도)이다. 해당 값이 가장 큰 클래스가 그 픽셀이 속하는 것으로 예측된 물체 라벨이다.
- Decoder 및 AuxLoss 모듈 구현
class DecodePSPFeature(nn.Module):
def __init__(self, height, width, n_classes):
super(DecodePSPFeature, self).__init__()
# forward에 사용하는 화상 크기
self.height = height
self.width = width
self.cbr = conv2DBatchNormRelu(
in_channels=4096, out_channels=512, kernel_size=3, stride=1, padding=1, dilation=1, bias=False)
self.dropout = nn.Dropout2d(p=0.1)
self.classification = nn.Conv2d(
in_channels=512, out_channels=n_classes, kernel_size=1, stride=1, padding=0)
def forward(self, x):
x = self.cbr(x)
x = self.dropout(x)
x = self.classification(x)
output = F.interpolate(
x, size=(self.height, self.width), mode="bilinear", align_corners=True)
return output
class AuxiliaryPSPlayers(nn.Module):
def __init__(self, in_channels, height, width, n_classes):
super(AuxiliaryPSPlayers, self).__init__()
# forward에 사용하는 화상 크기
self.height = height
self.width = width
self.cbr = conv2DBatchNormRelu(
in_channels=in_channels, out_channels=256, kernel_size=3, stride=1, padding=1, dilation=1, bias=False)
self.dropout = nn.Dropout2d(p=0.1)
self.classification = nn.Conv2d(
in_channels=256, out_channels=n_classes, kernel_size=1, stride=1, padding=0)
def forward(self, x):
x = self.cbr(x)
x = self.dropout(x)
x = self.classification(x)
output = F.interpolate(
x, size=(self.height, self.width), mode="bilinear", align_corners=True)
return output
마지막 클래스 분류(self.classification) 시 전결합 층을 사용하지 않고 클래스 수와 동일하게 21을 출력 채널로 하는 커널 크기 1의 합성곱 층을 사용하는 것이 특징이다.
PSPNet 네트워크 구조, 네트워크 순전파 계산을 모두 구현했다. 마지막으로 네트워크 모델인 PSPNet 클래스의 인스턴스를 작성하고 오류 없이 계산되는지 확인해보자.
net=PSPNet(n_classes=21)
net
batch_size=2
dummy_img=torch.rand(batch_size,3,475,475)
outputs=net(dummy_img)
print(outputs)
'Semantic Segmentation' 카테고리의 다른 글
8. 시맨틱 분할 추론 (0) | 2021.11.21 |
---|---|
7. 파인튜닝을 활용한 학습 및 검증 실시 (0) | 2021.11.21 |
5. Pyramid Pooling 모듈 (0) | 2021.11.20 |
4. Feature 모듈 설명 및 구현(ResNet) (0) | 2021.11.20 |
3. PSPNet 네트워크 구성 및 구현 (0) | 2021.11.20 |