From 1be0f7a389aa213d3fb9cdf20c0b70331d002bc7 Mon Sep 17 00:00:00 2001 From: Fan Yang Date: Thu, 8 Dec 2022 22:05:39 +0800 Subject: [PATCH 1/2] handle n > m --- km_matcher.py | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/km_matcher.py b/km_matcher.py index 5cb0c5f..188d67b 100644 --- a/km_matcher.py +++ b/km_matcher.py @@ -12,9 +12,15 @@ def __init__(self, weights): weights = np.array(weights).astype(np.float32) self.weights = weights self.n, self.m = weights.shape + self.is_transpose = False + if self.n > self.m: + self.weights = self.weights.T + self.n, self.m = self.weights.shape + self.is_transpose = True assert self.n <= self.m + # init label - self.label_x = np.max(weights, axis=1) + self.label_x = np.max(self.weights, axis=1) self.label_y = np.zeros((self.m, ), dtype=np.float32) self.max_match = 0 @@ -80,20 +86,29 @@ def find_augment_path(self): queue.append(x) self.add_to_tree(self.yx[y], x) - def solve(self, verbose = False): + def _solve(self, verbose = False): while self.max_match < self.n: x, y = self.find_augment_path() self.do_augment(x, y) - sum = 0. + sum_ = 0. + matches = [] for x in range(self.n): if verbose: print('match {} to {}, weight {:.4f}'.format(x, self.xy[x], self.weights[x, self.xy[x]])) - sum += self.weights[x, self.xy[x]] - self.best = sum + sum_ += self.weights[x, self.xy[x]] + matches.append((x, self.xy[x])) + self.best = sum_ if verbose: - print('ans: {:.4f}'.format(sum)) - return sum + print('ans: {:.4f}'.format(sum_)) + return sum_, matches + + + def solve(self, verbose=False): + sum_, matches = self._solve(verbose=verbose) + if self.is_transpose: + matches = sorted([(y, x) for x, y in matches]) + return sum_, matches, self.is_transpose def add_to_tree(self, x, prevx): From 23a5462cc1d65f64ed752e2d335cc04cf8815491 Mon Sep 17 00:00:00 2001 From: Fan Yang Date: Thu, 8 Dec 2022 22:09:01 +0800 Subject: [PATCH 2/2] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 9b05abb..5fbeeb6 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ or using default setting ``` weights = np.random.randn(n, m) matcher = KMMatcher(weights) -best = matcher.solve() +best, all_matches, _ = matcher.solve() ``` ### Performance