Converted /ml folder to typescript
This commit is contained in:
parent
349cad53a2
commit
32f6e1fb23
@ -1,15 +1,21 @@
|
||||
import { ModelJSON, WeightsManifestConfig } from "@tensorflow/tfjs-core/dist/io/types";
|
||||
import blobToBuffer from "../helpers/blobToBuffer";
|
||||
|
||||
class Model {
|
||||
constructor(config, weightsMapping) {
|
||||
config: ModelJSON;
|
||||
weightsMapping: { [path: string]: string };
|
||||
constructor(config: ModelJSON, weightsMapping: { [path: string]: string }) {
|
||||
this.config = config;
|
||||
this.weightsMapping = weightsMapping;
|
||||
}
|
||||
|
||||
async load() {
|
||||
// Load weights from the manifest then fetch them into an ArrayBuffer
|
||||
let buffers = [];
|
||||
const manifest = this.config.weightsManifest[0];
|
||||
let buffers: ArrayBuffer[] = [];
|
||||
if (this.config === undefined) {
|
||||
return;
|
||||
}
|
||||
const manifest = this.config?.weightsManifest[0];
|
||||
for (let path of manifest.paths) {
|
||||
const url = this.weightsMapping[path];
|
||||
const response = await fetch(url);
|
@ -2,17 +2,21 @@ import Model from "../Model";
|
||||
|
||||
import config from "./model.json";
|
||||
import weights from "./group1-shard1of1.bin";
|
||||
import { LayersModel } from "@tensorflow/tfjs";
|
||||
import { ModelJSON } from "@tensorflow/tfjs-core/dist/io/types";
|
||||
|
||||
class GridSizeModel extends Model {
|
||||
// Store model as static to prevent extra network requests
|
||||
static model;
|
||||
static model: LayersModel;
|
||||
// Load tensorflow dynamically
|
||||
static tf;
|
||||
|
||||
// TODO: find type for tf
|
||||
static tf: any;
|
||||
constructor() {
|
||||
super(config, { "group1-shard1of1.bin": weights });
|
||||
super(config as ModelJSON, { "group1-shard1of1.bin": weights });
|
||||
}
|
||||
|
||||
async predict(imageData) {
|
||||
async predict(imageData: ImageData) {
|
||||
if (!GridSizeModel.tf) {
|
||||
GridSizeModel.tf = await import("@tensorflow/tfjs");
|
||||
}
|
||||
@ -23,7 +27,8 @@ class GridSizeModel extends Model {
|
||||
}
|
||||
const model = GridSizeModel.model;
|
||||
|
||||
const prediction = tf.tidy(() => {
|
||||
// TODO: check this mess -> changing type on prediction causes issues
|
||||
const prediction: any = tf.tidy(() => {
|
||||
const image = tf.browser.fromPixels(imageData, 1).toFloat();
|
||||
const normalized = image.div(tf.scalar(255.0));
|
||||
const batched = tf.expandDims(normalized);
|
Loading…
Reference in New Issue
Block a user