mirror of https://github.com/skishore/makemeahanzi.git (synced 2025-10-28 21:13:40 +08:00)

Restore trained stroke extractor
lib/classifier.js (new file, 19 lines)
@@ -0,0 +1,19 @@
+"use strict";
+
+Meteor.startup(() => {
+  const input = new convnetjs.Vol(1, 1, 8 /* feature vector dimensions */);
+  const net = new convnetjs.Net();
+  net.fromJSON(NEURAL_NET_TRAINED_FOR_STROKE_EXTRACTION);
+  const weight = 0.8;
+
+  const trainedClassifier = (features) => {
+    input.w = features;
+    const softmax = net.forward(input).w;
+    return softmax[1] - softmax[0];
+  }
+
+  stroke_extractor.combinedClassifier = (features) => {
+    return stroke_extractor.handTunedClassifier(features) +
+           weight*trainedClassifier(features);
+  }
+});
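Note: the sketch below is not part of the commit; it only illustrates how the restored classifier might be exercised once the Meteor server has started. The feature values are made up; in the app an 8-dimensional feature vector is computed by the stroke extractor for a candidate bridge between two path endpoints.

  // Hypothetical usage sketch (assumes the app has started and lib/net.js is loaded).
  Meteor.startup(() => {
    const features = [0.2, -0.5, 1.0, 0.0, 0.3, -0.1, 0.8, 0.4];  // made-up feature vector
    const score = stroke_extractor.combinedClassifier(features);
    // Higher scores favor keeping the candidate bridge, mirroring
    // softmax[1] - softmax[0] in trainedClassifier above.
    console.log('combined bridge score:', score);
  });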
lib/external/convnet/1.1.0/LICENSE (new file, vendored, 22 lines)
@@ -0,0 +1,22 @@
+The MIT License
+
+Copyright (c) 2014 Andrej Karpathy
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+
lib/external/convnet/1.1.0/convnet.js (new file, vendored, 2115 lines)
File diff suppressed because it is too large
lib/net.js (new file, 1 line)
File diff suppressed because one or more lines are too long
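Note: the single generated line in lib/net.js is not shown by the diff viewer. Judging from lib/classifier.js above (which calls net.fromJSON(NEURAL_NET_TRAINED_FOR_STROKE_EXTRACTION)) and from server/training.js below (which logs JSON.stringify(net.toJSON())), it presumably assigns that serialized network to a global. A sketch of its likely shape only, with the actual contents elided:

  // Presumed shape; the real line carries the full serialized network weights.
  NEURAL_NET_TRAINED_FOR_STROKE_EXTRACTION = {/* output of net.toJSON() */};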
@@ -354,7 +354,7 @@ if (this.stroke_extractor !== undefined) {
 }
 this.stroke_extractor = {};
 
-this.stroke_extractor.getBridges = (glyph, classifier) => {
+stroke_extractor.getBridges = (glyph, classifier) => {
   assert(glyph.stages.path)
   const paths = svg.convertSVGPathToPaths(glyph.stages.path);
   const endpoints = [];
@@ -363,11 +363,12 @@ this.stroke_extractor.getBridges = (glyph, classifier) => {
       endpoints.push(new Endpoint(paths, [i, j]));
     }
   }
-  const bridges = getBridges(endpoints, classifier || handTunedClassifier);
+  classifier = classifier || stroke_extractor.combinedClassifier;
+  const bridges = getBridges(endpoints, classifier);
   return {endpoints: endpoints, bridges: bridges};
 }
 
-this.stroke_extractor.getStrokes = (glyph) => {
+stroke_extractor.getStrokes = (glyph) => {
   assert(glyph.stages.path)
   assert(glyph.stages.bridges)
   const paths = svg.convertSVGPathToPaths(glyph.stages.path);
@@ -383,3 +384,5 @@ this.stroke_extractor.getStrokes = (glyph) => {
   const strokes = stroke_paths.map((x) => svg.convertPathsToSVGPath([x]));
   return {log: log, strokes: strokes};
 }
+
+stroke_extractor.handTunedClassifier = handTunedClassifier;
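Note: the net effect of the hunks above is that getBridges now falls back to the combined (hand-tuned plus trained) classifier when no classifier is passed, and handTunedClassifier is exported for callers that want the old behavior. A hedged sketch, assuming a glyph document with glyph.stages.path already loaded:

  // Default now scores bridges with combinedClassifier from lib/classifier.js.
  const result = stroke_extractor.getBridges(glyph);
  // The previous hand-tuned-only scoring can still be requested explicitly.
  const baseline = stroke_extractor.getBridges(
      glyph, stroke_extractor.handTunedClassifier);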
server/training.js (new file, 78 lines)
@@ -0,0 +1,78 @@
+"use strict";
+
+function evaluate(glyphs, classifier) {
+  var num_correct = 0;
+  for (var i = 0; i < glyphs.length; i++) {
+    if (check_classifier_on_glyph(glyphs[i], classifier)) {
+      num_correct += 1;
+    }
+  }
+  return num_correct/glyphs.length;
+}
+
+function train_neural_net() {
+  var glyphs = Glyphs.find({'manual.verified': true}).fetch();
+  var sample = _.sample(glyphs, 400);
+  console.log('Hand-tuned accuracy:', evaluate(sample, hand_tuned_classifier));
+
+  var training_data = [];
+  for (var i = 0; i < glyphs.length; i++) {
+    var glyph_data = get_glyph_training_data(glyphs[i]);
+    var positive_data = glyph_data.filter(function(x) { return x[1] > 0; });
+    var negative_data = glyph_data.filter(function(x) { return x[1] === 0; });
+    if (positive_data.length > negative_data.length) {
+      positive_data = _.sample(positive_data, negative_data.length);
+    } else {
+      negative_data = _.sample(negative_data, positive_data.length);
+    }
+    glyph_data = negative_data.concat(positive_data);
+    for (var j = 0; j < glyph_data.length; j++) {
+      training_data.push(glyph_data[j]);
+    }
+  }
+  console.log('Got ' + training_data.length + ' rows of training data.');
+
+  var net = new convnetjs.Net();
+  net.makeLayers([
+    {type: 'input', out_sx: 1, out_sy: 1, out_depth: 8},
+    {type: 'fc', num_neurons: 8, activation: 'tanh'},
+    {type: 'fc', num_neurons: 8, activation: 'tanh'},
+    {type: 'softmax', num_classes: 2},
+  ]);
+  var trainer = new convnetjs.Trainer(
+      net, {method: 'adadelta', l2_decay: 0.001, batch_size: 10});
+  var input = new convnetjs.Vol(1, 1, 8);
+  for (var iteration = 0; iteration < 10; iteration++) {
+    var loss = 0;
+    var round_data = _.sample(training_data, 4000);
+    for (var i = 0; i < round_data.length; i++) {
+      assert(input.w.length === round_data[i][0].length);
+      input.w = round_data[i][0];
+      var stats = trainer.train(input, round_data[i][1]);
+      assert(!isNaN(stats.loss))
+      loss += stats.loss;
+    }
+    console.log('Iteration', iteration, 'mean loss:', loss/round_data.length);
+  }
+  console.log('Trained neural network:', JSON.stringify(net.toJSON()));
+
+  function net_classifier(features) {
+    assert(input.w.length === features.length);
+    input.w = features;
+    var softmax = net.forward(input).w;
+    assert(softmax.length === 2);
+    return softmax[1] - softmax[0];
+  }
+  console.log('Neural-net accuracy:', evaluate(sample, net_classifier));
+
+  function combined_classifier(weight) {
+    return function(features) {
+      return hand_tuned_classifier(features) + weight*net_classifier(features);
+    }
+  }
+  var weights = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1];
+  for (var i = 0; i < weights.length; i++) {
+    console.log('Weight', weights[i], 'combined accuracy:',
+                evaluate(sample, combined_classifier(weights[i])));
+  }
+}
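Note: train_neural_net is defined but not exported or invoked anywhere in this diff, so the snippet below is only an assumption about how it might be run (for example from a server-only startup hook or the meteor shell); the TRAIN_STROKE_CLASSIFIER flag is hypothetical.

  // Hypothetical way to trigger a training run on the server.
  Meteor.startup(() => {
    if (process.env.TRAIN_STROKE_CLASSIFIER) {
      train_neural_net();
      // The 'Trained neural network: {...}' JSON that this logs is presumably
      // what gets saved in lib/net.js as NEURAL_NET_TRAINED_FOR_STROKE_EXTRACTION.
    }
  });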