diff --git a/server/training.js b/server/training.js index ea945757..1014c518 100644 --- a/server/training.js +++ b/server/training.js @@ -10,7 +10,7 @@ function evaluate(glyphs, classifier) { Meteor.startup(function() { var glyphs = Glyphs.find({'manual.verified': true}).fetch(); - var sample = _.sample(glyphs, 100); + var sample = _.sample(glyphs, 400); console.log('Hand-tuned accuracy:', evaluate(sample, hand_tuned_classifier)); var training_data = []; @@ -42,7 +42,7 @@ Meteor.startup(function() { var input = new convnetjs.Vol(1, 1, 8); for (var iteration = 0; iteration < 10; iteration++) { var loss = 0; - var round_data = _.sample(training_data, 1000); + var round_data = _.sample(training_data, 4000); for (var i = 0; i < round_data.length; i++) { assert(input.w.length === round_data[i][0].length); input.w = round_data[i][0]; @@ -62,4 +62,15 @@ Meteor.startup(function() { return softmax[1] - softmax[0]; } console.log('Neural-net accuracy:', evaluate(sample, net_classifier)); + + function combined_classifier(weight) { + return function(features) { + return hand_tuned_classifier(features) + weight*net_classifier(features); + } + } + var weights = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]; + for (var i = 0; i < weights.length; i++) { + console.log('Weight', weights[i], 'combined accuracy:', + evaluate(sample, combined_classifier(weights[i]))); + } });