Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,5 @@ model/trainingdata
*.pt
*.onnx
*.bin
*.exe
*.log
2 changes: 1 addition & 1 deletion backend.Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM python:3.10-slim
FROM python:3.12-slim
WORKDIR /quickdraw
COPY ./backend/requirements.txt .
RUN pip install --no-cache-dir --upgrade -r ./requirements.txt
Expand Down
2 changes: 1 addition & 1 deletion backend/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ async def websocket_endpoint(websocket: WebSocket):
await wait_pubsub_subscribe(f"game:{game_id}:channel", subs)
await send_next_round(game_id, 0)
await asyncio.gather(
websocket_loop(websocket, game_id, player_id, game_data, player_data),
websocket_loop(websocket, game_id, game_data, player_data),
pubsub_loop(websocket, pubsub),
)

Expand Down
2 changes: 1 addition & 1 deletion deploy/nginx/nginx.conf
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ server {
proxy_set_header X-Forwarded-Proto $scheme;
}

location /model3_4_large.onnx {
location /CNN_cat16_v6-0_large_gputrain.onnx {
root /out;
expires 7d;
add_header Cache-Control "public";
Expand Down
1 change: 1 addition & 0 deletions frontend/.dockerignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
**/node_modules/
**/out/
22 changes: 22 additions & 0 deletions frontend/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions frontend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
},
"dependencies": {
"js-cookie": "^3.0.5",
"lodash.debounce": "^4.0.8",
"next": "14.2.5",
"onnxruntime-web": "^1.19.0",
"react": "^18",
Expand All @@ -19,6 +20,7 @@
},
"devDependencies": {
"@types/js-cookie": "^3.0.6",
"@types/lodash.debounce": "^4.0.9",
"@types/node": "^20",
"@types/react": "^18",
"@types/react-dom": "^18",
Expand Down
104 changes: 63 additions & 41 deletions frontend/src/components/DrawCanvas.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@ import React, {
useRef,
useState,
useEffect,
useImperativeHandle,
forwardRef,
} from "react";
import { Stage, Layer, Line } from "react-konva";
import { InferenceSession, Tensor } from "onnxruntime-web";
import { clear } from "console";
// import { clearInterval } from "timers";
import debounce from "lodash.debounce";

interface AnimateProps {
children: React.ReactNode;
Expand Down Expand Up @@ -35,7 +34,6 @@ interface DrawCanvasProps {
clearCanvas: boolean;
}

let lastDrawn = Date.now();
const DrawCanvas: React.FC<DrawCanvasProps> = ({
dataPass,
onParentClearCanvas,
Expand All @@ -46,39 +44,27 @@ const DrawCanvas: React.FC<DrawCanvasProps> = ({
const [confidence, setConfidence] = useState(0);
const isDrawing = useRef(false);
const session = useRef<InferenceSession | null>(null);
const [shouldReEval,setShouldReEval] = useState(false);

const predDebounce = 750;

const modelCategories = [
"apple",
"anvil",
"dresser",
"broom",
"hat",
"camera",
"dog",
"basketball",
"pencil",
"hammer",
"hexagon",
"banana",
"angel",
"airplane",
"ant",
"paper clip",
];
const predDebounce = 400;

const modelCategories = ['The%20Eiffel%20Tower.bin', 'airplane.bin', 'alarm%20clock.bin', 'anvil.bin', 'apple.bin', 'axe.bin', 'banana.bin', 'bed.bin', 'bee.bin', 'birthday%20cake.bin', 'book.bin', 'brain.bin', 'broom.bin', 'bucket.bin', 'calculator.bin', 'camera.bin', 'carrot.bin', 'car.bin', 'clock.bin', 'chair.bin', 'cookie.bin', 'diamond.bin', 'donut.bin', 'door.bin', 'elephant.bin', 'eye.bin', 'fish.bin', 'giraffe.bin', 'hammer.bin', 'hat.bin', 'key.bin', 'knife.bin', 'leaf.bin', 'map.bin', 'microphone.bin', 'mug.bin', 'mushroom.bin', 'nose.bin', 'palm%20tree.bin', 'pants.bin', 'paper%20clip.bin', 'peanut.bin', 'pillow.bin', 'rabbit.bin', 'river.bin']

useEffect(() => {
(async () => {
try {
session.current = await InferenceSession.create("model3_4_large.onnx");
session.current = await InferenceSession.create("CNN_cat45_v6-1_large_gputrain.onnx");
} catch (error) {
// TODO: Handle this error properly
console.error("Failed to load model", error);
}
})();

// const evalTimer = setInterval(() => { console.log("debug"); handleEvaluate()},predDebounce);
// console.log("evaltimer!",evalTimer);

return () => {
// clearInterval(evalTimer)
session.current?.release();
};
}, []);
Expand All @@ -98,12 +84,6 @@ const DrawCanvas: React.FC<DrawCanvasProps> = ({
const lastLine = lines[lines.length - 1].concat([point]);
setLines(lines.slice(0, -1).concat([lastLine]));
}

if (Date.now() - lastDrawn > predDebounce) {
// console.log("Evaluating drawing now");
lastDrawn = Date.now();
handleEvaluate();
}
};

const handleMouseUp = () => {
Expand Down Expand Up @@ -201,6 +181,21 @@ const DrawCanvas: React.FC<DrawCanvasProps> = ({
rasterImage[i / 4] = imageData.data[i]; // Invert colors (black on white background)
}

// const rasterImage: number[][][][] = new Array(1).fill(null).map(() =>
// new Array(1).fill(null).map(() =>
// new Array(side).fill(null).map(() =>
// new Array(side).fill(0)
// )
// )
// );

// for (let y = 0; y < side; y++) {
// for (let x = 0; x < side; x++) {
// const i = (y * side + x) * 4;
// rasterImage[0][0][y][x] = imageData.data[i] / 255; // Normalize to 0-1
// }
// }

return rasterImage;
};

Expand All @@ -213,23 +208,24 @@ const DrawCanvas: React.FC<DrawCanvasProps> = ({

const argMax = (arr: Float32Array): number => arr.indexOf(Math.max(...arr));

async function ONNX(input: any) {
async function ONNX(input: number[]) {
if (session.current === null) {
console.error(
"Attempted to run inference while InferenceSession is null"
);
return;
}
try {
const tensor = new Tensor("float32", new Float32Array(input), [1, 784]);
// const flattenedInput = input.flat(3);
const tensor = new Tensor("float32", new Float32Array(input), [1, 1, 28, 28]);

const inputMap = { input: tensor };

const outputMap = await session.current.run(inputMap);

const output = outputMap["output"].data as Float32Array;

// console.log(output);

return output;
} catch (error) {
console.error("Error running ONNX model:", error);
Expand Down Expand Up @@ -260,26 +256,52 @@ const DrawCanvas: React.FC<DrawCanvasProps> = ({
return svgContent;
};

const evalTimeoutRef = useRef<NodeJS.Timeout | null>(null);

useEffect(() => { // evaluate on lines changed
if (lines.length > 0) {

if (lines.length > 0) {
if (evalTimeoutRef.current) {
clearTimeout(evalTimeoutRef.current);
}
evalTimeoutRef.current = setTimeout(handleEvaluate, predDebounce);
}

return () => {
if (evalTimeoutRef.current) {
clearTimeout(evalTimeoutRef.current);
}
};
// const debounceEval = debounce(handleEvaluate,predDebounce)
// debounceEval()

// return () => {
// debounceEval.cancel(); // cleanup on unmount
// };

}
}, [lines]);

const handleEvaluate = () => {
console.log("ah")
const normalizedStrokes = normalizeStrokes(lines);
const rasterArray = rasterizeStrokes(normalizedStrokes);

ONNX(rasterArray).then((res) => {
// console.log(res);
res = res as Float32Array;
let i = argMax(res);
setPrediction(modelCategories[i]);
let prob = softmax(res)[i];
let probPercent = Math.floor(prob * 1000) / 10;
setPrediction(modelCategories[i]);
setConfidence(probPercent);
if (probPercent > 70) {
if (probPercent > 80) {
dataPass(prediction);
}
});
};

useEffect(() => {
// effect to check if clearCanvas is true
useEffect(() => { // effect to check if clearCanvas is true
if (clearCanvas) {
setLines([]);
onParentClearCanvas(); // call the callback function to reset the state in parent component
Expand All @@ -296,10 +318,10 @@ const DrawCanvas: React.FC<DrawCanvasProps> = ({
<div className="grid place-items-center">
{prediction && (
<AnimateText on={prediction}>
I guess... {confidence > 70 ? prediction : "not sure"}!
I guess... {confidence > 80 ? prediction : "not sure"}!
</AnimateText>
)}
{/* confidence > 70 ? (
{/* confidence > 80 ? (
<p className="text-lg font-medium text-green-400">
Confidence (dev): {confidence + "%"}
</p>
Expand Down
Loading