Cover image generated by Nano Banana Pro

Introduction

I recently came across a course called “Multi-Vector Image Retrieval” by DeepLearning.ai. The course mainly introduces ColPali [1], a vision-language model that generalizes the late-interaction retrieval paradigm pioneered by ColBERT [2], extending it from text-only tokens to both text and visual tokens. It also contains a few tutorials on performance optimization techniques using Qdrant’s Python SDK. It is a great introductory resource, and I recommend it to anyone interested in visual document understanding and retrieval.

While I’m fascinated by the effectiveness of this end-to-end visual understanding approach without any explicit OCR components, I’m a bit skeptical when looking at visualizations of the activations on high-resolution document images. I wonder: “How does this model handle high-resolution images efficiently?” and “How should I preprocess my images to get the highest accuracy from the pre-trained models?”

Unfortunately, image-preprocessing information was not documented in either the paper or the associated GitHub repository. Therefore, I went on a journey to uncover this information by examining the source code and the model files for the open-source ColPali models. This blog post briefly documents that journey and the eventual findings.

tl;dr: The original ColPali model and its variants simply follow the preprocessing recipe of their base model, PaliGemma 3B, which always resizes the input image to 448×448. As one would expect, this design is not ideal for text-heavy, non-square document images. To lift this restriction, the researchers later developed newer models that no longer enforce a fixed aspect ratio and allow higher image resolutions.

A Brief Review of Model Implementations

Before investigating the image preprocessor used by ColPali models, let us take a look at the models themselves to understand the general architecture and how the pre-trained model files are structured.

The ColPali class is essentially a PaliGemma model with a custom linear projection layer (the custom_text_proj attribute).
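
To make that concrete, below is a minimal sketch of the architecture. It is only meant to convey the shape of the model, assuming the 128-dimensional projection described in the paper; the real implementation lives in the colpali_engine package and differs in its details.

import torch
from torch import nn
from transformers import PaliGemmaForConditionalGeneration


class ColPaliSketch(nn.Module):
    """Illustrative sketch: a PaliGemma backbone plus a linear projection head."""

    def __init__(self, base_model_name: str, embedding_dim: int = 128):
        super().__init__()
        self.model = PaliGemmaForConditionalGeneration.from_pretrained(base_model_name)
        hidden_size = self.model.config.text_config.hidden_size
        # The extra head that maps every token's hidden state to a
        # low-dimensional embedding (128 dimensions in the ColPali paper).
        self.custom_text_proj = nn.Linear(hidden_size, embedding_dim)

    def forward(self, **inputs) -> torch.Tensor:
        outputs = self.model(**inputs, output_hidden_states=True)
        last_hidden = outputs.hidden_states[-1]          # (batch, seq_len, hidden_size)
        embeddings = self.custom_text_proj(last_hidden)  # (batch, seq_len, embedding_dim)
        # L2-normalize each token embedding for the late-interaction dot products.
        return embeddings / embeddings.norm(dim=-1, keepdim=True)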

There are four variants of ColPali at the time of writing:

  1. vidore/colpali
  2. vidore/colpali-v1.1
  3. vidore/colpali-v1.2
  4. vidore/colpali-v1.3

According to the adapter_config.json, they were all fine-tuned with LoRA on the same base model — vidore/colpaligemma-3b-pt-448-base.

Judging from the target_modules string, the fine-tuned components include all linear projection layers in the language model portion of PaliGemma, as well as the final custom linear projection layer.

{
  "target_modules": "(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)"
}

The vidore/colpaligemma-3b-pt-448-base model seems to be just the PaliGemma-3B model weights ported into the format used by the ColPali class.
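
In practice, we never have to assemble the base model and the adapters manually. A snippet along the lines of the one below (adapted from the colpali_engine usage examples) pulls in the base checkpoint referenced by adapter_config.json and applies the LoRA weights on top:

import torch
from colpali_engine.models import ColPali, ColPaliProcessor

model = ColPali.from_pretrained(
    "vidore/colpali-v1.3",
    torch_dtype=torch.bfloat16,
    device_map="cpu",  # or "cuda:0" if a GPU is available
).eval()
processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.3")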

Analyze the ColPali Class

A quick glance at the newer ColQwen2 models shows a similar structure: a ported base model (vidore/colqwen2.5-base) with a linear projection layer at the top, and LoRA-fine-tuned variants.

The ColPali Image Preprocessing Pipeline

The patch size used by the ColPali model is the same as that of the PaliGemma-3B model (14×14 pixels), as shown in the base model’s config.json file.

The vision model is a SigLIP model [3], whose preprocessor configurations are included in the model files for all ColPali variants (e.g., preprocessor_config.json for vidore/colpali-v1.3). Below are the contents of the file:

{
  "do_convert_rgb": null,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "SiglipImageProcessor",
  "image_seq_length": 1024,
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "processor_class": "ColPaliProcessor",
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 448,
    "width": 448
  }
}

Before we can understand what each of these configuration options means, we need to locate the Python class that uses them. The preprocessor for ColPali models is ColPaliProcessor, which is based on PaliGemmaProcessor from the transformers library. PaliGemmaProcessor is a wrapper around the underlying tokenizer, the chat template, and the image processor. The image processor class is specified by the image_processor_type field in the configuration, which leads us to the SiglipImageProcessor class.
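
A quick way to confirm this wiring is to load the processor and inspect its components. The exact class names printed may vary slightly across transformers versions:

from colpali_engine.models import ColPaliProcessor

processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.3")

print(type(processor).__name__)                  # ColPaliProcessor
print(type(processor.image_processor).__name__)  # expected: SiglipImageProcessor
print(type(processor.tokenizer).__name__)        # the underlying Gemma tokenizer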

Image processing pipeline explained

Image preprocessing primarily occurs in the .preprocess method of SiglipImageProcessor. By looking at the code, we can see that the input image is first resized to 448×448, then rescaled by multiplying by 1/255 (≈ 0.00392156862745098) to map pixel values to the [0, 1] range, and finally normalized by subtracting 0.5 and then dividing by 0.5.

The resulting image sequence length is therefore fixed: 448 / 14 = 32 patches per side, so 32 × 32 = 1024 patches for every input image. This forces the input image into a 1:1 (square) aspect ratio, which is not ideal for many document images.
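
To make those steps concrete, below is a rough hand-rolled equivalent of the pipeline using the settings from the configuration above. The image path is a placeholder, and the real SiglipImageProcessor handles batching and channel ordering differently, so treat this purely as an illustration:

import numpy as np
from PIL import Image

image = Image.open("page.png").convert("RGB")  # placeholder path to any document page

# 1. Resize to a fixed 448x448 square (resample=3 corresponds to bicubic in PIL).
resized = image.resize((448, 448), resample=Image.Resampling.BICUBIC)

# 2. Rescale pixel values from [0, 255] to [0, 1].
pixels = np.asarray(resized).astype(np.float32) * (1 / 255)

# 3. Normalize with mean 0.5 and std 0.5, mapping values to [-1, 1].
pixels = (pixels - 0.5) / 0.5

# The fixed resolution also fixes the number of image patches/tokens.
patches_per_side = 448 // 14
print(patches_per_side, patches_per_side ** 2)  # 32 1024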

Let’s run a simple experiment to investigate the impact of this image processor on document images. We’ll use the first page of the “Attention Is All You Need” paper as an example. Below is the uncompressed version:

Page 1; Courtesy of the Multi-Vector Image Retrieval course

We can run the following code to reverse the normalization and rescaling operations in the pipeline and get the resized version of the image:

from PIL import Image
from colpali_engine.models import ColPaliProcessor

model_name = "vidore/colpali-v1.3"
processor = ColPaliProcessor.from_pretrained(model_name)

images = [
    Image.open("../data/attention-is-all-you-need/page-0.png"),
]

# Run the full preprocessing pipeline (resize -> rescale -> normalize).
batch_images = processor.process_images(images)

# Undo the normalization (* 0.5 + 0.5) and the rescaling (* 255).
restored_pixels = ((batch_images["pixel_values"] * 0.5 + 0.5) * 255).numpy()

# Convert from (channels, height, width) back to an image and save it.
Image.fromarray(restored_pixels[0].transpose(1, 2, 0).astype("uint8")).save("page-0-resized.png")

Resized Page 1

The text in the resized image is barely legible, which makes it all the more impressive that the model can understand it. The model seems to have learned to read very blurry and slightly distorted characters during training.

A closer look at the Late-Interaction mechanism

I’ve seen some visualizations of the interactions between the query vectors and the document vectors, such as the one in the ColPali paper below. However, seeing the barely legible resized image above made me want to verify the results for myself.

Taken from Figure 3 of the ColPali paper

Below is a visualization of the interactions between the stacks token vector in the query “How do the Encoder and Decoder stacks work together in Transformers?” and the image patch vectors from page 3 of the Attention Is All You Need paper. I picked this token because it has the cleanest activation map among all tokens. However, the Transformers, Encoder, and Decoder tokens also produce similar activation maps.

Activation Map on the original scale

I modified the official plotting function by adding grid lines to better identify each image patch and replacing Image.Resampling.BICUBIC with Image.Resampling.NEAREST during upscaling of the activation map to prevent the high-activation (dot-product) values from spilling into adjacent patches. This creates a less aesthetically pleasing but more accurate visualization.
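
For reference, the core of that change is just the resampling filter used when blowing the 32×32 activation map up to the image size. A simplified helper (my own, not part of the official plotting utilities) might look like this:

import numpy as np
from PIL import Image


def upscale_activation_map(activation_map: np.ndarray, size: tuple[int, int]) -> np.ndarray:
    """Upscale a (grid_h, grid_w) activation map to `size` = (width, height)."""
    # Normalize to [0, 255] so the map can be stored as an 8-bit grayscale image.
    span = activation_map.max() - activation_map.min()
    normalized = (activation_map - activation_map.min()) / (span + 1e-8)
    map_image = Image.fromarray((normalized * 255).astype(np.uint8))
    # NEAREST keeps each patch's value constant instead of smearing it into
    # neighboring patches the way BICUBIC interpolation would.
    upscaled = map_image.resize(size, resample=Image.Resampling.NEAREST)
    return np.asarray(upscaled) / 255.0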

Activation Map on the resized scale

The resized version shows what the model actually sees (technically speaking, it’s slightly blurrier because of the lossy manipulation by Matplotlib). The patches are perfect squares now. It’s still baffling that the model somehow managed to identify the words “Stack” and “stacks” in this noisy, blurry image.

Finally, I picked the patch with the maximum activation for each query token (including the padded tokens, as specified by the paper) to visualize all the patches whose corresponding vectors at the final layer contribute to the results of the late-interaction operator for this query–document pair. The result seems to make sense. Most of the selected patches are around terms like “Transformer”, “Stacks”, “Decoder”, and “Encoder”.

Late-Interaction Activations for ColPali

import matplotlib.pyplot as plt
import numpy as np
import torch

# `batched_similarity_maps`, `images`, `image_idx`, and `plot_similarity_map`
# are defined earlier in the notebook (see the notebooks referenced at the end of this post).
similarity_maps = batched_similarity_maps[0]  # shape: (num_query_tokens, grid_h, grid_w)

new_map = np.zeros(similarity_maps.shape[1:])
for idx in range(similarity_maps.shape[0]):
    # For this query token, pick the patch index (i, j) that maximizes the dot product
    index = np.unravel_index(
        similarity_maps[idx].to(torch.float32).cpu().numpy().argmax(),
        similarity_maps[idx].shape,
    )
    new_map[index] += similarity_maps[idx][index[0], index[1]].item()

fig, ax = plot_similarity_map(
    images[image_idx],
    torch.tensor(new_map),
    show_grid_lines=True,
    normalization_range=(0, new_map.max()),
)
ax.set_title("Late-Interaction Activations", fontsize=12)
plt.tight_layout()  # Adjust the layout to fit the title
fig.savefig("late-interaction-activations.jpg")
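
As a side note, the scalar relevance score behind all of these maps is simply the sum of those per-query-token maxima. A minimal version of the late-interaction (MaxSim) scoring for a single query-document pair could look like this:

import torch


def late_interaction_score(query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor) -> torch.Tensor:
    """MaxSim: (num_query_tokens, dim) and (num_doc_tokens, dim) -> scalar score."""
    # Dot product between every query token and every document patch/token.
    similarity = query_embeddings @ doc_embeddings.T  # (num_query_tokens, num_doc_tokens)
    # For each query token, keep only its best-matching document token, then sum.
    return similarity.max(dim=1).values.sum()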

Overcoming the Restrictive Image Processing Pipeline

As we’ve seen above, the SigLIP [3] image-processing pipeline requires the input image to be resized to exactly 448×448. This imposes a major restriction on the vision model, forcing it to infer text content from blurry, distorted images. Several newer variants of ColPali aim to address this issue by switching to a vision model that supports input images at various resolutions. In this section, let’s examine one of them, ColQwen2 v1.0. Below is its preprocessor configuration:

{
  "do_convert_rgb": true,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.48145466,
    0.4578275,
    0.40821073
  ],
  "image_processor_type": "Qwen2VLImageProcessor",
  "image_std": [
    0.26862954,
    0.26130258,
    0.27577711
  ],
  "max_pixels": 602112,
  "merge_size": 2,
  "min_pixels": 3136,
  "patch_size": 14,
  "processor_class": "ColQwen2Processor",
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "max_pixels": 602112,
    "min_pixels": 3136,
    "longest_edge": 602112,
    "shortest_edge": 3136
  },
  "temporal_patch_size": 2
}

The preprocessing pipeline employed by Qwen2VLImageProcessor [4] is a bit more involved, so it’s harder to reconstruct the resized image from the returned pixel_values. However, we can infer the shape of the resized image from this reshape operation at the end of its _preprocess method.

flatten_patches = patches.reshape(
    grid_t * grid_h * grid_w, channel * temporal_patch_size * patch_size * patch_size
)

For our input images, grid_t is 1 (a static image), grid_h is 62, and grid_w is 48. As the patch size of Qwen2-VL is also 14, these figures translate to an 868-by-672 image. This is much bigger than the 448-by-448 image used by the original ColPali model.
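
To make that arithmetic explicit, here is a tiny sanity check using the grid values observed above and the patch_size from the configuration:

patch_size = 14
grid_t, grid_h, grid_w = 1, 62, 48  # values observed for our page image

height, width = grid_h * patch_size, grid_w * patch_size
print(height, width)             # 868 672
print(height * width <= 602112)  # True: the resized image stays within max_pixels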

Inferring the shape of the resized images

Now let’s take a look at the activation map of the same stacks token on the resized image while using ColQwen2 v1.0:

Activation Map on the resized scale

The text is much more legible now! The patches with high activation values also seem to make sense.

Interestingly, the overall late-interaction activation map for ColQwen2 v1.0 looks quite different from the one for ColPali. The former has many high-activation patches in empty areas.

Late-Interaction Activations for ColQwen2

Higher resolutions and aspect-ratio-preserving preprocessing likely contribute to the significantly higher benchmark scores of ColQwen2 variants compared with the original ColPali variants. This finding highlights the importance of carefully designing the image preprocessing pipeline for multimodal models.

References

  1. Faysse, M., Sibille, H., Wu, T., Omrani, B., Viaud, G., Hudelot, C., & Colombo, P. (2024). ColPali: Efficient Document Retrieval with Vision Language Models.
  2. Khattab, O., & Zaharia, M. (2020). ColBERT: Efficient and Effective Passage Search via Contextualized Late Interaction over BERT.
  3. Zhai, X., Mustafa, B., Kolesnikov, A., & Beyer, L. (2023). Sigmoid Loss for Language Image Pre-Training.
  4. Wang, P., Bai, S., Tan, S., Wang, S., Fan, Z., Bai, J., Chen, K., Liu, X., Wang, J., Ge, W., Fan, Y., Dang, K., Du, M., Ren, X., Men, R., Liu, D., Zhou, C., Zhou, J., & Lin, J. (2024). Qwen2-VL: Enhancing Vision-Language Model’s Perception of the World at Any Resolution.

Notebooks and Code Used in this post

I’ve uploaded the notebooks and code files I used to write this blog post to GitHub Gist. Please note that I did not put much effort into ensuring the code is reproducible out of the box, as these are mostly simple exploratory notebooks and scripts. However, most of the code should be adaptable to your use case with just a small amount of tweaking.

AI Use Disclosure

I used AI tools to revise my writing (primarily for grammar and lexical correctness) in this post. However, I wrote most of the content myself; it was not generated with prompts.