多层前馈神经元网络(二)
ts the ith label
* @param maxIteration : maximum number of training iterations
* @param threshold : threshold of weights update
* @param errorRate : threshold for error rate
* @return always true, returned once the training loop has ended
*/
public boolean train(double[][] data, double target[][], int maxIteration, double threshold, double errorRate) {
// check status
if ( status == Status.TRAINED ){
throw new IllegalStateException();
}
// check input arguments and input parameters
if ( data.length <=0 || data[0].length != neurodes[0].length ||
target.length == 0 || target[0].length != neurodes[depth-1].length ) {
throw new IllegalArgumentException();
}
int round = 1;
boolean convergence = false;
while ( round <= maxIteration && ! convergence ) {
double rate = 0.2;//1.0/round; // learn rate
double delta = 0.0;
for ( int r=0; r
double res = trainWithOneSample(data[r], target[r], rate);
delta = (delta
}
convergence = (delta
round++;
System.out.printf(" %d round of train, delta is %f %n", round-1, delta);
}
return true;
}
/**
* Train the neuronetwork with one entry of sample data
*
* @param data : an vector represent one entry of taining sample
* @param target : an vector represent class label of the training sample
* @param rate : learn rate
* @return : maximum detla of weights
*/
private double trainWithOneSample(double[] data, double[] target, double rate) {
calculateOutput(data);
// calculate error for layer n-1
for ( int j=0; j
double output = neurodes[depth-1][j].output;
neurodes[depth-1][j].err = output*(1-output)*(target[j]-output);
}
// calculate error for hidden layers n-2 ... 1
for ( int d=depth-2; d>0; d-- ) {
for ( int j=0; j
double error = 0.0;
for ( int k=0; k
error += neurodes[d+1][k].err*weights[d][j][k];
}
double output = neurodes[d][j].output;
neurodes[d][j].err = output*(1-output)*error;
}
}
double maxDelta = 0.0;
// update weights
for ( int d=0; d
for ( int i=0; i
for ( int j=0; j
double delta = neurodes[d][i].output*neurodes[d+1][j].err;
weights[d][i][j] += rate*delta;
if ( maxDelta < Math.abs(delta) ) {
maxDelta = Math.abs(delta);
}
}
}
}
// update theta
for ( int d=1; d
for ( int j=0; j
neurodes[d][j].theta += rate*neurodes[d][j].err;
}
}
return maxDelta;
}
}
测试:
[java]
public class TestMain {
public static double[][][] generateData(int m) {
double[][][] res = new double[2][][];
double[][] data = new double[m*m][2];
double[][] label = new double[m*m][3];