Skip to content

[feature] ONNX Export #1490

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions client/dive-common/apispec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ interface Pipe {
name: string;
pipe: string;
type: string;
folderId?: string;
ownerId?: string;
ownerLogin?: string;
}

interface Category {
Expand Down Expand Up @@ -142,6 +145,8 @@ interface DatasetMeta extends DatasetMetaMutable {
interface Api {
getPipelineList(): Promise<Pipelines>;
runPipeline(itemId: string, pipeline: Pipe): Promise<unknown>;
deleteTrainedPipeline(pipeline: Pipe): Promise<void>;
exportTrainedPipeline(path: string, pipeline: Pipe): Promise<unknown>;

getTrainingConfigurations(): Promise<TrainingConfigs>;
runTraining(
Expand Down
12 changes: 12 additions & 0 deletions client/platform/desktop/backend/ipcService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ import OS from 'os';
import http from 'http';
import { ipcMain } from 'electron';
import { MultiCamImportArgs } from 'dive-common/apispec';
import type { Pipe } from 'dive-common/apispec';
import {
DesktopJobUpdate, RunPipeline, RunTraining, Settings, ExportDatasetArgs,
DesktopMediaImportResponse,
ExportTrainedPipeline,
} from 'platform/desktop/constants';

import linux from './native/linux';
Expand Down Expand Up @@ -32,6 +34,10 @@ export default function register() {
const ret = await common.getPipelineList(settings.get());
return ret;
});
ipcMain.handle('delete-trained-pipeline', async (event, args: Pipe) => {
const ret = await common.deleteTrainedPipeline(args);
return ret;
})
ipcMain.handle('get-training-configs', async () => {
const ret = await common.getTrainingConfigs(settings.get());
return ret;
Expand Down Expand Up @@ -122,6 +128,12 @@ export default function register() {
};
return currentPlatform.runPipeline(settings.get(), args, updater);
});
ipcMain.handle('export-trained-pipeline', async (event, args: ExportTrainedPipeline) => {
const updater = (update: DesktopJobUpdate) => {
event.sender.send('job-update', update);
};
return currentPlatform.exportTrainedPipeline(settings.get(), args, updater);
});
ipcMain.handle('run-training', async (event, args: RunTraining) => {
const updater = (update: DesktopJobUpdate) => {
event.sender.send('job-update', update);
Expand Down
12 changes: 12 additions & 0 deletions client/platform/desktop/backend/native/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import {
DatasetMetaMutableKeys,
AnnotationSchema,
SaveAttributeTrackFilterArgs,
Pipe,
} from 'dive-common/apispec';
import * as viameSerializers from 'platform/desktop/backend/serializers/viame';
import * as nistSerializers from 'platform/desktop/backend/serializers/nist';
Expand Down Expand Up @@ -444,6 +445,16 @@ async function getTrainingConfigs(settings: Settings): Promise<TrainingConfigs>
};
}

/**
* delete a trained pipeline
*/
async function deleteTrainedPipeline(pipeline: Pipe): Promise<void> {
if (pipeline.type !== 'trained') throw new Error(`${pipeline.name} is not a trained pipeline`);

const parent = npath.parse(pipeline.pipe).dir;
await fs.remove(parent);
}

/**
* _saveSerialized save pre-serialized tracks to disk
*/
Expand Down Expand Up @@ -1195,6 +1206,7 @@ export {
exportDataset,
finalizeMediaImport,
getPipelineList,
deleteTrainedPipeline,
getTrainingConfigs,
getProjectDir,
getValidatedProjectDir,
Expand Down
13 changes: 13 additions & 0 deletions client/platform/desktop/backend/native/linux.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import {
NvidiaSmiReply,
RunTraining,
DesktopJobUpdater,
ExportTrainedPipeline,
} from 'platform/desktop/constants';
import { observeChild } from 'platform/desktop/backend/native/processManager';
import * as viame from './viame';
Expand Down Expand Up @@ -88,6 +89,17 @@ async function runPipeline(
});
}

async function exportTrainedPipeline(
settings: Settings,
exportTrainedPipelineArgs: ExportTrainedPipeline,
updater: DesktopJobUpdater,
): Promise<DesktopJob> {
return viame.exportTrainedPipeline(settings, exportTrainedPipelineArgs, updater, validateViamePath, {
...ViameLinuxConstants,
setupScriptAbs: sourceString(settings),
});
}

async function train(
settings: Settings,
runTrainingArgs: RunTraining,
Expand Down Expand Up @@ -132,6 +144,7 @@ export default {
DefaultSettings,
nvidiaSmi,
runPipeline,
exportTrainedPipeline,
train,
validateViamePath,
};
15 changes: 15 additions & 0 deletions client/platform/desktop/backend/native/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,20 @@ async function createWorkingDirectory(settings: Settings, jsonMetaList: JsonMeta
return runFolderPath;
}

async function createCustomWorkingDirectory(settings: Settings, prefix: string, pipeline: string) {
const jobFolderPath = path.join(settings.dataPath, JobsFolderName);
// Formating prefix if for any reason the prefix is input by the user in the futur
// eslint-disable-next-line no-useless-escape
const safePrefix = prefix.replace(/[\.\s/]+/g, '_');
const runFolderName = moment().format(`[${safePrefix}_${pipeline}]_MM-DD-yy_hh-mm-ss.SSS`);
const runFolderPath = path.join(jobFolderPath, runFolderName);
if (!fs.existsSync(jobFolderPath)) {
await fs.mkdir(jobFolderPath);
}
await fs.mkdir(runFolderPath);
return runFolderPath;
}

/* same as os.path.splitext */
function splitExt(input: string): [string, string] {
const ext = path.extname(input);
Expand All @@ -112,6 +126,7 @@ export {
getBinaryPath,
jobFileEchoMiddleware,
createWorkingDirectory,
createCustomWorkingDirectory,
spawnResult,
splitExt,
};
96 changes: 95 additions & 1 deletion client/platform/desktop/backend/native/viame.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@ import fs from 'fs-extra';
import {
Settings, DesktopJob, RunPipeline, RunTraining,
DesktopJobUpdater,
ExportTrainedPipeline,
} from 'platform/desktop/constants';
import { cleanString } from 'platform/desktop/sharedUtils';
import { serialize } from 'platform/desktop/backend/serializers/viame';
import { observeChild } from 'platform/desktop/backend/native/processManager';

import { MultiType, stereoPipelineMarker, multiCamPipelineMarkers } from 'dive-common/constants';
import * as common from './common';
import { jobFileEchoMiddleware, createWorkingDirectory } from './utils';
import { jobFileEchoMiddleware, createWorkingDirectory, createCustomWorkingDirectory } from './utils';
import {
getMultiCamImageFiles, getMultiCamVideoPath,
writeMultiCamStereoPipelineArgs,
Expand Down Expand Up @@ -212,6 +213,98 @@ async function runPipeline(
return jobBase;
}

/**
* a node.js implementation of dive_tasks.tasks.export_trained_model
*/
async function exportTrainedPipeline(settings: Settings,
exportTrainedPipelineArgs: ExportTrainedPipeline,
updater: DesktopJobUpdater,
validateViamePath: (settings: Settings) => Promise<true | string>,
viameConstants: ViameConstants,
): Promise<DesktopJob> {
const { path, pipeline } = exportTrainedPipelineArgs;

const isValid = await validateViamePath(settings);
if (isValid !== true) {
throw new Error(isValid);
}

const exportPipelinePath = npath.join(settings.viamePath, PipelineRelativeDir, "convert_to_onnx.pipe");
if (!fs.existsSync(npath.join(exportPipelinePath))) {
throw new Error("Your VIAME version doesn't support ONNX export. You have to update it to a newer version to be able to export models.");
}

const modelPipelineDir = npath.parse(pipeline.pipe).dir;
let weightsPath: string;
if (fs.existsSync(npath.join(modelPipelineDir, 'yolo.weights'))) {
weightsPath = npath.join(modelPipelineDir, 'yolo.weights');
} else {
throw new Error("Your pipeline has no trained weights (yolo.weights is missing)");
}

const jobWorkDir = await createCustomWorkingDirectory(settings, 'OnnxExport', pipeline.name);

const converterOutput = npath.join(jobWorkDir, 'model.onnx');
const joblog = npath.join(jobWorkDir, 'runlog.txt');

const command = [
`${viameConstants.setupScriptAbs} &&`,
`"${viameConstants.kwiverExe}" runner ${exportPipelinePath}`,
`-s "onnx_convert:model_path=${weightsPath}"`,
`-s "onnx_convert:onnx_model_prefix=${converterOutput}"`,
];

const job = observeChild(spawn(command.join(' '), {
shell: viameConstants.shell,
cwd: jobWorkDir,
}));

const jobBase: DesktopJob = {
key: `pipeline_${job.pid}_${jobWorkDir}`,
command: command.join(' '),
jobType: 'export',
pid: job.pid,
args: exportTrainedPipelineArgs,
title: `${exportTrainedPipelineArgs.pipeline.name} to ONNX`,
workingDir: jobWorkDir,
datasetIds: [],
exitCode: job.exitCode,
startTime: new Date(),
};

fs.writeFile(npath.join(jobWorkDir, DiveJobManifestName), JSON.stringify(jobBase, null, 2));

updater({
...jobBase,
body: [''],
});

job.stdout.on('data', jobFileEchoMiddleware(jobBase, updater, joblog));
job.stderr.on('data', jobFileEchoMiddleware(jobBase, updater, joblog));

job.on('exit', async (code) => {
if (code === 0) {
if (fs.existsSync(converterOutput)) {
if (fs.existsSync(path)) {
fs.unlinkSync(path);
}
// We move instead of copying because .onnx files can be huge
fs.moveSync(converterOutput, path);
} else {
console.error("An error occured while creating the ONNX file.");
}
}
updater({
...jobBase,
body: [''],
exitCode: code,
endTime: new Date(),
});
});

return jobBase;
}

/**
* a node.js implementation of dive_tasks.tasks.run_training
*/
Expand Down Expand Up @@ -356,5 +449,6 @@ async function train(

export {
runPipeline,
exportTrainedPipeline,
train,
};
13 changes: 13 additions & 0 deletions client/platform/desktop/backend/native/windows.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import {
Settings, SettingsCurrentVersion,
DesktopJob, RunPipeline, NvidiaSmiReply, RunTraining,
DesktopJobUpdater,
ExportTrainedPipeline,
} from 'platform/desktop/constants';
import * as viame from './viame';

Expand Down Expand Up @@ -91,6 +92,17 @@ async function runPipeline(
});
}

async function exportTrainedPipeline(
settings: Settings,
exportTrainedPipelineArgs: ExportTrainedPipeline,
updater: DesktopJobUpdater,
): Promise<DesktopJob> {
return viame.exportTrainedPipeline(settings, exportTrainedPipelineArgs, updater, validateFake, {
...ViameWindowsConstants,
setupScriptAbs: `"${npath.join(settings.viamePath, ViameWindowsConstants.setup)}"`,
});
}

async function train(
settings: Settings,
runTrainingArgs: RunTraining,
Expand Down Expand Up @@ -168,6 +180,7 @@ export default {
DefaultSettings,
validateViamePath,
runPipeline,
exportTrainedPipeline,
train,
nvidiaSmi,
initialize,
Expand Down
9 changes: 7 additions & 2 deletions client/platform/desktop/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,11 @@ export interface RunPipeline {
pipeline: Pipe;
}

export interface ExportTrainedPipeline {
path: string;
pipeline: Pipe;
}

/** TODO promote to apispec */
export interface RunTraining {
// datasets to run training on
Expand All @@ -176,11 +181,11 @@ export interface DesktopJob {
// command that was run
command: string;
// jobType identify type of job
jobType: 'pipeline' | 'training' | 'conversion';
jobType: 'pipeline' | 'training' | 'conversion' | 'export';
// title whatever humans should see this job called
title: string;
// arguments to creation
args: RunPipeline | RunTraining | ConversionArgs;
args: RunPipeline | RunTraining | ExportTrainedPipeline | ConversionArgs;
// datasetIds of the involved datasets
datasetIds: string[];
// pid of the process spawned
Expand Down
16 changes: 15 additions & 1 deletion client/platform/desktop/frontend/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import {
} from 'dive-common/constants';
import {
DesktopJob, DesktopMetadata, JsonMeta, NvidiaSmiReply,
RunPipeline, RunTraining, ExportDatasetArgs, ExportConfigurationArgs,
RunPipeline, RunTraining, ExportTrainedPipeline, ExportDatasetArgs, ExportConfigurationArgs,
DesktopMediaImportResponse,
} from 'platform/desktop/constants';

Expand Down Expand Up @@ -96,6 +96,14 @@ async function runPipeline(itemId: string, pipeline: Pipe): Promise<DesktopJob>
return ipcRenderer.invoke('run-pipeline', args);
}

async function exportTrainedPipeline(path: string, pipeline: Pipe): Promise<DesktopJob> {
const args: ExportTrainedPipeline = {
path,
pipeline,
};
return ipcRenderer.invoke('export-trained-pipeline', args);
}

async function runTraining(
folderIds: string[],
pipelineName: string,
Expand All @@ -113,6 +121,10 @@ async function runTraining(
return ipcRenderer.invoke('run-training', args);
}

async function deleteTrainedPipeline(pipeline: Pipe): Promise<void> {
return ipcRenderer.invoke('delete-trained-pipeline', pipeline);
}

function importMedia(path: string): Promise<DesktopMediaImportResponse> {
return ipcRenderer.invoke('import-media', { path });
}
Expand Down Expand Up @@ -219,7 +231,9 @@ export {
/* Standard Specification APIs */
loadMetadata,
getPipelineList,
deleteTrainedPipeline,
runPipeline,
exportTrainedPipeline,
getTrainingConfigurations,
runTraining,
saveMetadata,
Expand Down
Loading
Loading