이 논문에서는 Semi-supervised video object segmentation을 위한 새로운 Matching-based framework를 제안한다.
기존의 matching-based 방식에 region matching과 classification을 위해서 feature bank를 사용했지만, 지속적으로 커지는 feature bank에 정보를 효과적으로 구성하는 방법을 찾진 못했다.
이를 해결하기 위해 저자는 동적으로 새로운 feature을 흡수하고 구식의 feature는 무시하는 adaptive한 feature bank를 소개하고 있다.
또한 불확실한 region에 대한 segmentation accuracy를 높이기 위해 새로운 confidence loss와 fine-grained segmentation module을 제안한다.

Proceeding: NIPS 2020
Authors: Yongqing Liang, Xin Li, Navid Jafari, Jim Chen

Code Link
Paper Link

Introduction

VOS(Video Object Segmentation)의 semi-supervised setting에서 성능에 가장 영향을 주는 두 스텝은
(1) object의 region과 배경과 구별
(2) segmentation의 boundary를 깔끔하게 구별
하는 것이다.

최근의 연구를 implicit learning과 explicit learing으로 카테고리화 할 수 있다.
implicit 방식은 detection-based와 propagation-based learning을 포함한다. FCN 파이프라인을 사용하고 종종 비디오에서 새로운 물체에 대해 학습하기 위해 online fine-tune을 요구한다.
explicit 방식은 물체의 appearance를 명시적으로 학습한다. 이 embedding 공간에서 pixel-wise classification로 공식화한다.
object의 appearance를 기억하기 위해 embedding 공간을 설계하고 similarity 계산 이용해서 subsequent frame들을 segment 한다.
그래서 이방식을 matching-based라고 칭하기도 한다. 최근에 이 방식은 VOS 벤치마크에서 SOTA를 달성하였다.

matching-based VOS의 근본적인 이슈는 어떻게 이전 프레임의 정보를 이용하여 새로운 프레임을 명시적으로 segment 하느냐에 있다.
메모리 크기는 제한적이고, 모든 프레임의 정보들을 기억하는 것은 불가능하다.
대부분의 방식은 처음과 최근의 프레임을 이용하거나 프레임 uniformly sampling하여 사용한다.
하지만 비디오가 점차 길어지면서 out-of-memory crash 문제를 일으키거나 key-frame을 샘플링 하지 못하는 문제를 발생시킨다.

저자는 이를 해결하기 위해 Adaptive Feature Bank(AFB)를 도입하였다.
weighted averaging과 least frequently used(LFU) index를 이용하여 구식의 feature를 무시하는 매커니즘을 사용한다.
이 방법으로 다양한 object의 특징을 기억할 수 있고, 긴 비디오에서도 적은 메모리 만으로 동시에 segmentation을 수행할 수 있다.

또한 uncertain-region refinement(URR)을 이용하여 segmentation의 퀄리티를 높였다. 이는 segmentation의 모호함을 평가하는 새로운 confidence loss와 이 모호한 영역을 refine하기 위한 local fine-grained segmentation을 이용한다.

Approach

Figure 1

Matching-based segmentation

첫번째 프레임과 gt annotation을 인코딩하여 feature bank를 설계하는데 사용한다.
그다음 두번째 프레임부터 target object를 segmentation 하기 위해 feature bank를 이용한다.
decoder는 matching 결과를 이용하여 segmentation을 수행한다.

Encoders

  • query Encoder: 현재의 frame을 인코딩하며, named query frame을 이용한다.
    • ResNet-50의 layer3 feature map을 이용 (Size: H/8 X W/8 X 1024)
  • Reference Encoder: Target object의 특징을 기억하기 위해 설계
    • t-th frame을 segmentation 하기 위해서 past frame으로 1부터 t-1번째 까지를 사용
    • 각 프레임의 reference frame: L개의 물체를 구분하기 위해서 L 개의 feature map을 사용 (Size: L X H/8 X W/8 X 1024)
    • Pj = {p1, p2, …, pL}, j ∈ [1, t − 1].

