OA0
OA0 是一个探索 AI 的社区
现在注册
已注册用户请  登录
OA0  ›  代码  ›  ColPali — 用视觉语言模型做文档检索的新思路

ColPali — 用视觉语言模型做文档检索的新思路

 
  index ·  2026-02-21 20:58:01 · 5 次点击  · 0 条评论  

ColPali:基于视觉语言模型的高效文档检索 👀

arXiv
GitHub
Hugging Face
GitHub

Test
Version
Downloads


[模型卡片]
[ViDoRe 排行榜]
[演示]
[博客文章]

相关论文

本仓库包含用于训练论文 ColPali: Efficient Document Retrieval with Vision Language Models 中视觉检索器的代码。具体来说,它包含了训练 ColPali 模型的代码,该模型是一个基于 ColBERT 架构和 PaliGemma 模型的视觉检索器。

简介

我们提出的新模型 ColPali,旨在利用视觉语言模型(VLM)在视觉空间中构建高效的多向量嵌入用于文档检索。通过将 PaliGemma-3B 的 ViT 输出图像块送入一个线性投影层,我们创建了文档的多向量表示。我们按照 ColBERT 方法训练模型,以最大化这些文档嵌入与查询嵌入之间的相似度。

使用 ColPali 无需复杂且脆弱的版面识别和 OCR 流程,单个模型即可同时考虑文档的文本和视觉内容(布局、图表等)。

ColPali 架构

ColVision 模型列表

模型 ViDoRe 得分 🏆 许可证 说明 当前是否支持
vidore/colpali 81.3 Gemma • 基于 google/paligemma-3b-mix-448
• ColPali 论文中使用的检查点。
vidore/colpali-v1.1 81.5 Gemma • 基于 google/paligemma-3b-mix-448
• 修复了查询的右填充。
vidore/colpali-v1.2 83.9 Gemma • 与 vidore/colpali-v1.1 类似。
vidore/colpali-v1.3 84.8 Gemma • 与 vidore/colpali-v1.2 类似。
• 使用更大的有效批次大小(256)训练了 3 个周期。
vidore/colqwen2-v0.1 87.3 Apache 2.0 • 基于 Qwen/Qwen2-VL-2B-Instruct
• 支持动态分辨率。
• 使用每页 768 个图像块和有效批次大小 32 进行训练。
vidore/colqwen2-v1.0 89.3 Apache 2.0 • 与 vidore/colqwen2-v0.1 类似,但使用更强大的 GPU 和更大的有效批次大小(256)进行训练。
vidore/colqwen2.5-v0.1 88.8 Apache 2.0 • 基于 Qwen/Qwen2 5-VL-3B-Instruct
• 支持动态分辨率。
• 使用每页 768 个图像块和有效批次大小 32 进行训练。
vidore/colqwen2.5-v0.2 89.4 Apache 2.0 • 与 vidore/colqwen2.5-v0.1 类似,但使用略有不同的超参数训练。
TomoroAI/tomoro-colqwen3-embed-4b 90.6 Apache 2.0 • 基于 Qwen3-VL 骨干网络。
• 320 维 ColBERT 风格嵌入,支持动态分辨率。
• 为多向量文档检索训练。
vidore/colSmol-256M 80.1 Apache 2.0 • 基于 HuggingFaceTB/SmolVLM-256M-Instruct
vidore/colSmol-500M 82.3 Apache 2.0 • 基于 HuggingFaceTB/SmolVLM-500M-Instruct
Cognitive-Lab/ColNetraEmbed 86.4 Gemma • 基于 google/gemma-3-4b-it
• 多向量延迟交互检索模型。
• 支持 22 种语言的多语言检索。
Cognitive-Lab/NetraEmbed 81.0 Gemma • 基于 google/gemma-3-4b-it
• 双编码器检索模型。
• 支持 Matryoshka 嵌入(768, 1536, 2560)。
• 支持 22 种语言的多语言检索。

