import argparse
import codecs
import json
import re
import string
import sys

# Translation table turning every ASCII punctuation character into a space.
# Shared by the indexer and the query parser so both normalise text the same
# way (a term that is split while indexing must also be split while querying).
_PUNCT_TO_SPACE = str.maketrans(string.punctuation, " " * len(string.punctuation))

# Query lexer: a single operator character, a run of whitespace, or a term.
_QUERY_TOKEN_RE = re.compile(r"[()|]|\s+|[^\s()|]+")


def _make_stemmer():
    """Return the default stemmer: nltk's Snowball stemmer for Russian.

    nltk is imported here rather than at module level so the module can be
    imported (and used with an injected stemmer) without nltk installed.
    """
    from nltk.stem import SnowballStemmer

    return SnowballStemmer("russian")


class Index:
    """Inverted index: stemmed lower-cased word -> set of document ids."""

    def __init__(self, index_file, stemmer=None):
        """Build the index from *index_file*.

        Each line is stripped, ASCII punctuation is replaced by spaces and the
        line is split on whitespace; the first field is taken as the document
        id and every following field as a word of that document.

        Args:
            index_file: path of the UTF-8 encoded documents file.
            stemmer: object with a ``stem(word)`` method; defaults to the
                Russian Snowball stemmer.
        """
        if stemmer is None:
            stemmer = _make_stemmer()

        self.index: dict[str, set[str]] = {}
        # Words repeat heavily across documents; remember their stems so the
        # stemmer runs once per distinct word.
        stem_cache: dict[str, str] = {}

        with codecs.open(index_file, encoding="utf-8", mode="r") as docs:
            for line_no, line in enumerate(docs):
                if line_no % 200 == 0:
                    print(f"index {line_no}")

                fields = line.strip().translate(_PUNCT_TO_SPACE).split()
                if not fields:
                    continue

                doc_id, *words = fields
                for raw_word in words:
                    word = raw_word.lower()
                    stem = stem_cache.get(word)
                    if stem is None:
                        stem = stem_cache[word] = stemmer.stem(word)
                    self.index.setdefault(stem, set()).add(doc_id)


class QueryTree:
    """Recursive-descent evaluator for boolean queries.

    Grammar (whitespace between two operands means AND, ``|`` means OR,
    AND binds tighter than OR)::

        query   := conj ('|' conj)*
        conj    := operand (' ' operand)*
        operand := '(' query ')' | term
    """

    def __init__(self, qid, query, stemmer=None):
        """Tokenise *query*.

        Args:
            qid: query identifier; accepted for interface compatibility and
                not used by the parser.
            query: query text.
            stemmer: object with a ``stem(word)`` method; defaults to the
                Russian Snowball stemmer. Pass a shared instance to avoid
                building one per query.
        """
        self._stemmer = _make_stemmer() if stemmer is None else stemmer
        self._request: list[str] = self._tokenize(query)
        self._i = 0
        self._c = None

    @staticmethod
    def _tokenize(query):
        """Split *query* into terms, ``(``, ``)``, ``|`` and ``" "`` (AND).

        Whitespace itself is discarded; an AND token is emitted exactly where
        the end of one operand (a term or ``)``) meets the start of the next
        (a term or ``(``). Redundant spaces -- leading, trailing, repeated or
        next to an operator -- therefore never turn into bogus operands, and
        ``a(b)`` is read as ``a (b)``.
        """
        tokens = []
        for tok in _QUERY_TOKEN_RE.findall(query.lower()):
            if tok.isspace():
                continue
            if tokens and tokens[-1] not in ("(", "|") and tok not in (")", "|"):
                tokens.append(" ")
            tokens.append(tok)
        return tokens

    def _get(self):
        """Advance to the next token; ``self._c`` is None at end of input."""
        if self._i < len(self._request):
            self._c = self._request[self._i]
            self._i += 1
        else:
            # None cannot collide with any token text, unlike a string marker.
            self._c = None

    def search(self, index):
        """Evaluate the query against *index* and return matching doc ids.

        The returned set may be one of the index's own posting sets (for a
        single-term query), so callers must treat it as read-only.

        Raises:
            ValueError: if an opening bracket has no matching closing one.
        """
        self._i = 0
        self._get()
        return self._or(index)

    def _or(self, index):
        """query := conj ('|' conj)* -- union of the alternatives."""
        result = self._and(index)
        while self._c == "|":
            self._get()
            result = result | self._and(index)
        return result

    def _and(self, index):
        """conj := operand (' ' operand)* -- intersection of the operands."""
        result = self._token(index)
        while self._c == " ":
            self._get()
            result = result & self._token(index)
        return result

    def _token(self, index):
        """operand := '(' query ')' | term."""
        if self._c is None or self._c in (")", "|"):
            # Missing operand ("a|", "()", "a||b"): it matches nothing. The
            # current token is left in place for the enclosing rule.
            return set()

        if self._c == "(":
            self._get()
            result = self._or(index)
            if self._c != ")":
                raise ValueError('Unmatched bracket')
        else:
            result = self._lookup(index, self._c)

        self._get()
        return result

    def _lookup(self, index, term):
        """Return the documents containing every word of *term*.

        The term is normalised like indexed text: punctuation becomes a
        separator, so ``air-force`` matches documents containing both ``air``
        and ``force``. A term with no word characters matches nothing.
        """
        words = term.translate(_PUNCT_TO_SPACE).split()
        if not words:
            return set()

        postings = index.index
        result = None
        for word in words:
            docs = postings.get(self._stemmer.stem(word), set())
            result = docs if result is None else result & docs
        return result


class SearchResults:
    """Collects per-query result sets and writes the submission file."""

    def __init__(self):
        self._results = []

    def add(self, found):
        """Append the result set of the next query (queries are added in order)."""
        self._results.append(found)

    def print_submission(self, objects_file, submission_file):
        """Write ``ObjectId,Relevance`` rows for every row of *objects_file*.

        The first line of *objects_file* is skipped as a header. Every other
        non-blank line is split on commas into object id, query id and
        document id; the query id is used as a 1-based position into the
        results passed to :meth:`add`. Relevance is 1 when the document id is
        in that query's result set, otherwise 0.
        """
        with (
            codecs.open(objects_file, encoding="utf-8", mode="r") as inp,
            codecs.open(submission_file, encoding="utf-8", mode="w") as outp,
        ):
            outp.write("ObjectId,Relevance\n")
            inp.readline()  # header row
            for line in inp:
                row = line.strip()
                if not row:
                    continue
                object_id, query_id, doc_id = row.split(",")[:3]
                relevant = doc_id in self._results[int(query_id) - 1]
                outp.write(f"{object_id},{int(relevant)}\n")


# NOTE(review): the original patch continues past this point with the context
# line `def main():` (its body lies outside the visible hunk and is not
# reproduced here) and with a new requirements.txt containing `nltk==3.8.1`.