Converted /ml folder to typescript

This commit is contained in:
Nicola Thouliss 2021-05-29 14:20:24 +10:00
parent 349cad53a2
commit 32f6e1fb23
2 changed files with 19 additions and 8 deletions

View File

@ -1,15 +1,21 @@
import { ModelJSON, WeightsManifestConfig } from "@tensorflow/tfjs-core/dist/io/types";
import blobToBuffer from "../helpers/blobToBuffer";
class Model {
constructor(config, weightsMapping) {
config: ModelJSON;
weightsMapping: { [path: string]: string };
constructor(config: ModelJSON, weightsMapping: { [path: string]: string }) {
this.config = config;
this.weightsMapping = weightsMapping;
}
async load() {
// Load weights from the manifest then fetch them into an ArrayBuffer
let buffers = [];
const manifest = this.config.weightsManifest[0];
let buffers: ArrayBuffer[] = [];
if (this.config === undefined) {
return;
}
const manifest = this.config?.weightsManifest[0];
for (let path of manifest.paths) {
const url = this.weightsMapping[path];
const response = await fetch(url);

View File

@ -2,17 +2,21 @@ import Model from "../Model";
import config from "./model.json";
import weights from "./group1-shard1of1.bin";
import { LayersModel } from "@tensorflow/tfjs";
import { ModelJSON } from "@tensorflow/tfjs-core/dist/io/types";
class GridSizeModel extends Model {
// Store model as static to prevent extra network requests
static model;
static model: LayersModel;
// Load tensorflow dynamically
static tf;
// TODO: find type for tf
static tf: any;
constructor() {
super(config, { "group1-shard1of1.bin": weights });
super(config as ModelJSON, { "group1-shard1of1.bin": weights });
}
async predict(imageData) {
async predict(imageData: ImageData) {
if (!GridSizeModel.tf) {
GridSizeModel.tf = await import("@tensorflow/tfjs");
}
@ -23,7 +27,8 @@ class GridSizeModel extends Model {
}
const model = GridSizeModel.model;
const prediction = tf.tidy(() => {
// TODO: check this mess -> changing type on prediction causes issues
const prediction: any = tf.tidy(() => {
const image = tf.browser.fromPixels(imageData, 1).toFloat();
const normalized = image.div(tf.scalar(255.0));
const batched = tf.expandDims(normalized);