环境设置

我们使用 Python 3.11.6 和 PyTorch 2.4 来训练和测试我们的模型,但代码库兼容 Python >=3.10 和较新的 PyTorch 版本。要安装该包,请运行:

pip install colpali-engine # 从 PyPi 安装
pip install git+https://github.com/illuin-tech/colpali # 从源码安装

在 Mac 上使用 MPS 运行 ColQwen 模型的用户报告了 torch 2.6.0 的错误。降级到 torch 2.5.1 可以修复这些错误。

[!WARNING]
对于 v1.0 以上的 ColPali 版本,请确保从源码安装 colpali-engine 包,或安装版本高于 v0.2.0。

使用方法

快速开始

import torch
from PIL import Image
from transformers.utils.import_utils import is_flash_attn_2_available

from colpali_engine.models import ColQwen2, ColQwen2Processor

model_name = "vidore/colqwen2-v1.0"

model = ColQwen2.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",  # 如果是 Apple Silicon 芯片,使用 "mps"
    attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
).eval()

processor = ColQwen2Processor.from_pretrained(model_name)

# 你的输入
images = [
    Image.new("RGB", (128, 128), color="white"),
    Image.new("RGB", (64, 32), color="black"),
]
queries = [
    "What is the organizational structure for our R&D department?",
    "Can you provide a breakdown of last year’s financial performance?",
]

# 处理输入
batch_images = processor.process_images(images).to(model.device)
batch_queries = processor.process_queries(queries).to(model.device)

# 前向传播
with torch.no_grad():
    image_embeddings = model(**batch_images)
    query_embeddings = model(**batch_queries)

scores = processor.score_multi_vector(query_embeddings, image_embeddings)

我们现在实验性地支持 fast-plaid,以便为更大的语料库提供更快的匹配:

# !pip install --no-deps fast-plaid fastkmeans

# 以批次大小为 4 处理输入
dataloader = DataLoader(
    dataset=images,
    batch_size=4,
    shuffle=False,
    collate_fn=lambda x: processor.process_images(x),
)

ds  = []
for batch_doc in tqdm(dataloader):
    with torch.no_grad():
        batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
        embeddings_doc = model(**batch_doc)
    ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))

plaid_index = processor.create_plaid_index(ds)

scores = processor.get_topk_plaid(query_embeddings, plaid_index, k=10)

基准测试

要在 ViDoRe 排行榜 上对 ColPali 进行基准测试,请使用 vidore-benchmark 包。

通过相似度图进行可解释性分析

通过将延迟交互相似度图叠加到原始图像上,我们可以可视化每个查询词对应的最显著图像块,从而获得关于模型关注区域的可解释性洞察。

要使用 interpretability 模块,你需要安装 colpali-engine[interpretability] 包:

pip install colpali-engine[interpretability]

然后,在使用 ColPali 生成嵌入后,使用以下代码为每个查询词绘制相似度图:

🔽 点击展开代码片段
import torch
from PIL import Image

from colpali_engine.interpretability import (
    get_similarity_maps_from_embeddings,
    plot_all_similarity_maps,
)
from colpali_engine.models import ColPali, ColPaliProcessor
from colpali_engine.utils.torch_utils import get_torch_device

model_name = "vidore/colpali-v1.3"
device = get_torch_device("auto")

# 加载模型
model = ColPali.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=device,
).eval()

# 加载处理器
processor = ColPaliProcessor.from_pretrained(model_name)

# 加载图像和查询
image = Image.open("shift_kazakhstan.jpg")
query = "Quelle partie de la production pétrolière du Kazakhstan provient de champs en mer ?"

# 预处理输入
batch_images = processor.process_images([image]).to(device)
batch_queries = processor.process_queries([query]).to(device)

# 前向传播
with torch.no_grad():
    image_embeddings = model.forward(**batch_images)
    query_embeddings = model.forward(**batch_queries)

