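- Call the static MultiLayerNetwork.load() method to import the model from the saved location. The boolean flag controls whether the updater state (for example, momentum or Adam history) is restored too: pass true if you plan to continue training, or false if the model is only needed for inference. Note that load() throws IOException:
import java.io.File;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
File savedLocation = new File("model.zip");
boolean loadUpdater = true; // also restore the updater state so training can resume
MultiLayerNetwork restored = MultiLayerNetwork.load(savedLocation, loadUpdater);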
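The plain-Java sketch below (not DL4J code) shows why the updater flag matters. It models a checkpoint of a single parameter trained with SGD plus momentum. Momentum is an assumed stand-in for whatever updater the saved network actually used. The velocity value plays the role of the updater state that loading with `true` restores. The class and method names are hypothetical.

```java
import java.util.Arrays;

public class UpdaterStateDemo {
    static final double LEARNING_RATE = 0.1;
    static final double MOMENTUM = 0.9;

    // One step of classical momentum SGD on loss(w) = w^2, whose gradient is 2w.
    // state[0] is the parameter w; state[1] is the momentum velocity (the "updater state").
    static void step(double[] state) {
        double gradient = 2.0 * state[0];
        state[1] = MOMENTUM * state[1] - LEARNING_RATE * gradient;
        state[0] += state[1];
    }

    // Runs the given number of steps, mutating and returning the same array.
    static double[] train(double[] state, int steps) {
        for (int i = 0; i < steps; i++) {
            step(state);
        }
        return state;
    }

    public static void main(String[] args) {
        double[] uninterrupted = train(new double[] {1.0, 0.0}, 6);

        // "Checkpoint" after 5 steps, then resume for one more step in two ways.
        double[] checkpoint = train(new double[] {1.0, 0.0}, 5);
        double[] withUpdater = train(new double[] {checkpoint[0], checkpoint[1]}, 1); // velocity restored
        double[] withoutUpdater = train(new double[] {checkpoint[0], 0.0}, 1);        // velocity reset to 0

        System.out.println("uninterrupted:   " + Arrays.toString(uninterrupted));
        System.out.println("with updater:    " + Arrays.toString(withUpdater));
        System.out.println("without updater: " + Arrays.toString(withoutUpdater));
    }
}
```

With the velocity restored, the resumed run reproduces the uninterrupted run exactly. With the velocity reset, the sixth step lands somewhere else, which is the kind of discontinuity that loading without the updater can introduce when training continues.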
- Add the required pom dependency to use the deeplearning4j-zoo module:
<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-zoo</artifactId>
    <version>1.0.0-beta3</version>
</dependency>
- Add the fine-tuning configuration for MultiLayerNetwork using the TransferLearning API:
MultiLayerNetwork newModel = new TransferLearning.Builder(oldModel)
.fineTuneConfiguration(fineTuneConf)
.build();
- Add the fine-tuning configuration for ComputationGraph using the TransferLearning API:
ComputationGraph newModel = new TransferLearning.GraphBuilder(oldModel)
.fineTuneConfiguration(fineTuneConf)
.build();
- Configure the training session using TransferLearningHelper. TransferLearningHelper can be created in two ways:
- Pass in the model object that was created with the TransferLearning builder in the previous steps, with its frozen layers already specified (for example, with setFeatureExtractor() on the builder):
TransferLearningHelper transferLearningHelper = new TransferLearningHelper(newModel);
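- Create it directly from the imported model by specifying the frozen layers explicitly. For a ComputationGraph, pass the name(s) of the layer(s) up to which the model is frozen; for a MultiLayerNetwork, pass the index of the last frozen layer instead:
TransferLearningHelper transferLearningHelper = new TransferLearningHelper(oldModel, "layer1");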
- Featurize the train/test data using the featurize() method:
while(iterator.hasNext()) {
DataSet currentFeaturized = transferLearningHelper.featurize(iterator.next());
saveToDisk(currentFeaturized); // save the featurized data to disk (saveToDisk is a user-defined helper, e.g. wrapping DataSet.save(File))
}
- Create train/test iterators by using ExistingMiniBatchDataSetIterator:
DataSetIterator existingTrainingData = new ExistingMiniBatchDataSetIterator(new File("trainFolder"),"churn-"+featureExtractorLayer+"-train-%d.bin");
DataSetIterator existingTestData = new ExistingMiniBatchDataSetIterator(new File("testFolder"),"churn-"+featureExtractorLayer+"-test-%d.bin");
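The featurize-then-reload steps above rely on a file-naming contract. Each featurized mini-batch goes into its own file, named by the pattern with `%d` replaced by the batch index.

The plain-Java sketch below models that contract with `String.format`. It is not how DL4J serializes a `DataSet`: the batches here are plain `double[]` arrays written with `DataOutputStream`. It also assumes that indices start at 0 and are contiguous, and the real iterator's rules may differ.

`MiniBatchFilesDemo`, `saveBatches`, `loadBatches`, and the layer name `"layer1"` are hypothetical. Keeping training and test batches in separate folders, as the step above does, avoids any ambiguity about which files belong to which iterator.

```java
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;

public class MiniBatchFilesDemo {
    // Writes batch i to dir/String.format(pattern, i), with i starting at 0.
    static void saveBatches(Path dir, String pattern, List<double[]> batches) throws IOException {
        for (int i = 0; i < batches.size(); i++) {
            Path file = dir.resolve(String.format(pattern, i));
            try (DataOutputStream out = new DataOutputStream(Files.newOutputStream(file))) {
                double[] batch = batches.get(i);
                out.writeInt(batch.length);
                for (double value : batch) {
                    out.writeDouble(value);
                }
            }
        }
    }

    // Reads batches back in index order, stopping at the first missing index.
    static List<double[]> loadBatches(Path dir, String pattern) throws IOException {
        List<double[]> result = new ArrayList<>();
        for (int i = 0; ; i++) {
            Path file = dir.resolve(String.format(pattern, i));
            if (!Files.exists(file)) {
                break;
            }
            try (DataInputStream in = new DataInputStream(Files.newInputStream(file))) {
                double[] batch = new double[in.readInt()];
                for (int j = 0; j < batch.length; j++) {
                    batch[j] = in.readDouble();
                }
                result.add(batch);
            }
        }
        return result;
    }

    public static void main(String[] args) throws IOException {
        Path trainFolder = Files.createTempDirectory("trainFolder");
        String featureExtractorLayer = "layer1"; // hypothetical layer name
        String pattern = "churn-" + featureExtractorLayer + "-train-%d.bin";

        List<double[]> batches = List.of(new double[] {0.1, 0.2}, new double[] {0.3}, new double[] {0.4, 0.5, 0.6});
        saveBatches(trainFolder, pattern, batches);

        List<double[]> restored = loadBatches(trainFolder, pattern);
        System.out.println(restored.size() + " batches restored; first file: " + String.format(pattern, 0));
    }
}
```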
- Train the unfrozen layers on the featurized data by calling fitFeaturized():
transferLearningHelper.fitFeaturized(existingTrainingData);
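The plain-Java model below is not DL4J's implementation. It shows what featurizing buys you: the frozen part runs once per example and its outputs are cached. Every training epoch then updates only the unfrozen weight, using those cached features.

The sketch makes several assumptions, all chosen for illustration:
- a one-weight frozen "layer" (multiply by 2),
- a one-weight unfrozen layer,
- mean squared error loss,
- full-batch gradient descent.

`FeaturizedTrainingDemo`, `featurizeOnce`, and `fitOnCachedFeatures` are hypothetical names.

```java
import java.util.ArrayList;
import java.util.List;

public class FeaturizedTrainingDemo {
    static int frozenCalls = 0;

    // Frozen "layer": a fixed weight of 2.0 that is never updated. Counts how often it runs.
    static double frozenLayer(double x) {
        frozenCalls++;
        return 2.0 * x;
    }

    // Featurize once: pass every input through the frozen part and cache the outputs.
    static List<Double> featurizeOnce(double[] inputs) {
        List<Double> features = new ArrayList<>();
        for (double x : inputs) {
            features.add(frozenLayer(x));
        }
        return features;
    }

    // Train only the unfrozen weight w on cached features (full-batch gradient descent on MSE).
    // The frozen layer is never called here, no matter how many epochs run.
    static double fitOnCachedFeatures(List<Double> features, double[] targets, int epochs, double learningRate) {
        double w = 0.0;
        for (int epoch = 0; epoch < epochs; epoch++) {
            double gradient = 0.0;
            for (int i = 0; i < targets.length; i++) {
                double f = features.get(i);
                gradient += 2.0 * (w * f - targets[i]) * f;
            }
            w -= learningRate * gradient / targets.length;
        }
        return w;
    }

    public static void main(String[] args) {
        double[] inputs = {1, 2, 3, 4};
        double[] targets = {6, 12, 18, 24}; // equals 3 * frozenLayer(x), so w should approach 3
        List<Double> features = featurizeOnce(inputs);
        double w = fitOnCachedFeatures(features, targets, 50, 0.01);
        System.out.printf("unfrozen weight = %.6f, frozen forward passes = %d%n", w, frozenCalls);
    }
}
```

Running `main` prints an unfrozen weight of 3.000000 after 50 epochs. The frozen layer ran only 4 times, once per input. If every epoch re-ran the full network, it would have run 4 × 50 times. This saving is the motivation for training on featurized data when the frozen layers dominate the cost.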
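- Evaluate the model by calling evaluate() on the unfrozen part of the network, which unfrozenMLN() returns for a MultiLayerNetwork (use unfrozenGraph() for a ComputationGraph):
Evaluation evaluation = transferLearningHelper.unfrozenMLN().evaluate(existingTestData);
System.out.println(evaluation.stats());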