-
Notifications
You must be signed in to change notification settings - Fork 30
ONNX to Tensorflow.js conversion of GPT-2 #927
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
55538d7
de946dd
29b1d4b
464ff8b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,50 +1,102 @@ | ||||||||||||||||||||||
| // import fs from 'fs'; | ||||||||||||||||||||||
| import fsPromise from 'node:fs/promises'; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| import { dirname } from 'path'; | ||||||||||||||||||||||
| import { fileURLToPath } from 'url'; | ||||||||||||||||||||||
| import { parse } from 'ts-command-line-args' | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| import '@tensorflow/tfjs-node'; | ||||||||||||||||||||||
| import fs from 'node:fs'; | ||||||||||||||||||||||
| import path from 'node:path'; | ||||||||||||||||||||||
| import { Tokenizer, models } from '@epfml/discojs'; | ||||||||||||||||||||||
| import { models, serialization, Tokenizer } from '@epfml/discojs'; | ||||||||||||||||||||||
| import { loadHellaSwag } from '@epfml/discojs-node'; | ||||||||||||||||||||||
| // import { AutoTokenizer } from '@xenova/transformers'; | ||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. commented |
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| const logFile = path.join('..', 'datasets', 'LogFile_hellaswag.txt'); | ||||||||||||||||||||||
| const logLines: string[] = []; | ||||||||||||||||||||||
| const __dirname = dirname(fileURLToPath(import.meta.url)); | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| const logLines: string[] = []; | ||||||||||||||||||||||
| function log(message: string) { | ||||||||||||||||||||||
| console.log(message); | ||||||||||||||||||||||
| logLines.push(message); | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
Comment on lines
+17
to
21
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we don't need a log system for the CLI, we can simply output to the console, no? |
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| const hellaswagDataset: models.HellaSwagDataset = await loadHellaSwag(-1) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| async function evaluateTFJS(tokenizer: Tokenizer) { | ||||||||||||||||||||||
| const model = new models.GPT({ seed: 42 }); | ||||||||||||||||||||||
| log('Evaluating TFJS GPT on HellaSwag...'); | ||||||||||||||||||||||
| async function evaluateModel(model: models.GPT | models.ONNXModel, numDataPoints = -1) { | ||||||||||||||||||||||
| const hellaswagDataset: models.HellaSwagDataset = await loadHellaSwag(numDataPoints) | ||||||||||||||||||||||
| const tokenizer = await Tokenizer.from_pretrained('Xenova/gpt2'); | ||||||||||||||||||||||
| log('Starting the HellaSwag benchmark...'); | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| const start = Date.now(); | ||||||||||||||||||||||
| const accuracy = await models.evaluate_hellaswag(model, tokenizer, hellaswagDataset, false); | ||||||||||||||||||||||
| const accuracy = await models.evaluate_hellaswag(model, tokenizer, hellaswagDataset, true); | ||||||||||||||||||||||
| const duration = ((Date.now() - start) / 1000).toFixed(2); | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| log(`TFJS GPT Accuracy: ${(accuracy * 100).toFixed(2)}%`); | ||||||||||||||||||||||
| log(`TFJS GPT Evaluation Time: ${duration} seconds`); | ||||||||||||||||||||||
| log(`Final accuracy: ${(accuracy * 100).toFixed(2)}%`); | ||||||||||||||||||||||
| log(`Evaluation Time: ${duration} seconds`); | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| async function evaluateXenova(tokenizer: Tokenizer) { | ||||||||||||||||||||||
| const model = await models.ONNXModel.init_pretrained('Xenova/gpt2'); | ||||||||||||||||||||||
| log('Evaluating Xenova GPT-2 (ONNX) on HellaSwag...'); | ||||||||||||||||||||||
| const ModelTypes = ['onnx', 'gpt-tfjs-random', 'gpt-tfjs-pretrained'] as const; | ||||||||||||||||||||||
| type ModelType = typeof ModelTypes[number]; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| const start = Date.now(); | ||||||||||||||||||||||
| const accuracy = await models.evaluate_hellaswag(model, tokenizer, hellaswagDataset, false); | ||||||||||||||||||||||
| const duration = ((Date.now() - start) / 1000).toFixed(2); | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| log(`Xenova GPT-2 Accuracy: ${(accuracy * 100).toFixed(2)}%`); | ||||||||||||||||||||||
| log(`Xenova GPT-2 Evaluation Time: ${duration} seconds`); | ||||||||||||||||||||||
| interface HellaSwagArgs { | ||||||||||||||||||||||
| model: ModelType | ||||||||||||||||||||||
| numDataPoints: number | ||||||||||||||||||||||
| logFile: string | ||||||||||||||||||||||
| pretrainedModelPath: string | ||||||||||||||||||||||
| help?: boolean | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| async function main(): Promise<void> { | ||||||||||||||||||||||
| fs.writeFileSync(logFile, '', 'utf-8'); // Clear old log file | ||||||||||||||||||||||
| const defaultPretrainedModelPath = path.join(__dirname, "..", "..", "onnx-converter", "assets", "model.json") | ||||||||||||||||||||||
| const args = parse<HellaSwagArgs>({ | ||||||||||||||||||||||
| model: { | ||||||||||||||||||||||
| type: (raw: string) => raw as ModelType, | ||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. casting isn't nice, especially for user interfaces, better to implement so value checking with a |
||||||||||||||||||||||
| description: `Model type, one of ${ModelTypes.toString()}`, | ||||||||||||||||||||||
| defaultValue: 'onnx' | ||||||||||||||||||||||
| }, | ||||||||||||||||||||||
| numDataPoints: { | ||||||||||||||||||||||
| type: Number, | ||||||||||||||||||||||
| description: 'Number of HellaSwag datapoints to evaluate, set -1 for the whole benchmark', | ||||||||||||||||||||||
| defaultValue: -1 | ||||||||||||||||||||||
| }, | ||||||||||||||||||||||
| logFile: { | ||||||||||||||||||||||
| type: String, | ||||||||||||||||||||||
| description: 'Relative path to the log file, default to ./hellaswag.log', defaultValue: 'hellaswag.log' | ||||||||||||||||||||||
| }, | ||||||||||||||||||||||
| pretrainedModelPath: { | ||||||||||||||||||||||
| type: String, | ||||||||||||||||||||||
| description: 'If specifying gpt-tfjs-pretrained, provide the relative path to the TF.js pretrained model', | ||||||||||||||||||||||
| defaultValue: defaultPretrainedModelPath | ||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. single use, we can inline it IMO |
||||||||||||||||||||||
| }, | ||||||||||||||||||||||
| help: { | ||||||||||||||||||||||
| type: Boolean, | ||||||||||||||||||||||
| optional: true, | ||||||||||||||||||||||
| alias: 'h', | ||||||||||||||||||||||
| description: 'Prints this usage guide' | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| }, { helpArg: 'help' }) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| const tokenizer = await Tokenizer.from_pretrained('Xenova/gpt2'); | ||||||||||||||||||||||
| await evaluateTFJS(tokenizer); | ||||||||||||||||||||||
| log('\n---\n'); | ||||||||||||||||||||||
| await evaluateXenova(tokenizer); | ||||||||||||||||||||||
| const logFile = path.join(__dirname, args.logFile); | ||||||||||||||||||||||
| fs.writeFileSync(logFile, '', 'utf-8'); // Clear the log file | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| let model: models.GPT | models.ONNXModel | undefined; | ||||||||||||||||||||||
| switch (args.model) { | ||||||||||||||||||||||
| case 'onnx': | ||||||||||||||||||||||
| log("Using ONNX pretrained model Xenova/gpt2") | ||||||||||||||||||||||
| model = await models.ONNXModel.init_pretrained('Xenova/gpt2'); | ||||||||||||||||||||||
| break; | ||||||||||||||||||||||
| case 'gpt-tfjs-random': | ||||||||||||||||||||||
| log("Using GPT-TFJS with random initialization") | ||||||||||||||||||||||
| model = new models.GPT({ seed: 42 }); | ||||||||||||||||||||||
| break; | ||||||||||||||||||||||
| case 'gpt-tfjs-pretrained': | ||||||||||||||||||||||
|
Comment on lines
+86
to
+90
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. that was confusing
Suggested change
|
||||||||||||||||||||||
| log("Using GPT-TFJS with pretrained weights") | ||||||||||||||||||||||
| if (args.pretrainedModelPath === undefined) { | ||||||||||||||||||||||
| throw new Error("If choosing gpt-tfjs-pretrained, provide the relative path to the TF.js pretrained model `pretrainedModelPath") | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| const encodedModel = await fsPromise.readFile(args.pretrainedModelPath); | ||||||||||||||||||||||
| model = await serialization.model.decode(encodedModel) as models.GPT; | ||||||||||||||||||||||
| break; | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| await evaluateModel(model, args.numDataPoints); | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| fs.writeFileSync(logFile, logLines.join('\n'), 'utf-8'); | ||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no sync write, we are in an async function so better to use |
||||||||||||||||||||||
| console.log(`\nResults written to ${logFile}`); | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think this is needed as the dataset is fetched everytime (I would argue against doing that) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,3 +20,6 @@ | |
|
|
||
| # GDHF demo | ||
| /tinder_dog/ | ||
|
|
||
| # HellaSwag benchmark | ||
| hellaswag* | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,10 +12,10 @@ export const titanic: TaskProvider<"tabular", "federated"> = { | |
| title: 'Titanic Prediction', | ||
| summary: { | ||
| preview: "The Titanic classification task is one of the main entrypoints into machine learning. Using passenger data (name, age, gender, socio-economic class, etc), the goal is to identify who was more likely to survive the infamous shipwreck.", | ||
| overview: "The original competition can be found on <a target='_blank' class='underline text-blue-400' href='https://www.kaggle.com/c/titanic'>Kaggle</a> and a link to the training set can be found here <a target='_blank' class='underline text-blue-400' href='https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/titanic_train.csv'>here</a>." | ||
| overview: "" | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no text? |
||
| }, | ||
| model: 'The model is a simple 5-layer feedforward network with ReLU activations. The model is optimized with Adam and binary cross-entropy loss. The preprocessing only fills missing value with a placeholder value (0).', | ||
| dataFormatInformation: 'The expected format for the tabular dataset is exactly the same as the sample data provided above or in the Kaggle competition. It is a CSV file with 12 columns. The features are general information about the passenger (sex, age, name, etc.) and specific related Titanic data such as the ticket class bought by the passenger, its cabin number, etc.<br>The first line of the CSV contains the header: "PassengerId, Survived, Pclass, Name, Sex, Age, SibSp, Parch, Ticket, Fare, Cabin, Embarked"<br>Each subsequent row contains passenger data.', | ||
| dataFormatInformation: 'The expected format for the tabular dataset is exactly the same as the sample data provided above or in the Kaggle competition. It is a CSV file with 12 columns. The features are general information about the passenger (sex, age, name, etc.) and specific related Titanic data such as the ticket class bought by the passenger, its cabin number, etc. The first line of the CSV contains the header: "PassengerId, Survived, Pclass, Name, Sex, Age, SibSp, Parch, Ticket, Fare, Cabin, Embarked". Each subsequent row contains passenger data.', | ||
| dataExample: [ | ||
| { name: "PassengerId", data: "1" }, | ||
| { name: "Survived", data: "0" }, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
commented