# 获取图像块数量
n_patches = processor.get_n_patches(image_size=image.size, patch_size=model.patch_size)

# 获取用于过滤与图像无关的嵌入的张量掩码
image_mask = processor.get_image_mask(batch_images)

# 生成相似度图
batched_similarity_maps = get_similarity_maps_from_embeddings(
    image_embeddings=image_embeddings,
    query_embeddings=query_embeddings,
    n_patches=n_patches,
    image_mask=image_mask,
)

# 获取我们(唯一)输入图像的相似度图
similarity_maps = batched_similarity_maps[0]  # (query_length, n_patches_x, n_patches_y)

# 对查询进行分词
query_tokens = processor.tokenizer.tokenize(query)

# 为每个查询词绘制并保存相似度图
plots = plot_all_similarity_maps(
    image=image,
    query_tokens=query_tokens,
    similarity_maps=similarity_maps,
)
for idx, (fig, ax) in enumerate(plots):
    fig.savefig(f"similarity_map_{idx}.png")

更详细的示例,可以参考 ColPali Cookbooks 👨🏻‍🍳 仓库中的可解释性笔记。

词元池化

词元池化 是一种符合 CRUDE 原则(支持文档增删)的方法,旨在减少多向量嵌入的序列长度。对于 ColPali,许多图像块共享冗余信息,例如白色背景块。通过将这些块池化在一起,我们可以减少嵌入数量,同时保留页面的大部分信号。关于图像嵌入上分层平均词元池化的检索性能,可以在 ColPali 论文 中找到。在我们的实验中,我们发现池化因子为 3 时提供了最佳权衡:向量总数减少了 $66.7\%$,同时保持了 $97.8\%$ 的原始性能。

要使用词元池化,可以使用 colpali-engine 包中的 HierarchicalEmbeddingPooler 类:

🔽 点击展开代码片段
import torch

from colpali_engine.compression.token_pooling import HierarchicalTokenPooler

# 虚拟多向量嵌入
list_embeddings = [
    torch.rand(10, 768),
    torch.rand(20, 768),
]

# 定义具有所需压缩级别的池化器
pooler = HierarchicalTokenPooler()

# 池化嵌入
outputs = pooler.pool_embeddings(list_embeddings, pool_factor=2)
如果你的输入是填充后的 3D 张量嵌入,而不是 2D 张量列表,请使用 `padding=True` 并指定分词器使用的填充方式,以确保 `HierarchicalTokenPooler` 在池化前正确移除填充值: ```python import torch from PIL import Image from transformers.utils.import_utils import is_flash_attn_2_available from colpali_engine.compression.token_pooling import HierarchicalTokenPooler from colpali_engine.models import ColQwen2, ColQwen2Processor model_name = "vidore/colqwen2-v1.0" model = ColQwen2.from_pretrained( model_name, torch_dtype=torch.bfloat16, device_map="cuda:0", # 如果是 Apple Silicon 芯片,使用 "mps" attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None, ).eval() processor = ColQwen2Processor.from_pretrained(model_name) token_pooler = HierarchicalTokenPooler() # 你的页面图像 images = [ Image.new("RGB", (128, 128), color="white"), Image.new("RGB", (32, 32), color="black"), ] # 处理输入 batch_images = processor.process_images(images).to(model.device) # 前向传播 with torch.no_grad(): image_embeddings = model(**batch_images) # 应用词元池化(减少多向量嵌入的序列长度) image_embeddings = token_pooler.pool_embeddings( image_embeddings, pool_factor=2, padding=True, padding_side=processor.tokenizer.padding_side,
5 次点击  ∙  0 人收藏  
登录后收藏  
0 条回复
关于 ·  帮助 ·  PING ·  隐私政策 ·  服务条款   
OA0 - Omni AI 0 一个探索 AI 的社区
沪ICP备2024103595号-2
耗时 24 ms
Developed with Cursor