diff --git a/src/components/party/DiceTrayButton.js b/src/components/party/DiceTrayButton.js index 96679db..524983d 100644 --- a/src/components/party/DiceTrayButton.js +++ b/src/components/party/DiceTrayButton.js @@ -1,13 +1,16 @@ import React, { useState } from "react"; -import { Flex, IconButton } from "theme-ui"; +import { Flex, IconButton, Box } from "theme-ui"; import ExpandMoreDiceIcon from "../../icons/ExpandMoreDiceIcon"; -import DiceTrayOverlay from "../dice/DiceTrayOverlay"; import { DiceLoadingProvider } from "../../contexts/DiceLoadingContext"; import useSetting from "../../helpers/useSetting"; +import LoadingOverlay from "../LoadingOverlay"; + +const DiceTrayOverlay = React.lazy(() => import("../dice/DiceTrayOverlay")); + function DiceTrayButton({ shareDice, onShareDiceChage, @@ -46,13 +49,31 @@ function DiceTrayButton({ - + + + + ) + } + > + + ); diff --git a/src/ml/gridSize/GridSizeModel.js b/src/ml/gridSize/GridSizeModel.js index 2b4dcaa..6775f75 100644 --- a/src/ml/gridSize/GridSizeModel.js +++ b/src/ml/gridSize/GridSizeModel.js @@ -1,5 +1,3 @@ -import * as tf from "@tensorflow/tfjs"; - import Model from "../Model"; import config from "./model.json"; @@ -8,19 +6,28 @@ import weights from "./group1-shard1of1.bin"; class GridSizeModel extends Model { // Store model as static to prevent extra network requests static model; + // Load tensorflow dynamically + static tf; constructor() { super(config, { "group1-shard1of1.bin": weights }); } async predict(imageData) { + if (!GridSizeModel.tf) { + GridSizeModel.tf = await import("@tensorflow/tfjs"); + } + const tf = GridSizeModel.tf; + if (!GridSizeModel.model) { GridSizeModel.model = await tf.loadLayersModel(this); } + const model = GridSizeModel.model; + const prediction = tf.tidy(() => { const image = tf.browser.fromPixels(imageData, 1).toFloat(); const normalized = image.div(tf.scalar(255.0)); const batched = tf.expandDims(normalized); - return GridSizeModel.model.predict(batched); + return model.predict(batched); }); const data = await prediction.data(); return data[0];