Skip to content

Commit 8f56b73

Browse files
[feature] ONNX Export (#1490)
* Add: ONNX Export The main objective of this commit was to add a 'ONNX export' button to export trained models on the platform. Export to ONNX function is available on Dive Web & Desktop. Added a tab to list, delete and export all trained pipelines. Added the possibility to download single or multiple files from the browser on Dive web Fixed alignment issue on Jobs list on Dive web
1 parent 93c061a commit 8f56b73

File tree

23 files changed

+783
-25
lines changed

23 files changed

+783
-25
lines changed

client/dive-common/apispec.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ interface Pipe {
3434
name: string;
3535
pipe: string;
3636
type: string;
37+
folderId?: string;
38+
ownerId?: string;
39+
ownerLogin?: string;
3740
}
3841

3942
interface Category {
@@ -142,6 +145,8 @@ interface DatasetMeta extends DatasetMetaMutable {
142145
interface Api {
143146
getPipelineList(): Promise<Pipelines>;
144147
runPipeline(itemId: string, pipeline: Pipe): Promise<unknown>;
148+
deleteTrainedPipeline(pipeline: Pipe): Promise<void>;
149+
exportTrainedPipeline(path: string, pipeline: Pipe): Promise<unknown>;
145150

146151
getTrainingConfigurations(): Promise<TrainingConfigs>;
147152
runTraining(

client/platform/desktop/backend/ipcService.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@ import OS from 'os';
22
import http from 'http';
33
import { ipcMain } from 'electron';
44
import { MultiCamImportArgs } from 'dive-common/apispec';
5+
import type { Pipe } from 'dive-common/apispec';
56
import {
67
DesktopJobUpdate, RunPipeline, RunTraining, Settings, ExportDatasetArgs,
78
DesktopMediaImportResponse,
9+
ExportTrainedPipeline,
810
} from 'platform/desktop/constants';
911

1012
import linux from './native/linux';
@@ -32,6 +34,10 @@ export default function register() {
3234
const ret = await common.getPipelineList(settings.get());
3335
return ret;
3436
});
37+
ipcMain.handle('delete-trained-pipeline', async (event, args: Pipe) => {
38+
const ret = await common.deleteTrainedPipeline(args);
39+
return ret;
40+
})
3541
ipcMain.handle('get-training-configs', async () => {
3642
const ret = await common.getTrainingConfigs(settings.get());
3743
return ret;
@@ -122,6 +128,12 @@ export default function register() {
122128
};
123129
return currentPlatform.runPipeline(settings.get(), args, updater);
124130
});
131+
ipcMain.handle('export-trained-pipeline', async (event, args: ExportTrainedPipeline) => {
132+
const updater = (update: DesktopJobUpdate) => {
133+
event.sender.send('job-update', update);
134+
};
135+
return currentPlatform.exportTrainedPipeline(settings.get(), args, updater);
136+
});
125137
ipcMain.handle('run-training', async (event, args: RunTraining) => {
126138
const updater = (update: DesktopJobUpdate) => {
127139
event.sender.send('job-update', update);

client/platform/desktop/backend/native/common.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import {
2222
DatasetMetaMutableKeys,
2323
AnnotationSchema,
2424
SaveAttributeTrackFilterArgs,
25+
Pipe,
2526
} from 'dive-common/apispec';
2627
import * as viameSerializers from 'platform/desktop/backend/serializers/viame';
2728
import * as nistSerializers from 'platform/desktop/backend/serializers/nist';
@@ -444,6 +445,16 @@ async function getTrainingConfigs(settings: Settings): Promise<TrainingConfigs>
444445
};
445446
}
446447

448+
/**
449+
* delete a trained pipeline
450+
*/
451+
async function deleteTrainedPipeline(pipeline: Pipe): Promise<void> {
452+
if (pipeline.type !== 'trained') throw new Error(`${pipeline.name} is not a trained pipeline`);
453+
454+
const parent = npath.parse(pipeline.pipe).dir;
455+
await fs.remove(parent);
456+
}
457+
447458
/**
448459
* _saveSerialized save pre-serialized tracks to disk
449460
*/
@@ -1196,6 +1207,7 @@ export {
11961207
exportDataset,
11971208
finalizeMediaImport,
11981209
getPipelineList,
1210+
deleteTrainedPipeline,
11991211
getTrainingConfigs,
12001212
getProjectDir,
12011213
getValidatedProjectDir,

client/platform/desktop/backend/native/linux.ts

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import {
1313
NvidiaSmiReply,
1414
RunTraining,
1515
DesktopJobUpdater,
16+
ExportTrainedPipeline,
1617
} from 'platform/desktop/constants';
1718
import { observeChild } from 'platform/desktop/backend/native/processManager';
1819
import * as viame from './viame';
@@ -88,6 +89,17 @@ async function runPipeline(
8889
});
8990
}
9091

92+
async function exportTrainedPipeline(
93+
settings: Settings,
94+
exportTrainedPipelineArgs: ExportTrainedPipeline,
95+
updater: DesktopJobUpdater,
96+
): Promise<DesktopJob> {
97+
return viame.exportTrainedPipeline(settings, exportTrainedPipelineArgs, updater, validateViamePath, {
98+
...ViameLinuxConstants,
99+
setupScriptAbs: sourceString(settings),
100+
});
101+
}
102+
91103
async function train(
92104
settings: Settings,
93105
runTrainingArgs: RunTraining,
@@ -132,6 +144,7 @@ export default {
132144
DefaultSettings,
133145
nvidiaSmi,
134146
runPipeline,
147+
exportTrainedPipeline,
135148
train,
136149
validateViamePath,
137150
};

client/platform/desktop/backend/native/utils.ts

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,20 @@ async function createWorkingDirectory(settings: Settings, jsonMetaList: JsonMeta
101101
return runFolderPath;
102102
}
103103

104+
async function createCustomWorkingDirectory(settings: Settings, prefix: string, pipeline: string) {
105+
const jobFolderPath = path.join(settings.dataPath, JobsFolderName);
106+
// Formating prefix if for any reason the prefix is input by the user in the futur
107+
// eslint-disable-next-line no-useless-escape
108+
const safePrefix = prefix.replace(/[\.\s/]+/g, '_');
109+
const runFolderName = moment().format(`[${safePrefix}_${pipeline}]_MM-DD-yy_hh-mm-ss.SSS`);
110+
const runFolderPath = path.join(jobFolderPath, runFolderName);
111+
if (!fs.existsSync(jobFolderPath)) {
112+
await fs.mkdir(jobFolderPath);
113+
}
114+
await fs.mkdir(runFolderPath);
115+
return runFolderPath;
116+
}
117+
104118
/* same as os.path.splitext */
105119
function splitExt(input: string): [string, string] {
106120
const ext = path.extname(input);
@@ -112,6 +126,7 @@ export {
112126
getBinaryPath,
113127
jobFileEchoMiddleware,
114128
createWorkingDirectory,
129+
createCustomWorkingDirectory,
115130
spawnResult,
116131
splitExt,
117132
};

client/platform/desktop/backend/native/viame.ts

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@ import fs from 'fs-extra';
55
import {
66
Settings, DesktopJob, RunPipeline, RunTraining,
77
DesktopJobUpdater,
8+
ExportTrainedPipeline,
89
} from 'platform/desktop/constants';
910
import { cleanString } from 'platform/desktop/sharedUtils';
1011
import { serialize } from 'platform/desktop/backend/serializers/viame';
1112
import { observeChild } from 'platform/desktop/backend/native/processManager';
1213

1314
import { MultiType, stereoPipelineMarker, multiCamPipelineMarkers } from 'dive-common/constants';
1415
import * as common from './common';
15-
import { jobFileEchoMiddleware, createWorkingDirectory } from './utils';
16+
import { jobFileEchoMiddleware, createWorkingDirectory, createCustomWorkingDirectory } from './utils';
1617
import {
1718
getMultiCamImageFiles, getMultiCamVideoPath,
1819
writeMultiCamStereoPipelineArgs,
@@ -212,6 +213,98 @@ async function runPipeline(
212213
return jobBase;
213214
}
214215

216+
/**
217+
* a node.js implementation of dive_tasks.tasks.export_trained_model
218+
*/
219+
async function exportTrainedPipeline(settings: Settings,
220+
exportTrainedPipelineArgs: ExportTrainedPipeline,
221+
updater: DesktopJobUpdater,
222+
validateViamePath: (settings: Settings) => Promise<true | string>,
223+
viameConstants: ViameConstants,
224+
): Promise<DesktopJob> {
225+
const { path, pipeline } = exportTrainedPipelineArgs;
226+
227+
const isValid = await validateViamePath(settings);
228+
if (isValid !== true) {
229+
throw new Error(isValid);
230+
}
231+
232+
const exportPipelinePath = npath.join(settings.viamePath, PipelineRelativeDir, "convert_to_onnx.pipe");
233+
if (!fs.existsSync(npath.join(exportPipelinePath))) {
234+
throw new Error("Your VIAME version doesn't support ONNX export. You have to update it to a newer version to be able to export models.");
235+
}
236+
237+
const modelPipelineDir = npath.parse(pipeline.pipe).dir;
238+
let weightsPath: string;
239+
if (fs.existsSync(npath.join(modelPipelineDir, 'yolo.weights'))) {
240+
weightsPath = npath.join(modelPipelineDir, 'yolo.weights');
241+
} else {
242+
throw new Error("Your pipeline has no trained weights (yolo.weights is missing)");
243+
}
244+
245+
const jobWorkDir = await createCustomWorkingDirectory(settings, 'OnnxExport', pipeline.name);
246+
247+
const converterOutput = npath.join(jobWorkDir, 'model.onnx');
248+
const joblog = npath.join(jobWorkDir, 'runlog.txt');
249+
250+
const command = [
251+
`${viameConstants.setupScriptAbs} &&`,
252+
`"${viameConstants.kwiverExe}" runner ${exportPipelinePath}`,
253+
`-s "onnx_convert:model_path=${weightsPath}"`,
254+
`-s "onnx_convert:onnx_model_prefix=${converterOutput}"`,
255+
];
256+
257+
const job = observeChild(spawn(command.join(' '), {
258+
shell: viameConstants.shell,
259+
cwd: jobWorkDir,
260+
}));
261+
262+
const jobBase: DesktopJob = {
263+
key: `pipeline_${job.pid}_${jobWorkDir}`,
264+
command: command.join(' '),
265+
jobType: 'export',
266+
pid: job.pid,
267+
args: exportTrainedPipelineArgs,
268+
title: `${exportTrainedPipelineArgs.pipeline.name} to ONNX`,
269+
workingDir: jobWorkDir,
270+
datasetIds: [],
271+
exitCode: job.exitCode,
272+
startTime: new Date(),
273+
};
274+
275+
fs.writeFile(npath.join(jobWorkDir, DiveJobManifestName), JSON.stringify(jobBase, null, 2));
276+
277+
updater({
278+
...jobBase,
279+
body: [''],
280+
});
281+
282+
job.stdout.on('data', jobFileEchoMiddleware(jobBase, updater, joblog));
283+
job.stderr.on('data', jobFileEchoMiddleware(jobBase, updater, joblog));
284+
285+
job.on('exit', async (code) => {
286+
if (code === 0) {
287+
if (fs.existsSync(converterOutput)) {
288+
if (fs.existsSync(path)) {
289+
fs.unlinkSync(path);
290+
}
291+
// We move instead of copying because .onnx files can be huge
292+
fs.moveSync(converterOutput, path);
293+
} else {
294+
console.error("An error occured while creating the ONNX file.");
295+
}
296+
}
297+
updater({
298+
...jobBase,
299+
body: [''],
300+
exitCode: code,
301+
endTime: new Date(),
302+
});
303+
});
304+
305+
return jobBase;
306+
}
307+
215308
/**
216309
* a node.js implementation of dive_tasks.tasks.run_training
217310
*/
@@ -356,5 +449,6 @@ async function train(
356449

357450
export {
358451
runPipeline,
452+
exportTrainedPipeline,
359453
train,
360454
};

client/platform/desktop/backend/native/windows.ts

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import {
1212
Settings, SettingsCurrentVersion,
1313
DesktopJob, RunPipeline, NvidiaSmiReply, RunTraining,
1414
DesktopJobUpdater,
15+
ExportTrainedPipeline,
1516
} from 'platform/desktop/constants';
1617
import * as viame from './viame';
1718

@@ -91,6 +92,17 @@ async function runPipeline(
9192
});
9293
}
9394

95+
async function exportTrainedPipeline(
96+
settings: Settings,
97+
exportTrainedPipelineArgs: ExportTrainedPipeline,
98+
updater: DesktopJobUpdater,
99+
): Promise<DesktopJob> {
100+
return viame.exportTrainedPipeline(settings, exportTrainedPipelineArgs, updater, validateFake, {
101+
...ViameWindowsConstants,
102+
setupScriptAbs: `"${npath.join(settings.viamePath, ViameWindowsConstants.setup)}"`,
103+
});
104+
}
105+
94106
async function train(
95107
settings: Settings,
96108
runTrainingArgs: RunTraining,
@@ -168,6 +180,7 @@ export default {
168180
DefaultSettings,
169181
validateViamePath,
170182
runPipeline,
183+
exportTrainedPipeline,
171184
train,
172185
nvidiaSmi,
173186
initialize,

client/platform/desktop/constants.ts

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,11 @@ export interface RunPipeline {
154154
pipeline: Pipe;
155155
}
156156

157+
export interface ExportTrainedPipeline {
158+
path: string;
159+
pipeline: Pipe;
160+
}
161+
157162
/** TODO promote to apispec */
158163
export interface RunTraining {
159164
// datasets to run training on
@@ -179,11 +184,11 @@ export interface DesktopJob {
179184
// command that was run
180185
command: string;
181186
// jobType identify type of job
182-
jobType: 'pipeline' | 'training' | 'conversion';
187+
jobType: 'pipeline' | 'training' | 'conversion' | 'export';
183188
// title whatever humans should see this job called
184189
title: string;
185190
// arguments to creation
186-
args: RunPipeline | RunTraining | ConversionArgs;
191+
args: RunPipeline | RunTraining | ExportTrainedPipeline | ConversionArgs;
187192
// datasetIds of the involved datasets
188193
datasetIds: string[];
189194
// pid of the process spawned

client/platform/desktop/frontend/api.ts

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ import {
1818
} from 'dive-common/constants';
1919
import {
2020
DesktopJob, DesktopMetadata, JsonMeta, NvidiaSmiReply,
21-
RunPipeline, RunTraining, ExportDatasetArgs, ExportConfigurationArgs,
21+
RunPipeline, RunTraining, ExportTrainedPipeline, ExportDatasetArgs, ExportConfigurationArgs,
2222
DesktopMediaImportResponse,
2323
} from 'platform/desktop/constants';
2424

@@ -96,6 +96,14 @@ async function runPipeline(itemId: string, pipeline: Pipe): Promise<DesktopJob>
9696
return ipcRenderer.invoke('run-pipeline', args);
9797
}
9898

99+
async function exportTrainedPipeline(path: string, pipeline: Pipe): Promise<DesktopJob> {
100+
const args: ExportTrainedPipeline = {
101+
path,
102+
pipeline,
103+
};
104+
return ipcRenderer.invoke('export-trained-pipeline', args);
105+
}
106+
99107
async function runTraining(
100108
folderIds: string[],
101109
pipelineName: string,
@@ -113,6 +121,10 @@ async function runTraining(
113121
return ipcRenderer.invoke('run-training', args);
114122
}
115123

124+
async function deleteTrainedPipeline(pipeline: Pipe): Promise<void> {
125+
return ipcRenderer.invoke('delete-trained-pipeline', pipeline);
126+
}
127+
116128
function importMedia(path: string): Promise<DesktopMediaImportResponse> {
117129
return ipcRenderer.invoke('import-media', { path });
118130
}
@@ -219,7 +231,9 @@ export {
219231
/* Standard Specification APIs */
220232
loadMetadata,
221233
getPipelineList,
234+
deleteTrainedPipeline,
222235
runPipeline,
236+
exportTrainedPipeline,
223237
getTrainingConfigurations,
224238
runTraining,
225239
saveMetadata,

0 commit comments

Comments
 (0)