diff --git a/msed/predict_events.py b/msed/predict_events.py index 521d7ca..0c6f7aa 100644 --- a/msed/predict_events.py +++ b/msed/predict_events.py @@ -146,7 +146,7 @@ def predict_events( for ev in window_events: start = int(ev[0] * window_size) + window_idx * stride stop = int(ev[1] * window_size) + window_idx * stride - prediction_mask[ev[-1] - 1, start:stop] = 1 + prediction_mask[ev[-1], start:stop] = 1 for ev, p in zip(class_names, prediction_mask): predictions[subject_id][ev] = binary_to_array(p) n_event = len(predictions[subject_id][ev])