diff --git a/matchmaker/dp/oltw_dixon.py b/matchmaker/dp/oltw_dixon.py index 38905e8..8c55a05 100644 --- a/matchmaker/dp/oltw_dixon.py +++ b/matchmaker/dp/oltw_dixon.py @@ -74,6 +74,9 @@ def __init__( max_run_count=MAX_RUN_COUNT, frame_per_seg=FRAME_PER_SEG, frame_rate=FRAME_RATE, + state_to_ref_time_map = None, + ref_to_state_time_map = None, + state_space = None, **kwargs, ): super().__init__(reference_features=reference_features) @@ -84,6 +87,9 @@ def __init__( self.distance_func = distance_func.lower() self.max_run_count = max_run_count self.frame_per_seg = frame_per_seg + self.state_to_ref_time_map = state_to_ref_time_map + self.ref_to_state_time_map = ref_to_state_time_map + self.state_space = state_space self.reset() def reset(self): diff --git a/matchmaker/matchmaker.py b/matchmaker/matchmaker.py index cb21790..776e081 100644 --- a/matchmaker/matchmaker.py +++ b/matchmaker/matchmaker.py @@ -326,11 +326,16 @@ def __init__( state_space=np.unique(self.score_part.note_array()["onset_beat"]) ) elif method == "dixon": + state_to_ref_time_map, ref_to_state_time_map = self.get_time_maps() self.score_follower = OnlineTimeWarpingDixon( reference_features=self.reference_features, queue=self.stream.queue, distance_func=distance_func, frame_rate=self.frame_rate, + window_size=self.config["window_size"], + state_to_ref_time_map=state_to_ref_time_map, + ref_to_state_time_map=ref_to_state_time_map, + state_space=np.unique(self.score_part.note_array()["onset_beat"]) ) elif method == "hmm" and self.input_type == "midi": self.score_follower = PitchIOIHMM(