67 changes: 67 additions & 0 deletions scripts/export_sentencebert/export_cedr_to_sentencebert.py
@@ -0,0 +1,67 @@
#!/usr/bin/env python

import argparse
import os
import json

import flexneuart.io.train_data
from flexneuart.io.runs import read_run_dict
from flexneuart.io.qrels import read_qrels_dict
from flexneuart.models.train.batching import TrainSamplerFixedChunkSize


def main_cli():
    parser = argparse.ArgumentParser(description='Convert CEDR-format training data to the SentenceBERT format')

    parser.add_argument('--datafiles', metavar='data files', help='data files: docs & queries',
                        type=str, nargs='+', required=True)

    parser.add_argument('--qrels', metavar='QREL file', help='QREL file',
                        type=str, required=True)

    parser.add_argument('--train_pairs', metavar='paired train data', help='paired train data',
                        type=str, required=True)

    parser.add_argument('--valid_run', metavar='validation run', help='validation run file',
                        type=str, required=True)

    parser.add_argument('--output_dir_name', metavar='output dir',
                        help='output directory for the training data in the SentenceBERT format',
                        type=str, required=True)

    args = parser.parse_args()

    # Create the output directory for the SentenceBERT training data
    os.makedirs(args.output_dir_name, exist_ok=True)

    # Query/document text is loaded here, but only IDs and scores are exported below
    dataset = flexneuart.io.train_data.read_datafiles(args.datafiles)
    qrels = read_qrels_dict(args.qrels)
    train_pairs_all = flexneuart.io.train_data.read_pairs_dict(args.train_pairs)
    # The validation run is read to check that it is present and parsable,
    # but it is not exported by this script
    valid_run = read_run_dict(args.valid_run)

    # Each generated sample pairs one positive document with neg_qty_per_query negatives
    train_sampler = TrainSamplerFixedChunkSize(train_pairs=train_pairs_all,
                                               neg_qty_per_query=7,
                                               qrels=qrels,
                                               epoch_repeat_qty=1,
                                               do_shuffle=False)

    with open(os.path.join(args.output_dir_name, 'train.jsonl'), 'w') as train_file:
        for sample in train_sampler:
            d = {
                'qid': sample.qid,
                'pos_id': str(sample.pos_id),
                'pos_id_score': float(sample.pos_id_score),
                'neg_ids': [str(s) for s in sample.neg_ids],
                'neg_ids_score': [float(sc) for sc in sample.neg_id_scores]
            }
            train_file.write(json.dumps(d) + '\n')



if __name__ == '__main__':
    main_cli()
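
For reference, each line of the resulting train.jsonl is a self-contained JSON object written by json.dumps on a single line. With neg_qty_per_query=7, a line looks roughly as follows (all IDs and scores here are invented for illustration, and the object is wrapped for readability):

{"qid": "100", "pos_id": "D10", "pos_id_score": 12.5,
 "neg_ids": ["D21", "D34", "D56", "D71", "D88", "D92", "D97"],
 "neg_ids_score": [9.1, 8.7, 8.2, 7.9, 7.5, 7.1, 6.4]}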
53 changes: 53 additions & 0 deletions scripts/export_sentencebert/export_cedr_to_sentencebert.sh
@@ -0,0 +1,53 @@
#!/bin/bash -e

source ./common_proc.sh
source ./config.sh

checkVarNonEmpty "COLLECT_ROOT"
checkVarNonEmpty "DERIVED_DATA_SUBDIR"
checkVarNonEmpty "SAMPLE_COLLECT_ARG"

parseArguments "$@"

usageMain="<collection> <train data subdir (relative to derived data)> <output data subdir (relative to derived data)>"

if [ "$help" = "1" ] ; then
genUsage "$usageMain"
exit 1
fi

collect=${posArgs[0]}

if [ "$collect" = "" ] ; then
genUsage "$usageMain" "Specify $SAMPLE_COLLECT_ARG (1st arg)"
exit 1
fi

derivedDataDir="$COLLECT_ROOT/$collect/$DERIVED_DATA_SUBDIR"

trainSubDir=${posArgs[1]}

if [ "$trainSubDir" = "" ] ; then
genUsage "$usageMain" "Specify the CEDR training data subdir relative to $derivedDataDir (2nd arg)"
exit 1
fi

outSubDir=${posArgs[2]}

if [ "$outSubDir" = "" ] ; then
genUsage "$usageMain" "Specify the SentenceBERT training data subdir relative to $derivedDataDir (3rd arg)"
exit 1
fi

echo "=========================================================================="

set -o pipefail
trainDir="$derivedDataDir/$trainSubDir"
sbertDir="$derivedDataDir/$outSubDir"

python -u ./export_sentencebert/export_cedr_to_sentencebert.py \
    --datafiles "$trainDir/data_query.tsv" "$trainDir/data_docs.tsv" \
    --train_pairs "$trainDir/train_pairs.tsv" \
    --valid_run "$trainDir/test_run.txt" \
    --qrels "$trainDir/qrels.txt" \
    --output_dir_name "$sbertDir"
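
A typical invocation, run from the directory containing common_proc.sh and config.sh (the collection and subdirectory names below are made up for illustration):

./export_sentencebert/export_cedr_to_sentencebert.sh msmarco_pass cedr_train sbert_train

Under these assumptions, the script reads the CEDR training files from $COLLECT_ROOT/msmarco_pass/$DERIVED_DATA_SUBDIR/cedr_train and writes train.jsonl to $COLLECT_ROOT/msmarco_pass/$DERIVED_DATA_SUBDIR/sbert_train.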