MNIST
The MNIST dataset is a dataset of handwritten digits that is commonly used as the 'Hello World' dataset in the Deep Learning domain. It contains 60,000 training images and 10,000 testing images, and `carefree-learn` provides a straightforward API to access it. The MNIST dataset can be used for training various image processing systems, and in this article we will demonstrate how to utilize `carefree-learn` to solve several different tasks on it.
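For instance, the data preparation step could look like this. This is a minimal sketch: `cflearn.cv.MNISTData` and its arguments are assumptions based on this article's description, so check the API reference for the exact names.

```python
import cflearn

# `cflearn.cv.MNISTData` is an assumed helper name: it fetches MNIST and
# wraps it into a `DLDataModule`; `transform` selects the batch pre-processing.
data = cflearn.cv.MNISTData(batch_size=16, transform="to_tensor")
```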
:::tip
- As shown above, the MNIST dataset can easily be turned into a `DLDataModule` instance, which is the common data interface used in `carefree-learn`.
- The `transform` argument specifies which transform we want to use to pre-process the input batch. See the Transforms section for more details.
:::
Classification
| Python source code | Jupyter Notebook | Task |
| --- | --- | --- |
| run_clf.py | clf.ipynb | Computer Vision 🖼️ |
For demo purposes, we are going to build a simple convolution-based classifier:
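Here's a minimal sketch of what the model could look like. The decorator-style usage of `register_module` is an assumption; everything else is plain PyTorch.

```python
import torch
import torch.nn as nn

import cflearn


@cflearn.register_module("simple_conv")  # assumed decorator-style usage
class SimpleConvClassifier(nn.Module):
    def __init__(self, in_channels: int = 1, num_classes: int = 10):
        super().__init__()
        # a tiny conv stack: 28x28 -> 14x14 -> 7x7 -> logits
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 16, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(32, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
```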
We leveraged the `register_module` API here, which can turn a general `nn.Module` instance into a `ModelProtocol` in `carefree-learn`. Once registered, the model can easily be accessed by its name (`"simple_conv"`):
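Training could then be kicked off along these lines. This is a sketch: `cflearn.api.fit_cv` is an assumed entry point, and the argument names follow the conventions used elsewhere in this article.

```python
m = cflearn.api.fit_cv(   # assumed entry point
    data,                 # the DLDataModule built above
    "simple_conv",        # the name we registered
    {"in_channels": 1, "num_classes": 10},
    loss_name="cross_entropy",
    metric_names="acc",
)
```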
Our model achieves 98.04% accuracy on the validation set within a single epoch, not bad!
Variational Auto Encoder
| Python source code | Jupyter Notebook | Task |
| --- | --- | --- |
| run_vae.py | vae.ipynb | Computer Vision 🖼️ |
For demo purposes, we are going to build a simple convolution-based VAE:
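A minimal sketch of such a VAE could look like this. The import path of the built-in blocks (`Lambda`, `UpsampleConv2d`, `interpolate`), their exact signatures, and the output dict keys are assumptions.

```python
import torch
import torch.nn as nn

import cflearn
# assumed import path for the built-in blocks mentioned below
from cflearn.modules.blocks import Lambda, UpsampleConv2d, interpolate


@cflearn.register_module("simple_vae")  # assumed decorator-style usage
class SimpleVAE(nn.Module):
    def __init__(self, in_channels: int = 1, img_size: int = 28, latent_dim: int = 128):
        super().__init__()
        self.img_size = img_size
        # encoder: down-sample twice, then project to (mu, log_var)
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(4),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 2 * latent_dim),
        )
        # decoder: project back to a feature map, then up-sample twice
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 64 * 7 * 7),
            Lambda(lambda t: t.view(-1, 64, 7, 7)),  # turn a function into an nn.Module
            UpsampleConv2d(64, 32, kernel_size=3, padding=1, factor=2),  # signature assumed
            nn.ReLU(inplace=True),
            UpsampleConv2d(32, in_channels, kernel_size=3, padding=1, factor=2),
        )

    def forward(self, x: torch.Tensor) -> dict:
        mu, log_var = self.encoder(x).chunk(2, dim=1)
        std = torch.exp(0.5 * log_var)
        z = mu + std * torch.randn_like(std)  # re-parameterization trick
        decoded = self.decoder(z)
        # resize to the original resolution with the handy `interpolate` (signature assumed)
        decoded = interpolate(decoded, size=self.img_size, mode="bilinear")
        # the key names are our own convention, consumed by the loss sketch below
        return {"predictions": torch.tanh(decoded), "mu": mu, "log_var": log_var}
```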
There are quite a few details worth mentioning:
- We leveraged the `register_module` API here, which can turn a general `nn.Module` instance into a `ModelProtocol` in `carefree-learn`. Once registered, the model can easily be accessed by its name (`"simple_vae"`).
- We leveraged some built-in common blocks of `carefree-learn` to build our simple VAE:
  - `Lambda`, which can turn a function into an `nn.Module`.
  - `UpsampleConv2d`, which can be used to upsample the input image.
  - `interpolate`, which is a handy function to resize the input image to the desired size.
After implementing our model, we need to implement the special loss used in VAE tasks:
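A sketch of a standard VAE loss (reconstruction + KL divergence) could look like this. The `LossModule` import path and the `forward` signature shown here are assumptions, and the dict keys follow the model sketch above.

```python
import torch
import torch.nn.functional as F

import cflearn
from cflearn import LossModule  # assumed import path


@cflearn.register_loss_module("simple_vae")  # shares the model's name
@cflearn.register_loss_module("vae")         # an extra alias: registered twice
class VAELoss(LossModule):
    kld_w = 1.0e-3  # weight of the KL-divergence term

    # signature assumed: forward results from the model + the input batch
    def forward(self, forward_results: dict, batch: dict) -> torch.Tensor:
        reconstruction = forward_results["predictions"]
        mu, log_var = forward_results["mu"], forward_results["log_var"]
        mse = F.mse_loss(reconstruction, batch["input"])  # "input" key assumed
        kld = -0.5 * torch.mean(1.0 + log_var - mu.pow(2) - log_var.exp())
        return mse + self.kld_w * kld
```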
- We used `register_loss_module` here, which can turn a general `LossModule` instance into a `LossProtocol` in `carefree-learn`.
- We can call `register_loss_module` multiple times to assign multiple names to the same loss function.
- When the loss function shares the same name with the model, we don't need to specify the `loss_name` argument explicitly:
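```python
# assumed entry point; no `loss_name` needed, since the loss above is also
# registered under "simple_vae"
m = cflearn.api.fit_cv(data, "simple_vae", {"in_channels": 1})
```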
Of course, we can still specify `loss_name` explicitly:
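```python
# same assumptions as above
m = cflearn.api.fit_cv(
    data,
    "simple_vae",
    {"in_channels": 1},
    loss_name="vae",  # the alias we registered for the same loss
)
```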
Generative Adversarial Network
| Python source code | Jupyter Notebook | Task |
| --- | --- | --- |
| run_gan.py | gan.ipynb | Computer Vision 🖼️ |
For demo purposes, we are going to build a simple convolution-based GAN. But first, let's build the loss function of the GAN:
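Here's a minimal sketch of a vanilla GAN loss in plain PyTorch; the class name and interface are our own, not carefree-learn's.

```python
import torch
import torch.nn as nn


class GANLoss(nn.Module):
    """Vanilla GAN loss computed on raw discriminator logits."""

    def __init__(self):
        super().__init__()
        self.loss = nn.BCEWithLogitsLoss()

    def forward(self, logits: torch.Tensor, target_is_real: bool) -> torch.Tensor:
        # build an all-ones / all-zeros target matching the logits' shape
        target = torch.ones_like(logits) if target_is_real else torch.zeros_like(logits)
        return self.loss(logits, target)
```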
Although the concept of GAN is fairly simple, it's pretty complicated to implement it within a 'pre-defined' framework. In order to provide full flexibility, `carefree-learn` exposes two methods for users:

- `train_step`, which is used to control ALL training behaviours, including:
  - calculating losses
  - applying back propagation
  - performing automatic mixed precision, gradient norm clipping, and so on
- `evaluate_step`, which is used to define the final metric that we want to monitor.
Besides, we also need to define the `forward` method, as usual:
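Putting it together, a sketch of the model could look like this. The exact `train_step` / `evaluate_step` signatures and the `CustomModule` import path are assumptions, and the `GANLoss` sketch from above is reused.

```python
import torch
import torch.nn as nn

import cflearn
from cflearn import CustomModule  # assumed import path


@cflearn.register_custom_module("simple_gan")  # assumed decorator-style usage
class SimpleGAN(CustomModule):
    def __init__(self, in_channels: int = 1, latent_dim: int = 128):
        super().__init__()
        self.latent_dim = latent_dim
        # generator: latent vector -> 28x28 image
        self.generator = nn.Sequential(
            nn.Linear(latent_dim, 64 * 7 * 7),
            nn.Unflatten(1, (64, 7, 7)),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, in_channels, 4, stride=2, padding=1),
            nn.Tanh(),
        )
        # discriminator: image -> real/fake logit
        self.discriminator = nn.Sequential(
            nn.Conv2d(in_channels, 32, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 1),
        )
        self.loss = GANLoss()  # defined in the previous snippet

    # these two properties define the optimizer `scope`s explained below
    @property
    def g_parameters(self):
        return list(self.generator.parameters())

    @property
    def d_parameters(self):
        return list(self.discriminator.parameters())

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return self.generator(z)

    # NOTE: signature assumed; a real implementation should return
    # carefree-learn's step outputs instead of a plain dict
    def train_step(self, batch: dict, optimizers: dict) -> dict:
        net = batch["input"]  # "input" key assumed
        z = torch.randn(net.shape[0], self.latent_dim, device=net.device)
        sampled = self.forward(z)
        # 1) generator step: fool the discriminator
        g_opt = optimizers["core.g_parameters"]
        g_loss = self.loss(self.discriminator(sampled), target_is_real=True)
        g_opt.zero_grad()
        g_loss.backward()
        g_opt.step()
        # 2) discriminator step: separate real from (detached) fake
        d_opt = optimizers["core.d_parameters"]
        d_loss = self.loss(self.discriminator(net), True) + self.loss(
            self.discriminator(sampled.detach()), False
        )
        d_opt.zero_grad()
        d_loss.backward()
        d_opt.step()
        return {"g_loss": g_loss.item(), "d_loss": d_loss.item()}

    # NOTE: signature assumed; report the metric we want to monitor
    # (higher is better, hence the negated generator loss)
    def evaluate_step(self, batch: dict) -> float:
        net = batch["input"]
        z = torch.randn(net.shape[0], self.latent_dim, device=net.device)
        return -self.loss(self.discriminator(self.forward(z)), True).item()
```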
We leveraged the `register_custom_module` API here, which can turn a general `CustomModule` instance into a `ModelProtocol` in `carefree-learn`. Once registered, the model can easily be accessed by its name (`"simple_gan"`).
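Training could then look like this (a sketch; `fit_cv` and the `optimizer_settings` argument are assumed names, and the scopes are explained below):

```python
m = cflearn.api.fit_cv(  # assumed entry point
    data,
    "simple_gan",
    {"in_channels": 1},
    # no `loss_name` here: the losses live inside `train_step`
    optimizer_settings={  # argument name assumed; see the OptimizerPack section
        "core.g_parameters": {"optimizer": "adam", "lr": 2.0e-4},
        "core.d_parameters": {"optimizer": "adam", "lr": 2.0e-4},
    },
)
```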
There are two more things that are worth mentioning:
- When using models with custom steps, we don't need to specify `loss_name` anymore, because the losses are calculated inside `train_step`.
- The `register_custom_module` API will generate a `ModelProtocol`, whose `core` property points to the original `CustomModule`. From the above codes, we can see that `SimpleGAN` implements `g_parameters` and `d_parameters`, which means the `self.core.g_parameters` and `self.core.d_parameters` of the generated `ModelProtocol` will be the two sets of parameters that we wish to optimize.
  - In this case, `core.g_parameters` and `core.d_parameters` will be the optimizer `scope`s of the generated `ModelProtocol`. That's why we access the optimizers with them.
  - Please refer to the `OptimizerPack` section for more details.