v-v1150n/HW_Natural_Language_Generation

Step 1: Set Up Environment

1. Create a Virtual Environment

conda create -n ADL_ENV python=3.8
conda activate ADL_ENV

2. Install Required Packages

pip install torch matplotlib evaluate datasets transformers
pip install absl-py nltk rouge-score sentencepiece
pip install "transformers[torch]" accelerate

Step 2: Model Training

Run the training script run_mt5_train.py with the training data train.jsonl and a validation set val_split.jsonl, which you need to split off yourself. The fine-tuned model is saved to the directory given by --output_dir.

python run_mt5_train.py --train_path train.jsonl --validation_path val_split.jsonl --model_name google/mt5-small --output_dir ./mt5-finetuned_output
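Step 2 leaves the validation split to the reader. A minimal sketch of one way to produce it is shown below, assuming each line of train.jsonl is one JSON record. The function name split_jsonl, the 10% ratio, the seed, and the output name train_split.jsonl are illustrative assumptions, not part of this repository.

```python
import json
import random
from pathlib import Path


def split_jsonl(src, train_out, val_out, val_ratio=0.1, seed=42):
    """Shuffle the records of a JSONL file and write two disjoint files.

    Hypothetical helper: the ratio, seed, and file names are assumptions.
    Returns (number of training records, number of validation records).
    """
    text = Path(src).read_text(encoding="utf-8")
    lines = [ln for ln in text.splitlines() if ln.strip()]
    for ln in lines:
        json.loads(ln)  # fail early if a record is not valid JSON

    # A fixed seed keeps the split reproducible across runs.
    random.Random(seed).shuffle(lines)
    n_val = max(1, int(len(lines) * val_ratio))
    val, train = lines[:n_val], lines[n_val:]

    Path(val_out).write_text("\n".join(val) + "\n", encoding="utf-8")
    Path(train_out).write_text("\n".join(train) + "\n", encoding="utf-8")
    return len(train), len(val)
```

Calling split_jsonl("train.jsonl", "train_split.jsonl", "val_split.jsonl") would write both files next to the original. If the validation records are drawn from train.jsonl this way, passing the reduced file to --train_path keeps them out of training.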

Step 3: Model Inference

Once you have a trained model, either from Step 2 or downloaded with download.sh, run run.sh with the path to public.jsonl as the input and prediction.jsonl as the output file for the predictions.

bash download.sh
bash run.sh data/public.jsonl prediction.jsonl

Step 4: Evaluate Prediction Results

Run eval.py with prediction.jsonl as the submission and public.jsonl as the reference. The results are written to rouge_scores.json, which contains the metrics for rouge-1, rouge-2, and rouge-l.

python eval.py --reference data/public.jsonl --submission prediction.jsonl
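For readers unfamiliar with the three metrics, the sketch below illustrates what rouge-1, rouge-2, and rouge-l measure on already tokenized text. It is a simplified illustration and not the implementation used by eval.py: the function names are hypothetical, and tokenization, stemming, and multi-sentence handling are left out.

```python
from collections import Counter


def ngrams(tokens, n):
    """Count the n-grams in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def f1(overlap, pred_total, ref_total):
    """F1 from an overlap count and the prediction/reference totals."""
    if overlap == 0 or pred_total == 0 or ref_total == 0:
        return 0.0
    precision, recall = overlap / pred_total, overlap / ref_total
    return 2 * precision * recall / (precision + recall)


def rouge_n(pred, ref, n):
    """ROUGE-N F1: clipped n-gram overlap between prediction and reference."""
    p, r = ngrams(pred, n), ngrams(ref, n)
    overlap = sum((p & r).values())  # Counter intersection clips counts
    return f1(overlap, sum(p.values()), sum(r.values()))


def lcs_len(a, b):
    """Length of the longest common subsequence (dynamic programming)."""
    prev = [0] * (len(b) + 1)
    for x in a:
        cur = [0]
        for j, y in enumerate(b, 1):
            cur.append(prev[j - 1] + 1 if x == y else max(prev[j], cur[j - 1]))
        prev = cur
    return prev[-1]


def rouge_l(pred, ref):
    """ROUGE-L F1 based on the longest common subsequence."""
    return f1(lcs_len(pred, ref), len(pred), len(ref))
```

Each function takes two token lists, for example rouge_n("the cat sat".split(), "the cat sat down".split(), 1). Scores reported by eval.py may differ from this sketch because of its own tokenization and aggregation.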

Step 5: Plot ROUGE Curve

Confirm that rouge_scores.json was written by the previous step, then run plot_curve.py to generate the ROUGE curve plot.

python plot_curve.py

About

Applied Deep Learning Natural Language Generation
