Analogous to human learning, neural networks should be able to acquire new knowledge without forgetting what they have already learned. With traditional training approaches this is nearly impossible, because every new training session replaces the connection weights already established, thereby "forgetting" the previous knowledge. This gives rise to the need for neural networks that adapt to new knowledge by incrementing, instead of replacing, their current knowledge. To address this issue, we are going to explore a method called adaptive resonance theory (ART).
The question that drove the development of this theory was the following: "How can an adaptive system remain plastic to significant inputs and yet stable to irrelevant ones?" In other words: "How can previously learned information be retained while new information is being learned?"
We've seen that competitive learning in unsupervised learning deals with pattern recognition, wherein similar inputs yield similar outputs, that is, fire the same neurons. In an ART topology, resonance comes in when information is being retrieved from the network, through feedback between the competitive layer and the input layer. While the network receives the data to learn, there is an oscillation resulting from this feedback. The oscillation stabilizes when the pattern is fully developed inside the neural network, and the resulting resonance then reinforces the stored pattern.
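Before looking at the class itself, the resonance idea can be sketched in a few lines. The following is a minimal, illustrative example, assuming binary (ART-1 style) input patterns and a vigilance parameter; the class `ResonanceSketch` and the methods `match` and `resonates` are hypothetical names for this sketch, not part of the `som` package:

```java
// A minimal sketch of the resonance/vigilance idea in ART-1, assuming binary
// input patterns; the names here are illustrative, not the book's ART class.
public class ResonanceSketch {

    // Degree of match between a binary input and a stored prototype:
    // |input AND prototype| / |input|
    static double match(int[] input, int[] prototype) {
        int common = 0, active = 0;
        for (int i = 0; i < input.length; i++) {
            if (input[i] == 1) {
                active++;
                if (prototype[i] == 1) common++;
            }
        }
        return active == 0 ? 0.0 : (double) common / active;
    }

    // Resonance occurs when the match reaches the vigilance parameter rho:
    // the stored pattern is then reinforced rather than replaced.
    static boolean resonates(int[] input, int[] prototype, double rho) {
        return match(input, prototype) >= rho;
    }

    public static void main(String[] args) {
        int[] input     = {1, 1, 0, 1};
        int[] prototype = {1, 1, 0, 0};
        // Two of the three active input bits are shared with the prototype.
        System.out.println(match(input, prototype));          // prints 0.666...
        System.out.println(resonates(input, prototype, 0.5)); // prints true
        System.out.println(resonates(input, prototype, 0.9)); // prints false
    }
}
```

A low vigilance value makes the network group many inputs under one stored pattern; a high value forces finer categories, which is exactly the stability-plasticity trade-off described above.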
A new class called ART has been created in the som package. The following table describes the attributes and methods of this class:
Class name: `ART`

Attributes:

| Attribute | Description |
| --- | --- |
| `private int SIZE_OF_INPUT_LAYER;` | Global variable to store the number of neurons in the input layer |
| `private int SIZE_OF_OUTPUT_LAYER;` | Global variable to store the number of neurons in the output layer |

Methods:

| Method | Description |
| --- | --- |
| `public NeuralNet train(NeuralNet n)` | Trains the neural net based on the ART algorithm. Parameters: neural net object to train. Returns: trained neural net object |
| `private void initGlobalVars(NeuralNet n)` | Initializes the global variables. Parameters: neural net object. Returns: - |
| `private NeuralNet initNet(NeuralNet n)` | Initializes the neural net weights. Parameters: neural net object. Returns: neural net object with the initialized weights |
| `private int calcWinnerNeuron(NeuralNet n, int row_i, double[][] patterns)` | Calculates the winner neuron. Parameters: neural net object, row of the training set, training set patterns. Returns: index of the winner neuron |
| `private NeuralNet setNetOutput(NeuralNet n, int winnerNeuron)` | Sets the neural net output. Parameters: neural net object, index of the winner neuron. Returns: neural net object with the output set |
| `private boolean vigilanceTest(NeuralNet n, int row_i)` | Verifies whether the neural net has learned the pattern. Parameters: neural net object, row of the training set. Returns: true if the neural net has learned, false if not |
| `private NeuralNet fixWeights(NeuralNet n, int row_i, int winnerNeuron)` | Fixes (adjusts) the weights of the neural net. Parameters: neural net object, row of the training set, index of the winner neuron. Returns: neural net object with the weights fixed |

Class implementation with Java: file `ART.java`
The training method is shown in the following code. Notice that first the global variables and the neural net are initialized. After that, the number of training rows and the training patterns are stored, and the training process begins. The first step of this process is to calculate the index of the winner neuron; the second is to set the neural net output accordingly. The next step is to verify whether the neural net has learned the pattern (the vigilance test). If it has, the weights are fixed; if not, the next training sample is presented to the net.
```java
public NeuralNet train(NeuralNet n) {
    this.initGlobalVars( n );
    n = this.initNet( n );
    int rows = n.getTrainSet().length;
    double[][] trainPatterns = n.getTrainSet();
    for (int epoch = 0; epoch < n.getMaxEpochs(); epoch++) {
        for (int row_i = 0; row_i < rows; row_i++) {
            int winnerNeuron = this.calcWinnerNeuron( n, row_i, trainPatterns );
            n = this.setNetOutput( n, winnerNeuron );
            boolean isMatched = this.vigilanceTest( n, row_i );
            if ( isMatched ) {
                n = this.fixWeights( n, row_i, winnerNeuron );
            }
        }
    }
    return n;
}
```
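The helper methods called inside the loop are described in the table above but not listed here. The following standalone sketch illustrates what `calcWinnerNeuron`, `vigilanceTest`, and `fixWeights` might do, assuming Euclidean-distance competition over plain weight arrays; the real methods operate on a `NeuralNet` object, so the class `ArtHelpersSketch` and its signatures are illustrative assumptions, not the book's implementation:

```java
// A self-contained sketch of the three helper steps, assuming the winner is
// the output neuron whose weight vector is closest (Euclidean distance) to the
// input pattern. Plain arrays stand in for the NeuralNet object.
public class ArtHelpersSketch {

    // Index of the output neuron with the smallest squared distance
    // to the pattern: the "winner" of the competition.
    static int calcWinnerNeuron(double[][] weights, double[] pattern) {
        int winner = 0;
        double best = Double.MAX_VALUE;
        for (int j = 0; j < weights.length; j++) {
            double dist = 0.0;
            for (int i = 0; i < pattern.length; i++) {
                double d = pattern[i] - weights[j][i];
                dist += d * d;
            }
            if (dist < best) { best = dist; winner = j; }
        }
        return winner;
    }

    // The net "has learned" the pattern when the winner's distance
    // stays within the vigilance threshold.
    static boolean vigilanceTest(double[][] weights, double[] pattern,
                                 int winner, double vigilance) {
        double dist = 0.0;
        for (int i = 0; i < pattern.length; i++) {
            double d = pattern[i] - weights[winner][i];
            dist += d * d;
        }
        return Math.sqrt(dist) <= vigilance;
    }

    // Move the winner's weights toward the pattern, reinforcing the
    // stored category instead of replacing it.
    static void fixWeights(double[][] weights, double[] pattern,
                           int winner, double learningRate) {
        for (int i = 0; i < pattern.length; i++) {
            weights[winner][i] += learningRate * (pattern[i] - weights[winner][i]);
        }
    }

    public static void main(String[] args) {
        double[][] weights = { {0.0, 0.0}, {1.0, 1.0} };
        double[] pattern = {0.9, 1.1};
        int winner = calcWinnerNeuron(weights, pattern);
        System.out.println(winner);                                       // prints 1
        System.out.println(vigilanceTest(weights, pattern, winner, 0.2)); // prints true
    }
}
```

Note how `fixWeights` only nudges the winning neuron's weights toward the input: the other categories stay untouched, which is what lets the network add knowledge incrementally instead of overwriting it.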