Uformer

"""
Uformer: A General U-Shaped Transformer for Image Restoration
Zhendong Wang, Xiaodong Cun, Jianmin Bao, Jianzhuang Liu
https://arxiv.org/abs/2106.03106
"""

class Uformer(nn.Module):
    def __init__(self, img_size=128, in_chans=3,
                 embed_dim=32, depths=[2, 2, 2, 2, 2, 2, 2, 2, 2], num_heads=[1, 2, 4, 8, 16, 16, 8, 4, 2],
                 win_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, patch_norm=True,
                 use_checkpoint=False, token_projection='linear', token_mlp='ffn', se_layer=False,
                 dowsample=Downsample, upsample=Upsample, **kwargs):
        super().__init__()
        ...

    def forward(self, x, mask=None):
        # Input Projection
        y = self.input_proj(x)
        y = self.pos_drop(y)

        # Encoder
        conv0 = self.encoderlayer_0(y, mask=mask)
        pool0 = self.dowsample_0(conv0)
        conv1 = self.encoderlayer_1(pool0, mask=mask)
        pool1 = self.dowsample_1(conv1)
        conv2 = self.encoderlayer_2(pool1, mask=mask)
        pool2 = self.dowsample_2(conv2)
        conv3 = self.encoderlayer_3(pool2, mask=mask)
        pool3 = self.dowsample_3(conv3)

        # Bottleneck
        conv4 = self.conv(pool3, mask=mask)

        # Decoder
        up0 = self.upsample_0(conv4)
        deconv0 = torch.cat([up0, conv3], -1)
        deconv0 = self.decoderlayer_0(deconv0, mask=mask)

        up1 = self.upsample_1(deconv0)
        deconv1 = torch.cat([up1, conv2], -1)
        deconv1 = self.decoderlayer_1(deconv1, mask=mask)

        up2 = self.upsample_2(deconv1)
        deconv2 = torch.cat([up2, conv1], -1)
        deconv2 = self.decoderlayer_2(deconv2, mask=mask)

        up3 = self.upsample_3(deconv2)
        deconv3 = torch.cat([up3, conv0], -1)
        deconv3 = self.decoderlayer_3(deconv3, mask=mask)

        # Output Projection
        y = self.output_proj(deconv3)
        return x + y
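As a quick orientation, here is a hedged smoke test for the model as a whole. It assumes the full repository code (Downsample, Upsample, the projection and encoder/decoder layers) is importable, since the constructor body is elided above; the shapes are the defaults from the signature.

import torch

model = Uformer(img_size=128, embed_dim=32, win_size=8)   # defaults from the signature above
x = torch.randn(1, 3, 128, 128)                           # degraded input image
restored = model(x)                                       # forward returns x + predicted residual
print(restored.shape)                                     # expected: torch.Size([1, 3, 128, 128])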

The overall structure is a U-Net: every encoder stage is followed by a downsampling layer, which is a convolution with kernel_size=4 and stride 2. Note that the network propagates features in the shape (B, H*W, C), i.e., the spatial dimensions are flattened into a token sequence, and attention is computed inside windows of size 8x8. You have to break the habit of thinking in (B, C, H, W) terms: the 2D HxW layout and the flattened 1D sequence carry exactly the same information, and the 1D view makes the attention computation easier to reason about (see the sketch just below).
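To make the (B, H*W, C) convention concrete, here is a minimal sketch (names are illustrative, not from the repo) showing that flattening to a 1D token sequence and reshaping back to 2D loses nothing:

import torch

B, C, H, W = 2, 32, 16, 16
feat_2d = torch.randn(B, C, H, W)

# (B, C, H, W) -> (B, H*W, C): the layout the network propagates
tokens = feat_2d.flatten(2).transpose(1, 2)                       # (2, 256, 32)

# (B, H*W, C) -> (B, C, H, W): recover the spatial layout exactly
feat_back = tokens.transpose(1, 2).contiguous().view(B, C, H, W)
assert torch.equal(feat_2d, feat_back)                            # same information, different view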

class Downsample(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(Downsample, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=4, stride=2, padding=1),
        )

    def forward(self, x):
        B, L, C = x.shape
        H = int(math.sqrt(L))
        W = int(math.sqrt(L))
        x = x.transpose(1, 2).contiguous().view(B, C, H, W)          # (B, L, C) -> (B, C, H, W)
        out = self.conv(x).flatten(2).transpose(1, 2).contiguous()   # back to (B, H*W, C)
        return out

Every encoder and decoder stage is built from a BasicUformerLayer. Depending on how the decoder consumes the skip connection there are three variants: Concat, Cross, and ConcatCross; the best result is achieved by Uformer32-Concat, at 39.77. The encoder stages are all plain BasicUformerLayers.
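A minimal sketch of the Concat-style skip used in the forward pass above (dims are illustrative): the upsampled decoder feature and the matching encoder feature are concatenated along the channel dimension, so the decoder layer that follows has to be built with twice the channel width.

import torch

B, L, C = 2, 32 * 32, 64
up   = torch.randn(B, L, C)              # upsampled decoder feature at this resolution
skip = torch.randn(B, L, C)              # encoder feature saved at the same resolution

dec_in = torch.cat([up, skip], dim=-1)   # (B, L, 2*C): Concat skip
print(dec_in.shape)                      # torch.Size([2, 1024, 128])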

class BasicUformerLayer(nn.Module):
    def __init__(self, dim, output_dim, input_resolution, depth, num_heads, win_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, use_checkpoint=False,
                 token_projection='linear', token_mlp='ffn', se_layer=False):
        super().__init__()
        ...
        # build blocks: alternate regular (shift_size=0) and shifted windows, as in Swin
        self.blocks = nn.ModuleList([
            LeWinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                  num_heads=num_heads, win_size=win_size,
                                  shift_size=0 if (i % 2 == 0) else win_size // 2,
                                  mlp_ratio=mlp_ratio,
                                  qkv_bias=qkv_bias, qk_scale=qk_scale,
                                  drop=drop, attn_drop=attn_drop,
                                  drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                  norm_layer=norm_layer, token_projection=token_projection, token_mlp=token_mlp,
                                  se_layer=se_layer)
            for i in range(depth)])
        ...

The module the paper mainly proposes is the LeWinTransformerBlock, which computes self-attention inside each window. Note that attention is not computed across windows; it is restricted to each 8x8 window. This is similar in spirit to a crop operation, but more elegant: the windows are folded into the batch dimension, enlarging the effective batch size, because every window can be processed in exactly the same way, i.e., all windows share the same set of parameters.
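The window folding relies on two helpers, window_partition and window_reverse, which are not shown in the excerpts here; the sketch below follows the standard Swin-style implementation and shows how the nW windows of each image get stacked along the batch axis and recovered afterwards.

import torch

def window_partition(x, win_size):
    # (B, H, W, C) -> (nW*B, win_size, win_size, C): fold windows into the batch dim
    B, H, W, C = x.shape
    x = x.view(B, H // win_size, win_size, W // win_size, win_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C)

def window_reverse(windows, win_size, H, W):
    # (nW*B, win_size, win_size, C) -> (B, H, W, C): undo window_partition
    B = int(windows.shape[0] / (H * W / win_size / win_size))
    x = windows.view(B, H // win_size, W // win_size, win_size, win_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)

x = torch.randn(2, 64, 64, 32)                          # B=2, 64x64 feature map, C=32
wins = window_partition(x, 8)                           # (2 * 64, 8, 8, 32) = (128, 8, 8, 32)
assert torch.equal(window_reverse(wins, 8, 64, 64), x)  # exact round trip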

class LeWinTransformerBlock(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, win_size=8, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, token_projection='linear', token_mlp='leff',
                 se_layer=False):
        super().__init__()
        ...

        self.attn = WindowAttention(
            dim, win_size=to_2tuple(self.win_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
            token_projection=token_projection, se_layer=se_layer)

        ...
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer,
                       drop=drop) if token_mlp == 'ffn' else LeFF(dim, mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, mask=None):
        B, L, C = x.shape
        H = int(math.sqrt(L))
        W = int(math.sqrt(L))

        ## input mask
        if mask is not None:
            input_mask = F.interpolate(mask, size=(H, W)).permute(0, 2, 3, 1)
            input_mask_windows = window_partition(input_mask, self.win_size)  # nW, win_size, win_size, 1
            attn_mask = input_mask_windows.view(-1, self.win_size * self.win_size)  # nW, win_size*win_size
            attn_mask = attn_mask.unsqueeze(2) * attn_mask.unsqueeze(1)  # nW, win_size*win_size, win_size*win_size
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        ...

        # partition windows
        x_windows = window_partition(shifted_x, self.win_size)  # nW*B, win_size, win_size, C
        x_windows = x_windows.view(-1, self.win_size * self.win_size, C)  # nW*B, win_size*win_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, win_size*win_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.win_size, self.win_size, C)
        shifted_x = window_reverse(attn_windows, self.win_size, H, W)  # B H' W' C

        ...

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        del attn_mask
        return x

WindowAttention is the same as in Swin: every position inside a window gets a relative position bias (position embedding), which means every window is treated equally.
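As a worked example of this indexing, take a tiny 2x2 window: there are (2*2-1)*(2*2-1) = 9 possible relative offsets, and the 4x4 relative_position_index matrix picks one bias-table row per token pair (for the real 8x8 window the table has 15*15 = 225 rows). The snippet below just replays the construction from __init__ at this small size.

import torch

win = (2, 2)
coords = torch.stack(torch.meshgrid([torch.arange(win[0]), torch.arange(win[1])]))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)                                           # 2, Wh*Ww
rel = coords_flatten[:, :, None] - coords_flatten[:, None, :]                       # 2, 4, 4
rel = rel.permute(1, 2, 0).contiguous()
rel[:, :, 0] += win[0] - 1          # shift row offsets to start from 0
rel[:, :, 1] += win[1] - 1          # shift col offsets to start from 0
rel[:, :, 0] *= 2 * win[1] - 1      # row-major encoding of the (dy, dx) pair
index = rel.sum(-1)                 # 4x4 matrix with values in [0, 8]
print(index)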

class WindowAttention(nn.Module):
    def __init__(self, dim, win_size, num_heads, token_projection='linear', qkv_bias=True, qk_scale=None,
                 attn_drop=0., proj_drop=0., se_layer=False):
        super().__init__()
        self.dim = dim
        self.win_size = win_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * win_size[0] - 1) * (2 * win_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.win_size[0])  # [0,...,Wh-1]
        coords_w = torch.arange(self.win_size[1])  # [0,...,Ww-1]
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.win_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.win_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.win_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        if token_projection == 'conv':
            self.qkv = ConvProjection(dim, num_heads, dim // num_heads, bias=qkv_bias)
        elif token_projection == 'linear_concat':
            self.qkv = LinearProjection_Concat_kv(dim, num_heads, dim // num_heads, bias=qkv_bias)
        else:
            self.qkv = LinearProjection(dim, num_heads, dim // num_heads, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.se_layer = SELayer(dim) if se_layer else nn.Identity()
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, attn_kv=None, mask=None):
        B_, N, C = x.shape
        q, k, v = self.qkv(x, attn_kv)
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.win_size[0] * self.win_size[1], self.win_size[0] * self.win_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        ratio = attn.size(-1) // relative_position_bias.size(-1)
        relative_position_bias = repeat(relative_position_bias, 'nH l c -> nH l (c d)', d=ratio)

        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            mask = repeat(mask, 'nW m n -> nW m (n d)', d=ratio)
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N * ratio) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N * ratio)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.se_layer(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, win_size={self.win_size}, num_heads={self.num_heads}'
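Finally, a hedged smoke test for WindowAttention on a batch of already-partitioned windows; it assumes the repo's LinearProjection, SELayer, trunc_normal_ and einops.repeat are importable, as in the original file.

import torch

attn = WindowAttention(dim=32, win_size=(8, 8), num_heads=2)
windows = torch.randn(128, 8 * 8, 32)   # nW*B = 128 windows, 64 tokens each, C = 32
out = attn(windows)                     # self-attention within each 8x8 window
print(out.shape)                        # expected: torch.Size([128, 64, 32])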