Feature map embedding

기존의 matching-based 방식에서는 query feature map과 reference feature map간에 direct하게 비교를 수행했다.
classification에서는 좋은 성능을 내지만 segmantic information에는 부족한 방식이다.
이를 보완하기 위해 두 convolutional module을 이용하여 두 개의 key와 value embedding space로 인코딩을 수행한다.
key를 이용하여 matching을 수행하고 value는 풍부한 sementic 정보를 보존하기 위해 다른 값을 갖을 수 있도록 허용한다.

Matcher

query frame과 feature bank 간의 similarity는 object by object 간에 수행된다. Figure2
Figure2-1
i ∈ [1, L]
query value와 가장 비슷한 value map과 concatenate를 수행 한다. Figure2-2

    def forward(self, feature_bank, q_in, q_out):

        mem_out_list = []

        for i in range(0, feature_bank.obj_n):
            d_key, bank_n = feature_bank.keys[i].size()

            try:
                p = torch.matmul(feature_bank.keys[i].transpose(0, 1), q_in) / math.sqrt(d_key)  # THW, HW
                p = NF.softmax(p, dim=1)  # bs, bank_n, HW
                mem = torch.matmul(feature_bank.values[i], p)  # bs, D_o, HW
            except RuntimeError as e:
                device = feature_bank.keys[i].device
                key_cpu = feature_bank.keys[i].cpu()
                value_cpu = feature_bank.values[i].cpu()
                q_in_cpu = q_in.cpu()

                p = torch.matmul(key_cpu.transpose(0, 1), q_in_cpu) / math.sqrt(d_key)  # THW, HW
                p = NF.softmax(p, dim=1)  # bs, bank_n, HW
                mem = torch.matmul(value_cpu, p).to(device)  # bs, D_o, HW
                p = p.to(device)
                print('\tLine 158. GPU out of memory, use CPU', f'p size: {p.shape}')

            mem_out_list.append(torch.cat([mem, q_out], dim=1))

            if self.update_bank:
                try:
                    ones = torch.ones_like(p)
                    zeros = torch.zeros_like(p)
                    bank_cnt = torch.where(p > self.thres_valid, ones, zeros).sum(dim=2)[0]
                except RuntimeError as e:
                    device = p.device
                    p = p.cpu()
                    ones = torch.ones_like(p)
                    zeros = torch.zeros_like(p)
                    bank_cnt = torch.where(p > self.thres_valid, ones, zeros).sum(dim=2)[0].to(device)
                    print('\tLine 170. GPU out of memory, use CPU', f'p size: {p.shape}')

                feature_bank.info[i][:, 1] += torch.log(bank_cnt + 1)

        mem_out_tensor = torch.stack(mem_out_list, dim=0).transpose(0, 1)  # bs, obj_n, dim, pixel_n

        return mem_out_tensor

Decoder

matching 결과인 y를 이용하여 각 object의 mask를 독립적으로 예측한다.
refinement module을 이용하여 feature map을 점진적으로 upscale한다.
각 stage에서 previous stage와 query encoder로부터 온 feature map간 skip connection을 이용한다.
최종적으로 object i에 대한 마스크인 Mi를 구한다.
Figure3

Adaptive Feature Bank

AFB의 주요 특징은 새로운 feature를 흡수하고 구식은 제거하는 것이다.

Absorbing new features

대부분 최근의 프레임이 더 중요한 정보를 가지고 있지만 초기의 프레임도 필요한 정보를 담고 있을 수 있다.
그래서 초기 프레임을 단순히 무시하기 보다 이를 저장하고 weighted averaging을 이용하여 관리하는 방식을 취한다. 새로운 특징이 추출되고, 이전 특징과 충분히 가깝다면 이들을 merge 하므로써 중복된 정보를 저장하는 것을 피할 수 있다.
Figure4

new feature embedding: a(i) = (Kp, Vp)
old feature embeddings: b(j) = (Kfb, Vfb)
Similarity Function: Figure5

