package org.jlinalg.demo;

import java.util.Random;

import org.jlinalg.DoubleWrapper;
import org.jlinalg.FieldElement;
import org.jlinalg.LinAlgFactory;
import org.jlinalg.Matrix;
import org.jlinalg.MonadicOperator;
import org.jlinalg.Vector;

/**
 * Example showing the Exclusive-Or neural net problem using JLinAlg: a network
 * with one sigmoid hidden layer and sigmoid outputs, trained by batch
 * back-propagation with momentum. The example can easily be modified to solve
 * other problems by changing the values of {@code inpat} and {@code tgpat}.
 *
 * @author Simon D. Levy
 */
public class Xor {

    // training patterns: row j of inpat is the input for target row j of tgpat
    private static final double[][] inpat = { { 0, 0 }, { 0, 1 }, { 1, 0 }, { 1, 1 } };
    private static final double[][] tgpat = { { 0 }, { 1 }, { 1 }, { 0 } };

    // arbitrary params
    private static final int NEPOCH = 10000; // how many epochs
    private static final double ETA = 0.1; // learning rate
    private static final double MU = 0.9; // momentum
    private static final int NHID = 3; // how many hidden units

    // stateless element-wise operators, shared by training and testing
    private static final MonadicOperator SIGMOID = new SigmoidOperator();
    private static final MonadicOperator SIGMOID_DERIV = new SigdervOperator();

    // instance vars
    private final LinAlgFactory df;
    private final Random random;

    int ninp; // how many inputs
    int nout; // how many outputs
    int npat; // number of patterns

    private Matrix wih; // input->hidden weights
    private Matrix who; // hidden->output weights
    private Vector bh; // bias on hidden
    private Vector bo; // bias on output

    /**
     * Creates a network whose initial weights and biases are Gaussian noise
     * drawn from a newly created {@link Random}.
     */
    public Xor() {
        this(new Random());
    }

    /**
     * Creates a network whose initial weights and biases are Gaussian noise
     * drawn from the given generator. Passing a generator with a fixed seed
     * should make runs repeatable, assuming the factory draws its noise only
     * from this generator (NOTE(review): confirm against LinAlgFactory).
     *
     * @param random source of the noise used to initialize weights and biases
     */
    public Xor(Random random) {
        this.df = new LinAlgFactory(new DoubleWrapper(0));
        this.random = random;

        // derive sizes from the patterns so the code generalizes to new problems
        ninp = inpat[0].length;
        nout = tgpat[0].length;
        npat = inpat.length;

        // weights are initially random
        wih = df.gaussianNoise(ninp, NHID, random);
        who = df.gaussianNoise(NHID, nout, random);

        // biases are initially random
        bh = df.gaussianNoise(NHID, random);
        bo = df.gaussianNoise(nout, random);
    }

