MNIST
The MNIST dataset is a dataset of handwritten digits that is commonly used as the 'Hello World' dataset of the Deep Learning domain. It contains 60,000 training images and 10,000 testing images, and carefree-learn provides a straightforward API to access it.
Since MNIST can be used for training various image processing systems, in this article we will demonstrate how to utilize carefree-learn to solve several different tasks on it.
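Loading the data might look like the sketch below; the `cflearn.cv.MNISTData` helper and its exact arguments are assumptions about the API, kept minimal for illustration:

```python
import cflearn

# Hypothetical helper: the exact name and signature are assumptions.
# The intent is to download MNIST and wrap it into a `DLDataModule`,
# pre-processing each input batch with the given `transform`.
data = cflearn.cv.MNISTData(batch_size=64, transform="to_tensor")
```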
tip
- As shown above, the MNIST dataset can easily be turned into a `DLDataModule` instance, which is the common data interface used in carefree-learn.
- The `transform` argument specifies which transform we want to use to pre-process the input batch. See the Transforms section for more details.
Classification
| Python source code | Jupyter Notebook | Task |
|---|---|---|
| run_clf.py | clf.ipynb | Computer Vision 🖼️ |
For demo purposes, we are going to build a simple convolution-based classifier.
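A minimal sketch follows; the layer sizes are illustrative, and using `register_module` as a decorator is an assumption about its exact usage:

```python
import torch
import torch.nn as nn

import cflearn

# `register_module` turns a plain `nn.Module` into a `ModelProtocol`;
# the decorator form shown here is an assumption.
@cflearn.register_module("simple_conv")
class SimpleConv(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 28 -> 14
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 14 -> 7
            nn.Flatten(),
            nn.Linear(32 * 7 * 7, 10),  # logits for the 10 digit classes
        )

    def forward(self, net: torch.Tensor) -> torch.Tensor:
        return self.net(net)
```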
We leveraged the `register_module` API here, which can turn a general `nn.Module` instance into a `ModelProtocol` in carefree-learn. Once registered, the model can easily be accessed by its name ("simple_conv").
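A training call might then look like the following; the `cflearn.api.fit_cv` entry point and its arguments are assumptions sketched from the surrounding text:

```python
# Hypothetical entry point: the exact API is an assumption.
m = cflearn.api.fit_cv(
    data,            # the `DLDataModule` built above
    "simple_conv",   # resolved through the name registered by `register_module`
    loss_name="cross_entropy",
)
```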
Our model achieves 98.0400% accuracy on the validation set within one epoch, not bad!
Variational Auto Encoder
| Python source code | Jupyter Notebook | Task |
|---|---|---|
| run_vae.py | vae.ipynb | Computer Vision 🖼️ |
For demo purposes, we are going to build a simple convolution-based VAE.
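A condensed sketch of such a VAE follows; the output-dict convention, the `Lambda` / `UpsampleConv2d` signatures, and the import paths are assumptions:

```python
import torch
import torch.nn as nn

import cflearn
# the import paths below are assumptions
from cflearn.modules.blocks import Lambda, UpsampleConv2d
from cflearn.misc.toolkit import interpolate

@cflearn.register_module("simple_vae")
class SimpleVAE(nn.Module):
    def __init__(self, img_size: int = 28, latent_dim: int = 16):
        super().__init__()
        self.img_size = img_size
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=2, padding=1),   # 28 -> 14
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),  # 14 -> 7
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * 7 * 7, 2 * latent_dim),      # -> concat(mu, log_var)
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 32 * 7 * 7),
            # `Lambda` turns a plain function into an `nn.Module`
            Lambda(lambda t: t.view(-1, 32, 7, 7), name="reshape"),
            # `UpsampleConv2d` upsamples the feature map while convolving
            # (the `factor` keyword is an assumption)
            UpsampleConv2d(32, 16, kernel_size=3, padding=1, factor=2),  # 7 -> 14
            nn.ReLU(),
            UpsampleConv2d(16, 1, kernel_size=3, padding=1, factor=2),   # 14 -> 28
        )

    def forward(self, net: torch.Tensor) -> dict:
        mu, log_var = self.encoder(net).chunk(2, dim=1)
        # reparameterization trick
        z = mu + (0.5 * log_var).exp() * torch.randn_like(mu)
        net = self.decoder(z)
        # `interpolate` resizes the output back to the desired size
        net = interpolate(net, size=self.img_size)
        return {"mu": mu, "log_var": log_var, "predictions": net}
```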
There are quite a few details worth mentioning:
- We leveraged the `register_module` API here, which can turn a general `nn.Module` instance into a `ModelProtocol` in carefree-learn. Once registered, the model can easily be accessed by its name ("simple_vae").
- We leveraged some built-in common blocks of carefree-learn to build our simple VAE:
  - `Lambda`, which can turn a function into an `nn.Module`.
  - `UpsampleConv2d`, which can be used to upsample the input image.
  - `interpolate`, which is a handy function to resize the input image to the desired size.
After finishing our model, we need to implement the special loss used in VAE tasks.
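Here is a sketch of that loss, assuming the model's `forward` returns the dict shown above and that `LossModule` can be subclassed with a plain `forward` (both the signature and the import path are assumptions):

```python
import torch
import torch.nn.functional as F

import cflearn
from cflearn import LossModule  # import path is an assumption

# registering twice assigns two names to the same loss; one of them
# matches the model name, which is what makes `loss_name` optional below
@cflearn.register_loss_module("simple_vae")
@cflearn.register_loss_module("vae")
class VAELoss(LossModule):
    # signature is an assumption: `forward_results` is the dict returned
    # by the model, `batch` holds the original input images
    def forward(self, forward_results: dict, batch: dict) -> torch.Tensor:
        mu, log_var = forward_results["mu"], forward_results["log_var"]
        reconstruction = F.mse_loss(forward_results["predictions"], batch["input"])
        # KL divergence between N(mu, sigma) and the standard normal prior
        kl = -0.5 * torch.mean(1.0 + log_var - mu.pow(2) - log_var.exp())
        return reconstruction + 0.001 * kl  # the KL weight is illustrative
```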
- We used `register_loss_module` to register a general `LossModule` instance as a `LossProtocol` in carefree-learn.
- We can call `register_loss_module` multiple times to assign multiple names to the same loss function.
- When the loss function shares the same name with the model, we don't need to specify the `loss_name` argument explicitly, as shown below.
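A hypothetical call (the `fit_cv` entry point is an assumption carried over from above):

```python
# the loss is looked up by the model's name ("simple_vae"),
# so `loss_name` is omitted
m = cflearn.api.fit_cv(data, "simple_vae")
```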
Of course, we can still specify `loss_name` explicitly.
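Again with the hypothetical entry point:

```python
# equivalent, with the loss spelled out via its second registered name
m = cflearn.api.fit_cv(data, "simple_vae", loss_name="vae")
```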
Generative Adversarial Network
| Python source code | Jupyter Notebook | Task |
|---|---|---|
| run_gan.py | gan.ipynb | Computer Vision 🖼️ |
For demo purposes, we are going to build a simple convolution-based GAN. But first, let's build the loss function of the GAN.
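Here is a sketch of a vanilla GAN loss as a plain `nn.Module`; since the losses of models with custom steps are computed inside `train_step` (as noted below), no registration is assumed to be required:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GANLoss(nn.Module):
    """Vanilla GAN loss: binary cross entropy against real / fake targets."""

    def forward(self, predictions: torch.Tensor, target_is_real: bool) -> torch.Tensor:
        target_fn = torch.ones_like if target_is_real else torch.zeros_like
        return F.binary_cross_entropy_with_logits(predictions, target_fn(predictions))
```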
Although the concept of GAN is fairly simple, implementing it within a 'pre-defined' framework can be pretty complicated. In order to provide full flexibility, carefree-learn exposes two methods to users:
- `train_step`, which is used to control ALL training behaviours, including:
  - calculating losses
  - applying back propagation
  - performing automatic mixed precision, gradient norm clipping, and so on
- `evaluate_step`, which is used to define the final metric that we want to monitor.
Besides these, we also need to define the `forward` method, as usual.
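Below is a condensed sketch of the whole model, reusing the `GANLoss` sketched above; the `CustomModule` base class is named by the surrounding text, but its import path and the `train_step` / `evaluate_step` signatures shown here are assumptions:

```python
import torch
import torch.nn as nn

import cflearn
from cflearn import CustomModule  # import path is an assumption

@cflearn.register_custom_module("simple_gan")
class SimpleGAN(CustomModule):
    def __init__(self, latent_dim: int = 64):
        super().__init__()
        self.latent_dim = latent_dim
        self.generator = nn.Sequential(
            nn.Linear(latent_dim, 32 * 7 * 7),
            nn.ReLU(),
            nn.Unflatten(1, (32, 7, 7)),
            nn.ConvTranspose2d(32, 16, 4, stride=2, padding=1),  # 7 -> 14
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, 4, stride=2, padding=1),   # 14 -> 28
            nn.Tanh(),
        )
        self.discriminator = nn.Sequential(
            nn.Conv2d(1, 16, 4, stride=2, padding=1),   # 28 -> 14
            nn.LeakyReLU(0.2),
            nn.Conv2d(16, 32, 4, stride=2, padding=1),  # 14 -> 7
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(32 * 7 * 7, 1),
        )
        self.loss = GANLoss()

    # the two parameter groups referenced by the optimizer scopes
    def g_parameters(self):
        return list(self.generator.parameters())

    def d_parameters(self):
        return list(self.discriminator.parameters())

    def forward(self, batch_size: int) -> torch.Tensor:
        device = next(self.generator.parameters()).device
        z = torch.randn(batch_size, self.latent_dim, device=device)
        return self.generator(z)

    # signature is an assumption: `batch` holds the real images and
    # `optimizers` maps each scope name to its optimizer
    def train_step(self, batch: dict, optimizers: dict) -> dict:
        net = batch["input"]
        # 1) discriminator step
        opt_d = optimizers["core.d_parameters"]
        opt_d.zero_grad()
        fake = self.forward(net.shape[0]).detach()
        d_loss = 0.5 * (
            self.loss(self.discriminator(net), True)
            + self.loss(self.discriminator(fake), False)
        )
        d_loss.backward()
        opt_d.step()
        # 2) generator step
        opt_g = optimizers["core.g_parameters"]
        opt_g.zero_grad()
        g_loss = self.loss(self.discriminator(self.forward(net.shape[0])), True)
        g_loss.backward()
        opt_g.step()
        return {"g_loss": g_loss.item(), "d_loss": d_loss.item()}

    # signature is an assumption: monitor the generator loss as the metric
    def evaluate_step(self, batch: dict) -> dict:
        with torch.no_grad():
            fake = self.forward(batch["input"].shape[0])
            return {"g_loss": self.loss(self.discriminator(fake), True).item()}
```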
We leveraged the `register_custom_module` API here, which can turn a general `CustomModule` instance into a `ModelProtocol` in carefree-learn. Once registered, the model can easily be accessed by its name ("simple_gan").
There are two more things that are worth mentioning:
- When using models with custom steps, we don't need to specify `loss_name` anymore, because the losses are calculated inside `train_step`.
- The `register_custom_module` API will generate a `ModelProtocol` whose `core` property points to the original `CustomModule`. From the above codes, we can see that `SimpleGAN` implements `g_parameters` and `d_parameters`, which means the `self.core.g_parameters` and `self.core.d_parameters` of the generated `ModelProtocol` will be the two sets of parameters that we wish to optimize.
  - In this case, `core.g_parameters` and `core.d_parameters` will be the optimizer scopes of the generated `ModelProtocol`, which is why we access the optimizers with these names, as sketched below.
  - Please refer to the `OptimizerPack` section for more details.
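Putting it together, the two scopes might be configured like this; the `optimizer_settings` key and its structure are assumptions, so treat this purely as a sketch of the idea:

```python
# Hypothetical configuration: one optimizer (i.e. one `OptimizerPack`)
# per scope, so the generator and the discriminator are trained separately.
m = cflearn.api.fit_cv(
    data,
    "simple_gan",
    optimizer_settings={
        "core.g_parameters": {"optimizer": "adam", "lr": 2.0e-4},
        "core.d_parameters": {"optimizer": "adam", "lr": 2.0e-4},
    },
)
```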