a(i)에 대해서 가장 비슷한 b(j`)을 구한 후 merging을 수행
Figure6
비슷하지 않으면 feature bank에 추가한다.

    def update(self, prev_key, prev_value, frame_idx):

        for class_idx in range(self.obj_n):

            d_key, bank_n = self.keys[class_idx].shape
            d_val, _ = self.values[class_idx].shape

            normed_keys = NF.normalize(self.keys[class_idx], dim=0)
            normed_prev_key = NF.normalize(prev_key[class_idx], dim=0)
            mag_keys = self.keys[class_idx].norm(p=2, dim=0)
            corr = torch.mm(normed_keys.transpose(0, 1), normed_prev_key)  # bank_n, prev_n
            related_bank_idx = corr.argmax(dim=0, keepdim=True)  # 1, HW
            related_bank_corr = torch.gather(corr, 0, related_bank_idx)  # 1, HW

            # greater than threshold, merge them
            selected_idx = (related_bank_corr[0] > self.thres_close).nonzero()
            class_related_bank_idx = related_bank_idx[0, selected_idx[:, 0]]  # selected_HW
            unique_related_bank_idx, cnt = class_related_bank_idx.unique(dim=0, return_counts=True)  # selected_HW

            # Update key
            key_bank_update = torch.zeros((d_key, bank_n), dtype=torch.float, device=self.device)  # d_key, THW
            key_bank_idx = class_related_bank_idx.unsqueeze(0).expand(d_key, -1)  # d_key, HW
            scatter_mean(normed_prev_key[:, selected_idx[:, 0]], key_bank_idx, dim=1, out=key_bank_update)
            # d_key, selected_HW

            self.keys[class_idx][:, unique_related_bank_idx] = \
                mag_keys[unique_related_bank_idx] * \
                ((1 - self.update_rate) * normed_keys[:, unique_related_bank_idx] + \
                 self.update_rate * key_bank_update[:, unique_related_bank_idx])

            # Update value
            normed_values = NF.normalize(self.values[class_idx], dim=0)
            normed_prev_value = NF.normalize(prev_value[class_idx], dim=0)
            mag_values = self.values[class_idx].norm(p=2, dim=0)
            val_bank_update = torch.zeros((d_val, bank_n), dtype=torch.float, device=self.device)
            val_bank_idx = class_related_bank_idx.unsqueeze(0).expand(d_val, -1)
            scatter_mean(normed_prev_value[:, selected_idx[:, 0]], val_bank_idx, dim=1, out=val_bank_update)

            self.values[class_idx][:, unique_related_bank_idx] = \
                mag_values[unique_related_bank_idx] * \
                ((1 - self.update_rate) * normed_values[:, unique_related_bank_idx] + \
                 self.update_rate * val_bank_update[:, unique_related_bank_idx])

            # less than the threshold, concat them
            selected_idx = (related_bank_corr[0] <= self.thres_close).nonzero()

            if self.class_budget < bank_n + selected_idx.shape[0]:
                self.remove(class_idx, selected_idx.shape[0], frame_idx)

            self.keys[class_idx] = torch.cat([self.keys[class_idx], prev_key[class_idx][:, selected_idx[:, 0]]], dim=1)
            self.values[class_idx] = \
                torch.cat([self.values[class_idx], prev_value[class_idx][:, selected_idx[:, 0]]], dim=1)

            new_info = torch.zeros((selected_idx.shape[0], 2), device=self.device)
            new_info[:, 0] = frame_idx
            self.info[class_idx] = torch.cat([self.info[class_idx], new_info], dim=0)

            self.peak_n[class_idx] = max(self.peak_n[class_idx], self.info[class_idx].shape[0])

            self.info[class_idx][:, 1] = torch.clamp(self.info[class_idx][:, 1], 0, 1e5)  # Prevent inf

Removing obsolete features

cache 정책과 비슷하게 오래된 feature에 대해서 최근 사용된 빈도를 계산한다.(LFU)
만약 각 matching에서 similarity가 10^-4보다 크다면 count를 증가시킨다
Figure7
feature bank의 크기가 정해진 크기를 초과하면 삭제를 진행한다.

    def remove(self, class_idx, request_n, frame_idx):

        old_size = self.keys[class_idx].shape[1]

        LFU = frame_idx - self.info[class_idx][:, 0]  # time length
        LFU = self.info[class_idx][:, 1] / LFU
        thres_dynamic = int(LFU.min()) + 1
        iter_cnt = 0

        while True:
            selected_idx = LFU > thres_dynamic
            self.keys[class_idx] = self.keys[class_idx][:, selected_idx]
            self.values[class_idx] = self.values[class_idx][:, selected_idx]
            self.info[class_idx] = self.info[class_idx][selected_idx]
            LFU = LFU[selected_idx]
            iter_cnt += 1

            balance = (self.class_budget - self.keys[class_idx].shape[1]) - request_n
            if balance < 0:
                thres_dynamic = int(LFU.min()) + 1
            else:
                break

        new_size = self.keys[class_idx].shape[1]
        self.replace_n[class_idx] += old_size - new_size

        return balance

Uncertain-region Refinement

Confidence loss

decoding과 softmax normalization을 수행한 initial segmentation Mi는 [0, 1]의 값을 가지고
합은 1이다. 이는 해당 object에 대한 likelihood를 뜻한다.
pixel-wise uncertainity map U를 구하기 위해서 마스크에서 가장 큰 값과 두번 째로 가장 큰 값의 비율을 이용한다. 그리고 0에서 1의 값을 갖게 하기 위해 L2 norm을 수행한다.
Figure8

def calc_uncertainty(score):
    # seg shape: bs, obj_n, h, w
    score_top, _ = score.topk(k=2, dim=1)
    uncertainty = score_top[:, 0] / (score_top[:, 1] + 1e-8)  # bs, h, w
    uncertainty = torch.exp(1 - uncertainty).unsqueeze(1)  # bs, 1, h, w
    return uncertainty

Local Refinement mechanism

Figure9
Figure9
Figure9
Figure9
Figure9
e: local refinement mask
c: confidence score
S: final segmentation

아래는 decoder 소스 코드의 일부이다.

    def forward(self, patch_match, r3, r2, r1=None, feature_shape=None):

        p = self.ResMM(self.convFM(patch_match))
        p = self.RF3(r3, p)  # out: 1/8, 256
        p = self.RF2(r2, p)  # out: 1/4, 256
        p = self.pred2(NF.relu(p))
        p = NF.interpolate(p, scale_factor=2, mode='bilinear', align_corners=False)
        bs, obj_n, h, w = feature_shape
        rough_seg = NF.softmax(p, dim=1)[:, 1]
        rough_seg = rough_seg.view(bs, obj_n, h, w)
        rough_seg = NF.softmax(rough_seg, dim=1)  # object-level normalization

        # Local refinement
        uncertainty = myutils.calc_uncertainty(rough_seg) 
        uncertainty = uncertainty.expand(-1, obj_n, -1, -1).reshape(bs * obj_n, 1, h, w)

        rough_seg = rough_seg.view(bs * obj_n, 1, h, w)  # bs*obj_n, 1, h, w
        r1_weighted = r1 * rough_seg
        r1_local = self.local_avg(r1_weighted)  # bs*obj_n, 64, h, w -> AveragePooling(same padding)을 수행
        r1_local = r1_local / (self.local_avg(rough_seg) + 1e-8)  # neighborhood reference
        r1_conf = self.local_max(rough_seg)  # bs*obj_n, 1, h, w 

        local_match = torch.cat([r1, r1_local], dim=1)
        q = self.local_ResMM(self.local_convFM(local_match)) # convolution 수
        q = r1_conf * self.local_pred2(NF.relu(q))

        p = p + uncertainty * q
        p = NF.interpolate(p, scale_factor=2, mode='bilinear', align_corners=False)
        p = NF.softmax(p, dim=1)[:, 1]  # no, h, w

        return p