I recently came across this new image data augmentation technique called SnapMix. It looks like a very sensible improvement over CutMix, so I was eager to give it a try.
The SnapMix author provides a PyTorch implementation. I made some adjustments to improve numerical stability and converted it into a PyTorch Lightning callback. I ran into one major obstacle along the way: SnapMix uses Class Activation Mapping (CAM) to calculate an augmented example’s label weights. This requires access to the final linear classifier’s weights and to the model activations before the pooling operation. Some pre-trained PyTorch CV models do implement methods to access these two things, but the naming is inconsistent. We need a unified API to do this.
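To see why both pieces are needed, here is a framework-free sketch of the CAM idea, with nested Python lists standing in for tensors. The function names (`class_activation_map`, `box_label_weight`) and the toy numbers are made up for illustration, and the normalization is a simplification rather than the exact SnapMix code. The map for a class is the channel-wise weighted sum of the pre-pooling activations, and a patch's label weight is taken as the share of that map that falls inside the patch.

```python
def class_activation_map(features, class_weights):
    """CAM for one class: weighted sum of the feature maps over channels.

    features: nested list shaped [channels][height][width]
              (stand-in for the activations before pooling)
    class_weights: classifier weights of the target class, one per channel
    """
    height, width = len(features[0]), len(features[0][0])
    cam = [[0.0] * width for _ in range(height)]
    for channel, weight in zip(features, class_weights):
        for y in range(height):
            for x in range(width):
                cam[y][x] += weight * channel[y][x]
    return cam


def box_label_weight(cam, top, left, bottom, right):
    """Share of the min-shifted CAM mass inside rows top:bottom, cols left:right."""
    lowest = min(min(row) for row in cam)
    shifted = [[value - lowest for value in row] for row in cam]
    total = sum(sum(row) for row in shifted)
    if total == 0:
        # Flat map: fall back to the plain area ratio (an assumption of this sketch).
        return (bottom - top) * (right - left) / (len(cam) * len(cam[0]))
    inside = sum(sum(row[left:right]) for row in shifted[top:bottom])
    return inside / total


# Toy input: 2 channels, each a 2x2 feature map.
features = [
    [[1.0, 0.0], [0.0, 0.0]],
    [[0.0, 0.0], [0.0, 1.0]],
]
class_weights = [2.0, 1.0]

cam = class_activation_map(features, class_weights)
weight = box_label_weight(cam, top=0, left=0, bottom=1, right=1)
print(cam)     # [[2.0, 0.0], [0.0, 1.0]]
print(weight)  # about 0.667
```

The top-left cell covers a quarter of the area but gets about two thirds of the label weight, because that is where the class evidence sits. Neither function can run without the classifier weights and the unpooled feature maps.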
One way to create a unified API is to subclass every pre-trained model and implement `get_fc` and `extract_features` methods. However, this requires switching existing imports to the new subclasses. Since I wanted to quickly test whether SnapMix works, I wanted to change the existing codebase as little as possible. Here’s where `@patch_to` from fastcore comes in.
The fastcore library is a spin-off of the fast.ai library. It provides some useful power-ups to the Python standard library, and `@patch_to` is one of them.
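If patching is new to you, the core idea fits in a few lines of plain Python. The `simple_patch_to` decorator and `TinyModel` class below are hypothetical stand-ins written for this illustration; this is not fastcore's implementation, which handles more cases. The sketch only shows that attaching a function to a class makes it available as a method, even on instances created earlier.

```python
def simple_patch_to(cls):
    """Hypothetical, stripped-down stand-in for a patching decorator."""
    def decorator(func):
        # Attach the function to the class under its own name.
        setattr(cls, func.__name__, func)
        return func
    return decorator


class TinyModel:
    """Toy class that plays the role of a pre-trained model."""
    def __init__(self):
        self.classifier = "final linear layer"


model = TinyModel()  # created before the patch is applied


@simple_patch_to(TinyModel)
def get_fc(self):
    return self.classifier


# Method lookup goes through the class, so the existing instance sees it too.
print(model.get_fc())  # final linear layer
```

No subclass is involved and no import has to change, which is exactly what makes this approach attractive for a quick experiment.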
The following code block shows how I patched the EfficientNet class from `gen_efficientnet` to make it compatible with SnapMix:
```python
import geffnet
from fastcore.basics import patch_to


@patch_to(geffnet.gen_efficientnet.GenEfficientNet)
def extract_features(self, input_tensor):
    # Model activations before the pooling operation
    return self.features(input_tensor)


@patch_to(geffnet.gen_efficientnet.GenEfficientNet)
def get_fc(self):
    # Final linear classifier
    return self.classifier
```
Pretty neat, isn’t it? Another advantage of this approach is that it supports patching multiple classes in one call. For example, if we have another PyTorch class that also stores its final classifier in `self.classifier`, we can pass it along with the existing class as a tuple to the `@patch_to` decorator.
The documentation and the source code of the fastcore library can be a bit confusing. Therefore, I created a small demo script to showcase the abilities of the `@patch_to` decorator:
That’s it. Thanks for reading this short post. Let me know if you can think of a better way to patch PyTorch models. I’d also love to hear about your experience using the `fastcore` library. Leave a comment if you have any questions. I usually answer within a week. (DM me on Twitter if that doesn’t happen; sometimes I miss the notification.)