[Notes] Understanding XCiT - Part 2
Local Patch Interaction (LPI) and Class Attention Layer
Jul 25, 2021 · 1661 words · 8 minute read
In Part 1, we introduced the XCiT architecture and reviewed the implementation of the Cross-Covariance Attention (XCA) block. In this Part 2, we’ll review the implementation of the Local Patch Interaction (LPI) block and the Class Attention layer.
Local Patch Interaction (LPI)
Because there is no explicit communication between patches (tokens) in XCA, a block consisting of two depth-wise 3×3 convolutional layers, with Batch Normalization and a GELU non-linearity in between, is added to enable explicit communication.
Here’s the implementation of LPI in rwightman/pytorch-image-models:
The module should be very familiar to you if you’re versed in traditional convolutional networks. Let’s first review the initialization of the first convolutional layer:
self.conv1 = torch.nn.Conv2d(
in_features, in_features, kernel_size=kernel_size, padding=padding, groups=in_features)
kernel_size is 3 by default. padding is calculated from kernel_size to retain the size of the feature map. The number of groups is set to the number of channels, so there is no interaction between channels (layers of this kind are called “depth-wise convolution layers”).
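To make “depth-wise” concrete, here is a minimal NumPy sketch (not the timm code) of a k×k depth-wise convolution. It assumes padding = kernel_size // 2, the usual way to keep the spatial size for odd kernels; the function name depthwise_conv2d and the (C, k, k) weight layout are illustrative assumptions. Each output channel only ever reads its own input channel.

```python
import numpy as np


def depthwise_conv2d(x, weight, bias):
    """Depth-wise 2-D convolution (cross-correlation, as in torch.nn.Conv2d).

    x:      (B, C, H, W) input feature map
    weight: (C, k, k), one k x k filter per channel (groups == C)
    bias:   (C,)
    """
    B, C, H, W = x.shape
    k = weight.shape[-1]
    pad = k // 2  # assumed "same" padding for odd kernel sizes keeps H and W unchanged
    xp = np.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)))
    out = np.zeros((B, C, H, W))
    for i in range(k):
        for j in range(k):
            # channel c of the output only ever reads channel c of the input
            out += xp[:, :, i:i + H, j:j + W] * weight[None, :, i, j, None, None]
    return out + bias[None, :, None, None]


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4, 5, 6))   # B=2, C=4, H=5, W=6
w = rng.standard_normal((4, 3, 3))      # kernel_size = 3
b = np.zeros(4)
y = depthwise_conv2d(x, w, b)
print(y.shape)  # (2, 4, 5, 6)
```

A full C×C×k×k kernel would instead mix every input channel into every output channel; setting groups to the number of channels is what rules that out.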
The second convolutional layer is mostly the same as the first but with a configurable number of output channels (defaults to the number of input channels).
Remember from Part 1 that the output tensor of an XCA block has shape BxNxC. But LPI performs 2-D convolutions, so we need to restore the tensor to a proper image-like shape at the start of the forward method (H and W, the height and width of the patch grid, are passed to forward as arguments):
B, N, C = x.shape
x = x.permute(0, 2, 1).reshape(B, C, H, W)
The tensor is first permuted to BxCxN and then reshaped to BxCxHxW. Remember that in vision transformers, an image is split into patches. Each patch is sent through some (usually convolutional) processing layers to be transformed into a vector, which is analogous to looking up an embedding matrix in NLP applications. The patch vectors are then flattened into a sequence and treated like tokens in traditional NLP transformers. This step reverses the flattening and restores the 2-D structure.
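A quick NumPy check of what the permute and reshape do (np.transpose plays the role of torch’s permute here). Assuming the patches were flattened in row-major order, token n of the sequence lands at row n // W and column n % W of the restored grid:

```python
import numpy as np

B, H, W, C = 2, 3, 4, 5          # a 3 x 4 grid of patches
N = H * W
tokens = np.arange(B * N * C, dtype=float).reshape(B, N, C)  # B x N x C

# torch: x.permute(0, 2, 1).reshape(B, C, H, W)
grid = tokens.transpose(0, 2, 1).reshape(B, C, H, W)

n = 6  # the 7th token
print(np.array_equal(grid[:, :, n // W, n % W], tokens[:, n, :]))  # True
```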
What follows is the usual convolution operations. The first depth-wise convolution is followed by a non-linearity and a batch normalization layer. Then the second convolution is applied.
x = self.conv1(x)
x = self.act(x)
x = self.bn(x)
x = self.conv2(x)
At the last step, the tensor is flattened back to BxCxN and then permuted to BxNxC, the shape the rest of the transformer expects.
x = x.reshape(B, C, N).permute(0, 2, 1)
As you can see, LPI is just a convolutional block with some additional reshaping and permuting operations to fit into the transformer pipeline.
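The flatten-and-permute at the end of forward is the exact inverse of the permute-and-reshape at the start. This NumPy sketch (the helper names tokens_to_grid and grid_to_tokens are hypothetical) confirms that a BxNxC tensor survives the round trip unchanged:

```python
import numpy as np


def tokens_to_grid(x, H, W):
    B, N, C = x.shape
    return x.transpose(0, 2, 1).reshape(B, C, H, W)    # B x C x H x W


def grid_to_tokens(x):
    B, C, H, W = x.shape
    return x.reshape(B, C, H * W).transpose(0, 2, 1)   # back to B x N x C


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 12, 8))   # B=2, N=12 (a 3 x 4 patch grid), C=8
restored = grid_to_tokens(tokens_to_grid(x, 3, 4))
print(restored.shape, np.array_equal(restored, x))  # (2, 12, 8) True
```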
Class Attention Layer
This special layer for class attention was introduced in the CaiT architecture [2]. In the original vision transformer, the CLS token is added as part of the input to the first layer. This design choice gives the CLS token two objectives: (1) guiding the self-attention between patches while (2) summarizing the information useful to the linear classifier. CaiT moves the insertion of the CLS token towards the top of the network and freezes the patch embeddings after the insertion.
Class Attention
We first look at the implementation of the class attention (I slightly rearranged the code without affecting the functionality):
The query tensor is created only from the first token (CLS). The resulting tensor after unsqueeze is shaped Bx1xC. It is then reshaped to Bx1xHx(C/H) and permuted to BxHx1x(C/H), where H here is the number of attention heads (not the feature-map height from the LPI section). Multiplying by self.scale then divides the values in the Q tensor by $\sqrt{C/H}$ (as in regular self-attention).
B, N, C = x.shape
q = self.q(x[:, 0]).unsqueeze(1).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
q = q * self.scale
The key and value tensors are the standard ones, computed from all tokens. Each is shaped BxHxNx(C/H).
k = self.k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
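To keep the shapes straight, here is a NumPy sketch of the three projections. The random matrices Wq, Wk and Wv (biases omitted) are assumptions standing in for self.q, self.k and self.v, not trained layers, and scale = head_dim ** -0.5 implements the division by $\sqrt{C/H}$:

```python
import numpy as np

B, N, C, num_heads = 2, 5, 8, 2        # N counts the CLS token plus the patches
head_dim = C // num_heads
scale = head_dim ** -0.5               # multiplying by this divides by sqrt(C/H)

rng = np.random.default_rng(0)
x = rng.standard_normal((B, N, C))
Wq, Wk, Wv = (rng.standard_normal((C, C)) for _ in range(3))  # stand-ins for self.q/k/v

# Query from the CLS token only: B x C -> B x 1 x C -> B x H x 1 x (C/H)
q = (x[:, 0] @ Wq)[:, None, :].reshape(B, 1, num_heads, head_dim).transpose(0, 2, 1, 3)
q = q * scale

# Keys and values from all tokens: B x H x N x (C/H)
k = (x @ Wk).reshape(B, N, num_heads, head_dim).transpose(0, 2, 1, 3)
v = (x @ Wv).reshape(B, N, num_heads, head_dim).transpose(0, 2, 1, 3)

print(q.shape, k.shape, v.shape)  # (2, 2, 1, 4) (2, 2, 5, 4) (2, 2, 5, 4)
```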
The following lines are standard self-attention calculations:
attn = (q @ k.transpose(-2, -1))
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
The first line is the (batched) matrix multiplication between a BxHx1x(C/H) tensor and a BxHx(C/H)xN tensor, which results in a BxHx1xN tensor. The 1xN part is the row vector containing the attention weights from the CLS token to all tokens (CLS plus the patches).
Softmax is applied to the last dimension (so the attention weights sum to one), and a dropout layer is applied.
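Because there is only one query row, each head computes N attention scores instead of the N×N of full self-attention. The NumPy sketch below uses random tensors (dropout omitted) to check that the 1xN weights sum to one and that they match the first row of a full self-attention map built from the same queries and keys:

```python
import numpy as np


def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)


B, N, num_heads, head_dim = 2, 5, 2, 4
rng = np.random.default_rng(0)
q_all = rng.standard_normal((B, num_heads, N, head_dim)) * head_dim ** -0.5  # scaled queries, every token
k = rng.standard_normal((B, num_heads, N, head_dim))

q_cls = q_all[:, :, 0:1]                      # B x H x 1 x (C/H): only the CLS query
attn = softmax(q_cls @ k.swapaxes(-2, -1))    # B x H x 1 x N

full = softmax(q_all @ k.swapaxes(-2, -1))    # B x H x N x N, ordinary self-attention
print(attn.shape)                             # (2, 2, 1, 5)
print(np.allclose(attn.sum(-1), 1.0))         # True
print(np.allclose(attn, full[:, :, 0:1]))     # True: same as the CLS row of the full map
```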
Then the new embedding vector for the CLS token is computed from the attention matrices (which are actually vectors in this case):
x_cls = (attn @ v).transpose(1, 2).reshape(B, 1, C)
x_cls = self.proj(x_cls)
x_cls = self.proj_drop(x_cls)
The matrix multiplication in the first line results in a BxHx1x(C/H) tensor. After transposing and reshaping (which concatenates the results from all heads), the tensor’s shape becomes Bx1xC.
The tensor from the first line then goes through another linear transformation (self.proj) and a dropout layer. Then we have the new embedding vector for the CLS token!
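The transpose-then-reshape is what concatenates the heads. This NumPy sketch uses random attention weights and values (projection and dropout left out) to check that the result equals concatenating each head’s 1x(C/H) output along the channel axis, head 0 first:

```python
import numpy as np

B, N, num_heads, head_dim = 2, 5, 2, 4
C = num_heads * head_dim
rng = np.random.default_rng(0)
attn = rng.random((B, num_heads, 1, N))
attn /= attn.sum(-1, keepdims=True)          # rows sum to one, like a softmax output
v = rng.standard_normal((B, num_heads, N, head_dim))

out = attn @ v                                       # B x H x 1 x (C/H)
x_cls = out.transpose(0, 2, 1, 3).reshape(B, 1, C)   # torch: .transpose(1, 2).reshape(B, 1, C)

per_head = [out[:, h] for h in range(num_heads)]     # each B x 1 x (C/H)
print(x_cls.shape, np.allclose(x_cls, np.concatenate(per_head, axis=-1)))  # (2, 1, 8) True
```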
Stochastic Depth
Before moving on to the specialized class attention layer, I invite you to review this DropPath module if you’re not already familiar with the stochastic depth regularization.
The dropout mask has a shape of Bx1x1..., that is, a single value per sample. The exact number of trailing singleton dimensions depends on the number of dimensions of the input tensor:
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
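The random_tensor lies in the range [keep_prob, 1 + keep_prob), because torch.rand samples from [0, 1). After flooring, an entry is one exactly when the uniform sample is at least 1 - keep_prob, so the probability of a zero is 1 - keep_prob = drop_prob, and the probability of a one is keep_prob.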
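The flooring trick is easy to check empirically. This NumPy sketch (the helper name drop_path_mask is hypothetical, and rng.random stands in for torch.rand; both draw from [0, 1)) builds the per-sample mask for a 4-D input and estimates how often it equals one:

```python
import numpy as np


def drop_path_mask(x, keep_prob, rng):
    # one random value per sample: shape (B, 1, 1, ...) so it broadcasts over the rest
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + rng.random(shape)   # uniform in [keep_prob, 1 + keep_prob)
    return np.floor(random_tensor)                  # 1 with probability keep_prob, else 0


rng = np.random.default_rng(0)
x = np.ones((100_000, 3, 2, 2))
mask = drop_path_mask(x, keep_prob=0.8, rng=rng)
print(mask.shape)           # (100000, 1, 1, 1)
print(round(mask.mean(), 2))  # roughly 0.8
```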
The mask (the binarized random_tensor) is then applied to the input tensor (.div(keep_prob) keeps the expected value of the tensor the same after the dropout):
output = x.div(keep_prob) * random_tensor
If the mask value is zero, the entire sample will be erased. This should be used on the “main path” (as opposed to the “shortcut path”) in a residual network. The dropped out sample means that the network will only take the shortcut for that sample, which effectively removes one residual layer (hence the name “Stochastic Depth”).
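Putting the mask and the rescaling together, here is a NumPy sketch of drop_path applied to the main path of a residual connection. The function and the all-ones branch output are illustrative stand-ins, not timm’s code. Dropped samples come out exactly equal to the shortcut input, and averaged over many draws the branch keeps its expected value thanks to the division by keep_prob:

```python
import numpy as np


def drop_path(x, drop_prob, rng, training=True):
    """Per-sample stochastic depth, following the mask logic quoted above."""
    if drop_prob == 0.0 or not training:
        return x                                    # identity at inference time
    keep_prob = 1.0 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)     # one draw per sample
    mask = np.floor(keep_prob + rng.random(shape))  # 1 with probability keep_prob, else 0
    return x / keep_prob * mask                     # rescale kept samples, erase dropped ones


rng = np.random.default_rng(0)
x = rng.standard_normal((8, 4, 3))                  # shortcut input, B x N x C
branch = np.ones_like(x)                            # stand-in for a residual branch's output
out = x + drop_path(branch, drop_prob=0.5, rng=rng)

took_shortcut = [bool(np.array_equal(out[b], x[b])) for b in range(x.shape[0])]
print(took_shortcut)  # True where the whole sample was dropped
```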
We’ll see DropPath in practice in the next section, so please read ahead even if you’re still confused about this concept.
Class Attention Block
The final specialized class attention layer consists of several ClassAttentionBlock modules:
There are two sets of learnable layer scale parameters (also introduced by CaiT [2]) in the __init__ method:
if eta is not None: # LayerScale Initialization (no layerscale when None)
self.gamma1 = nn.Parameter(eta * torch.ones(dim), requires_grad=True)
self.gamma2 = nn.Parameter(eta * torch.ones(dim), requires_grad=True)
else:
self.gamma1, self.gamma2 = 1.0, 1.0
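LayerScale is a learnable per-channel multiplier. In this NumPy sketch, gamma1 has shape (dim,) and broadcasts over the batch and token axes of a BxNxC tensor; eta = 1e-4 is an arbitrary example value (an assumption, not necessarily XCiT’s default). With eta set to None the code falls back to the scalar 1.0, which makes the scaling a no-op:

```python
import numpy as np

B, N, dim = 2, 5, 8
eta = 1e-4                                   # example initial value (assumption)
gamma1 = eta * np.ones(dim)                  # one learnable scale per channel
rng = np.random.default_rng(0)
x_attn = rng.standard_normal((B, N, dim))

scaled = gamma1 * x_attn                     # (dim,) broadcasts over B x N x dim
print(scaled.shape)                          # (2, 5, 8)
print(np.allclose(scaled, x_attn * eta))     # True while gamma1 is at its initial value

gamma_none = 1.0                             # what the code uses when eta is None
print(np.array_equal(gamma_none * x_attn, x_attn))  # True: no scaling
```

A small eta keeps each residual branch close to zero at initialization, so the block starts out behaving almost like the identity.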
The input tensor x should already have the CLS token prepended (giving it the shape Bx(N+1)xC). It first goes through one layer of class attention:
x_norm1 = self.norm1(x)
x_attn = torch.cat([self.attn(x_norm1), x_norm1[:, 1:]], dim=1)
x = x + self.drop_path(self.gamma1 * x_attn)
The input tensor is layer normalized. self.attn(x_norm1) gives us the new CLS embedding vectors, which are then concatenated with the other (normalized) tokens x_norm1[:, 1:]. The resulting tensor is scaled by self.gamma1 and sent through DropPath before being added back to the original input tensor x.
(If the dropout mask for a particular sample is zero, x = x + self.drop_path(self.gamma1 * x_attn) becomes x = x for that sample.)
Note that the patch embedding vectors are still updated here: they receive self.gamma1 * x_norm1[:, 1:] through the residual connection, so the network can still change them through self.gamma1 and they are not strictly frozen.
The tensor x then goes through another layer normalization. Two modes are implemented: one that normalizes all token vectors, and one that only normalizes the CLS token vectors (the default behavior):
if self.tokens_norm:
x = self.norm2(x)
else:
x = torch.cat([self.norm2(x[:, 0:1]), x[:, 1:]], dim=1)
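The difference between the two modes shows up clearly with a bare-bones layer norm. In this NumPy sketch the layer norm has no learnable affine parameters and eps = 1e-6 is chosen for illustration; norm2_step is a hypothetical helper. With tokens_norm=False only row 0 (the CLS token) changes:

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    # normalize over the channel (last) axis; learnable weight/bias omitted
    mean = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def norm2_step(x, tokens_norm):
    if tokens_norm:
        return layer_norm(x)                                              # every token
    return np.concatenate([layer_norm(x[:, 0:1]), x[:, 1:]], axis=1)      # CLS only


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5, 8)) * 3 + 1     # B x (N+1) x C, CLS at index 0
y = norm2_step(x, tokens_norm=False)
print(np.array_equal(y[:, 1:], x[:, 1:]))      # True: patch tokens untouched
print(np.allclose(y[:, 0].mean(-1), 0.0))      # True: CLS row is normalized
```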
The CLS token vectors go through a feed-forward network (MLP) and are then scaled by self.gamma2. The new CLS token vectors are concatenated with the vectors of the other tokens, go through DropPath, and are added back to x, which at this point is the output of the normalization step above. This is almost the same as in the class attention part, except that the scaling is only applied to CLS here:
x_res = x
cls_token = x[:, 0:1]
cls_token = self.gamma2 * self.mlp(cls_token)
x = torch.cat([cls_token, x[:, 1:]], dim=1)
x = x_res + self.drop_path(x)
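Reading this code literally, x_res and the concatenated tensor both contain x[:, 1:], so whenever DropPath keeps the sample (always, at inference), the patch rows are added to themselves while the CLS row gets the usual residual update. The NumPy sketch below reproduces only this residual step, with a fixed random linear map as a stand-in for self.mlp (an assumption) and gamma2 = 1:

```python
import numpy as np

rng = np.random.default_rng(0)
B, N1, C = 2, 5, 8                          # N1 = N + 1 tokens, CLS first
x = rng.standard_normal((B, N1, C))         # tensor after the norm2 step
W_mlp = rng.standard_normal((C, C))         # stand-in for self.mlp (assumption)
gamma2 = np.ones(C)


def drop_path(t):
    return t                                # inference, or sample kept and keep_prob = 1


x_res = x
cls_token = gamma2 * (x[:, 0:1] @ W_mlp)    # only the CLS token goes through the MLP
out = x_res + drop_path(np.concatenate([cls_token, x[:, 1:]], axis=1))

print(np.allclose(out[:, 0:1], x[:, 0:1] + cls_token))  # True: usual residual for CLS
print(np.allclose(out[:, 1:], 2 * x[:, 1:]))            # True: patch rows are x + x
```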
The tensor x is returned as the new embedding vectors. We’ve gone through the entire class attention block!
Putting it together
Here’s how the ClassAttentionBlock modules are initialized in the main XCiT module:
self.cls_attn_blocks = nn.ModuleList([
ClassAttentionBlock(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
attn_drop=attn_drop_rate, act_layer=act_layer, norm_layer=norm_layer, eta=eta, tokens_norm=tokens_norm)
for _ in range(cls_attn_layers)])
And here’s the relevant code in the forward_features method:
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
for blk in self.cls_attn_blocks:
x = blk(x)
x = self.norm(x)[:, 0]
return x
The initial CLS vector is expanded to the batch size and prepended to the tensor x before it is sent through the ClassAttentionBlock modules. The result is normalized once more, and only the CLS token’s vector (x[:, 0]) is returned.
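The shapes through forward_features and the classifier can be traced with a NumPy sketch. Here 196 patch tokens (a 14 × 14 grid) are an example choice, np.broadcast_to stands in for torch’s expand, and the class attention blocks and the head are replaced by identity placeholders and a random linear map; all of these are assumptions made only to follow the shapes:

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    mean = x.mean(-1, keepdims=True)
    return (x - mean) / np.sqrt(x.var(-1, keepdims=True) + eps)


B, N, C, num_classes = 2, 196, 16, 10
rng = np.random.default_rng(0)
x = rng.standard_normal((B, N, C))                  # patch tokens after the XCA layers
cls_token = rng.standard_normal((1, 1, C))          # learnable parameter, 1 x 1 x C

cls_tokens = np.broadcast_to(cls_token, (B, 1, C))  # torch: self.cls_token.expand(B, -1, -1)
x = np.concatenate([cls_tokens, x], axis=1)         # B x (N+1) x C, CLS at index 0

for blk in [lambda t: t, lambda t: t]:              # placeholders for ClassAttentionBlock modules
    x = blk(x)

features = layer_norm(x)[:, 0]                      # keep only the CLS token: B x C
W_head = rng.standard_normal((C, num_classes))      # stand-in for self.head
logits = features @ W_head
print(x.shape, features.shape, logits.shape)        # (2, 197, 16) (2, 16) (2, 10)
```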
In case it’s not clear enough, the result of the forward_features method is then sent through a classifier (a linear layer) to get the logits:
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
Fin
I hope this two-part series has been helpful. XCiT is an interesting architecture. Its lower memory requirements and competitive performance (compared to other SOTA models) are good news for people without access to high-powered GPUs (like me). One of the few things I don’t like about the paper and the implementation is the hard-coded use of batch norm in LPI and in the patch feature extraction layers. It’d be even better if the authors could provide model weights pretrained with GroupNorm [3], which has been shown to perform more robustly with tiny batch sizes.
I’ve also been trying to fine-tune an image classifier using XCiT, and the results so far are promising. Be sure to tune the learning rate if you come from the ResNet world. Although XCiT can be understood as performing “dynamic” convolutions, the usual fine-tuning learning rates for transformers (e.g., 3e-5) seem to work better than the ones for ResNet.
References
[1] El-Nouby, A., Touvron, H., Caron, M., Bojanowski, P., Douze, M., Joulin, A., … Jégou, H. (2021). XCiT: Cross-Covariance Image Transformers.
[2] Touvron, H., Cord, M., Sablayrolles, A., Synnaeve, G., & Jégou, H. (2021). Going deeper with Image Transformers.
[3] Kolesnikov, A., Beyer, L., Zhai, X., Puigcerver, J., Yung, J., Gelly, S., & Houlsby, N. (2019). Big Transfer (BiT): General Visual Representation Learning.