Uformer
"""
Uformer: A General U-Shaped Transformer for Image Restoration
Zhendong Wang, Xiaodong Cun, Jianmin Bao, Jianzhuang Liu
https://arxiv.org/abs/2106.03106
"""
class Uformer(nn.Module):
The overall architecture is a U-Net: each encoder stage is followed by a downsampling layer, implemented as a stride-2 convolution with kernel_size=4. Note that the network propagates features in shape (B, H*W, C), and the attention window size used here is 8×8. You have to break the habit of thinking only in the usual (B, C, H, W) layout: a 2D H×W grid and a flattened 1D sequence carry the same information, and the 1D view makes the attention computation easier to follow. A quick shape check follows the Downsample code below.
import math
import torch.nn as nn

class Downsample(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(Downsample, self).__init__()
        # stride-2 conv with kernel_size=4 halves the spatial resolution
        self.conv = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=4, stride=2, padding=1),
        )

    def forward(self, x):
        # x: (B, L, C) with L = H*W
        B, L, C = x.shape
        H = int(math.sqrt(L))
        W = int(math.sqrt(L))
        # back to (B, C, H, W) for the convolution
        x = x.transpose(1, 2).contiguous().view(B, C, H, W)
        # downsample, then flatten back to token form
        out = self.conv(x).flatten(2).transpose(1, 2).contiguous()  # B H*W C
        return out
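A quick shape check using the class above (a minimal sketch; the channel counts are purely illustrative, not the paper's configuration):

import torch

x = torch.randn(2, 64 * 64, 32)                    # (B, L, C) tokens of a 64x64 feature map
down = Downsample(in_channel=32, out_channel=64)
y = down(x)
print(y.shape)                                     # torch.Size([2, 1024, 64]): 32x32 spatial, 64 channels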
Each encoder and decoder stage is built from a BasicUformerLayer. The decoder variants differ in how the skip connection is handled, giving three types: Concat, Cross, and ConcatCross; the best result is achieved by Uformer32-Concat at 39.77 (PSNR). The encoders are likewise composed of BasicUformerLayer. A small sketch of the Concat skip follows the BasicUformerLayer code below.
class BasicUformerLayer(nn.Module):
    def __init__(self, dim, output_dim, input_resolution, depth, num_heads, win_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, use_checkpoint=False,
                 token_projection='linear', token_mlp='ffn', se_layer=False):
        super().__init__()
        ...
        # build blocks: as in Swin, even-indexed blocks use regular windows (shift_size=0),
        # odd-indexed blocks shift the windows by half the window size
        self.blocks = nn.ModuleList([
            LeWinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                  num_heads=num_heads, win_size=win_size,
                                  shift_size=0 if (i % 2 == 0) else win_size // 2,
                                  mlp_ratio=mlp_ratio,
                                  qkv_bias=qkv_bias, qk_scale=qk_scale,
                                  drop=drop, attn_drop=attn_drop,
                                  drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                  norm_layer=norm_layer, token_projection=token_projection, token_mlp=token_mlp,
                                  se_layer=se_layer)
            for i in range(depth)])
        ...
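To make the Concat skip concrete: the decoder's upsampled tokens and the matching encoder tokens are concatenated along the channel dimension before being processed by the decoder stage. A minimal standalone sketch (the shapes and channel counts are assumptions, not the repository's configuration):

import torch

B, H, W, C = 2, 32, 32, 32
dec_up   = torch.randn(B, H * W, C)   # upsampled decoder tokens, (B, H*W, C)
enc_skip = torch.randn(B, H * W, C)   # encoder tokens saved before downsampling

x = torch.cat([dec_up, enc_skip], dim=-1)   # (B, H*W, 2C): concatenate along channels
print(x.shape)                              # torch.Size([2, 1024, 64])
# x is then fed to the decoder-stage BasicUformerLayer (whose dim is 2C in the Concat variant)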
The module the paper mainly proposes: LeWinTransformerBlock, which computes self-attention inside each window. Note that it is not attention between windows; attention is only computed within each 8×8 window. This is similar in spirit to a crop operation, but the windowed form is more elegant: the windows are folded into the batch dimension, enlarging the effective batch size. Because every window is processed in exactly the same way, all windows share the same set of parameters. A sketch of the partition/reverse step follows the LeWinTransformerBlock code below.
class LeWinTransformerBlock(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, win_size=8, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, token_projection='linear', token_mlp='leff',
                 se_layer=False):
        super().__init__()
        ...
        self.attn = WindowAttention(
            dim, win_size=to_2tuple(self.win_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
            token_projection=token_projection, se_layer=se_layer)
        ...
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer,
                       drop=drop) if token_mlp == 'ffn' else LeFF(dim, mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, mask=None):
        B, L, C = x.shape
        H = int(math.sqrt(L))
        W = int(math.sqrt(L))

        ## input mask
        if mask is not None:
            input_mask = F.interpolate(mask, size=(H, W)).permute(0, 2, 3, 1)
            input_mask_windows = window_partition(input_mask, self.win_size)  # nW, win_size, win_size, 1
            attn_mask = input_mask_windows.view(-1, self.win_size * self.win_size)  # nW, win_size*win_size
            attn_mask = attn_mask.unsqueeze(2) * attn_mask.unsqueeze(1)  # nW, win_size*win_size, win_size*win_size
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)
        ...
        # partition windows
        x_windows = window_partition(shifted_x, self.win_size)  # nW*B, win_size, win_size, C
        x_windows = x_windows.view(-1, self.win_size * self.win_size, C)  # nW*B, win_size*win_size, C
        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, win_size*win_size, C
        # merge windows
        attn_windows = attn_windows.view(-1, self.win_size, self.win_size, C)
        shifted_x = window_reverse(attn_windows, self.win_size, H, W)  # B H' W' C
        ...
        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        del attn_mask
        return x
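The "windows into the batch dimension" step above is performed by window_partition and window_reverse. The repository's own helpers are not reproduced here; this is a minimal sketch of the standard Swin-style implementation that matches the shape comments above:

import torch

def window_partition(x, win_size):
    # (B, H, W, C) -> (num_windows*B, win_size, win_size, C)
    B, H, W, C = x.shape
    x = x.view(B, H // win_size, win_size, W // win_size, win_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C)

def window_reverse(windows, win_size, H, W):
    # (num_windows*B, win_size, win_size, C) -> (B, H, W, C)
    B = int(windows.shape[0] / (H * W / win_size / win_size))
    x = windows.view(B, H // win_size, W // win_size, win_size, win_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)

# Shape check: a 32x32 map with 8x8 windows yields 16 windows per image.
x = torch.randn(2, 32, 32, 16)
w = window_partition(x, 8)                           # torch.Size([32, 8, 8, 16])
assert torch.equal(window_reverse(w, 8, 32, 32), x)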
WindowAttention is the same as in Swin: every position inside a window has a (relative) position embedding, which means every window is treated equally, i.e. the same parameters and the same bias are applied to each window. A minimal sketch of the Swin-style relative position bias follows the class signature below.
class WindowAttention(nn.Module):
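As a reminder of how the Swin-style relative position bias works (a minimal standalone sketch, not the repository's WindowAttention code): every pair of positions inside a win_size×win_size window indexes a learned bias table, and the resulting (win_size², win_size²) bias is added to the attention logits of every window alike.

import torch
import torch.nn as nn

win, num_heads = 8, 4
# one learnable bias per relative offset and per head: (2*win-1)^2 entries
bias_table = nn.Parameter(torch.zeros((2 * win - 1) ** 2, num_heads))

coords = torch.stack(torch.meshgrid(torch.arange(win), torch.arange(win), indexing='ij'))  # 2, win, win
coords = torch.flatten(coords, 1)                          # 2, win*win
rel = coords[:, :, None] - coords[:, None, :]              # 2, win*win, win*win
rel = rel.permute(1, 2, 0) + (win - 1)                     # shift offsets so they start at 0
rel_index = rel[:, :, 0] * (2 * win - 1) + rel[:, :, 1]    # win*win, win*win

bias = bias_table[rel_index.view(-1)].view(win * win, win * win, num_heads)
bias = bias.permute(2, 0, 1).contiguous()                  # num_heads, win*win, win*win
# attn = attn + bias.unsqueeze(0)  # added to the attention logits, identically for every window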