
read: Render - NVS | S3IM Loss for NeRF

S3IM: Stochastic Structural SIMilarity and Its Unreasonable Effectiveness for Neural Fields (ICCV 2023)

(Discussed in QQ group 706949479)

Code | Arxiv | ProjPage

Notes

Abs

Previous NeRF methods do not exploit structural information at the image level; they train and predict in a purely point-wise (per-pixel) manner.

Method

(Figure: randomly sampled training pixels and their rendered counterparts are reordered into stochastic virtual patches, on which SSIM is computed; repeated reordering gives S3IM.)

Steps:

  1. Apply SSIM to a virtual patch built from randomly selected training pixels, with a kernel size $K$ (=4) and stride $S$ (=$K$).

  2. Repeatedly reorder the predicted and target pixels into new virtual patches, and compute SSIM $M$ (=10) times in total.

  3. The final loss term is one minus the average of these SSIM values, multiplied by a weight factor (hyperparameter) $\lambda$:

    $$ \mathcal{L}_{\rm S3IM} = \lambda \cdot \Big( 1 - \frac{1}{M} \sum_{m=1}^{M} \mathrm{SSIM}\big(\mathrm{Patch}^{(m)}_{\rm rendered}, \mathrm{Patch}^{(m)}_{\rm target}\big) \Big) $$
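A minimal sketch of this loss term as written (not the repo's exact implementation), assuming a differentiable windowed ssim(x, y) that returns the mean SSIM of two (1, 3, H, W) tensors and a hypothetical make_virtual_patch helper that reshapes a permuted (num_rays, 3) color batch into such a tensor:

import torch

def s3im_loss(rendered, target, M=10, lam=1.0):
    # rendered, target: (num_rays, 3) pixel colors from the current ray batch
    total_ssim = 0.0
    for m in range(M):
        # keep the original order once, then use random reorderings
        perm = torch.arange(len(target)) if m == 0 else torch.randperm(len(target))
        total_ssim = total_ssim + ssim(make_virtual_patch(rendered[perm]),
                                       make_virtual_patch(target[perm]))
    return lam * (1.0 - total_ssim / M)

The released code (shown below) follows the same idea but concatenates the $M$ reordered copies into one wide patch and calls SSIM once.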

Comparison with SSIM:

  1. S3IM applied to random pixel patches significantly outperforms SSIM applied to local continuous patches.

  2. The authors explain that SSIM can only capture local similarity, whereas S3IM compares nonlocal structural similarity across pixels drawn from all training images.

  3. Training NeRF on local continuous patches actually hurts performance (as stated at the end of Section 3.1).


Play

Code

# Build the S3IM loss module once; the hyperparameters come from command-line args
s3im_func = S3IM(kernel_size=args.s3im_kernel,
                 stride=args.s3im_stride,
                 repeat_time=args.s3im_repeat_time,
                 patch_height=args.s3im_patch_height,
                 patch_width=args.s3im_patch_width).cuda()

# Add the weighted S3IM term to the total training loss
if args.s3im_weight > 0:
  s3im_pp = args.s3im_weight * s3im_func(rgb_map, rgb_train)
  total_loss += s3im_pp
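For reference, the args.s3im_* values above are plain command-line options. A plausible argparse setup is sketched below; the flag names are inferred from the attributes used above and the defaults mirror the class defaults further down, so treat both as assumptions rather than the repo's exact definitions:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--s3im_weight", type=float, default=1.0,
                    help="weight factor (lambda) of the S3IM loss term; 0 disables it")
parser.add_argument("--s3im_kernel", type=int, default=4,
                    help="kernel size of SSIM's convolution")
parser.add_argument("--s3im_stride", type=int, default=4,
                    help="stride of SSIM's convolution")
parser.add_argument("--s3im_repeat_time", type=int, default=10,
                    help="number of re-shuffles M of the virtual patch")
parser.add_argument("--s3im_patch_height", type=int, default=64,
                    help="height of the virtual patch")
parser.add_argument("--s3im_patch_width", type=int, default=64,
                    help="width of the virtual patch")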
import torch
# SSIM below refers to the differentiable windowed SSIM module used by the repo (not shown here).

class S3IM(torch.nn.Module):
  r"""Implements the Stochastic Structural SIMilarity (S3IM) algorithm.
  Arguments:
    kernel_size (int): kernel size of SSIM's convolution (default: 4)
    stride (int): stride of SSIM's convolution (default: 4)
    repeat_time (int): number of re-shuffles of the virtual patch (default: 10)
    patch_height (int): height of the virtual patch (default: 64)
    patch_width (int): width of the virtual patch (default: 64)
  """
  def __init__(self, kernel_size=4, stride=4, repeat_time=10, patch_height=64, patch_width=64):
    super(S3IM, self).__init__()
    self.kernel_size = kernel_size
    self.stride = stride
    self.repeat_time = repeat_time
    self.patch_height = patch_height
    self.patch_width = patch_width
    self.ssim_loss = SSIM(window_size=self.kernel_size, stride=self.stride)

  def forward(self, src_vec, tar_vec):
    r"""
    src_vec: rendered pixel colors, (ray_batch_size=4096=64*64, 3)
    tar_vec: ground-truth pixel colors, same shape as src_vec
    """
    index_list = []
    for i in range(self.repeat_time):
      if i == 0:
        # keep the original ray order for the first copy
        tmp_index = torch.arange(len(tar_vec))  # (4096,)
        index_list.append(tmp_index)
      else:
        # every further copy is a random permutation of the rays
        ran_idx = torch.randperm(len(tar_vec))
        index_list.append(ran_idx)

    res_index = torch.cat(index_list)  # (M * ray_bs = 10*4096,)
    tar_all = tar_vec[res_index]       # (10*4096, 3)
    src_all = src_vec[res_index]

    # stack the M shuffled copies side by side into one wide virtual patch:
    # shape (1, 3, patch_height, patch_width * repeat_time)
    tar_patch = tar_all.permute(1, 0).reshape(1, 3, self.patch_height, self.patch_width * self.repeat_time)
    src_patch = src_all.permute(1, 0).reshape(1, 3, self.patch_height, self.patch_width * self.repeat_time)
    loss = 1 - self.ssim_loss(src_patch, tar_patch)
    return loss
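A quick smoke test of the module (a hypothetical example, assuming the repo's SSIM module is available and a GPU is present; the ray batch of 4096 = 64×64 matches the docstring):

rays = 64 * 64                          # ray_batch_size = patch_height * patch_width
rgb_pred = torch.rand(rays, 3).cuda()   # rendered colors
rgb_gt = torch.rand(rays, 3).cuda()     # ground-truth colors

s3im = S3IM(kernel_size=4, stride=4, repeat_time=10,
            patch_height=64, patch_width=64).cuda()
loss = s3im(rgb_pred, rgb_gt)           # scalar tensor: 1 - SSIM of the virtual patches
print(loss.item())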