@@ -937,6 +937,10 @@ def train(config: dict[str, t.Any]):
937937 if eval_metrics :
938938 mlflow .log_metric ("val loss" , eval_metrics ["loss" ], step = epoch )
939939 mlflow .log_metric ("val accuracy" , eval_metrics ["top1" ], step = epoch )
940+ for vr in utils .EVAL_VERIFICATION_RATES :
941+ mlflow .log_metric (f"FA@{ int (100 * vr )} " , eval_metrics [f"fa@{ int (100 * vr )} " ])
942+ mlflow .log_metric (f"AFA@{ int (100 * vr )} " , eval_metrics [f"afa@{ int (100 * vr )} " ])
943+
940944
941945 if output_dir is not None :
942946 lrs = [param_group ['lr' ] for param_group in optimizer .param_groups ]
@@ -1152,6 +1156,7 @@ def validate(
11521156 losses_m = utils .AverageMeter ()
11531157 top1_m = utils .AverageMeter ()
11541158 top5_m = utils .AverageMeter ()
1159+ correct_with_confidences_m = utils .CorrectnessOfPredictionsWithConfidencesMeter ()
11551160
11561161 model .eval ()
11571162
@@ -1193,6 +1198,7 @@ def validate(
11931198 losses_m .update (reduced_loss .item (), input .size (0 ))
11941199 top1_m .update (acc1 .item (), output .size (0 ))
11951200 top5_m .update (acc5 .item (), output .size (0 ))
1201+ correct_with_confidences_m .update (output , target )
11961202
11971203 batch_time_m .update (time .time () - end )
11981204 end = time .time ()
@@ -1206,7 +1212,32 @@ def validate(
12061212 f'Acc@5: { top5_m .val :>7.3f} ({ top5_m .avg :>7.3f} )'
12071213 )
12081214
1209- metrics = OrderedDict ([('loss' , losses_m .avg ), ('top1' , top1_m .avg ), ('top5' , top5_m .avg )])
1215+ metrics = OrderedDict (
1216+ [
1217+ ("loss" , losses_m .avg ),
1218+ ("top1" , top1_m .avg ),
1219+ ("top5" , top5_m .avg ),
1220+ * [
1221+ (f"fa@{ int (vr * 100 )} " , fa )
1222+ for vr , fa in zip (
1223+ utils .EVAL_VERIFICATION_RATES ,
1224+ correct_with_confidences_m .final_accuracy (
1225+ utils .EVAL_VERIFICATION_RATES
1226+ ),
1227+ )
1228+ ],
1229+ * [
1230+ (f"afa@{ int (vr * 100 )} " , afa )
1231+ for vr , afa in zip (
1232+ utils .EVAL_VERIFICATION_RATES ,
1233+ correct_with_confidences_m .average_final_accuracy (
1234+ utils .EVAL_VERIFICATION_RATES
1235+ ),
1236+ )
1237+ ],
1238+ ]
1239+ )
1240+
12101241
12111242 return metrics
12121243
0 commit comments