2020 年 5 月,Facebook AI 推出了DERT( Detection Transformer),用于目标检测和全景分割。
2020 年 10 月,谷歌提出了Vit(Vision Transformer),利用 Transformer 对图像进行分类,而不需要卷积网络。
2021年1月,OpenAI 提出两个模型:DALL·E 基于本文直接生成图像,CLIP将图像映射到文本描述的类别中。两个模型都利用 Transformer 。
2021年3月,微软提出Swin Transformer,把CV各大任务给屠榜了。。。。
我能放过它?我不能。。。总结下前段时间看了论文和代码梳理出来的swin_transformer框架和实现。
论文: https://arxiv.org/abs/2103.14030
代码: https://github.com/microsoft/Swin-Transformer
swin_transformer对比之前Vit有两个改进点:
1.引入了CNN里常用的多层次transformers结构
Vit的尺度是不变的,不易于接入到下游任务中,比如分割的encoder阶段可以方便的接入resnet等backbone网络,而Vit的特征图尺寸是不变的下图(b)。swin_transfomer通过合并image_patchesd的方式引入多层次结构,如下图(a)。
2、降低计算复杂度和内存占用
论文中定义上图中灰色块为patch,红色块定义为window。swin_transfomer通过切分窗口,计算self_attention是针对这些局部的无重叠的window。原始的MSA和论文中W-MSA的计算复杂度如下图公式,其中M是窗口包含patch的个数,也就是window_size,其大小是远小于h,w的。通过公式可以看出其计算复杂度和hw是线性关系。这里复杂度计算方法,我们后续分析源码后可以更清晰了解。
针对第一个优化点,论文使用的网络架构如下:
结构分为4个stage,stages中特征图大小分别缩小为1/4,1/8,1/16,1/32。
针对第二个优化点,论文指出仅仅对FM切分windows,然后对每个window进行self_attention有一个缺点,就是窗口之间是无沟通的。所以提出使用串联W-MSA和SW-MSA的方式。
W-MSA就是无重叠的窗口self_attention计算,而cyclic shift就如下图,对窗口进行一个shift。本来2*2的窗口个数,不等比切分为3*3个窗口。但是这样计算量会增大1.5*1.5倍。作者提出一个替换方法是进行一个roll操作,将2*2的窗口向左向上移动,移动后的窗口就包含了上层其他区域窗口的信息了。但是ABC区域本不该是邻近区域,所以还需要进行一个mask操作。
最后记得反shift把整个窗口移回去~
结果就是把CV几个大任务屠榜了。。
下面介绍从代码角度深入了解swin_transformer
先了解主要类:BasicLayer实现stage的流程,SwinTransformerBlock是BasicLayer的主要逻辑模块也是论文核心模块,WindowAttention是SwinTransformerBlock中实现attention的模块。
depths:(2,2,6,4)决定每个layer的SwinTransformerBlock执行次数。
论文提出了4套参数模型,我们下面以Swin-T为例介绍。
代码模块逻辑:
patch_embed + pos_embed
stage1
-BasicLayer
--SwinTransformerBlock(*2)
---WindowAttention
stage2
-BasicLayer
--SwinTransformerBlock(*2)
---WindowAttention
stage3
-BasicLayer
--SwinTransformerBlock(*6)
---WindowAttention
stage4
-BasicLayer
--SwinTransformerBlock(*4)
---WindowAttention
主要模块的代码逻辑:
首先进行一次patch_embed,patch_embed就是把输入按patch进行一次向量映射。我认为就是卷积操作(标题swin_transfomer,第一步就是卷积~卷积yyds)
设定输入:(3,256,256),patch_size=4,embeding_dim=96
(1)分辨率不够4整除就pad到4的倍数
(2)通用卷积kernel=4,stride=4,将image映射为无重叠的4*4的patchs:(96,64,64)
(3)如果需要norm,再进行一次layerNorm
(4)(3,256,256) 通过patch_embed,特征为(96,64,64)
如果有position_embeding步骤,需要学习一个96,64,64的pos_emded参数。和patch_embed进行concat.
将emded矩阵进行flatten+transpose-->64*64, 96
对分辨率缩小*4的特征图进行4个stage的-BasicLayer
设定window_size=7,以stage1为例输入特征图大小为(64,64)。img_mask初始为(70,70),那么通过window_partition就把特征图切分为100个7*7的窗口。
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) h_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)) w_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)) cnt = 0 for h in h_slices: for w in w_slices: img_mask:, h, w, : = cnt cnt += 1 mask_windows = window_partition(img_mask, self.window_size) mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
以上代码目的是得到100个49*49的attn_mask。
这里的attn_mask是为后续的cyclic shift,也就是SW-MSA使用。
首先,对img_mask70*70的图进行切分9大块赋值
63*63=0 4*63=1 3*64=2
63*4=3 4*4=4 3*4=5
64*3=6 4*3=7 3*3=8
然后通过将window_partition将窗口切分为100个7*7窗口,对数据平铺,得到100*49,每个窗口和其他窗口进行相减,得到100*49*49,再将不为0的值赋值-100。这些不为0位置含义可以理解为和相对位置不为上图中划分的同一个区域。结合cyclic shift,表示cyclic shift中在一个window内,特征不相邻的sub_window的位置,所以需要mask掉。
对输入64*64, 96进行layer_norm+reshape+pad操作。pad作用是要FM的H,W是window_size的倍数。对stage1:64*64, 96-->70,70,96
先看第一阶段W-MSA blcok,也就是不加入cyclic shift。
(a)进行window_partition,将特征图切分为window_size*window_size的patch,1,70*70,96切分为100,7,7,96,再reshape100,49,96
(b) WindowAttention
计算self_attention
step1:获取QKV矩阵。X:100,49,64-->Q,K,V:100,3,49,32
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv0, qkv1, qkv2
具体操作:输入全连接C通道扩展到3C,再根据multi_head将FM切分为head_num份,最后slipe得到qkv矩阵。100,3,49,32表示窗口个数,attention头,窗口长度,C/head
step2:计算attention。
attn = (q @ k.transpose(-2, -1))
100,3,49,32*100,3,32,49-->:100,3,49,49 。self_attention方面的原理可以查看transformers论文,这里就不详细介绍了。
step3:计算relative_position_bias
论文提出,增加相对位置编码效果更好。也就是在step2计算出的attn加上relative_position_bias。和attn一样,大小应该为(3,49*49)的矩阵。
下面看如何计算relative_position_bias。
#define a parameter table of relative position bias self.relative_position_bias_table = nn.Parameter( torch.zeros((2 * window_size0 - 1) * (2 * window_size1 - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.window_size0) coords_w = torch.arange(self.window_size1) coords = torch.stack(torch.meshgrid(coords_h, coords_w)) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten:, :, None - coords_flatten:, None, : # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords:, :, 0 += self.window_size0 - 1 # shift to start from 0 relative_coords:, :, 1 += self.window_size1 - 1 relative_coords:, :, 0 *= 2 * self.window_size1 - 1 relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww relative_position_bias = self.relative_position_bias_tableself.relative_position_index.view(-1).view( self.window_size0 * self.window_size1, self.window_size0 * self.window_size1, -1)
我们假设窗口大小为2,方便理解计算相对位置编码逻辑。
首先建立坐标系:
然后在X和Y方向计算relative_coords。计算relative_coords第一步加(window_size-1)是为了让值都为正数,在X方向再*(2*window_size-1)是为了后续求和能区分(0,1)和(1,0)这类坐标。
最后将X和Y方向坐标值值求和,得到relative_position_index 。
根据以上计算过程,也可以知道,我们的relative_position_bias_table(需要学习的参数)最大值应该是(window_size+(window_size-1))*(2*window_size-1)。
有了relative_position_index和relative_position_bias_table后,relative_position_bias就可以通过查表方式获取。
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
step4:计算attn_out
attn = attn + relative_position_bias.unsqueeze(0) x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
根据self_attntion的公式:
softmax(q*KT)*V-->:100,3,49,49*100,3,49,32-->100,3,49,32
step5:进行全连接
reshape+proj -->100,49,96
计算self_attention和transformer里attention机制一样。在NLP领域,输入为BLC,计算的attn是L*L表示每个pos的token对另一个pos的attention值。在这里CV领域,之前将特征图划分为不同窗口,每个窗口大小windowsize*windowsize,所以L对应windowsize*windowsize的长度,也就是一个窗口内每个点对其他点的attention值,是对每个窗口计算self_attention。
以上过程是通过window_partition后处理,这里需要进行window_reverse,把100,49,96还原到1,70,70,96
reverse后的FM和SwinTransformerBlock最初的输入进行一次shortcut。SwinTransformerBlock模块流程结束~了么?没有。之前我们避开了cyclic shift。
在执行block中,对shift_size是
shift_size=0 if (i % 2 == 0) else window_size // 2,
所以第二个迭代 block,我们是需要进行cyclic shift的。
执行逻辑还是以上的(1)-(4),主要不同在于步骤(2),下面主要讲解,shift_size不为0时,步骤(2)的流程。
看第二阶段SW-MSA blcok,也就是加入cyclic shift。
(a)同样进行window_partition,得到b,100,49,96的特征图。然后
cyclic shift if self.shift_size > 0: shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) attn_mask = mask_matrix else: shifted_x = x attn_mask = None
这行代码的含义就是,将x向左移动shift_size,向上移动shift_size。也就是下图中的cyclic shift。执行这个操作的目的是,通过window_partition后进行W-MSA,窗口和窗口之间是没有重叠的,使用SW-MSA就可以让窗口之间有关联,但是这里存在的一个问题是下图中ABC区域和邻近窗口其实是不相邻的,是通过roll操作后赋值在这个区域。
(b)windowAttention
计算attention和上诉步骤一致,只是在步骤a中我们提到了,ABC区域在计算attention时需要mask掉,这里的mask就是我们BasicLayer的第一步获取的attn_mask(100,49,49)~
if mask is not None: nW = mask.shape0 attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N) attn = self.softmax(attn) else: attn = self.softmax(attn)
mask主要逻辑,attn假设目前是200,3,49,49,我们计算的attn_mask是(100,49,49),因为是针对窗口位置mask和bs和head_num无关,所以将attn和mask分别reshape到(2, 100, 3, 49, 49)和(1,100,1,49,49)就好了。
最后记得window_rever后,记得把shift_x给sereverse回去。
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 以上就将最复杂的SwinTransformerBlock模块介绍完了~
downsamp(最后一个stage不需要)使用的是PatchMerging.对FM进行间隔采样达到降采样的目的,再concat低分辨率FM后,通过全连接对C通道裁剪。很像pixelShuffle的反向操作。
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) x = x.view(B, H, W, C) padding pad_input = (H % 2 == 1) or (W % 2 == 1) if pad_input: x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) x0 = x:, 0::2, 0::2, : # B H/2 W/2 C x1 = x:, 1::2, 0::2, : # B H/2 W/2 C x2 = x:, 0::2, 1::2, : # B H/2 W/2 C x3 = x:, 1::2, 1::2, : # B H/2 W/2 C x = torch.cat(x0, x1, x2, x3, -1) # B H/2 W/2 4*C x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C x = self.norm(x) x = self.reduction(x)
以上就是一个basicLayer的逻辑,通过四个stage得到不同尺度的特征图(Swin-T)
stage1-->96, 64, 64
stage2-->192, 32, 32
stage3-->384, 16, 16
stage4--> 768, 8, 8
有了这个四个特征图就可以和resnet等结构一样,接入到下游任务了~
近期,快手可谓动作不断。 先是宣布和京东合作618购物节,再是邀请周杰伦入驻,...
高级配置 设置“云服务器名称”。 名称可自定义,但需符合命名规则:只能由中文...
2020年11月,数据猿推出了 数智跃新,破浪而出大数据的2020,我的2021 大型年度主题...
小编最近遇到点问题,就是自己的电脑自从升级为win10系统后,win10 edge浏览器网...
退订时如何扣费要根据资源状态、资源使用时长等条件而定,具体规则参见 表1 。 ...
对于现代人来说,有一个安全的地方来展示他们的作品和分享故事和观点是一个日益...
许力 仁威 阿里云数据库产品事业部高级产品经理 目前负责阿里云原生多模数据库Li...
被贴标签,是生活在这个时代的每个人,都无法避免的事情。 沉迷游戏、农村子弟、...
怎么配置网站服务器?我们在购买 云服务器 之后,往往是需要自己手动进行配置的...
『码哥』的 Redis 系列文章有一篇讲透了 Redis 的性能优化 《 Redis 核心篇:唯...