A Case Study of fastcore @patch_to
Trying out SnapMix with minimal changes to the codebase
Feb 19, 2021 · 424 words · 2 minute read

Motivation
I recently came across this new image data augmentation technique called SnapMix. It looks like a very sensible improvement over CutMix, so I was eager to give it a try.
The SnapMix author provides a PyTorch implementation. I made some adjustments to improve the numeric stability and converted it into a PyTorch Lightning callback. I ran into one major obstacle along the way: SnapMix uses Class Activation Mapping (CAM) to calculate an augmented example’s label weights, which requires access to the final linear classifier’s weights and the model activations before the pooling operation. Some pre-trained PyTorch CV models do implement methods that expose these two things, but the naming is inconsistent. We need a unified API for this.
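To make that requirement concrete, here is a minimal sketch of the CAM idea (my own illustration, not the SnapMix author’s code); extract_features and get_fc are the unified accessor names I settle on below:

import torch

def cam_for_class(model, input_tensor, target_class):
    # Activations right before the pooling layer: shape (N, C, H, W).
    feature_maps = model.extract_features(input_tensor)
    # The final linear classifier; its weight matrix has shape (num_classes, C).
    fc = model.get_fc()
    class_weights = fc.weight[target_class]  # shape (C,)
    # Channel-weighted sum gives one heat map per image: shape (N, H, W).
    cam = torch.relu(torch.einsum("nchw,c->nhw", feature_maps, class_weights))
    # SnapMix normalizes maps like this to weight the labels of the mixed patches.
    return cam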
One way to create a unified API is to subclass every pre-trained model and implement get_fc and extract_features methods (a sketch of that route is shown below). However, this requires switching every existing import to the new subclasses, and since I wanted to test SnapMix quickly to see if it works, I wanted to change the existing codebase as little as possible.
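As a hypothetical illustration, the subclassing route for the EfficientNet backbone used later in this post would look roughly like this, and one such wrapper would be needed per backbone:

import geffnet

class SnapMixEfficientNet(geffnet.gen_efficientnet.GenEfficientNet):
    # Hypothetical wrapper class; every call site that builds the model
    # would have to switch to it.
    def extract_features(self, input_tensor):
        return self.features(input_tensor)

    def get_fc(self):
        return self.classifier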
Here’s where @patch_to from fastcore came in.
The fastcore library is a spin-off of the fast.ai library. It provides some useful power-ups to the Python standard library, and @patch_to is one of them.
The Solution
The following code block shows how I patched the GenEfficientNet class from geffnet to make it compatible with SnapMix:
import geffnet
from fastcore.basics import patch_to

@patch_to(geffnet.gen_efficientnet.GenEfficientNet)
def extract_features(self, input_tensor):
    return self.features(input_tensor)

@patch_to(geffnet.gen_efficientnet.GenEfficientNet)
def get_fc(self):
    return self.classifier
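With those two methods patched in, every GenEfficientNet instance exposes the unified API. A quick hypothetical smoke test, using efficientnet_b0 as just one of geffnet’s factory functions:

import geffnet
import torch

model = geffnet.efficientnet_b0(pretrained=False)
model.eval()
dummy = torch.randn(1, 3, 224, 224)

features = model.extract_features(dummy)  # pre-pooling activations, shape (1, C, H, W)
fc = model.get_fc()                       # the final nn.Linear classifier
print(features.shape, fc.weight.shape)    # weight shape is (num_classes, feature_dim)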
Pretty neat, isn’t it? Another advantage of this approach is that it supports patching multiple classes in one call. For example, if we have another PyTorch class that also stores its final classifier in self.classifier, we can pass it along with the existing class as a tuple to the @patch_to decorator.
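For instance, assuming a hypothetical stand-in class OtherBackbone (not a real library class) that also keeps its head in self.classifier, both classes can share one patched get_fc:

import geffnet
import torch.nn as nn
from fastcore.basics import patch_to

class OtherBackbone(nn.Module):
    # Hypothetical stand-in for any other model that stores its head in `self.classifier`.
    def __init__(self):
        super().__init__()
        self.classifier = nn.Linear(1280, 1000)

@patch_to((geffnet.gen_efficientnet.GenEfficientNet, OtherBackbone))
def get_fc(self):
    return self.classifier

print(OtherBackbone().get_fc())  # Linear(in_features=1280, out_features=1000, bias=True)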
More Details
The documentation and the source code of the fastcore library can be a bit confusing, so I created a small demo script to show what the @patch_to decorator can do:
from fastcore.basics import patch_to

class Demo:
    val = 10

    def __init__(self, val):
        self.val = val

# ====================
# The default mode
# ====================
@patch_to(Demo)
def print(self):
    print(self.val)

Demo(5).print()  # prints 5

# =======================
# The class method mode
# =======================
@patch_to(Demo, cls_method=True)
def print(self):
    print(self.val)

Demo(5).print()  # prints 10 (the class attribute, since `self` is the class here)

# =====================
# The property mode
# =====================
@patch_to(Demo, as_prop=True)
def print(self):
    print(self.val)

Demo(5).print  # prints 5 (accessed as a property, no call needed)

# =====================
# Additional notes
# =====================
#
# The function under @patch_to doesn't overwrite existing functions.
# In this example, the global built-in print still works normally:
print("I am a print statement.")  # prints I am a print statement.
Conclusion
That’s it. Thanks for reading this short post. Let me know if you can think of a better way to patch the PyTorch models. I’d also love to hear about your experience with the fastcore library. Leave a comment if you have any questions. I usually answer within a week (DM me on Twitter if that doesn’t happen; sometimes I miss the notification).