From 04de8186aa9434cea12820846d27124c91080598 Mon Sep 17 00:00:00 2001 From: Michael Klear Date: Thu, 1 Jun 2023 12:19:08 -0700 Subject: [PATCH 1/3] add greedy move --- markovify/chain.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/markovify/chain.py b/markovify/chain.py index b19e8e2..c7a8740 100644 --- a/markovify/chain.py +++ b/markovify/chain.py @@ -118,6 +118,19 @@ def move(self, state): r = random.random() * cumdist[-1] selection = choices[bisect.bisect(cumdist, r)] return selection + + def greedy_move(self, state): + """ + Given a state, choose the most likely next item + """ + if self.compiled: + choices, _ = self.model[state] + elif state == tuple([BEGIN] * self.state_size): + choices = self.begin_choices + else: + choices, weights = zip(*self.model[state].items()) + selection = choices[-1] + return selection def gen(self, init_state=None): """ From 6cdc8775d8e1dbba3559a080826eb5dbcc7999ba Mon Sep 17 00:00:00 2001 From: Michael Klear Date: Thu, 1 Jun 2023 14:33:18 -0700 Subject: [PATCH 2/3] fix greedy algo --- markovify/chain.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/markovify/chain.py b/markovify/chain.py index c7a8740..bc2e6bf 100644 --- a/markovify/chain.py +++ b/markovify/chain.py @@ -124,13 +124,18 @@ def greedy_move(self, state): Given a state, choose the most likely next item """ if self.compiled: - choices, _ = self.model[state] + choices, cumdist = self.model[state] elif state == tuple([BEGIN] * self.state_size): choices = self.begin_choices + cumdist = self.begin_cumdist else: choices, weights = zip(*self.model[state].items()) - selection = choices[-1] - return selection + cumdist = list(accumulate(weights)) + # r = 0 * cumdist[-1] + # selection_idx = bisect.bisect(cumdist, r) + # print(selection_idx) + # selection = choices[selection_idx] + return choices[0] def gen(self, init_state=None): """ From 084dfe239cfb3c714ee4108d8b992a45eb39e6ec Mon Sep 17 00:00:00 2001 From: Michael Klear Date: Thu, 1 Jun 2023 14:33:32 -0700 Subject: [PATCH 3/3] remove commented code --- markovify/chain.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/markovify/chain.py b/markovify/chain.py index bc2e6bf..fef8ae5 100644 --- a/markovify/chain.py +++ b/markovify/chain.py @@ -131,10 +131,6 @@ def greedy_move(self, state): else: choices, weights = zip(*self.model[state].items()) cumdist = list(accumulate(weights)) - # r = 0 * cumdist[-1] - # selection_idx = bisect.bisect(cumdist, r) - # print(selection_idx) - # selection = choices[selection_idx] return choices[0] def gen(self, init_state=None):