Added lazy loading to dice and tensorflow

This commit is contained in:
Mitchell McCaffrey 2020-10-30 11:17:27 +11:00
parent d305532cdb
commit fcbc903d16
2 changed files with 40 additions and 12 deletions

View File

@ -1,13 +1,16 @@
import React, { useState } from "react";
import { Flex, IconButton } from "theme-ui";
import { Flex, IconButton, Box } from "theme-ui";
import ExpandMoreDiceIcon from "../../icons/ExpandMoreDiceIcon";
import DiceTrayOverlay from "../dice/DiceTrayOverlay";
import { DiceLoadingProvider } from "../../contexts/DiceLoadingContext";
import useSetting from "../../helpers/useSetting";
import LoadingOverlay from "../LoadingOverlay";
const DiceTrayOverlay = React.lazy(() => import("../dice/DiceTrayOverlay"));
function DiceTrayButton({
shareDice,
onShareDiceChage,
@ -46,6 +49,23 @@ function DiceTrayButton({
<ExpandMoreDiceIcon isExpanded={isExpanded} />
</IconButton>
<DiceLoadingProvider>
<React.Suspense
fallback={
isExpanded && (
<Box
sx={{
width: "32px",
height: "32px",
position: "absolute",
top: "40px",
left: "8px",
}}
>
<LoadingOverlay />
</Box>
)
}
>
<DiceTrayOverlay
isOpen={isExpanded}
shareDice={shareDice}
@ -53,6 +73,7 @@ function DiceTrayButton({
diceRolls={diceRolls}
onDiceRollsChange={onDiceRollsChange}
/>
</React.Suspense>
</DiceLoadingProvider>
</Flex>
);

View File

@ -1,5 +1,3 @@
import * as tf from "@tensorflow/tfjs";
import Model from "../Model";
import config from "./model.json";
@ -8,19 +6,28 @@ import weights from "./group1-shard1of1.bin";
class GridSizeModel extends Model {
// Store model as static to prevent extra network requests
static model;
// Load tensorflow dynamically
static tf;
constructor() {
super(config, { "group1-shard1of1.bin": weights });
}
async predict(imageData) {
if (!GridSizeModel.tf) {
GridSizeModel.tf = await import("@tensorflow/tfjs");
}
const tf = GridSizeModel.tf;
if (!GridSizeModel.model) {
GridSizeModel.model = await tf.loadLayersModel(this);
}
const model = GridSizeModel.model;
const prediction = tf.tidy(() => {
const image = tf.browser.fromPixels(imageData, 1).toFloat();
const normalized = image.div(tf.scalar(255.0));
const batched = tf.expandDims(normalized);
return GridSizeModel.model.predict(batched);
return model.predict(batched);
});
const data = await prediction.data();
return data[0];