- Add the features and labels to the schema:
Schema.Builder schemaBuilder = new Schema.Builder();
schemaBuilder.addColumnString("RowNumber");
schemaBuilder.addColumnInteger("CustomerId");
schemaBuilder.addColumnString("Surname");
schemaBuilder.addColumnInteger("CreditScore");
- Identify and add categorical features to the schema:
schemaBuilder.addColumnCategorical("Geography", Arrays.asList("France","Germany","Spain"));
schemaBuilder.addColumnCategorical("Gender", Arrays.asList("Male","Female"));
- Remove noise features from the dataset:
Schema schema = schemaBuilder.build();
TransformProcess.Builder transformProcessBuilder = new TransformProcess.Builder(schema);
transformProcessBuilder.removeColumns("RowNumber","CustomerId","Surname");
- Convert the two-valued Gender category to an integer index by calling categoricalToInteger():
transformProcessBuilder.categoricalToInteger("Gender");
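The idea behind this step can be sketched in plain Java, without DataVec: each label is replaced by its position in the state list declared in the schema. This is only an illustration. The class and method names are hypothetical, and the assumption that the index follows the declared order ("Male" first, "Female" second) should be checked against the DataVec version in use.

```java
import java.util.Arrays;
import java.util.List;

public class CategoryIndexSketch {

    // Returns the position of `value` in the declared state list.
    // Assumption: the integer index follows the order of the list in the schema.
    public static int toIndex(List<String> states, String value) {
        int index = states.indexOf(value);
        if (index < 0) {
            throw new IllegalArgumentException("Unknown category: " + value);
        }
        return index;
    }

    public static void main(String[] args) {
        List<String> genderStates = Arrays.asList("Male", "Female");
        System.out.println(toIndex(genderStates, "Male"));   // prints 0
        System.out.println(toIndex(genderStates, "Female")); // prints 1
    }
}
```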
- Apply one-hot encoding by calling categoricalToOneHot():
transformProcessBuilder.categoricalToOneHot("Geography");
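- Remove the correlation dependency on the Geography feature by calling removeColumns(). After one-hot encoding, the three Geography columns always sum to 1, so any one of them can be derived from the other two:
transformProcessBuilder.removeColumns("Geography[France]");
Here, we selected France as the reference category to drop: a record with 0 in both Geography[Germany] and Geography[Spain] represents France.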
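To make the effect of one-hot encoding followed by dropping one column concrete, here is a minimal plain-Java sketch. It does not use DataVec; the class and method names are hypothetical and only mirror the idea of categoricalToOneHot() followed by removeColumns().

```java
import java.util.Arrays;
import java.util.List;

public class OneHotSketch {

    // One-hot encodes `value`: one slot per state, with 1 in the matching slot.
    public static int[] oneHot(List<String> states, String value) {
        int index = states.indexOf(value);
        if (index < 0) {
            throw new IllegalArgumentException("Unknown category: " + value);
        }
        int[] encoded = new int[states.size()];
        encoded[index] = 1;
        return encoded;
    }

    // Removes the slot at `dropIndex`, that is, the reference category.
    public static int[] dropColumn(int[] encoded, int dropIndex) {
        int[] reduced = new int[encoded.length - 1];
        for (int i = 0, j = 0; i < encoded.length; i++) {
            if (i != dropIndex) {
                reduced[j++] = encoded[i];
            }
        }
        return reduced;
    }

    public static void main(String[] args) {
        List<String> geography = Arrays.asList("France", "Germany", "Spain");
        for (String country : geography) {
            // Drop index 0 (France) so the remaining columns are [Germany, Spain].
            int[] row = dropColumn(oneHot(geography, country), 0);
            System.out.println(country + " -> " + Arrays.toString(row));
        }
        // France -> [0, 0]
        // Germany -> [1, 0]
        // Spain -> [0, 1]
    }
}
```

Two columns are still enough to tell the three countries apart, which is why the dropped column carries no extra information.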
- Extract the data and apply the transformation using TransformProcessRecordReader:
TransformProcess transformProcess = transformProcessBuilder.build();
TransformProcessRecordReader transformProcessRecordReader = new TransformProcessRecordReader(recordReader,transformProcess);
- Create a dataset iterator to train/test:
DataSetIterator dataSetIterator = new RecordReaderDataSetIterator.Builder(transformProcessRecordReader,batchSize)
    .classification(labelIndex,numClasses)
    .build();
- Normalize the dataset:
DataNormalization dataNormalization = new NormalizerStandardize();
dataNormalization.fit(dataSetIterator);
dataSetIterator.setPreProcessor(dataNormalization);
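Standardization rescales each feature to zero mean and unit standard deviation: fit() gathers the statistics and the pre-processor applies them to every batch. The plain-Java sketch below shows the arithmetic for a single feature column. The class name and the sample credit scores are hypothetical, and the use of the population standard deviation (dividing by n) is an assumption of this sketch rather than a statement about the library's internals.

```java
public class StandardizeSketch {

    // Returns a copy of `values` rescaled to zero mean and unit standard deviation.
    // Assumption: population standard deviation (divide by n).
    public static double[] standardize(double[] values) {
        int n = values.length;
        double mean = 0.0;
        for (double v : values) {
            mean += v;
        }
        mean /= n;

        double variance = 0.0;
        for (double v : values) {
            variance += (v - mean) * (v - mean);
        }
        variance /= n;
        double std = Math.sqrt(variance);

        double[] out = new double[n];
        for (int i = 0; i < n; i++) {
            // A constant column has std == 0; map it to 0 instead of dividing by zero.
            out[i] = (std == 0.0) ? 0.0 : (values[i] - mean) / std;
        }
        return out;
    }

    public static void main(String[] args) {
        double[] creditScores = {600.0, 650.0, 700.0}; // made-up values
        for (double z : standardize(creditScores)) {
            System.out.printf("%.4f%n", z);
        }
    }
}
```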
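- Split the main dataset iterator into train and test iterators. Here, totalNoOfBatches is the number of batches in the dataset and ratio is the fraction of those batches used for training: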
DataSetIteratorSplitter dataSetIteratorSplitter = new DataSetIteratorSplitter(dataSetIterator,totalNoOfBatches,ratio);
- Generate train/test iterators from DataSetIteratorSplitter:
DataSetIterator trainIterator = dataSetIteratorSplitter.getTrainIterator();
DataSetIterator testIterator = dataSetIteratorSplitter.getTestIterator();
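The split is made on whole batches, not on individual records. As a rough sketch, assuming the splitter assigns the first floor(totalNoOfBatches * ratio) batches to training and the remainder to testing (an assumption worth verifying against the DL4J version in use), the batch counts can be estimated in plain Java. The class name and the figures in main() are hypothetical.

```java
public class BatchSplitSketch {

    // Number of batches assigned to training for a given split ratio.
    public static long trainBatches(long totalBatches, double ratio) {
        if (ratio <= 0.0 || ratio >= 1.0) {
            throw new IllegalArgumentException("ratio must be between 0 and 1 (exclusive)");
        }
        return (long) (totalBatches * ratio);
    }

    // The remaining batches are assigned to testing.
    public static long testBatches(long totalBatches, double ratio) {
        return totalBatches - trainBatches(totalBatches, ratio);
    }

    public static void main(String[] args) {
        long totalNoOfBatches = 1250; // hypothetical value
        double ratio = 0.8;           // hypothetical value
        System.out.println("train batches: " + trainBatches(totalNoOfBatches, ratio)); // prints train batches: 1000
        System.out.println("test batches: " + testBatches(totalNoOfBatches, ratio));   // prints test batches: 250
    }
}
```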