diff --git a/server/training.js b/server/training.js index bd5df70d..2f86c456 100644 --- a/server/training.js +++ b/server/training.js @@ -23,7 +23,6 @@ Meteor.startup(function() { } console.log('Got ' + training_data.length + ' rows of training data.'); - var input = new convnetjs.Vol(1, 1, 8); var net = new convnetjs.Net(); net.makeLayers([ {type: 'input', out_sx: 1, out_sy: 1, out_depth: 8}, @@ -33,20 +32,24 @@ Meteor.startup(function() { ]); var trainer = new convnetjs.Trainer( net, {method: 'adadelta', l2_decay: 0.001, batch_size: 10}); - for (var iteration = 0; iteration < 1; iteration++) { + var input = new convnetjs.Vol(1, 1, 8); + for (var iteration = 0; iteration < 10; iteration++) { + var loss = 0; for (var i = 0; i < training_data.length; i++) { assert(input.w.length === training_data[i][0].length); input.w = training_data[i][0]; - trainer.train(input, training_data[i][1]); + var stats = trainer.train(input, [training_data[i][1]]); + assert(!isNaN(stats.loss)) + loss += stats.loss; } - console.log('Completed iteration:', iteration); + console.log('Iteration', iteration, 'loss:', loss/training_data.length); } console.log('Trained neural network.'); function net_classifier(features) { assert(input.w.length === features.length); input.w = features; - return net.forward(input).w[0]; + return net.forward(input).w[0] || 0; } console.log('Neural-net accuracy:', evaluate(glyphs, net_classifier)); });