Qwen2.5VL解析

16 分钟阅读

本文以Qwen2.5VL-7B为例,分Vision部分和LLM部分,来解析该模型。

源码:modeling_qwen2_5_vl.py

Vision部分Permalink

config.json中vision部分的配置如下:

"vision_config": {
    "depth": 32,
    "embed_dim": 1280,
    "mlp_ratio": 4,
    "num_heads": 16,
    "in_chans": 3,
    "hidden_size": 1536,
    "patch_size": 14,
    "spatial_merge_size": 2,
    "spatial_patch_size": 14,
    "temporal_patch_size": 2
  }

图片预处理过程Permalink

预处理过程可以看Qwen2VLImageProcessor的实现

第一步,Resize

比如图片尺寸是800x600,先寻找接近28 (注:merge_sizexpatch_size)倍数的尺寸,得到784x588,确定这个像素量<= max_pixels。如果大于,则需要调整到max_piexls内;同理如下小于min_piexls也要调整。然后将图片Resize到784x588

源码如下:

def smart_resize(
    height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280
):
    """Rescales the image so that the following conditions are met:

    1. Both dimensions (height and width) are divisible by 'factor'.

    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].

    3. The aspect ratio of the image is maintained as closely as possible.

    """
    if height < factor or width < factor:
        raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
    elif max(height, width) / min(height, width) > 200:
        raise ValueError(
            f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
        )
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar

第二步,图片归一化

将图片数值从u8转换成[-1,1]的float值。此时数据为data = float [1, 3, 588, 784]

第三步,temporal处理

batch维度叠加到temporal_patch_size的大小,此时数据为data = float [2, 3, 588, 784]

也就是这里data[0] == data[1]

第四步,维度转换

参数如下:

grid_t = data.shape[0]//`temporal_patch_size` = 1
grid_h = 588 / patch_size = 42
grid_w = 784 / patch_size = 56
data = data.reshape(grid_t, temporal_patch_size, channel, grid_h//merge_size, merge_size, patch_size, grid_w//merge_size, merge_size, patch_size) = data[1, 2, 3, 21, 2, 14, 28,2,14]

转换过程如下:

data[2, 3, 588, 784] 
=>data[1, 2, 3, 21, 2, 14, 28, 2, 14]
=> transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
=> data[1, 21, 28, 2, 2, 3, 2, 14, 14]
=> reshape(grid_t x grid_h x grid_w, channel x temporal_patch_size x patch_size x patch_size)
=> data[2352, 1176]

理解

以上是预处理过程,进一步理解这个处理过程的涵义大致如下:

  • patch_size:图片会被切分成14x14的网格
  • temporal_patch_size: Qwen2-VL把视频当做1秒2帧图片处理,为了统一图片和视频,会把图片复制一份
  • grid_t/grid_h/grid_w: 表示时间/高度/宽度三个维度的网格数
  • merge_size: 从transpose过程可以看出grid_hgrid_w 做了2x2的融合处理
  • [2352, 1176]: 最终的shape,可以理解成网络index与网格数据

经过融合后,每个patch对应的position_id依次为:[0,0]、[0,1]、[1,0]、[1,1]、[0,2]、[0,3]、[1,2]、[1,3]……

VITPermalink

经过预处理后VIT输入为[2352, 1176]

1) PatchEmbed

可以理解成维度变化,权重为[1176, 1280] (config.vision_config.hidden_size)

表面上是3D Conv运算,实际等价于MatMul: [2352, 1176] x [1176, 1280] => [2352, 1280]

1) rot_pos_emb

输入是grid_thw,得到position_id[2352, 2](即每个patch对应的坐标),通过对[2352,2]做位置编码处理,得到输出[2352,40]

2) get_window_index

它的目的是将同一个窗口排列到一起。

输入是grid_thw[1, 42, 56],每2X2为一组,则一共21x28=588组,对应index为[0,1,...587]

window_size在配置中为112,映射到grid中就是112/14/2 = 4

