mirror of
https://github.com/skishore/makemeahanzi.git
synced 2025-11-01 20:27:44 +08:00
Update classifier to included trained neural net
This commit is contained in:
17
lib/classifier.js
Normal file
17
lib/classifier.js
Normal file
@ -0,0 +1,17 @@
|
||||
Meteor.startup(function() {
|
||||
var weight = 0.8;
|
||||
var dimensions = 8;
|
||||
var input = new convnetjs.Vol(1, 1, dimensions);
|
||||
var net = new convnetjs.Net();
|
||||
net.fromJSON(TRAINED_NEURAL_NET);
|
||||
|
||||
function net_classifier(features) {
|
||||
input.w = features;
|
||||
var softmax = net.forward(input).w;
|
||||
return softmax[1] - softmax[0];
|
||||
}
|
||||
|
||||
this.combined_classifier = function(features) {
|
||||
return hand_tuned_classifier(features) + weight*net_classifier(features);
|
||||
}
|
||||
});
|
||||
1
lib/net.js
Normal file
1
lib/net.js
Normal file
File diff suppressed because one or more lines are too long
@ -503,7 +503,7 @@ this.get_glyph_render_data = function(glyph, manual_bridges, classifier) {
|
||||
}
|
||||
}
|
||||
var log = [];
|
||||
var bridges = get_bridges(endpoints, classifier || hand_tuned_classifier);
|
||||
var bridges = get_bridges(endpoints, classifier || combined_classifier);
|
||||
var strokes = extract_strokes(
|
||||
paths, endpoints, manual_bridges || bridges, log);
|
||||
var expected = UNIHAN_STROKE_COUNTS[glyph.name];
|
||||
|
||||
@ -8,7 +8,7 @@ function evaluate(glyphs, classifier) {
|
||||
return num_correct/glyphs.length;
|
||||
}
|
||||
|
||||
Meteor.startup(function() {
|
||||
function train_neural_net() {
|
||||
var glyphs = Glyphs.find({'manual.verified': true}).fetch();
|
||||
var sample = _.sample(glyphs, 400);
|
||||
console.log('Hand-tuned accuracy:', evaluate(sample, hand_tuned_classifier));
|
||||
@ -52,7 +52,7 @@ Meteor.startup(function() {
|
||||
}
|
||||
console.log('Iteration', iteration, 'mean loss:', loss/round_data.length);
|
||||
}
|
||||
console.log('Trained neural network.');
|
||||
console.log('Trained neural network:', JSON.stringify(net.toJSON()));
|
||||
|
||||
function net_classifier(features) {
|
||||
assert(input.w.length === features.length);
|
||||
@ -73,4 +73,4 @@ Meteor.startup(function() {
|
||||
console.log('Weight', weights[i], 'combined accuracy:',
|
||||
evaluate(sample, combined_classifier(weights[i])));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user