From ac21b56ebd7b88d8efa68e7bc181ed9da12d9191 Mon Sep 17 00:00:00 2001 From: raruidol Date: Tue, 15 Jul 2025 09:23:25 +0100 Subject: [PATCH] docker version and confidence thresholds --- app/ari.py | 28 +++------------------------- docker-compose.yml | 2 +- 2 files changed, 4 insertions(+), 26 deletions(-) diff --git a/app/ari.py b/app/ari.py index 71b4a2e..cf20521 100644 --- a/app/ari.py +++ b/app/ari.py @@ -72,11 +72,11 @@ def pipeline_predictions(pipeline, data): outputs = pipeline(pipeline_input) for out in outputs: - if out['label'] == 'Inference' and out['score'] > 0.95: + if out['label'] == 'Inference' and out['score'] > 0.9: labels.append(1) - elif out['label'] == 'Conflict' and out['score'] > 0.8: + elif out['label'] == 'Conflict' and out['score'] > 0.7: labels.append(2) - elif out['label'] == 'Rephrase' and out['score'] > 0.8: + elif out['label'] == 'Rephrase' and out['score'] > 0.7: labels.append(3) else: labels.append(0) @@ -84,28 +84,6 @@ def pipeline_predictions(pipeline, data): return labels -def make_predictions(trainer, tknz_data): - predicted_logprobs = trainer.predict(tknz_data) - ''' - predicted_labels = np.argmax(predicted_logprobs.predictions, axis=-1) - - return predicted_labels - ''' - labels = [] - for sample in predicted_logprobs.predictions: - torch_logits = torch.from_numpy(sample) - probabilities = F.softmax(torch_logits, dim=-1).numpy() - valid_check = probabilities > 0.95 - if True in valid_check: - labels.append(np.argmax(sample, axis=-1)) - elif np.argmax(sample, axis=-1) == 2 and probabilities[2] > 0.8: - labels.append(2) - else: - labels.append(-1) - - return labels - - def output_xaif(idents, labels, fileaif): original_aif = xaif.AIF(fileaif) diff --git a/docker-compose.yml b/docker-compose.yml index 32ed09e..36551ef 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,4 +1,4 @@ -version: '1' +version: '3' services: amf_ari: