diff --git a/glove/glove.py b/glove/glove.py index ec90ca3..7068bd4 100644 --- a/glove/glove.py +++ b/glove/glove.py @@ -307,3 +307,21 @@ def most_similar_paragraph(self, paragraph, number=5, **kwargs): paragraph_vector = self.transform_paragraph(paragraph, **kwargs) return self._similarity_query(paragraph_vector, number) + + def return_word_vector(self, word): + """ + returns glove vector corresponding to word + """ + + if self.word_vectors is None: + raise Exception('Model must be fit before querying') + + if self.dictionary is None: + raise Exception('No word dictionary supplied') + + try: + word_idx = self.dictionary[word] + except KeyError: + raise Exception('Word not in dictionary') + + return self.word_vectors[word_idx]