- Create the ComputationGraph model from the previously created model configuration:
ComputationGraphConfiguration configuration = builder.build();
ComputationGraph model = new ComputationGraph(configuration);
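- Train the model by passing the training data iterator to the fit() method. Each call to fit() makes one pass (epoch) over the training data, so loop over the desired number of epochs: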
for (int i = 0; i < epochs; i++) {
    model.fit(trainDataSetIterator);
}
Alternatively, you can pass the number of epochs directly to fit():
model.fit(trainDataSetIterator, epochs);
This overload removes the need for an explicit for loop, because the epoch count is given as the second argument to fit().
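To illustrate why the two approaches are interchangeable, here is a minimal, dependency-free sketch. It does not use Deeplearning4j: `ToyBatchIterator` and `ToyModel` are hypothetical stand-ins for a resettable DataSetIterator and a ComputationGraph. The sketch assumes that a single-argument fit() rewinds an exhausted iterator before each pass; check how your own iterator handles resetting before relying on this.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

// Hypothetical stand-in for a resettable data iterator: it yields a fixed
// list of mini-batch labels and can be rewound to the start.
class ToyBatchIterator implements Iterator<String> {
    private final List<String> batches;
    private int cursor = 0;

    ToyBatchIterator(List<String> batches) {
        this.batches = batches;
    }

    @Override
    public boolean hasNext() {
        return cursor < batches.size();
    }

    @Override
    public String next() {
        return batches.get(cursor++);
    }

    // Rewind so the next epoch sees the full data again.
    void reset() {
        cursor = 0;
    }
}

// Hypothetical stand-in for a model: it only records the batches it "trains" on.
class ToyModel {
    final List<String> seen = new ArrayList<>();

    // One epoch: rewind the iterator if it is exhausted (an assumption of this
    // sketch), then consume every batch exactly once.
    void fit(ToyBatchIterator it) {
        if (!it.hasNext()) {
            it.reset();
        }
        while (it.hasNext()) {
            seen.add(it.next());
        }
    }

    // Convenience overload: equivalent to calling fit(it) numEpochs times.
    void fit(ToyBatchIterator it, int numEpochs) {
        for (int i = 0; i < numEpochs; i++) {
            fit(it);
        }
    }
}
```

Both styles produce the same sequence of training passes; the two-argument overload simply moves the epoch loop inside the method.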