Qwen2.5VL Analysis
This article takes Qwen2.5VL-7B as an example and analyzes the model in two parts: the Vision part and the LLM part.
Vision Part
The vision section of config.json is configured as follows:
"vision_config": {
"depth": 32,
"embed_dim": 1280,
"mlp_ratio": 4,
"num_heads": 16,
"in_chans": 3,
"hidden_size": 1536,
"patch_size": 14,
"spatial_merge_size": 2,
"spatial_patch_size": 14,
"temporal_patch_size": 2
}
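These fields can also be read programmatically (a sketch, assuming the Hugging Face model id Qwen/Qwen2.5-VL-7B-Instruct and a transformers version that ships Qwen2.5-VL support):

```python
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
print(cfg.vision_config.hidden_size)  # 1280, the ViT width used below
print(cfg.vision_config.window_size)  # 112, used by window attention
print(cfg.hidden_size)                # 3584, the LLM hidden size
```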
Image Preprocessing
The preprocessing is implemented in Qwen2VLImageProcessor.
Step 1: Resize
Suppose the image size is 780x600. Each dimension is first snapped to the nearest multiple of 28 (i.e. merge_size x patch_size), giving 784x588, and the resulting pixel count is checked to be <= max_pixels. If it is larger, the size is scaled down to fit within max_pixels; likewise, if it is smaller than min_pixels it is scaled up. The image is then resized to 784x588.
The source is as follows:
```python
import math

def smart_resize(
    height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280
):
    """Rescales the image so that the following conditions are met:

    1. Both dimensions (height and width) are divisible by 'factor'.
    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
    3. The aspect ratio of the image is maintained as closely as possible.
    """
    if height < factor or width < factor:
        raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
    elif max(height, width) / min(height, width) > 200:
        raise ValueError(
            f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
        )
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar
```
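A quick sanity check of the example above (note the (height, width) argument order):

```python
print(smart_resize(600, 780))    # (588, 784): both sides snapped to multiples of 28
print(smart_resize(4000, 6000))  # scaled down so that h_bar * w_bar <= max_pixels
```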
Step 2: Normalization
Convert the pixel values from uint8 to float: rescale by 1/255, then standardize with the CLIP image mean/std. The data is now data = float [1, 3, 588, 784]
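A minimal sketch of this step, assuming the OpenAI CLIP mean/std that Qwen2VLImageProcessor uses by default:

```python
import numpy as np

# OPENAI_CLIP_MEAN / OPENAI_CLIP_STD, the processor defaults.
mean = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32).reshape(3, 1, 1)
std = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32).reshape(3, 1, 1)

img = np.random.randint(0, 256, (3, 588, 784), dtype=np.uint8)  # resized image, [C, H, W]
data = (img.astype(np.float32) / 255.0 - mean) / std            # rescale, then standardize
data = data[None]                                               # [1, 3, 588, 784]
```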
Step 3: Temporal processing
The batch dimension is stacked up to temporal_patch_size, giving data = float [2, 3, 588, 784], where data[0] == data[1] (the single image is simply duplicated).
Step 4: Dimension transformation
The parameters are:

```
grid_t = data.shape[0] // temporal_patch_size = 1
grid_h = 588 / patch_size = 42
grid_w = 784 / patch_size = 56
```

The transformation proceeds as follows:

```
data[2, 3, 588, 784]
=> reshape(grid_t, temporal_patch_size, channel, grid_h//merge_size, merge_size, patch_size, grid_w//merge_size, merge_size, patch_size)
=> data[1, 2, 3, 21, 2, 14, 28, 2, 14]
=> transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
=> data[1, 21, 28, 2, 2, 3, 2, 14, 14]
=> reshape(grid_t x grid_h x grid_w, channel x temporal_patch_size x patch_size x patch_size)
=> data[2352, 1176]
```
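Steps 3 and 4 can be reproduced end to end in a few lines of numpy (a sketch; random data stands in for the normalized image):

```python
import numpy as np

patch_size, merge_size, temporal_patch_size = 14, 2, 2
data = np.random.rand(1, 3, 588, 784).astype(np.float32)  # normalized image, [N, C, H, W]

# Step 3: replicate the single frame so the temporal axis matches temporal_patch_size.
data = np.tile(data, (temporal_patch_size, 1, 1, 1))       # [2, 3, 588, 784]

grid_t = data.shape[0] // temporal_patch_size              # 1
grid_h, grid_w = 588 // patch_size, 784 // patch_size      # 42, 56

# Step 4: carve the frames into 14x14 patches, grouped 2x2 for the later merge.
data = data.reshape(grid_t, temporal_patch_size, 3,
                    grid_h // merge_size, merge_size, patch_size,
                    grid_w // merge_size, merge_size, patch_size)
data = data.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)           # [1, 21, 28, 2, 2, 3, 2, 14, 14]
flat = data.reshape(grid_t * grid_h * grid_w,
                    3 * temporal_patch_size * patch_size * patch_size)
print(flat.shape)                                          # (2352, 1176)
```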
Understanding
That is the preprocessing pipeline; the meaning of each piece is roughly as follows:

- patch_size: the image is cut into a grid of 14x14 patches.
- temporal_patch_size: Qwen2-VL treats video as 2 frames per second, and to handle images and video uniformly, a single image is duplicated once.
- grid_t/grid_h/grid_w: the number of grid cells along the time/height/width dimensions.
- merge_size: as the transpose step shows, grid_h and grid_w are fused in 2x2 groups.
- [2352, 1176]: the final shape, readable as (grid index, grid data).

After this fusion, the position_id of each patch is, in order: [0,0], [0,1], [1,0], [1,1], [0,2], [0,3], [1,2], [1,3], ...
ViT
After preprocessing, the ViT input is [2352, 1176]
1) PatchEmbed
This can be viewed as a pure dimension change; the weight is [1176, 1280] (1280 being config.vision_config.hidden_size).
On the surface it is a 3D Conv, but since the kernel size equals the stride (patches do not overlap), it is equivalent to a MatMul: [2352, 1176] x [1176, 1280] => [2352, 1280]
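A quick sketch verifying that equivalence (the weights are random here; the real layer loads the pretrained kernel):

```python
import torch
from torch import nn

C, T, P, D = 3, 2, 14, 1280  # in_chans, temporal_patch_size, patch_size, hidden_size
conv = nn.Conv3d(C, D, kernel_size=(T, P, P), stride=(T, P, P), bias=False)

patches = torch.randn(2352, C * T * P * P)                 # preprocessed output [2352, 1176]
out_conv = conv(patches.view(-1, C, T, P, P)).view(-1, D)  # run as a 3D convolution
out_mm = patches @ conv.weight.view(D, -1).t()             # run as a plain matmul
print(torch.allclose(out_conv, out_mm, atol=1e-4))         # True
```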
2) rot_pos_emb
The input is grid_thw, from which a position_id of shape [2352, 2] is derived (the (h, w) coordinate of each patch). Applying rotary position encoding to this [2352, 2] yields an output of [2352, 40] (head_dim is 1280/16 = 80; half of it, 40, carries the rotary frequencies, split evenly between the h and w axes).
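A sketch of how the [2352, 2] position ids and the [2352, 40] rotary table come about, assuming the default theta = 10000.0 (the first print also confirms the merge-order position_ids listed in the preprocessing section):

```python
import torch

t, h, w, ms = 1, 42, 56, 2  # grid_thw and merge_size

# (h, w) coordinate of every patch, emitted in 2x2-merge order.
hpos = torch.arange(h).unsqueeze(1).expand(-1, w)
wpos = torch.arange(w).unsqueeze(0).expand(h, -1)
hpos = hpos.reshape(h // ms, ms, w // ms, ms).permute(0, 2, 1, 3).flatten()
wpos = wpos.reshape(h // ms, ms, w // ms, ms).permute(0, 2, 1, 3).flatten()
pos_ids = torch.stack([hpos, wpos], dim=-1).repeat(t, 1)  # [2352, 2]

# Rotary frequency table: dim = head_dim // 2 = 40, i.e. 20 frequencies per axis.
dim, theta = 40, 10000.0
inv_freq = 1.0 / theta ** (torch.arange(0, dim, 2).float() / dim)    # [20]
freq_table = torch.outer(torch.arange(max(h, w)).float(), inv_freq)  # [56, 20]
rotary_pos_emb = freq_table[pos_ids].flatten(1)                      # [2352, 40]

print(pos_ids[:4].tolist())  # [[0, 0], [0, 1], [1, 0], [1, 1]]
print(rotary_pos_emb.shape)  # torch.Size([2352, 40])
```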
3) get_window_index
Its purpose is to group the patches of each window together.
The input is grid_thw = [1, 42, 56]. Taking every 2x2 patch block as one merge group gives 21x28 = 588 groups, with indices [0, 1, ..., 587].
window_size is 112 in the config; mapped onto the merge-group grid that is 112/14/2 = 4.
The group grid is then partitioned into 4x4 windows, a [6, 8] window layout; positions that fall short are padded, giving index_padded[1, 24, 32] with -100 in the padded slots. The values are:
```
tensor([[[   0,    1,    2,    3,    4,    5,    6,    7,    8,    9,   10,
            11,   12,   13,   14,   15,   16,   17,   18,   19,   20,   21,
            22,   23,   24,   25,   26,   27, -100, -100, -100, -100],
         [  28,   29,   30,   31,   32,   33,   34,   35,   36,   37,   38,
            39,   40,   41,   42,   43,   44,   45,   46,   47,   48,   49,
            50,   51,   52,   53,   54,   55, -100, -100, -100, -100],
         [  56,   57,   58,   59,   60,   61,   62,   63,   64,   65,   66,
            67,   68,   69,   70,   71,   72,   73,   74,   75,   76,   77,
            78,   79,   80,   81,   82,   83, -100, -100, -100, -100],
         [  84,   85,   86,   87,   88,   89,   90,   91,   92,   93,   94,
            95,   96,   97,   98,   99,  100,  101,  102,  103,  104,  105,
           106,  107,  108,  109,  110,  111, -100, -100, -100, -100],
         [ 112,  113,  114,  115,  116,  117,  118,  119,  120,  121,  122,
           123,  124,  125,  126,  127,  128,  129,  130,  131,  132,  133,
         ......
         [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
          -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
          -100, -100, -100, -100, -100, -100, -100, -100, -100, -100],
         [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
          -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
          -100, -100, -100, -100, -100, -100, -100, -100, -100, -100]]])
```
index_padded is then processed as follows:

```
index_padded[1, 24, 32]
=> reshape(1, 6, 4, 8, 4)
=> permute(0, 1, 3, 2, 4)
=> reshape(1, 48, 4, 4)
=> reshape(-1)           # [768]
=> drop the -100 entries
=> window_index[588]     # window order: [0, 1, 2, 3, 28, 29, ...]
=> cu_window_seqlens[49] # cumulative window boundaries, from 0 up to 2352 (the total patch count): [0, 64, 128, 192, ...]
# Why is the spacing mostly 64? The window is 112x112, so each window covers 8x8 = 64 patches.
```
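The whole routine fits in a short runnable sketch (mirroring, not quoting, the Hugging Face implementation):

```python
import torch

grid_t, grid_h, grid_w, ms = 1, 42, 56, 2
vit_ws = 112 // 14 // ms                    # 4 merge groups per window side
llm_h, llm_w = grid_h // ms, grid_w // ms   # 21, 28 merge groups

index = torch.arange(grid_t * llm_h * llm_w).reshape(grid_t, llm_h, llm_w)
pad_h = vit_ws - llm_h % vit_ws             # 3
pad_w = vit_ws - llm_w % vit_ws             # 4 (a full extra window when already divisible)
index_padded = torch.nn.functional.pad(index, (0, pad_w, 0, pad_h), value=-100)  # [1, 24, 32]

nw_h, nw_w = (llm_h + pad_h) // vit_ws, (llm_w + pad_w) // vit_ws  # 6, 8
index_padded = index_padded.reshape(grid_t, nw_h, vit_ws, nw_w, vit_ws)
index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(grid_t, nw_h * nw_w, vit_ws, vit_ws)

seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)  # valid merge groups per window
flat = index_padded.reshape(-1)                           # [768]
window_index = flat[flat != -100]                         # [588]

# Boundaries in patch units; empty (fully padded) windows repeat a value, which HF dedups.
cu_window_seqlens = torch.cat(
    [torch.zeros(1, dtype=torch.long), (seqlens * ms * ms).cumsum(0)])  # [49]
print(window_index[:8].tolist())       # [0, 1, 2, 3, 28, 29, 30, 31]
print(cu_window_seqlens[:4].tolist())  # [0, 64, 128, 192]
```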
4) Window attention
Using the window data computed above, hidden_states is reordered so that tokens belonging to the same window sit next to each other; the position embeddings are permuted in the same way.
This way most layers only compute attention within each window, and only a few layers compute full attention (those listed in the fullatt_block_indexes config).
How should window attention be understood?
If the QKV input is [1, n, head, dim], the Q x K product is [1, n, head, dim] x [1, n, head, dim] => [1, n, head, n], which costs 1 x n x head x n x dim x 2 FLOPs.
With window attention it becomes n/64 products of [1, 64, head, dim] x [1, 64, head, dim] => [1, 64, head, 64], which cost (1 x 64 x head x 64 x dim x 2) x n/64 = 1 x n x head x 64 x dim x 2 FLOPs in total.
So the cost ratio is n : 64; for our n = 2352 example, that is roughly a 37x saving on the attention scores.
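Spelling out the arithmetic for our running example (head = 16 and dim = 1280/16 = 80 follow from the config):

```python
# FLOPs of the attention-score matmul (Q x K^T), full vs. 64-patch windows.
n, head, dim, win = 2352, 16, 80, 64
full = 2 * n * n * head * dim
windowed = 2 * n * win * head * dim
print(full / windowed)  # 36.75 == n / 64
```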
5) Block
Runs the standard attention + mlp computation, with config.vision_config.depth (32) blocks in total.
The final output is still [2352, 1280]
6) Merge
Each 2x2 group is merged into the hidden dimension; the source makes this clear:

```python
import torch
from torch import nn
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm


class Qwen2_5_VLPatchMerger(nn.Module):
    # context_dim is 1280, i.e. vision_config.hidden_size (the ViT hidden size)
    # dim is 3584, i.e. config.hidden_size (the LLM hidden size)
    def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
        return x
```
The computation proceeds as follows:

```
input[2352, 1280]
=> rmsnorm
=> view(-1, 5120) => [588, 5120]
=> x [5120, 5120] => [588, 5120]
=> GELU
=> x [5120, 3584] => [588, 3584]
```
7) Reverse

```python
reverse_indices = torch.argsort(window_index)
hidden_states = hidden_states[reverse_indices, :]
```

Window attention shuffled the original order of hidden_states; this is the inverse operation, restoring the original order. Since window_index has length 588, the un-shuffle operates at merge-group granularity, after the merger.
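A toy check that torch.argsort of a permutation is exactly its inverse:

```python
import torch

perm = torch.tensor([2, 0, 3, 1])  # some shuffle, like window_index
inv = torch.argsort(perm)
x = torch.arange(4.0)
print(x[perm][inv])  # tensor([0., 1., 2., 3.]): original order restored
```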
LLM Part
(To be added.)