mirror of
https://github.com/skishore/makemeahanzi.git
synced 2025-11-01 20:27:44 +08:00
Update classifier to included trained neural net
This commit is contained in:
17
lib/classifier.js
Normal file
17
lib/classifier.js
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
Meteor.startup(function() {
|
||||||
|
var weight = 0.8;
|
||||||
|
var dimensions = 8;
|
||||||
|
var input = new convnetjs.Vol(1, 1, dimensions);
|
||||||
|
var net = new convnetjs.Net();
|
||||||
|
net.fromJSON(TRAINED_NEURAL_NET);
|
||||||
|
|
||||||
|
function net_classifier(features) {
|
||||||
|
input.w = features;
|
||||||
|
var softmax = net.forward(input).w;
|
||||||
|
return softmax[1] - softmax[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
this.combined_classifier = function(features) {
|
||||||
|
return hand_tuned_classifier(features) + weight*net_classifier(features);
|
||||||
|
}
|
||||||
|
});
|
||||||
1
lib/net.js
Normal file
1
lib/net.js
Normal file
File diff suppressed because one or more lines are too long
@ -503,7 +503,7 @@ this.get_glyph_render_data = function(glyph, manual_bridges, classifier) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
var log = [];
|
var log = [];
|
||||||
var bridges = get_bridges(endpoints, classifier || hand_tuned_classifier);
|
var bridges = get_bridges(endpoints, classifier || combined_classifier);
|
||||||
var strokes = extract_strokes(
|
var strokes = extract_strokes(
|
||||||
paths, endpoints, manual_bridges || bridges, log);
|
paths, endpoints, manual_bridges || bridges, log);
|
||||||
var expected = UNIHAN_STROKE_COUNTS[glyph.name];
|
var expected = UNIHAN_STROKE_COUNTS[glyph.name];
|
||||||
|
|||||||
@ -8,7 +8,7 @@ function evaluate(glyphs, classifier) {
|
|||||||
return num_correct/glyphs.length;
|
return num_correct/glyphs.length;
|
||||||
}
|
}
|
||||||
|
|
||||||
Meteor.startup(function() {
|
function train_neural_net() {
|
||||||
var glyphs = Glyphs.find({'manual.verified': true}).fetch();
|
var glyphs = Glyphs.find({'manual.verified': true}).fetch();
|
||||||
var sample = _.sample(glyphs, 400);
|
var sample = _.sample(glyphs, 400);
|
||||||
console.log('Hand-tuned accuracy:', evaluate(sample, hand_tuned_classifier));
|
console.log('Hand-tuned accuracy:', evaluate(sample, hand_tuned_classifier));
|
||||||
@ -52,7 +52,7 @@ Meteor.startup(function() {
|
|||||||
}
|
}
|
||||||
console.log('Iteration', iteration, 'mean loss:', loss/round_data.length);
|
console.log('Iteration', iteration, 'mean loss:', loss/round_data.length);
|
||||||
}
|
}
|
||||||
console.log('Trained neural network.');
|
console.log('Trained neural network:', JSON.stringify(net.toJSON()));
|
||||||
|
|
||||||
function net_classifier(features) {
|
function net_classifier(features) {
|
||||||
assert(input.w.length === features.length);
|
assert(input.w.length === features.length);
|
||||||
@ -73,4 +73,4 @@ Meteor.startup(function() {
|
|||||||
console.log('Weight', weights[i], 'combined accuracy:',
|
console.log('Weight', weights[i], 'combined accuracy:',
|
||||||
evaluate(sample, combined_classifier(weights[i])));
|
evaluate(sample, combined_classifier(weights[i])));
|
||||||
}
|
}
|
||||||
});
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user