OpenAI CLIP

1 分钟阅读

概述

CLIP全称Contrastive Language-Image Pretraining,基于语言图像对比预训练,是目前最为优秀的zero-shot模型,也是后续一系列图文模型的基石,甚至直接拿它做Backbone。它最大的优势在于,可以直接用文本+图像做训练,这部分数据量在网上是非常庞大的。而传统的标注类数据集成本非常高昂,数据量也不是一个量级。

zero-shot:零样本学习,无需专门对样本分类,使分类功能可以泛化。

原理如下图(图来自官方)所示:

  • 训练时,文本数据经过Text Encoder 生成[T1, T2, ..., TN]向量;图片数据经过Image Encoder转成[I1, I2, ..., IN]向量,然后两个向量求余弦距离,得到所有文件与图片的相关性
  • 推理时,同样的方式使用Text EncodeImage Encoder得到两个向量,求余弦举例
  • 通常图像用resnet50做backbone,文本用vit做backbone

代码分析

// 得到图像向量
vision_outputs = self.vision_model(
    pixel_values=pixel_values,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    return_dict=return_dict,
)

// 得到文本向量
text_outputs = self.text_model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    position_ids=position_ids,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    return_dict=return_dict,
)

// 图像向量与系数矩阵乘
image_embeds = vision_outputs[1]
image_embeds = self.visual_projection(image_embeds)

// 文本向量与系数矩阵乘
text_embeds = text_outputs[1]
text_embeds = self.text_projection(text_embeds)

// 分别做归一化
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

// 求余弦距离
logit_scale = self.logit_scale.exp()
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
logits_per_image = logits_per_text.t()