Training the network

Every neural network must be trained before it can make useful predictions on the data we give it. Our TrainNetwork() method handles this:

private void TrainNetwork(Dictionary<string, (float[][] train, float[][] valid, float[][] test)> dataSet, int hiDim, int cellDim, int iteration, int batchSize, Action<Trainer, Function, int, DeviceDescriptor> progressReport)
{

First, get the feature and label sets from the dataset. Each one is already split into train, validation, and test parts, as follows:

var featureSet = dataSet["features"];
var labelSet = dataSet["label"];
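
TrainNetwork() expects dataSet to map the keys "features" and "label" to (train, valid, test) tuples. The code that builds this dictionary is not shown in this section. The following standalone sketch, separate from the method body, shows one possible way to produce it. The SplitRows helper, its 80/10/10 proportions, and the toy data are hypothetical assumptions, not the book's code:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical helper (not from the book): splits rows into consecutive
// train/valid/test parts. The 80/10/10 default proportions are an assumption.
static (float[][] train, float[][] valid, float[][] test) SplitRows(
    float[][] rows, double trainFraction = 0.8, double validFraction = 0.1)
{
    int trainCount = (int)(rows.Length * trainFraction);
    int validCount = (int)(rows.Length * validFraction);
    // Keep time order (no shuffling) so later samples never leak into training
    return (rows.Take(trainCount).ToArray(),
            rows.Skip(trainCount).Take(validCount).ToArray(),
            rows.Skip(trainCount + validCount).ToArray());
}

// Made-up toy data: 10 samples, 3 features and 1 label per sample
float[][] features = Enumerable.Range(0, 10)
    .Select(i => new float[] { i, i + 1, i + 2 }).ToArray();
float[][] labels = Enumerable.Range(0, 10)
    .Select(i => new float[] { i + 3 }).ToArray();

// Same shape as the dataSet parameter of TrainNetwork()
var dataSet = new Dictionary<string, (float[][] train, float[][] valid, float[][] test)>
{
    ["features"] = SplitRows(features),
    ["label"] = SplitRows(labels),
};

var featureSet = dataSet["features"];
var labelSet = dataSet["label"];
Console.WriteLine($"train={featureSet.train.Length}, valid={featureSet.valid.Length}, test={featureSet.test.Length}");
```

Splitting in order rather than shuffling keeps every validation and test sample later in time than the training samples, which is the usual choice for time-series data.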

Create the model, as follows:

// Input and label variables; inDim, ouDim, featuresName and labelsName are defined outside this method
var feature = Variable.InputVariable(new int[] { inDim }, DataType.Float, featuresName, null, false /*isSparse*/);
var label = Variable.InputVariable(new int[] { ouDim }, DataType.Float, labelsName, new List<CNTK.Axis>() { CNTK.Axis.DefaultBatchAxis() }, false);
// LSTM model built by the LSTMHelper class
var lstmModel = LSTMHelper.CreateModel(feature, ouDim, hiDim, cellDim, DeviceDescriptor.CPUDevice, "timeSeriesOutput");
// Squared error serves as both the training loss and the evaluation criterion
Function trainingLoss = CNTKLib.SquaredError(lstmModel, label, "squarederrorLoss");
Function prediction = CNTKLib.SquaredError(lstmModel, label, "squarederrorEval");
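
Both trainingLoss and prediction use CNTKLib.SquaredError. As we understand CNTK's squared-error operation, for a single sample it sums the squared differences between the model output and the target. The standalone plain-C# sketch below is not CNTK code. The SquaredError function here is our own illustration, and it shows that arithmetic on one made-up sample:

```csharp
using System;

// Illustration only: per-sample squared error as the sum of squared
// differences between prediction and target (not a CNTK API).
static float SquaredError(float[] prediction, float[] target)
{
    if (prediction.Length != target.Length)
        throw new ArgumentException("prediction and target must have the same length");
    float sum = 0f;
    for (int k = 0; k < prediction.Length; k++)
    {
        float diff = prediction[k] - target[k];
        sum += diff * diff;
    }
    return sum;
}

// (1.5 - 1.0)^2 + (2.0 - 3.0)^2 = 0.25 + 1.0 = 1.25
float err = SquaredError(new[] { 1.5f, 2.0f }, new[] { 1.0f, 3.0f });
Console.WriteLine($"squared error = {err}");
```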

Prepare for training:

// Per-sample learning rate of 0.0005
TrainingParameterScheduleDouble learningRatePerSample = new TrainingParameterScheduleDouble(0.0005, 1);
// Momentum specified as a time constant of 256 samples
TrainingParameterScheduleDouble momentumTimeConstant = CNTKLib.MomentumAsTimeConstantSchedule(256);
// A single momentum SGD learner updates all model parameters
IList<Learner> parameterLearners = new List<Learner>()
{
Learner.MomentumSGDLearner(lstmModel?.Parameters(), learningRatePerSample, momentumTimeConstant, /*unitGainMomentum = */true)
};
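
The learner uses a per-sample learning rate of 0.0005 and a momentum given as a time constant of 256 samples. Under CNTK's time-constant convention, as we understand it, a gradient's influence decays to 1/e after τ samples. A minibatch of N samples therefore gets momentum μ = exp(-N/τ). The standalone sketch below only evaluates that formula. MomentumForMinibatch is a hypothetical helper name, and the minibatch sizes are illustrative:

```csharp
using System;

// Converts a momentum time constant (tau, in samples) into the momentum
// applied to one minibatch of n samples: mu = exp(-n / tau).
static double MomentumForMinibatch(double timeConstant, int minibatchSize) =>
    Math.Exp(-minibatchSize / timeConstant);

// Illustrative minibatch sizes; TrainNetwork() receives its own batchSize
foreach (int n in new[] { 1, 32, 256 })
{
    Console.WriteLine($"minibatch of {n}: momentum = {MomentumForMinibatch(256, n):F4}");
}
```

With the time-constant form, the effective momentum adjusts automatically when batchSize changes, so the same setting can be reused across minibatch sizes.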

Create the trainer, as follows:

var trainer = Trainer.CreateTrainer(lstmModel, trainingLoss, prediction, parameterLearners);

Train the model, as follows:

for (int i = 1; i <= iteration; i++)
{

Get the next minibatch of data, as follows:

// Convert each minibatch into a map from input variables to CNTK values
foreach (var batchData in from miniBatchData in GetNextDataBatch(featureSet.train, labelSet.train, batchSize)
let xValues = Value.CreateBatch(new NDShape(1, inDim), miniBatchData.X, DeviceDescriptor.CPUDevice)
let yValues = Value.CreateBatch(new NDShape(1, ouDim), miniBatchData.Y, DeviceDescriptor.CPUDevice)
select new Dictionary<Variable, Value>
{
{ feature, xValues },
{ label, yValues }})
{

Train on the current minibatch, as follows:

trainer?.TrainMinibatch(batchData, DeviceDescriptor.CPUDevice);
}
// Marshal the progress report onto the UI thread when called from a worker thread
if (InvokeRequired)
{
Invoke(new Action(() => progressReport?.Invoke(trainer, lstmModel.Clone(), i, DeviceDescriptor.CPUDevice)));
}
else
{
progressReport?.Invoke(trainer, lstmModel.Clone(), i, DeviceDescriptor.CPUDevice);
}
}
}
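
The loop relies on GetNextDataBatch(), whose body is not shown in this section. Value.CreateBatch builds a batch from one flat sequence of floats whose length is a multiple of the sample size (inDim or ouDim here), so the helper presumably returns each minibatch flattened. The following is a minimal, hypothetical sketch under that assumption. It walks the rows in order with no shuffling, and the book's own helper may differ:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical stand-in for GetNextDataBatch: yields each minibatch as flat
// arrays, because Value.CreateBatch expects all samples of a batch
// concatenated into one sequence of floats.
static IEnumerable<(float[] X, float[] Y)> GetNextDataBatch(float[][] xData, float[][] yData, int batchSize)
{
    if (batchSize <= 0) throw new ArgumentOutOfRangeException(nameof(batchSize));
    if (xData.Length != yData.Length)
        throw new ArgumentException("features and labels must have the same number of rows");
    for (int start = 0; start < xData.Length; start += batchSize)
    {
        // The last batch may hold fewer than batchSize samples
        int count = Math.Min(batchSize, xData.Length - start);
        float[] x = xData.Skip(start).Take(count).SelectMany(row => row).ToArray();
        float[] y = yData.Skip(start).Take(count).SelectMany(row => row).ToArray();
        yield return (x, y);
    }
}

// Toy data: 3 samples with 2 features and 1 label each (made-up values)
float[][] xs = { new float[] { 1, 2 }, new float[] { 3, 4 }, new float[] { 5, 6 } };
float[][] ys = { new float[] { 10 }, new float[] { 20 }, new float[] { 30 } };

var batches = GetNextDataBatch(xs, ys, batchSize: 2).ToList();
Console.WriteLine(batches.Count); // 2: one full batch of 2 and a final batch of 1
```

The tuple fields are named X and Y so that the miniBatchData.X and miniBatchData.Y accesses in the training loop would compile against this sketch.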