read: FNet


FNet: Mixing Tokens with Fourier Transforms - NAACL 2022

Code: arXiv appendix | official JAX | Keras code

(2023-07-05) Other implementations, found by asking Bing Chat: "Could you give its PyTorch code?"


(2023-06-16)

Video Intro

Source video: FNet: Mixing Tokens with Fourier Transforms (Machine Learning Research Paper Explained) - Yannic Kilcher


(2023-07-07)

Abstract

  • Replacing the self-attention sublayers with linear transformations speeds up the encoder.

  • Replacing the self-attention sublayers with an unparameterized Fourier Transform reaches over 90% of the accuracy of the BERT counterparts.

  • FNet has a light memory footprint (presumably because the Fourier sublayer has no learnable parameters and never builds an n×n attention matrix?)
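To make the "no parameters" point concrete, here is a minimal PyTorch sketch comparing a standard `nn.MultiheadAttention` layer with FNet-style Fourier mixing. The BERT-Base-like sizes (768 hidden, 12 heads, 512 tokens) are assumptions for illustration, and `fourier_mix` is a hypothetical helper name, not from the paper or the repo. Attention carries 4·d² + 4·d weights and produces a seq_len × seq_len weight matrix; the Fourier mixing has no weights and returns a tensor of the input's shape.

```python
import torch
from torch import nn

d_model, n_heads, seq_len = 768, 12, 512  # assumed BERT-Base-like sizes

# Standard self-attention: Q/K/V input projections plus an output projection.
attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads, batch_first=True)
attn_params = sum(p.numel() for p in attn.parameters())

def fourier_mix(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: FNet-style token mixing with no learnable weights."""
    # 2-D DFT over (sequence, hidden) dimensions, keeping only the real part.
    return torch.fft.fft2(x, dim=(-2, -1)).real

x = torch.randn(1, seq_len, d_model)
# Attention weights are returned averaged over heads: (batch, seq_len, seq_len).
_, weights = attn(x, x, x, need_weights=True)
y = fourier_mix(x)

print(attn_params)    # 2362368 == 4 * d^2 + 4 * d
print(weights.shape)  # the n x n matrix attention has to materialize
print(y.shape)        # same shape as x, no parameters involved
```

The parameter count covers only the attention sublayer; the feed-forward sublayers, which both models keep, are left out of the comparison.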

Introduction

  • Attention connects each token to every other token in the input, weighted by their relevance.

  • More complex mixing helps capture the relationships between tokens.

  • Can attention, the relevance-based "token mixer", be replaced by a simpler linear transformation such as $\mathbf{X}\mathbf{W} + \mathbf{b}$?

  • Replacing attention with two parameterized (learnable) matrix multiplications, one mixing along the sequence dimension and then one along the hidden dimension, already gives decent results.

    Figure: a sequence containing 5 tokens, each 4-dimensional (axes: sequence dimension, hidden dimension).

  • Using the FFT, a faster, structured linear transformation with no parameters, yields performance similar to dense-layer mixing and scales well.

  • Contributions:

    1. Attention may not be a necessary component, so seeking new mixing mechanisms is worthwhile.
    2. FNet uses the FFT to mix tokens, speeding up training at the cost of some accuracy.
    3. Attention does help increase accuracy to some extent.
    4. FNet scales well to long inputs.
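The linear-mixing idea and the FFT shortcut can be checked on the toy shape from the figure (5 tokens, 4 dimensions). In this sketch, `W_seq` and `W_hidden` are hypothetical names with random stand-in values (in a real model they would be learned): one matrix mixes along the sequence dimension, the other along the hidden dimension. Fixing both to DFT matrices and keeping the real part reproduces the FFT-based mixing exactly, which is the sense in which the FFT is a structured linear transformation with no parameters, computed in O(n log n) rather than by dense matrix products.

```python
import math
import torch

torch.manual_seed(0)
n, d = 5, 4                       # 5 tokens, each 4-dimensional
X = torch.randn(n, d)

# Linear mixing baseline (weights would be learned; random here for illustration).
W_seq = torch.randn(n, n)         # mixes along the sequence dimension
W_hidden = torch.randn(d, d)      # mixes along the hidden dimension
Y_linear = W_seq @ X @ W_hidden   # shape (n, d)

def dft_matrix(m: int) -> torch.Tensor:
    """m x m DFT matrix with entries exp(-2*pi*i*j*k/m) (symmetric)."""
    idx = torch.arange(m, dtype=torch.float64)
    return torch.exp(-2j * math.pi * torch.outer(idx, idx) / m)

# Fourier mixing as the same "sandwich", with fixed DFT matrices instead of weights.
Xc = X.to(torch.complex128)
Y_dft = (dft_matrix(n) @ Xc @ dft_matrix(d)).real

# FFT along the hidden dim, then along the sequence dim, real part only.
Y_fft = torch.fft.fft(torch.fft.fft(Xc, dim=-1), dim=-2).real

print(Y_linear.shape, torch.allclose(Y_dft, Y_fft))
```

The dense version costs a full matrix product per dimension, while the FFT computes the same DFT result without ever forming the matrices.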

Code from rishikksh20/FNet-pytorch:

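import torch
from torch import nn

class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout."""
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class PreNorm(nn.Module):
    """Apply LayerNorm to the input before the wrapped sublayer `fn`."""
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FNetBlock(nn.Module):
    """Parameter-free token mixing that replaces self-attention."""
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # x: (batch, seq_len, dim). 2-D DFT: first along the hidden dimension
        # (dim=-1), then along the sequence dimension (dim=-2); keep only the
        # real part. Equivalent to torch.fft.fft2(x, dim=(-2, -1)).real.
        x = torch.fft.fft(torch.fft.fft(x, dim=-1), dim=-2).real
        return x

class FNet(nn.Module):
    # Note: this repo uses pre-norm residual blocks; the paper follows BERT
    # and applies LayerNorm after each residual addition (post-norm).
    def __init__(self, dim, depth, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, FNetBlock()),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for mix, ff in self.layers:
            x = mix(x) + x   # Fourier mixing sublayer + residual
            x = ff(x) + x    # feed-forward sublayer + residual
        return x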
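The repository above uses pre-norm residual blocks, whereas the FNet paper follows BERT and applies "Add & Normalize" after each sublayer. Below is a minimal sketch of one post-norm encoder layer under that reading; the class name `PostNormFNetLayer` and the sizes in the usage lines are hypothetical. Only the two LayerNorms and the feed-forward network carry parameters; the mixing step has none.

```python
import torch
from torch import nn

class PostNormFNetLayer(nn.Module):
    """Hypothetical post-norm FNet encoder layer (BERT-style Add & Normalize)."""

    def __init__(self, dim: int, mlp_dim: int, dropout: float = 0.0):
        super().__init__()
        self.mix_norm = nn.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim),
            nn.Dropout(dropout),
        )
        self.ff_norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Parameter-free token mixing: 2-D DFT over (sequence, hidden), real part.
        mixed = torch.fft.fft2(x, dim=(-2, -1)).real
        x = self.mix_norm(x + mixed)      # Add & Normalize after mixing
        x = self.ff_norm(x + self.ff(x))  # Add & Normalize after feed-forward
        return x

layer = PostNormFNetLayer(dim=64, mlp_dim=256)
x = torch.randn(2, 128, 64)  # (batch, seq_len, dim)
y = layer(x)
print(y.shape)
print(sum(p.numel() for p in layer.parameters()))  # LayerNorms + MLP only
```

Stacking several of these layers gives a post-norm counterpart of the `FNet` class above; which placement trains better in practice is not something this note measures.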