Orientation Module
To predict the orientation of an aligned student ID image passed on from the detection module, we shall quickly develop an image classification model and train it on our orientation dataset. We expect the trained orientation model to predict confidence scores for the orientation angles (90, 180, 270, and 360 degrees) of an input image.
First, let's resolve the imports of our orientation module.
2.1. Orientation Dataset
The orientation dataset consists of folders, each containing four subfolders named after one of the four orientation classes, i.e. '090', '180', '270', and '360'. Each subfolder contains images rotated according to its name.
PyTorch provides torchvision.datasets.ImageFolder for loading datasets in this format, so we don't need to write a custom dataset class as we did for the detection dataset.
2.1.1. Define transforms for orientation datasets
Before instantiating our various orientation datasets, we have to define the various transforms which shall be used to initialize them.
Code Briefings
- Defined and initialized transforms specific to each of our orientation datasets (training, validation, and testing). We did so by importing torchvision.transforms, a module of common image transformations.
- transforms.Compose chains several transforms together.
- transforms.Resize resizes the input image to the given size.
- transforms.RandomAffine applies a random affine transformation to the image, keeping the center invariant.
- transforms.RandomApply applies a list of transformations randomly with a given probability.
- transforms.RandomGrayscale randomly converts the image to grayscale with probability p.
- transforms.Normalize normalizes a tensor image with a given mean and standard deviation.
- Note that we applied data augmentation only to the training dataset. Augmentation exposes the model to more varied inputs during training, which helps it generalize to unseen images, while the validation and testing data stay unaugmented so that evaluation reflects real inputs.
2.1.2. Instantiate orientation datasets
We shall leverage PyTorch's built-in torchvision.datasets.ImageFolder class to effortlessly instantiate our orientation training, validation, and testing datasets.
Code Briefings
- We initialized the training, validation, and testing datasets using torchvision.datasets.ImageFolder with their respective folders.
- We initialized the variables orientation_classes and num_orientation_classes to the names of our orientation classes and their number, respectively.
Let's check the names and number of classes in our orientation dataset to make sure everything is OK!
2.1.3. Visualize orientation dataset
Code Briefings
- Randomly selected four (image_tensor, label_tensor) pairs from our orientation training dataset.
- Denormalized image_tensor for each pair and plotted each image along with its corresponding class.
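Denormalization simply inverts transforms.Normalize: each channel is multiplied by its standard deviation and shifted back by its mean. The sketch below assumes ImageNet statistics (substitute whatever the Normalize transform actually used) and converts the result to the (H, W, C) layout that matplotlib's imshow expects.

```python
import torch

# Assumed normalization statistics, shaped (3, 1, 1) to broadcast over H and W.
MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def denormalize(image_tensor):
    """Invert Normalize: x = x_norm * std + mean, then clamp to the valid [0, 1] range."""
    return (image_tensor * STD + MEAN).clamp(0.0, 1.0)


def to_displayable(image_tensor):
    """(C, H, W) normalized tensor -> (H, W, C) NumPy array for plt.imshow."""
    return denormalize(image_tensor).permute(1, 2, 0).numpy()


# Round trip: normalize a random image by hand, then undo it.
original = torch.rand(3, 224, 224)
normalized = (original - MEAN) / STD
restored = denormalize(normalized)
print(torch.allclose(restored, original, atol=1e-5))  # True
print(to_displayable(normalized).shape)               # (224, 224, 3)
```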
2.2. Orientation Model
2.2.1. Define Orientation Model
Note that the model architecture defined below expects input image tensors of shape (3 x 224 x 224), matching the output of the orientation datasets' transforms. Let's define an architecture for our orientation model from scratch.
Code Briefings
- Defined OrientationModel, extending torch.nn.Module. The constructor argument num_classes is the desired number of classes/labels.
- Defined the feed-forward behavior of the neural network by overriding the forward method.
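The notebook's actual layers are not shown here, so the following is only a hypothetical small CNN with the same interface: an OrientationModel subclassing torch.nn.Module, a num_classes constructor argument, and a forward method returning one logit per class for (N, 3, 224, 224) input. Pooling to a 4x4 grid rather than 1x1 is a design choice that keeps some coarse spatial layout, which is useful for telling orientations apart.

```python
import torch
from torch import nn


class OrientationModel(nn.Module):
    """Hypothetical small CNN for orientation classification (illustrative only)."""

    def __init__(self, num_classes):
        super().__init__()

        def block(in_ch, out_ch):
            # conv -> batch norm -> ReLU -> halve the spatial size
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            )

        self.features = nn.Sequential(
            block(3, 16),    # 224 -> 112
            block(16, 32),   # 112 -> 56
            block(32, 64),   # 56 -> 28
            block(64, 128),  # 28 -> 14
        )
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((4, 4)),  # (N, 128, 14, 14) -> (N, 128, 4, 4)
            nn.Flatten(),                  # -> (N, 2048)
            nn.Dropout(0.3),
            nn.Linear(128 * 4 * 4, num_classes),
        )

    def forward(self, x):
        # Raw logits; apply softmax afterwards to get confidence scores.
        return self.classifier(self.features(x))


model = OrientationModel(num_classes=4)
logits = model(torch.randn(2, 3, 224, 224))
print(logits.shape)  # torch.Size([2, 4])
```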
Now that we have defined the architecture of our orientation model, let's define a helper function to instantiate it!
Code Briefings
- Used orientation_model.load_state_dict() to set the model weights from a state dictionary if state_dict is given.
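A sketch of such a helper follows. The name `get_orientation_model` and its signature are hypothetical, and a tiny stand-in model replaces the real architecture so that the snippet runs on its own. The key step is the one described above: `load_state_dict()` is called only when a state dictionary is supplied.

```python
import torch
from torch import nn


def get_orientation_model(model_cls, num_classes, state_dict=None, device="cpu"):
    """Hypothetical helper: build the model and optionally restore its weights."""
    model = model_cls(num_classes)
    if state_dict is not None:
        # strict loading (the default): parameter names and shapes must match
        model.load_state_dict(state_dict)
    return model.to(device)


class TinyModel(nn.Module):
    """Stand-in for OrientationModel so the sketch is self-contained."""

    def __init__(self, num_classes):
        super().__init__()
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x):
        return self.fc(x)


trained = get_orientation_model(TinyModel, num_classes=4)
restored = get_orientation_model(TinyModel, num_classes=4, state_dict=trained.state_dict())
print(torch.equal(trained.fc.weight, restored.fc.weight))  # True
```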
2.2.2. Specify checkpoint and instantiate the model
To support resumable training and saving of our orientation model, we now specify the checkpoints for the state dictionaries of the model and its optimizer, and initialize the model at the same time.
Code Briefings
- Selected the available computational hardware using torch.device().
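The device selection and the optional checkpoint reload might look like the sketch below. The checkpoint directory, file name, and the `"model"` key are hypothetical; only the `load_orientation_checkpoint` flag name comes from this tutorial. With the flag set to False, or when no checkpoint file exists yet, the model simply starts from fresh weights.

```python
import os

import torch

# Use a CUDA GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical checkpoint location; the notebook's real paths are not shown.
checkpoint_dir = "checkpoints"
orientation_checkpoint_path = os.path.join(checkpoint_dir, "orientation_checkpoint.pt")
load_orientation_checkpoint = False  # set to True to resume from a saved checkpoint

checkpoint = None
if load_orientation_checkpoint and os.path.exists(orientation_checkpoint_path):
    # map_location lets a checkpoint saved on a GPU be opened on a CPU-only machine.
    checkpoint = torch.load(orientation_checkpoint_path, map_location=device)

# None means "start from fresh weights" when passed to the model helper.
orientation_model_state_dict = checkpoint["model"] if checkpoint else None
print(device, orientation_model_state_dict is None)
```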
Now let's print our orientation model to check that it has been initialized as expected.
2.3. Training and Validation
2.3.1. Specify data loaders
After initializing the orientation datasets, let's use them to specify the data loaders for training, validation, and testing.
Code Briefings
- Initialized a data loader for each of our orientation datasets (training, validation, and testing) using torch.utils.data.DataLoader().
- Initialized the dictionary variable orientation_loaders, which references all of the data loaders.
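A minimal sketch of this step, using random TensorDatasets as stand-ins for the ImageFolder datasets. The batch size and the helper name `make_orientation_loaders` are assumptions. Only the training loader shuffles, so each epoch sees the training images in a new order while validation and testing stay deterministic.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_orientation_loaders(datasets_by_split, batch_size=32, num_workers=0):
    """Hypothetical helper: one DataLoader per split, shuffling only 'train'."""
    return {
        split: DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=(split == "train"),
            num_workers=num_workers,
        )
        for split, dataset in datasets_by_split.items()
    }


# Synthetic stand-ins for the orientation datasets: 10 images per split, 4 classes.
fake = {
    split: TensorDataset(torch.randn(10, 3, 224, 224), torch.randint(0, 4, (10,)))
    for split in ["train", "valid", "test"]
}
orientation_loaders = make_orientation_loaders(fake, batch_size=4)

data, target = next(iter(orientation_loaders["train"]))
print(data.shape, target.shape)          # torch.Size([4, 3, 224, 224]) torch.Size([4])
print(len(orientation_loaders["test"]))  # 3 batches: 4 + 4 + 2
```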
2.3.2. Define loss function and optimizer
Let's initialize the loss function and the optimizer for training the orientation model, and get ready for training!
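The original choices of loss and optimizer aren't shown, so this sketch assumes nn.CrossEntropyLoss (the standard loss for single-label, multi-class classification on raw logits) and SGD with momentum. A linear layer stands in for the orientation model, and one optimization step shows how the pieces fit together.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)
model = nn.Linear(16, 4)  # stand-in for the orientation model

# CrossEntropyLoss takes raw logits and integer class indices.
criterion = nn.CrossEntropyLoss()
# Assumed optimizer and hyperparameters.
orientation_optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# One optimization step on random data.
before = model.weight.detach().clone()
logits = model(torch.randn(8, 16))
loss = criterion(logits, torch.randint(0, 4, (8,)))
orientation_optimizer.zero_grad()
loss.backward()
orientation_optimizer.step()
print(f"loss = {loss.item():.4f}, weights changed: {not torch.equal(before, model.weight)}")
```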
2.3.3. Define training function
Code Briefings
- Moved the model, as well as the (data, target) pairs from loaders['train'], to the computation device.
- Within the training loop, we reset the gradients before predicting the output for each data batch, and then computed its loss against its target.
- Tracked the validation loss, valid_loss, to find the best loss, and updated the checkpoint accordingly.
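A sketch of a training function along these lines follows. The function signature and checkpoint keys are assumptions, not the notebook's code. It resets gradients before each forward pass, averages the training and validation losses over all samples, and writes a resumable checkpoint (model, optimizer, epoch, best validation loss) at the end of every epoch. The `start_epoch` and `best_valid_loss` arguments are what make resuming possible.

```python
import math
import os
import tempfile

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset


def train(n_epochs, loaders, model, optimizer, criterion, device, checkpoint_path,
          start_epoch=1, best_valid_loss=math.inf):
    """Train for epochs start_epoch..n_epochs and checkpoint after every epoch."""
    model.to(device)
    for epoch in range(start_epoch, n_epochs + 1):
        model.train()
        train_loss = 0.0
        for data, target in loaders["train"]:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()                  # clear gradients from the last step
            loss = criterion(model(data), target)  # predict, then compare with target
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * data.size(0)
        train_loss /= len(loaders["train"].dataset)

        model.eval()
        valid_loss = 0.0
        with torch.no_grad():
            for data, target in loaders["valid"]:
                data, target = data.to(device), target.to(device)
                valid_loss += criterion(model(data), target).item() * data.size(0)
        valid_loss /= len(loaders["valid"].dataset)

        is_best = valid_loss < best_valid_loss
        best_valid_loss = min(best_valid_loss, valid_loss)
        # Hypothetical checkpoint layout: enough to resume model, optimizer and epoch.
        torch.save({
            "epoch": epoch,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "best_valid_loss": best_valid_loss,
        }, checkpoint_path)
        print(f"Epoch {epoch}: train loss {train_loss:.4f}, valid loss {valid_loss:.4f}"
              + (" (best so far)" if is_best else ""))
    return best_valid_loss


# Usage on synthetic data with a stand-in linear model.
torch.manual_seed(0)


def fake_split(n):
    return TensorDataset(torch.randn(n, 16), torch.randint(0, 4, (n,)))


loaders = {"train": DataLoader(fake_split(32), batch_size=8, shuffle=True),
           "valid": DataLoader(fake_split(16), batch_size=8)}
model = nn.Linear(16, 4)
optimizer = optim.SGD(model.parameters(), lr=0.05, momentum=0.9)
ckpt_path = os.path.join(tempfile.mkdtemp(), "orientation_checkpoint.pt")
best = train(3, loaders, model, optimizer, nn.CrossEntropyLoss(), torch.device("cpu"), ckpt_path)
```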
2.3.4. Train orientation model
Now let's train our orientation model for 20 epochs.
2.3.5. Resume training orientation model
At the end of every epoch, the checkpoints of the orientation module were updated. Now let's use these updated checkpoints to reload the orientation model and the orientation optimizer, and resume training up to 30 epochs.
Important: To reload the orientation model and the orientation optimizer from the checkpoint, simply re-run the code cells in Section 2.2.2 and Section 2.3.2, respectively. Just make sure load_orientation_checkpoint is set to True. The resulting outputs should be identical to the ones below.
Reloading the orientation model from the checkpoint (Section 2.2.2)
Code Briefings
- Loaded the checkpoint using torch.load(). The argument map_location specifies the computing device onto which the checkpoint is loaded.
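Here is a minimal sketch of the reload, assuming the checkpoint is a dictionary with hypothetical keys `epoch`, `model`, `optimizer`, and `best_valid_loss`. map_location remaps the stored tensors onto the chosen device, and the stored epoch tells us where to resume: a checkpoint written after epoch 20 resumes at epoch 21.

```python
import os
import tempfile

import torch
from torch import nn, optim

model = nn.Linear(16, 4)  # stand-in for the orientation model
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
path = os.path.join(tempfile.mkdtemp(), "orientation_checkpoint.pt")

# Pretend 20 epochs have finished (hypothetical checkpoint keys and values).
torch.save({"epoch": 20, "model": model.state_dict(),
            "optimizer": optimizer.state_dict(), "best_valid_loss": 0.42}, path)

# map_location places every stored tensor on `device`, e.g. GPU-saved -> CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load(path, map_location=device)

reloaded = nn.Linear(16, 4).to(device)
reloaded.load_state_dict(checkpoint["model"])
start_epoch = checkpoint["epoch"] + 1              # resume at epoch 21
print(start_epoch, checkpoint["best_valid_loss"])  # 21 0.42
```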
Reloading the orientation optimizer from the checkpoint (Section 2.3.2)
Code Briefings
- Used orientation_optimizer.load_state_dict() to restore the optimizer state (its hyperparameters and any per-parameter buffers, such as momentum) if orientation_optimizer_state_dict is available. This returns the optimizer to the state it was in at the end of the previous training run.
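Why restore the optimizer at all? Optimizers such as SGD with momentum keep per-parameter buffers that a freshly constructed optimizer lacks. The sketch below (with a stand-in model and assumed hyperparameters) shows that a new optimizer starts with empty state, and that loading the saved state dictionary brings the momentum buffers back.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)
model = nn.Linear(4, 2)  # stand-in for the orientation model
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One training step creates a momentum buffer for each parameter.
loss = model(torch.randn(3, 4)).sum()
loss.backward()
optimizer.step()
orientation_optimizer_state_dict = optimizer.state_dict()

# A freshly built optimizer starts with empty per-parameter state...
fresh = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
print(len(fresh.state_dict()["state"]))  # 0

# ...until the saved state is loaded back in.
if orientation_optimizer_state_dict is not None:
    fresh.load_state_dict(orientation_optimizer_state_dict)
print(len(fresh.state_dict()["state"]))  # 2 (weight and bias buffers)
```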
Now let's resume the training of our orientation model.
Notice that training resumes from epoch 21, since the orientation model has already been trained for 20 epochs.
2.3.6. Evaluate orientation model
Before drawing conclusions about a model's performance, it is good practice to evaluate it on data it did not see during training. We shall evaluate the performance of the orientation model on sample images from the testing dataset.
But first, let's define the test function.
Code Briefings
- Put the model in evaluation mode using model.eval(). This switches off training-specific behavior: dropout layers are disabled, and batch normalization layers use their running statistics instead of per-batch statistics.
- Iterated over batches of the orientation test loader for (data, target) pairs.
- Moved the (data, target) pairs to the computation device using the to() method.
- Predicted the output for data and computed its loss against the targets. The average loss over the test set is computed as test_loss.
With our test function defined, we shall now use it to evaluate the performance of the orientation model on the orientation test dataset.
2.4. Orientation Correction
Let's visualize the performance of our orientation model by running inference on sample images from the test dataset, one at a time.
Keep in mind that the objective of an orientation module is to detect the orientation of an aligned document image and to rectify it where necessary. Therefore, after running inference on each image, we shall apply the appropriate transformation to rectify its orientation if necessary.
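One way to turn a prediction into a correction is sketched below. It rests on an assumption the tutorial does not state: that class '090' means the upright image was rotated 90 degrees counter-clockwise (and so on), with '360' meaning upright. If the dataset used the opposite convention, flip the sign of `k`. Softmax converts the logits into the per-class confidence scores mentioned at the start of this module, and torch.rot90 on the (H, W) dimensions performs an exact, lossless quarter-turn rotation. The function name `correct_orientation` is hypothetical.

```python
import torch

# Assumed convention: '090' = upright image rotated 90 degrees counter-clockwise.
ORIENTATION_CLASSES = ["090", "180", "270", "360"]


def correct_orientation(image, logits, classes=ORIENTATION_CLASSES):
    """Rotate a (C, H, W) image back to upright using the model's logits for it."""
    confidences = torch.softmax(logits, dim=-1)  # confidence score per class
    best = int(confidences.argmax())
    angle = int(classes[best]) % 360             # '360' -> 0, i.e. no rotation
    # rot90 with dims (H, W) and k > 0 turns the image counter-clockwise,
    # so a negative k undoes the assumed counter-clockwise rotation.
    corrected = torch.rot90(image, k=-(angle // 90), dims=(-2, -1))
    return corrected, classes[best], confidences[best].item()


# Usage: simulate a '090' input and fake logits that favour that class.
upright = torch.arange(12.0).reshape(1, 3, 4)        # toy 1x3x4 "image"
rotated = torch.rot90(upright, k=1, dims=(-2, -1))   # shape becomes (1, 4, 3)
fake_logits = torch.tensor([4.0, 0.5, 0.2, 0.1])
fixed, label, score = correct_orientation(rotated, fake_logits)
print(label, f"{score:.3f}", torch.equal(fixed, upright))
```

In the notebook, `logits` would come from the orientation model's output for a single image (for example, `model(x.unsqueeze(0))[0]` in evaluation mode), and `classes` from `orientation_classes`.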