将其再按4x4分组,则为[6,8],不足的部分补上pad,得到inded_padded[1, 24, 32],pad部分补上-100,数值如下:

tensor([[[   0,    1,    2,    3,    4,    5,    6,    7,    8,    9,   10,
            11,   12,   13,   14,   15,   16,   17,   18,   19,   20,   21,
            22,   23,   24,   25,   26,   27, -100, -100, -100, -100],
         [  28,   29,   30,   31,   32,   33,   34,   35,   36,   37,   38,
            39,   40,   41,   42,   43,   44,   45,   46,   47,   48,   49,
            50,   51,   52,   53,   54,   55, -100, -100, -100, -100],
         [  56,   57,   58,   59,   60,   61,   62,   63,   64,   65,   66,
            67,   68,   69,   70,   71,   72,   73,   74,   75,   76,   77,
            78,   79,   80,   81,   82,   83, -100, -100, -100, -100],
         [  84,   85,   86,   87,   88,   89,   90,   91,   92,   93,   94,
            95,   96,   97,   98,   99,  100,  101,  102,  103,  104,  105,
           106,  107,  108,  109,  110,  111, -100, -100, -100, -100],
         [ 112,  113,  114,  115,  116,  117,  118,  119,  120,  121,  122,
           123,  124,  125,  126,  127,  128,  129,  130,  131,  132,  133,
         ......
         [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
          -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
          -100, -100, -100, -100, -100, -100, -100, -100, -100, -100],
         [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
          -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
          -100, -100, -100, -100, -100, -100, -100, -100, -100, -100]]])

进而index_padded做如下处理:

index_padded[1, 24, 32]
=> reshape(1, 6, 4, 8, 4) 
=> permute(0, 1, 3, 2, 4)
=> reshape(1, 48, 4, 4)
=> index_padded[1, 48, 4, 4]
=> reshape(-1) # [768]
=> 去掉其中-100
=> window_index[588] # 窗口,[0,1,2,3,28,29,...]
=> cu_window_seqlens[49] # 记录窗口间隔,从0开始,到2352结束。其中2352为总patch数.[0, 64, 128, 192,...]
# 为什么多数间隔是64?因为窗口大小是112x112,每个窗口对应8x8个patch,也就是共64个patch

3) window attention

根据前面计算好的窗口数据,调整hidden_states,使同一个窗口内的token排列在一起,同时位置编码也要做重新排列。

如此有大部分层只用计算窗口注意力,少数计算全部注意力(对应配置fullatt_block_indexes)。

怎么理解window attention?
比如QKV的输入是[1, n, head, dim],那么Q和K运算为[1, n, head, dim]x[1, n, head, dim] => [1, n, head, n]
那么它的计算量是1 x n x head x n x dim x 2;
采用window attention后,它的计算为n/64 个[1, 64, head, dim]x[1, 64, head, dim] => [1, 64, head, 64]
那么它的计算量是(1 x 64 x head x 64 x dim x 2) x n/64 = 1 x n x head x 64 x dim x 2
可以看出计算量比例为n:64

4) block

执行标准的attention + mlp运算,共config.vision_config.depth个Block。

最终输入依然是[2352, 1280]

5) Merge

2x2 merge 到hidden_size中,直接看源码理解:

class Qwen2_5_VLPatchMerger(nn.Module):
    # context_dim对应1280,也就是vision_config.hidden_size (vision的hidden_size)
    # dim对应3584,也就是config.hidden_size (llm的hidden_size)
    def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
        return x

运算过程如下:

input[2352, 1280]
=> rmsnorm
=> view (-1, 5120) => [588, 5120]
=> x [5120, 5120] => [588, 5120]
=> GELU
=> x [5120, 3584] => [588, 3584]

6) Reverse

reverse_indices = torch.argsort(window_index)
hidden_states = hidden_states[reverse_indices, :]

由于window attention的原因,对原始的hidden_states做了调整,这里是其逆操作,还原成原始的顺序。

LLM部分Permalink

(待补充)