    /**
     * Trains the network for {@code NEPOCH} epochs of batch back-propagation.
     * The RMS error is printed to standard error on the first epoch, the last
     * epoch and every 1000th epoch.
     *
     * <p>Each epoch applies to every weight and bias the change
     * {@code delta(t) = ETA * meanGradient(t) + MU * delta(t-1)}, where
     * {@code delta(t-1)} is the change actually applied in the previous epoch
     * (classical momentum).
     */
    public void train() {
        DoubleWrapper eta = new DoubleWrapper(ETA);
        DoubleWrapper mu = new DoubleWrapper(MU);
        DoubleWrapper npatd = new DoubleWrapper(npat);
        DoubleWrapper errd = new DoubleWrapper(npat * nout);

        // changes applied in the previous epoch (the momentum terms); zero at start
        Matrix prevDwih = df.zeros(ninp, NHID);
        Matrix prevDwho = df.zeros(NHID, nout);
        Vector prevDbh = df.zeros(NHID);
        Vector prevDbo = df.zeros(nout);

        for (int i = 0; i < NEPOCH; ++i) {
            // gradient terms summed over all patterns of this epoch
            Matrix gwih = df.zeros(ninp, NHID);
            Matrix gwho = df.zeros(NHID, nout);
            Vector gbh = df.zeros(NHID);
            Vector gbo = df.zeros(nout);

            // squared error summed over patterns, per output unit
            Vector sqrerr = df.zeros(nout);

            for (int j = 0; j < npat; ++j) {
                Vector ai = df.buildVector(inpat[j]);

                // run forward pass
                Vector ah = layer(ai, wih, bh);
                Vector ao = layer(ah, who, bo);

                // compute output error/delta from target
                Vector eo = df.buildVector(tgpat[j]).subtract(ao);
                Vector dlo = eo.arrayMultiply(ao.apply(SIGMOID_DERIV));

                // compute hidden error/delta by back-prop from output delta
                Vector eh = dlo.multiply(who.transpose());
                Vector dlh = eh.arrayMultiply(ah.apply(SIGMOID_DERIV));

                // accumulate weight and bias gradients using the Delta Rule
                gwih = gwih.add(ai.cross(dlh));
                gwho = gwho.add(ah.cross(dlo));
                gbh = gbh.add(dlh);
                gbo = gbo.add(dlo);

                sqrerr = sqrerr.add(eo.arrayMultiply(eo));
            }

            // change = ETA * mean gradient + MU * change applied last epoch
            Matrix dwih = gwih.divide(npatd).multiply(eta).add(prevDwih.multiply(mu));
            Matrix dwho = gwho.divide(npatd).multiply(eta).add(prevDwho.multiply(mu));
            Vector dbh = gbh.divide(npatd).multiply(eta).add(prevDbh.multiply(mu));
            Vector dbo = gbo.divide(npatd).multiply(eta).add(prevDbo.multiply(mu));

            wih = wih.add(dwih);
            who = who.add(dwho);
            bh = bh.add(dbh);
            bo = bo.add(dbo);

            // remember the applied changes for the momentum term of the next epoch
            prevDwih = dwih;
            prevDwho = dwho;
            prevDbh = dbh;
            prevDbo = dbo;

            // report RMS error first, last, every 1000 epochs
            if (i == 0 || i == NEPOCH - 1 || (i + 1) % 1000 == 0) {
                double mse = ((DoubleWrapper) sqrerr.sum().divide(errd)).doubleValue();
                System.err.println("EPOCH: " + (i + 1) + "\tRMS ERROR: " + Math.sqrt(mse));
            }
        }
    }

    /**
     * Runs every training pattern through the network and prints the actual
     * output vector followed by the target vector on standard output.
     */
    public void test() {
        for (int j = 0; j < npat; ++j) {
            Vector ai = df.buildVector(inpat[j]);
            Vector tg = df.buildVector(tgpat[j]);

            // run forward pass
            Vector ao = layer(layer(ai, wih, bh), who, bo);

            // Print the whole vector. The previous substring(2, 10) depended on an
            // exact toString layout: it could throw on short renderings and showed
            // only part of the first output.
            System.out.println(ao + " " + tg);
        }
    }

    /**
     * Computes one layer's activations, {@code sigmoid(input * weights + bias)}.
     *
     * @param input   activations of the previous layer
     * @param weights previous-layer-to-this-layer weight matrix
     * @param bias    this layer's bias vector
     * @return this layer's sigmoid activations
     */
    private static Vector layer(Vector input, Matrix weights, Vector bias) {
        return input.multiply(weights).add(bias).apply(SIGMOID);
    }

    /** Sigmoid squashing function operator: 1 / (1 + e^-x). */
    private static final class SigmoidOperator implements MonadicOperator {
        @Override
        public FieldElement apply(FieldElement x) {
            double dx = ((DoubleWrapper) x).getValue();
            return new DoubleWrapper(1 / (1 + Math.exp(-dx)));
        }
    }

    /**
     * First derivative of the sigmoid, written in terms of the sigmoid's output
     * (activation) a as a * (1 - a). It must be applied to activations, not to
     * net inputs.
     */
    private static final class SigdervOperator implements MonadicOperator {
        @Override
        public FieldElement apply(FieldElement x) {
            double dx = ((DoubleWrapper) x).getValue();
            return new DoubleWrapper(dx * (1 - dx));
        }
    }

    public static void main(String[] argv) {
        Xor xor = new Xor();
        xor.train();
        xor.test();
    }
}