Converted /ml folder to typescript
This commit is contained in:
parent
349cad53a2
commit
32f6e1fb23
@ -1,15 +1,21 @@
|
|||||||
|
import { ModelJSON, WeightsManifestConfig } from "@tensorflow/tfjs-core/dist/io/types";
|
||||||
import blobToBuffer from "../helpers/blobToBuffer";
|
import blobToBuffer from "../helpers/blobToBuffer";
|
||||||
|
|
||||||
class Model {
|
class Model {
|
||||||
constructor(config, weightsMapping) {
|
config: ModelJSON;
|
||||||
|
weightsMapping: { [path: string]: string };
|
||||||
|
constructor(config: ModelJSON, weightsMapping: { [path: string]: string }) {
|
||||||
this.config = config;
|
this.config = config;
|
||||||
this.weightsMapping = weightsMapping;
|
this.weightsMapping = weightsMapping;
|
||||||
}
|
}
|
||||||
|
|
||||||
async load() {
|
async load() {
|
||||||
// Load weights from the manifest then fetch them into an ArrayBuffer
|
// Load weights from the manifest then fetch them into an ArrayBuffer
|
||||||
let buffers = [];
|
let buffers: ArrayBuffer[] = [];
|
||||||
const manifest = this.config.weightsManifest[0];
|
if (this.config === undefined) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const manifest = this.config?.weightsManifest[0];
|
||||||
for (let path of manifest.paths) {
|
for (let path of manifest.paths) {
|
||||||
const url = this.weightsMapping[path];
|
const url = this.weightsMapping[path];
|
||||||
const response = await fetch(url);
|
const response = await fetch(url);
|
@ -2,17 +2,21 @@ import Model from "../Model";
|
|||||||
|
|
||||||
import config from "./model.json";
|
import config from "./model.json";
|
||||||
import weights from "./group1-shard1of1.bin";
|
import weights from "./group1-shard1of1.bin";
|
||||||
|
import { LayersModel } from "@tensorflow/tfjs";
|
||||||
|
import { ModelJSON } from "@tensorflow/tfjs-core/dist/io/types";
|
||||||
|
|
||||||
class GridSizeModel extends Model {
|
class GridSizeModel extends Model {
|
||||||
// Store model as static to prevent extra network requests
|
// Store model as static to prevent extra network requests
|
||||||
static model;
|
static model: LayersModel;
|
||||||
// Load tensorflow dynamically
|
// Load tensorflow dynamically
|
||||||
static tf;
|
|
||||||
|
// TODO: find type for tf
|
||||||
|
static tf: any;
|
||||||
constructor() {
|
constructor() {
|
||||||
super(config, { "group1-shard1of1.bin": weights });
|
super(config as ModelJSON, { "group1-shard1of1.bin": weights });
|
||||||
}
|
}
|
||||||
|
|
||||||
async predict(imageData) {
|
async predict(imageData: ImageData) {
|
||||||
if (!GridSizeModel.tf) {
|
if (!GridSizeModel.tf) {
|
||||||
GridSizeModel.tf = await import("@tensorflow/tfjs");
|
GridSizeModel.tf = await import("@tensorflow/tfjs");
|
||||||
}
|
}
|
||||||
@ -23,7 +27,8 @@ class GridSizeModel extends Model {
|
|||||||
}
|
}
|
||||||
const model = GridSizeModel.model;
|
const model = GridSizeModel.model;
|
||||||
|
|
||||||
const prediction = tf.tidy(() => {
|
// TODO: check this mess -> changing type on prediction causes issues
|
||||||
|
const prediction: any = tf.tidy(() => {
|
||||||
const image = tf.browser.fromPixels(imageData, 1).toFloat();
|
const image = tf.browser.fromPixels(imageData, 1).toFloat();
|
||||||
const normalized = image.div(tf.scalar(255.0));
|
const normalized = image.div(tf.scalar(255.0));
|
||||||
const batched = tf.expandDims(normalized);
|
const batched = tf.expandDims(normalized);
|
Loading…
x
Reference in New Issue
Block a user