Updated grid size model to be static

This commit is contained in:
Mitchell McCaffrey 2020-10-15 15:54:10 +11:00
parent cc16c8cbf6
commit 396504fd85

View File

@ -6,20 +6,21 @@ import config from "./model.json";
import weights from "./group1-shard1of1.bin"; import weights from "./group1-shard1of1.bin";
class GridSizeModel extends Model { class GridSizeModel extends Model {
model; // Store model as static to prevent extra network requests
static model;
constructor() { constructor() {
super(config, { "group1-shard1of1.bin": weights }); super(config, { "group1-shard1of1.bin": weights });
} }
async predict(imageData) { async predict(imageData) {
if (!this.model) { if (!GridSizeModel.model) {
this.model = await tf.loadLayersModel(this); GridSizeModel.model = await tf.loadLayersModel(this);
} }
const prediction = tf.tidy(() => { const prediction = tf.tidy(() => {
const image = tf.browser.fromPixels(imageData, 1).toFloat(); const image = tf.browser.fromPixels(imageData, 1).toFloat();
const normalized = image.div(tf.scalar(255.0)); const normalized = image.div(tf.scalar(255.0));
const batched = tf.expandDims(normalized); const batched = tf.expandDims(normalized);
return this.model.predict(batched); return GridSizeModel.model.predict(batched);
}); });
const data = await prediction.data(); const data = await prediction.data();
return data[0]; return data[0];