Updated grid size model to be static
This commit is contained in:
parent
cc16c8cbf6
commit
396504fd85
@ -6,20 +6,21 @@ import config from "./model.json";
|
|||||||
import weights from "./group1-shard1of1.bin";
|
import weights from "./group1-shard1of1.bin";
|
||||||
|
|
||||||
class GridSizeModel extends Model {
|
class GridSizeModel extends Model {
|
||||||
model;
|
// Store model as static to prevent extra network requests
|
||||||
|
static model;
|
||||||
constructor() {
|
constructor() {
|
||||||
super(config, { "group1-shard1of1.bin": weights });
|
super(config, { "group1-shard1of1.bin": weights });
|
||||||
}
|
}
|
||||||
|
|
||||||
async predict(imageData) {
|
async predict(imageData) {
|
||||||
if (!this.model) {
|
if (!GridSizeModel.model) {
|
||||||
this.model = await tf.loadLayersModel(this);
|
GridSizeModel.model = await tf.loadLayersModel(this);
|
||||||
}
|
}
|
||||||
const prediction = tf.tidy(() => {
|
const prediction = tf.tidy(() => {
|
||||||
const image = tf.browser.fromPixels(imageData, 1).toFloat();
|
const image = tf.browser.fromPixels(imageData, 1).toFloat();
|
||||||
const normalized = image.div(tf.scalar(255.0));
|
const normalized = image.div(tf.scalar(255.0));
|
||||||
const batched = tf.expandDims(normalized);
|
const batched = tf.expandDims(normalized);
|
||||||
return this.model.predict(batched);
|
return GridSizeModel.model.predict(batched);
|
||||||
});
|
});
|
||||||
const data = await prediction.data();
|
const data = await prediction.data();
|
||||||
return data[0];
|
return data[0];
|
||||||
|
Loading…
Reference in New Issue
Block a user