Train from ONNX models
The ONNX specification does not include any training parameters. To train an imported ONNX model, the training elements (solvers, learning rate scheduler…) can be added on top of the ONNX model in N2D2, either directly in the INI file or using the Python API.
This is particularly useful for transfer learning from an existing ONNX model, for example one trained on ImageNet.
With an INI file
In this section, we apply transfer learning to a MobileNet v1 ONNX
model. We assume that this model is obtained by converting the reference
pre-trained model from Google using the tools/mobilenet_v1_to_onnx.sh tool
provided in N2D2. The resulting model file name is therefore assumed to be
mobilenet_v1_1.0_224.onnx.
1) Remove the original classifier
The first step to perform transfer learning is to remove the existing classifier
from the ONNX model. To do so, one can simply use the Ignore parameter in
the ONNX INI section.
[onnx]
Input=sp
Type=ONNX
File=mobilenet_v1_1.0_224.onnx
; Remove the last layer and the softmax for transfer learning
Ignore=Conv__252:0 MobilenetV1/Predictions/Softmax:0
2) Add a new classifier to the ONNX model
The next step is to add a new classifier (fully connected layer with a softmax) and connect it to the last layer in the ONNX model.
In order to properly handle graph dependencies, every N2D2 layer connected
to a layer embedded in an ONNX model must take the ONNX section name
(here onnx) as the first input in its Input parameter. The actual inputs
are then added to the comma-separated list, which can mix ONNX and N2D2 layers.
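The ordering rule above can be sketched with a small Python helper that assembles the value of the Input parameter. The function name build_input is hypothetical and is not part of N2D2 or its Python API. It only illustrates the "ONNX section name first, actual inputs after" convention as plain string handling.

```python
def build_input(onnx_section, inputs):
    """Return the value of an N2D2 ``Input`` parameter (hypothetical helper).

    The ONNX section name comes first, followed by the actual inputs
    (ONNX tensor names and/or N2D2 layer names), comma-separated.
    """
    if not inputs:
        raise ValueError("at least one actual input is required")
    return ",".join([onnx_section, *inputs])


# Connect a new layer to the average pooling output of the ONNX model.
line = "Input=" + build_input(
    "onnx", ["MobilenetV1/Logits/AvgPool_1a/AvgPool:0"]
)
print(line)
```

The printed line matches the Input entry used for the new fully connected layer in this section.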
In the example below, the average pooling layer from the ONNX model is connected
to the Fc cell:
; Here, we add our new layers for transfer learning
[fc]
; first input MUST BE "onnx"
; for proper dependency handling
Input=onnx,MobilenetV1/Logits/AvgPool_1a/AvgPool:0
Type=Fc
NbOutputs=100
ActivationFunction=Linear
WeightsFiller=XavierFiller
ConfigSection=common.config
[softmax]
Input=fc
Type=Softmax
NbOutputs=[fc]NbOutputs
WithLoss=1
[softmax.Target]
; Common config for static model
[common.config]
WeightsSolver.LearningRate=0.01
WeightsSolver.Momentum=0.9
WeightsSolver.Decay=0.0005
Solvers.LearningRatePolicy=StepDecay
Solvers.LearningRateStepSize=[sp]_EpochSize
Solvers.LearningRateDecay=0.993
As this new classifier must be trained, all the training parameters must be specified as usual for this layer.
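To get a feel for the solver settings in common.config, the following Python sketch computes the learning rate over time. It assumes that the StepDecay policy multiplies the learning rate by LearningRateDecay once every LearningRateStepSize stimuli. This behaviour is an assumption to be checked against the solver documentation. The function name step_decay_lr and the epoch size of 50000 are hypothetical, the latter standing in for [sp]_EpochSize.

```python
def step_decay_lr(base_lr, decay, step_size, nb_seen):
    """Learning rate after `nb_seen` stimuli under the assumed step decay rule."""
    # One decay step is applied for each completed block of `step_size` stimuli.
    return base_lr * decay ** (nb_seen // step_size)


EPOCH_SIZE = 50_000  # hypothetical value standing in for [sp]_EpochSize

for epoch in (0, 1, 10, 100):
    lr = step_decay_lr(0.01, 0.993, EPOCH_SIZE, epoch * EPOCH_SIZE)
    print(f"epoch {epoch:3d}: learning rate = {lr:.6f}")
```

Under this assumption, a decay factor of 0.993 per epoch roughly halves the learning rate after 100 epochs, since 0.993^100 is about 0.495.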
3) Fine tuning (optional)
If one wants to also fine-tune the existing ONNX layers, one must set the solver configuration for the ONNX layers, using default configuration sections.
Default configuration sections apply to all the layers of the same type in the
ONNX model. For example, to add default parameters to all convolution layers
in the ONNX model loaded in a section of type ONNX named onnx,
just add a section named [onnx:Conv_def] in the INI file. The name of the
default section follows the convention [ONNXSection:N2D2CellType_def].
; Default section for ONNX Conv from section "onnx"
; "ConfigSection", solvers and fillers can be specified here...
[onnx:Conv_def]
ConfigSection=common.config
; Default section for ONNX Fc from section "onnx"
[onnx:Fc_def]
ConfigSection=common.config
; For BatchNorm, make sure the stats won't change if there is no fine-tuning
[onnx:BatchNorm_def]
ConfigSection=bn_notrain.config
[bn_notrain.config]
MovingAverageMomentum=0.0
Note
Important: make sure that the BatchNorm statistics do not change if the
BatchNorm layers are not fine-tuned! This can be done by setting the
MovingAverageMomentum parameter to 0.0 for the layers that must not be
fine-tuned.
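The [ONNXSection:N2D2CellType_def] naming convention can be illustrated with a short Python sketch that generates the default sections as INI text. The helper names default_section and default_sections_ini are hypothetical and do not belong to N2D2. They only format strings according to the convention described in this section.

```python
def default_section(onnx_section, cell_type):
    """Header of the default section for all `cell_type` cells of an ONNX section."""
    return f"[{onnx_section}:{cell_type}_def]"


def default_sections_ini(onnx_section, config_by_cell_type):
    """Build INI text assigning one ConfigSection per N2D2 cell type."""
    lines = []
    for cell_type, config_section in config_by_cell_type.items():
        lines.append(default_section(onnx_section, cell_type))
        lines.append(f"ConfigSection={config_section}")
    return "\n".join(lines)


# Same mapping as in the INI listing of this section.
print(default_sections_ini("onnx", {
    "Conv": "common.config",
    "Fc": "common.config",
    "BatchNorm": "bn_notrain.config",
}))
```

Generating these sections from a mapping can help keep the cell type names and their configuration sections in one place when several ONNX models share the same training setup.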
It is possible to add parameters for a specific ONNX layer by adding a section named after that ONNX layer.
You can fine-tune the whole network or only some of its layers, usually the last
ones. To stop the fine-tuning at a specific layer, one can simply prevent the
gradient from back-propagating further. This can be achieved with the
BackPropagate=0 configuration parameter.
[Conv__250]
ConfigSection=common.config,notrain.config
[notrain.config]
BackPropagate=0
For the full configuration related to this example and more information, have a
look in models/MobileNet_v1_ONNX_transfer.ini.
With the Python API
Coming soon.