diff --git a/activity_browser/bwutils/__init__.py b/activity_browser/bwutils/__init__.py index 0f943859c..18839d5ac 100644 --- a/activity_browser/bwutils/__init__.py +++ b/activity_browser/bwutils/__init__.py @@ -13,6 +13,7 @@ from .montecarlo import MonteCarloLCA from .multilca import MLCA, Contributions from .pedigree import PedigreeMatrix +from .searchengine import SearchEngine, MetaDataSearchEngine from .sensitivity_analysis import GlobalSensitivityAnalysis from .superstructure import SuperstructureContributions, SuperstructureMLCA from .uncertainty import (CFUncertaintyInterface, ExchangeUncertaintyInterface, diff --git a/activity_browser/bwutils/metadata.py b/activity_browser/bwutils/metadata.py index 4b43e6c3f..6f96814fa 100644 --- a/activity_browser/bwutils/metadata.py +++ b/activity_browser/bwutils/metadata.py @@ -2,13 +2,13 @@ import itertools import sqlite3 import pickle +import sys from time import time from functools import lru_cache -from typing import Set +from typing import Set, Optional from logging import getLogger from playhouse.shortcuts import model_to_dict - import pandas as pd from qtpy.QtCore import Qt, QObject, Signal, SignalInstance @@ -17,6 +17,8 @@ from bw2data.errors import UnknownObject from bw2data.backends import sqlite3_lci_db, ActivityDataset +from activity_browser.bwutils.searchengine import MetaDataSearchEngine + from activity_browser import signals @@ -65,6 +67,12 @@ def __init__(self, parent=None): self.moveToThread(application.thread()) self.connect_signals() + self.search_engine_whitelist = [ + "id", "name", "synonyms", "unit", "key", "database", # generic + "CAS number", "categories", # biosphere specific + "product", "reference product", "classifications", "location", "properties" # activity specific + ] + def connect_signals(self): signals.project.changed.connect(self.sync) signals.node.changed.connect(self.on_node_changed) @@ -74,11 +82,32 @@ def connect_signals(self): def on_node_deleted(self, ds): try: - self.dataframe.drop(ds.key, inplace=True) + self.dataframe = self.dataframe.drop(ds.key) + self.remove_identifier_from_search_engine(ds) self.synced.emit() except KeyError: pass + def remove_identifier_from_search_engine(self, ds): + if not hasattr(self, "search_engine"): + return + data = model_to_dict(ds) + identifier = data["id"] + if identifier in self.search_engine.database_id_manager(data["database"]): + self.search_engine.remove_identifier(identifier) + self.search_engine.reset_database_id_manager() + + def remove_identifiers_from_search_engine(self, identifiers): + if not hasattr(self, "search_engine"): + return + t = time() + for identifier in identifiers: + self.search_engine.remove_identifier(identifier, logging=False) + self.search_engine.reset_database_id_manager() + log.debug(f"Search index updated in {time() - t:.2f} seconds " + f"for {len(identifiers)} removed items " + f"({len(self.search_engine.df)} items ({self.search_engine.size_of_index()}) currently).") + def on_node_changed(self, new, old): data_raw = model_to_dict(new) data = data_raw.pop("data") @@ -96,13 +125,32 @@ def on_node_changed(self, new, old): for col in [col for col in data.columns if col not in self.dataframe.columns]: self.dataframe[col] = pd.NA self.dataframe.loc[new.key] = data.loc[new.key] + self.change_identifier_in_search_engine(identifier=data.loc[new.key, "id"], data=data.loc[[new.key]]) elif self.dataframe.empty: # an activity has been added and the dataframe was empty self.dataframe = data + self.add_identifier_to_search_engine(data) else: # an 
activity has been added and needs to be concatenated to existing metadata self.dataframe = pd.concat([self.dataframe, data], join="outer") + self.add_identifier_to_search_engine(data) self.thread().eventDispatcher().awake.connect(self._emitSyncLater, Qt.ConnectionType.UniqueConnection) + def add_identifier_to_search_engine(self, data: pd.DataFrame): + if not hasattr(self, "search_engine"): + return + search_engine_cols = list(set(data.columns) & set(self.search_engine_whitelist)) # intersection becomes columns + data = data[search_engine_cols] + self.search_engine.add_identifier(data.copy()) + self.search_engine.reset_database_id_manager() + + def change_identifier_in_search_engine(self, identifier, data: pd.DataFrame): + if not hasattr(self, "search_engine"): + return + search_engine_cols = list(set(data.columns) & set(self.search_engine_whitelist)) # intersection becomes columns + data = data[search_engine_cols] + self.search_engine.change_identifier(identifier=identifier, data=data.copy()) + self.search_engine.reset_database_id_manager() + @property def databases(self): return set(self.dataframe.get("database", [])) @@ -154,7 +202,10 @@ def sync_databases(self) -> None: for db_name in [x for x in self.databases if x not in bd.databases]: # deleted databases + remove_search_engine = self.dataframe[self.dataframe["database"] == db_name]["id"] self.dataframe.drop(db_name, level=0, inplace=True) + if len(remove_search_engine) > 0: + self.remove_identifiers_from_search_engine(remove_search_engine) sync = True for db_name in [x for x in bd.databases if x not in self.databases]: @@ -167,7 +218,7 @@ def sync_databases(self) -> None: self.dataframe = data else: self.dataframe = pd.concat([self.dataframe, data], join="outer") - + self.add_identifier_to_search_engine(data) sync = True if sync: @@ -183,6 +234,7 @@ def _get_database(self, db_name: str) -> pd.DataFrame | None: def sync(self) -> None: """Deletes metadata when the project is changed.""" + t = time() log.debug("Synchronizing MetaDataStore") con = sqlite3.connect(sqlite3_lci_db._filepath) @@ -191,6 +243,13 @@ def sync(self) -> None: self.dataframe = self._parse_df(node_df) + size_bytes = sys.getsizeof(self.dataframe) + if size_bytes < 1024 ** 3: + size = f"{size_bytes / (1024 ** 2):.1f} MB" + else: + size = f"{size_bytes / (1024 ** 3):.2f} GB" + log.debug(f"MetaDataStore Synchronized in {time() - t:.2f} seconds for {len(self.dataframe)} items ({size}))") + self.init_search() # init search index self.synced.emit() def _parse_df(self, raw_df: pd.DataFrame) -> pd.DataFrame: @@ -343,5 +402,20 @@ def _unpacker(self, classifications: list, system: str) -> list: system_classifications.append(result) # result is either "" or the classification return system_classifications + def init_search(self): + self.search_engine = MetaDataSearchEngine(self.dataframe, identifier_name="id", searchable_columns=self.search_engine_whitelist) + + def db_search(self, query:str, database: Optional[str] = None, return_counter: bool = False, logging: bool = True): + # we do fuzzy search as we re-index results (combining products and activities) for database_products table + # anyway, so including literal results quite literally is a waste of time at this point + return self.search_engine.fuzzy_search(query, database=database, return_counter=return_counter, logging=logging) + + def search(self, query:str): + return self.search_engine.search(query) + + def auto_complete(self, word:str, context: Optional[set] = None, database: Optional[str] = None): + word = 
self.search_engine.clean_text(word) + completions = self.search_engine.auto_complete(word, context=context, database=database) + return completions AB_metadata = MetaDataStore() diff --git a/activity_browser/bwutils/searchengine/__init__.py b/activity_browser/bwutils/searchengine/__init__.py new file mode 100644 index 000000000..a3ed1d8e1 --- /dev/null +++ b/activity_browser/bwutils/searchengine/__init__.py @@ -0,0 +1,2 @@ +from .base import SearchEngine +from .metadata_search import MetaDataSearchEngine diff --git a/activity_browser/bwutils/searchengine/base.py b/activity_browser/bwutils/searchengine/base.py new file mode 100644 index 000000000..5b9127985 --- /dev/null +++ b/activity_browser/bwutils/searchengine/base.py @@ -0,0 +1,817 @@ +from itertools import permutations, chain +import itertools +import functools +from collections import Counter, OrderedDict, defaultdict +from logging import getLogger +import math +import multiprocessing as mp +from time import time +from typing import Iterable, Optional +import pandas as pd +import numpy as np +import re +import sys + + +log = getLogger(__name__) + + +class SearchEngine: + """ + A Search Engine class, takes a dataframe and makes it searchable. + + A search requires a string, and will return a list of unique identifiers in the dataframe. + There are three options for search: + SearchEngine.literal_search(): searches for exact matches of the search query + SearchEngine.fuzzy_search(): searches for approximate matches of search query, sorted by relevance + SearchEngine.search(): combines both of the above, literal matches are returned first, next all fuzzy results, + but subsets sorted by relevance. + It is recommended to always use searchEngine.search(), but the other options are there. + + Initialization takes: + df: Dataframe that needs to be searchable. + identifier_name: values in this column will be returned as search results, all values in this column need to be unique. + searchable_columns: these columns need to be searchable, if none are given, all columns will be made searchable. + + Updating data is possible as well: + add_identifier(): adds this identifier to the searchable data + remove_identifier(): removes this identifier from the searchable data + change_identifier(): changes this identifier (wrapper for remove_identifier and add_identifier) + + """ + + def __init__(self, df: pd.DataFrame, identifier_name: str, searchable_columns: list = []): + t = time() + log.debug(f"SearchEngine initializing for {len(df)} items") + + # compile regex patterns for cleaning + self.SUB_END_PATTERN = re.compile(r"[,.\"'`)\[\]}\\/\-−_:;+…]+(?=\s|$)") # remove these from end of word + self.SUB_START_PATTERN = re.compile(r"(?:^|\s)[,.\"'`(\[{\\/\-−_:;+]+") # remove these from start of word + self.ONE_SPACE_PATTERN = re.compile(r"\s+") # remove these multiple whitespaces + + self.q = 2 # character length of q grams + self.base_weight = 10 # base weighting for sorting results + + if identifier_name not in df.columns: # make sure identifier col exist + raise NameError(f"Identifier column {identifier_name} not found in dataframe. Use an existing column name.") + if df[identifier_name].nunique() != df.shape[0]: # make sure identifiers are all unique + raise KeyError( + f"Identifier column {identifier_name} must only contain unique values. 
Found {df[identifier_name].nunique()} unique values for length {df.shape[0]}") + + self.identifier_name = identifier_name + + # ensure columns given actually exist + # always ensure "identifier" is present + if searchable_columns == []: + # if no list is given, assume all columns are searchable + self.columns = list(df.columns) + else: + # create subset of columns to be searchable, discard rest + self.columns = [col for col in searchable_columns if col in df.columns] + if self.identifier_name not in self.columns: # keep identifier col + self.columns.append(self.identifier_name) + df = df[self.columns] + # set the identifier column as index + df = df.set_index(self.identifier_name, drop=False) + + # convert all data to str + df = df.astype(str) + + # find the self.identifier_name column index and store as int + self.identifier_column = self.columns.index(self.identifier_name) + + # store all searchable column indices except the identifier + self.searchable_columns = [i for i in range(len(self.columns)) if i != self.identifier_column] + + # initialize search index dicts and update df + self.identifier_to_word = {} + self.word_to_identifier = {} + self.word_to_q_grams = {} + self.q_gram_to_word = {} + self.df = pd.DataFrame() + + self.update_index(df) + + log.debug(f"SearchEngine Initialized in {time() - t:.2f} seconds") + + # +++ Utility functions + + def update_index(self, update_df: pd.DataFrame) -> None: + """Update search index dicts and the df.""" + + def update_dict(update_me: dict, new: dict) -> dict: + """Update a dict of counters with new dict of counters.""" + # set to empty set if we know update_me is empty, otherwise, find set intersection + update_keys = set() if len(update_me) == 0 else new.keys() & update_me.keys() + if len(update_keys) == 0: + new_data = new + else: + for update_key in update_keys: + update_me[update_key].update(new[update_key]) + new_data = {key: value for key, value in new.items() if key not in update_keys} + # finally add any completely new data + # update_me.update(new_data) + update_me = update_me | new_data + return update_me + + if len(update_df) == 0: + return + + t = time() + size_old = len(self.df) + # identifier to word and df + i2w, update_df = self.words_in_df(update_df) + self.identifier_to_word = update_dict(self.identifier_to_word, i2w) + self.df = pd.concat([self.df, update_df]) + # word to identifier + w2i = self.reverse_dict_many_to_one(i2w) + self.word_to_identifier = update_dict(self.word_to_identifier, w2i) + # word to q-gram + w2q = self.list_to_q_grams(w2i.keys()) + self.word_to_q_grams = update_dict(self.word_to_q_grams, w2q) + # q-gram to word + q2w = self.reverse_dict_many_to_one(w2q) + self.q_gram_to_word = update_dict(self.q_gram_to_word, q2w) + size_new = len(self.df) + size_dif = size_new - size_old + size_msg = (f"{size_dif} changed items at {int(round(size_dif/(time() - t), 0))} items/sec " + f"({size_new} items ({self.size_of_index()}) currently)") if size_dif > 1 \ + else f"1 changed item ({size_new} items ({self.size_of_index()}) currently)" + log.debug(f"Search index updated in {time() - t:.2f} seconds for {size_msg}.") + + def clean_text(self, text: str): + """Clean a string so it doesn't contain weird characters or multiple spaces etc.""" + text = text.lower() + text = self.SUB_END_PATTERN.sub("", text) + text = self.SUB_START_PATTERN.sub(" ", text) + + text = self.ONE_SPACE_PATTERN.sub(" ", text).strip() + return text + + def text_to_positional_q_gram(self, text: str) -> list: + """Return a positional list of q-grams 
for the given string. + + q-grams are n-grams on character level. + q-grams at q=2 of "word" would be "wo", "or" and "rd" + https://en.wikipedia.org/wiki/N-gram + + Note: these are technically _positional_ q-grams, but we don't use their positions currently. + """ + q = self.q + n = len(text) + # just return a single-item list if the text is equal or shorter than q + # else, generate q-grams + if n <= q: + return [text] + return list(text[i:i + q] for i in range(n - q + 1)) + + def df_clean_worker(self, df): + """Clean the text in query_col.""" + df["query_col"] = df["query_col"].apply(self.clean_text) + return df + + def df_clean(self, df): + """Clean the text in query_col. + + apply multi-processing when the computer is able and its relevant + """ + def chunk_dataframe(df: pd.DataFrame, chunk_size: int): + """Split DataFrame into chunks of specified size.""" + return [df.iloc[i:i + chunk_size] for i in range(0, len(df), chunk_size)] + + max_cores = max(1, mp.cpu_count() - 1) # leave at least 1 core for other processes + min_chunk_size = 2500 + if max_cores > 1 and len(df) > min_chunk_size * 2: + for i in range(max_cores, 0, -1): + chunk_size = int(math.ceil(len(df) / i)) + if chunk_size >= min_chunk_size: + break + use_cores = i + else: + use_cores = 1 + if use_cores == 1: + return self.df_clean_worker(df) + + chunks = chunk_dataframe(df, chunk_size) + with mp.Pool(processes=use_cores) as pool: + results = pool.starmap(self.df_clean_worker, [(chunk,) for chunk in chunks]) + return pd.concat(results) + + def words_in_df(self, df: pd.DataFrame = None) -> tuple[dict, pd.DataFrame]: + """Return a dict of {identifier: word} for df.""" + + df = df if df is not None else self.df.copy() + df = df.fillna("") # avoid nan + # assemble query_col + df["query_col"] = df.iloc[:, self.searchable_columns].astype(str).agg(" | ".join, axis=1) + # clean all text at once using vectorized operations + df["query_col"] = self.df_clean(df.loc[:, ["query_col"]]) + # build the identifier_word_dict dictionary + identifier_word_dict = df["query_col"].apply(lambda text: Counter(text.split(" "))).to_dict() + return identifier_word_dict, df + + def reverse_dict_many_to_one(self, dictionary: dict) -> dict: + """Reverse a dictionary of Counter objects.""" + reverse = defaultdict(Counter) + for identifier, counter_object in dictionary.items(): + for countable, count in counter_object.items(): + reverse[countable][identifier] += count + return dict(reverse) + + def list_to_q_grams(self, word_list: Iterable) -> dict: + """Convert a list of unique words to a dict with Counter objects. + + Number will be the occurrences of that q-gram in that word. + + return = { + "word": Counter( + "wo": 1 + "or": 1 + "rd": 1 + ), + ... 
+ } + """ + text_to_q_gram = self.text_to_positional_q_gram + return { + word: Counter(text_to_q_gram(word)) + for word in word_list + } + + def word_in_index(self, word: str) -> bool: + """Convenience function to check if a single word is in the search index.""" + if " " in word: + raise Exception( + f"Given word '{word}' must not contain spaces.") + return word in self.word_to_identifier.keys() + + def size_of_index(self): + """return the size of the search index in MB or GB.""" + s_df = sys.getsizeof(self.df) + s_i2w = sys.getsizeof(self.identifier_to_word) + s_w2i = sys.getsizeof(self.word_to_identifier) + s_w2q = sys.getsizeof(self.word_to_q_grams) + s_q2w = sys.getsizeof(self.q_gram_to_word) + size_bytes = s_df + s_i2w + s_w2i + s_w2q + s_q2w + + if size_bytes < 1024 ** 3: + return f"{size_bytes / (1024 ** 2):.1f} MB" + else: + return f"{size_bytes / (1024 ** 3):.2f} GB" + + # +++ Changes to searchable data + + def add_identifier(self, data: pd.DataFrame) -> None: + """Add this data to the search index. + + identifier column is REQUIRED to be present + ALL data in the given dataframe will be added, if columns should not be added, they should be removed before + calling this function + """ + + # ensure we have identifier column + if self.identifier_name not in data.columns: + raise Exception( + f"Identifier column '{self.identifier_name}' not in new data, impossible to add data without identifier") + + # make sure we the new identifiers do not yet exist + existing_ids = set(self.df.index.to_list()) + for identifier in data[self.identifier_name]: + if identifier in existing_ids: + raise Exception( + f"Identifier '{identifier}' is already in use, use a different identifier or use the change_identifier function.") + + # make sure all new identifiers given are unique + if data[self.identifier_name].nunique() != data.shape[0]: + raise KeyError( + f"Identifier column {self.identifier_name} must only contain unique values. Found {data[self.identifier_name].nunique()} unique values for length {data.shape[0]}") + + df_cols = self.columns + # add cols to new data that are missing + for col in df_cols: + if col not in data.columns: + data.loc[:, col] = [""] * len(data) + # re-order cols, first existing, then new + df_col_set = set(df_cols) + new_cols = [col for col in data.columns if col not in self.columns if col not in df_col_set] + data_cols = df_cols + new_cols + data = data[data_cols] # re-order new data to be in correct order + + # add cols from new data to correct places + self.columns.extend(new_cols) + self.searchable_columns.extend([i for i, col in enumerate(data_cols) if col in new_cols]) + + # convert df + data = data.set_index(self.identifier_name, drop=False) + data = data.fillna("") + data = data.astype(str) + + # update the search index data + self.update_index(data) + + def remove_identifier(self, identifier, logging=True) -> None: + """Remove this identifier from self.df and the search index. 
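+ Words and q-grams that occur only in this identifier are dropped from the word and q-gram indices as well; entries shared with other identifiers are kept.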
+ """ + if logging: + t = time() + + # make sure the identifier exists + if identifier not in self.df.index.to_list(): + raise Exception( + f"Identifier '{identifier}' does not exist in the search data, cannot remove identifier that do not exist.") + + self.df = self.df.drop(identifier) + + # find words that may need to be removed + words = self.identifier_to_word[identifier] + for word in words: + if len(self.word_to_identifier[word]) == 1: + # this word is only found in this identifier, + # remove the word and check for q grams + del self.word_to_identifier[word] + + q_grams = self.word_to_q_grams[word] + for q_gram in q_grams: + if len(self.q_gram_to_word[q_gram]) == 1: + # this q_gram is only used in this word, + # remove it + del self.q_gram_to_word[q_gram] + elif len(self.q_gram_to_word[q_gram]) > 1: + # this q_gram is used in multiple words, only remove the word from the q_gram + del self.q_gram_to_word[q_gram][word] + + del self.word_to_q_grams[word] + else: + # this word is found in multiple identifiers + # word_to_q_gram and q_gram_to_word do not need to be changed, the word still exists + # remove the identifier the word in word_to_identifier + del self.word_to_identifier[word][identifier] + # finally, remove the identifier + del self.identifier_to_word[identifier] + + if logging: + log.debug(f"Search index updated in {time() - t:.2f} seconds " + f"for 1 removed item ({len(self.df)} items ({self.size_of_index()}) currently).") + + def change_identifier(self, identifier, data: pd.DataFrame) -> None: + """Change this identifier. + + identifier must be an identifier that is in use + data must be a dataframe of 1 row with all change data + data is overwritten with the new data in 'data', columns not given remain unchanged + """ + + # make sure only 1 change item is given + if len(data) > 1 or len(data) < 1: + raise Exception( + f"change data must be for exactly 1 identifier, but {len(data)} items were given.") + # make sure correct use of identifier + if identifier not in self.df.index.to_list(): + raise Exception( + f"Identifier '{identifier}' does not exist in the search data, use an existing identifier or use the add_identifier function.") + if self.identifier_name in data.columns and data[self.identifier_name].to_list() != [identifier]: + raise Exception( + "Identifier field cannot be changed, first remove item and then add new identifier") + if "query_col" in data.keys(): + log.debug( + f"Field 'query_col' is a protected field for search engine and will be ignored for changing {identifier}") + + + # overwrite new data where relevant + update_data = self.df.loc[[identifier], self.columns] + data = data.reset_index(drop=True) + for col in data.columns: + value = data.loc[0, col] + update_data[col] = [value] + + # remove the entry + self.remove_identifier(identifier, logging=False) + # add entry with updated data + self.add_identifier(update_data) + + # +++ Search + + def filter_dataframe(self, df: pd.DataFrame, pattern: str, search_columns: Optional[list] = None) -> pd.Series: + """Filter the search columns of a dataframe on a pattern. 
+ + Returns a mask (true/false) pd.Series with matching items.""" + + search_columns = search_columns if search_columns else self.columns + mask = functools.reduce( + np.logical_or, + [ + df[col].apply(lambda x: pattern in x.lower()) + for col in search_columns + ], + ) + return mask + + def literal_search(self, text, df: Optional[pd.DataFrame] = None) -> list: + """Do literal search of the text in all original columns that were given.""" + + if df is None: + df = self.df.copy() + + identifiers = self.filter_dataframe(df, text) + df = df.loc[identifiers] + identifiers = df.index.to_list() + return identifiers + + def osa_distance(self, word1: str, word2: str, cutoff: int = 0, cutoff_return: int = 1000) -> int: + """Calculate the Optimal String Alignment (OSA) edit distance between two strings, return edit distance. + + Has additional cutoff variable, if cutoff is higher than 0 and if the words have + a larger edit distance, return a large number (note: cutoff <= edit_dist, not cutoff < edit_dist) + + OSA is a restricted form of the Damerau–Levenshtein distance. + https://en.wikipedia.org/wiki/Damerau%E2%80%93Levenshtein_distance#Optimal_string_alignment_distance + + The edit distance is how many operations (insert, delete, substitute or transpose a character) need to happen to convert one string to another. + insert and delete are obvious operations, but substitute and transpose are explained: + substitute: replace one character with another: e.g. word1='cat' word2='cab', 't'->'b' substitution is 1 operation + transpose: swap the places of two adjacent characters with each other: e.g. word1='coal' word2='cola' 'al' -> 'la' transposition is 1 operation + + The minimum amount of edit operations (OSA edit distance) is returned. + """ + if word1 == word2: + # if the strings are the same, immediately return 0 + return 0 + + len1, len2 = len(word1), len(word2) + + if 0 < cutoff <= abs(len1 - len2): + # if the length difference between 2 words is over the cutoff, + # just return instead of calculating the edit distance + return cutoff_return + + if len1 == 0 or len2 == 0: + # in case (at least) one of the strings is empty, + # return the length of the longest string + return max(len1, len2) + + if len1 < len2 and cutoff > 0: + # make sure word1 is always the longest (required for early stopping with cutoff) + word1, word2 = word2, word1 + len1, len2 = len2, len1 + + # Initialize matrix + distance = [[0] * len2 for _ in range(len1)] + + # calculate shortest edit distance + for i in range(len1): + for j in range(len2): + cost = 0 if word1[i] == word2[j] else 1 + + # Compute distances for insertion, deletion and substitution + insertion = distance[i][j - 1] + 1 if j > 0 else i + 1 + deletion = distance[i - 1][j] + 1 if i > 0 else j + 1 + substitution = distance[i - 1][j - 1] + cost if i > 0 and j > 0 else max(i, j) + cost + + distance[i][j] = min(deletion, insertion, substitution) + + # Compute transposition when relevant + if i > 0 and j > 0 and word1[i] == word2[j - 1] and word1[i - 1] == word2[j]: + transposition = distance[i - 2][j - 2] + 1 if i > 1 and j > 1 else max(i, j) - 1 + distance[i][j] = min(distance[i][j], transposition) + + # stop early if we surpass cutoff + if 0 < cutoff <= min(distance[i]): + return cutoff_return + return distance[i][j] + + def find_q_gram_matches(self, q_grams: set) -> pd.DataFrame: + """Find which of the given q_grams exist in self.q_gram_to_word, + return a sorted dataframe of best matching words. 
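+ Candidate words are scored by how many of the query's q-grams they contain; low-scoring candidates are pruned and at most 2500 words are returned.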
+ """ + n_q_grams = len(q_grams) + + matches = {} + + # find words that match our q-grams + for q_gram in q_grams: + if words := self.q_gram_to_word.get(q_gram, False): + # q_gram exists in our search index + for word in words: + matches[word] = matches.get(word, 0) + words[word] + + # if we find no results, return an empty dataframe + if len(matches) == 0: + return pd.DataFrame({"word": [], "matches": []}) + + # otherwise, create a dataframe and + # reduce search results to most relevant results + matches = {"word": matches.keys(), "matches": matches.values()} + matches = pd.DataFrame(matches) + max_q = max(matches["matches"]) # this has the most matching q-grams + + # determine how many results we want to keep based on how good our results are + min_q = min(max(max_q * 0.32, # have at least a third of q-grams of best match or... + max(n_q_grams * 0.5, # if more, at least half the q-grams in the query word? + 1)), # okay just do 1 q-gram if there are no more in the word + max_q) # never have min_q be over max_q + + matches = matches[matches["matches"] >= min_q] + matches = matches.sort_values(by="matches", ascending=False) + matches = matches.reset_index(drop=True) + + return matches.iloc[:min(len(matches), 2500), :] # return at most this many results + + def spell_check(self, text: str, skip_len=1) -> OrderedDict: + """Create an OrderedDict of each word in the text (space separated) + with as values possible alternatives. + + Alternatives are first found with q-grams, then refined with string edit distance + + We rank alternative words based on 1) edit distance 2) how often a word is used in an entry + If too many results are found, we only keep edit distance 1, + if we want more results, we keep with longer edit distance up to `never_accept_this` + + word_results = OrderedDict( + "word": [work] + ) + + NOTE: only ALTERNATIVES are ever returned, this function returns empty list for item BOTH when + 1) the exact word is in the data + 2) when there are no suitable alternatives + """ + count_occurence = lambda x: sum(self.word_to_identifier[x].values()) # count occurences of a word + + word_results = OrderedDict() + + matches_min = 3 # ideally we have at least this many alternatives + matches_max = 10 # ideally don't much more than this many matches + always_accept_this = 1 # values of this edit distance or lower always accepted + never_accept_this = 4 # values this edit distance or over always rejected + + # make list of unique words + text = self.clean_text(text) + words = OrderedDict() + for word in text.split(" "): + if len(word) != 0: + words[word] = False + words = words.keys() + + for word in words: + if len(word) <= skip_len: # dont look for alternatives for text this short + word_results[word] = [] + continue + + # reduce acceptable edit distance with short words + dont_accept = int(round(max(1, min((len(word) * 0.66), never_accept_this)), 0)) + + # first, find possible matches quickly + q_grams = self.text_to_positional_q_gram(word) + possible_matches = self.find_q_gram_matches(set(q_grams)) + + first_matches = Counter() + other_matches = {} + + # now, refine with edit distance + for row in possible_matches.itertuples(): + + edit_distance = self.osa_distance(word, row[1], cutoff=dont_accept) + + if edit_distance == 0: + continue # we are looking for alternatives only, not the exact word + elif edit_distance <= always_accept_this: + first_matches[row[1]] = count_occurence(row[1]) + elif edit_distance < dont_accept: + if not other_matches.get(edit_distance): + 
other_matches[edit_distance] = Counter() + other_matches[edit_distance][row[1]] = count_occurence(row[1]) + else: + continue + + # add matches in correct order: + matches = [match for match, _ in first_matches.most_common()] + # if we have fewer matches than goal, add more 'less good' matches + if len(matches) < matches_min: + for i in range(always_accept_this + 1, dont_accept): + # iteratively increase matches with 'worse' results so we hit goal of minimum alternatives + if new := other_matches.get(i): + prev_num = 10e100 + for match, num in new.most_common(): + if num == prev_num: + matches.append(match) + elif num != prev_num and len(matches) <= matches_max: + matches.append(match) + else: + break + prev_num = num + + word_results[word] = matches + return word_results + + def build_queries(self, query_text) -> list: + """Make all possible subsets of words in the query, including alternative words.""" + query_text = self.spell_check(query_text) + + # find all combinations of the query words as given + queries = list(query_text.keys()) + subsets = list(chain.from_iterable( + (itertools.combinations( + queries, r) for r in range(1, len(queries) + 1)))) + all_queries = [] + + for combination in subsets: + # add the 'default' option + all_queries.append(combination) + # now add all options with all alternatives + for i, word in enumerate(combination): + for alternative in query_text.get(word, []): + alternative_combination = list(combination) + alternative_combination[i] = alternative + all_queries.append(alternative_combination) + + return all_queries + + def weigh_identifiers(self, identifiers: Counter, weight: int, weighted_ids: Counter) -> Counter: + """Add weights to identifier counter for these identifiers times how often it occurs in identifier.""" + for identifier, occurrences in identifiers.items(): + weighted_ids[identifier] += (weight * occurrences) + return weighted_ids + + def search_size_1(self, queries: list, original_words: set, orig_word_weight=5, exact_word_weight=1) -> dict: + """Return a dict of {query_word: Counter(identifier)}. 
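+ (Handles the single-word queries; fuzzy_search combines these results itself for multi-word queries.)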
+ + queries: is a list of len 1 tuple/lists of words that are a searched word or a 'spell checked' similar word + original words: a list of words actually searched for (not including spellchecked) + + orig_word_weight: additional weight to add to original words + exact_word_weight: additional weight to add to exact word matches (as opposed to be 'in' str) + + First, we find all matching words, creating a dict of words in 'queries' as keys and words matching that query word as list of values + Next, we convert this to identifiers and add weights: + Weight will be increased if matching 'orig_word_weight' or 'exact_word_weight' + """ + matches = {} + # add each word in search index if query_word in word + for word in self.word_to_identifier.keys(): + for query in queries: + # query is list/tuple of len 1 + query_word = query[0] # only use the word + if query_word in word: + words = matches.get(query_word, []) + words.extend([word]) + matches[query_word] = words + + # now convert matched words to matched identifiers + matched_identifiers = {} + for word, matching_words in matches.items(): + for matched_word in matching_words: + weight = self.base_weight + id_counter = matched_identifiers.get(word, Counter()) + + # add the word n times, where n is the weight, original search word is weighted higher than alternatives + if matched_word in original_words: + weight += orig_word_weight # increase weight for original word + if matched_word == word: + weight += exact_word_weight # increase weight for exact matching word + + id_counter = self.weigh_identifiers(self.word_to_identifier[matched_word], weight, id_counter) + matched_identifiers[word] = id_counter + + return matched_identifiers + + def fuzzy_search(self, text: str, return_counter: bool = False) -> list: + """Search the dataframe, finding approximate matches and return a list of identifiers, + ranked by how well each identifier matches the search text. + + 1. First, identifiers matching single words (and spell-checked alternatives) are found and weighted. + 2. If the search term consisted of multiple words, combinations of those words are checked next. + 2.1 Increasing in size (first two words, then three etc.), we look for identifiers that contain that set of + words, these are also weighted, based on the sum of all one-word weights (from first step) and the length + of the sequence. + 2.2 Next, we also look specifically for combinations occurring next to each other. And add more weight like + the step above (2.1). + We multiply the weighting of step 2 by the sequence length, based on the assumption that finding more search + words will be a more relevant result than just finding a single word, and again if they are in the + correct order. + + Finally, all found identifiers are sorted on their weight and returned. 
+ """ + text = text.strip() + + queries = self.build_queries(text) + + # make list of unique original words + orig_words = OrderedDict() + for word in text.split(" "): + orig_words[word] = False + orig_words = orig_words.keys() + orig_words = {self.clean_text(word) for word in orig_words} + + # order the queries by the amount of words they contain + # we do this because longer queries (more words) are harder to find, but we have many alternatives so we search in a smaller search space + queries_by_size = OrderedDict() + longest_query = max([len(q) for q in queries]) + for query_len in range(1, longest_query + 1): + queries_by_size[query_len] = [q for q in queries if len(q) == query_len] + + # first handle queries of length 1 + query_to_identifier = self.search_size_1(queries_by_size[1], orig_words) + + # get all results into a df, we rank further later + all_identifiers = set() + for id_list in [id_list for id_list in query_to_identifier.values()]: + all_identifiers.update(id_list) + search_df = self.df.loc[list(all_identifiers)] + + # now, we search for combinations of query words and get only those identifiers + # we then reduce de search_df further for only those matching identifiers + # we then search the permutations of that set of words + for q_len, query_set in queries_by_size.items(): + if q_len == 1: + # we already did these above + continue + for query in query_set: + # get the intersection of all identifiers + # meaning, a set of identifiers that occur in ALL sets of len(1) for the individual words in the query + # this ensures we only ever search data where ALL items occur to substantially reduce search-space + # finally, make this a Counter (with each item=1) so we can properly weigh things later + query_id_sets = [set(query_to_identifier.get(q_word)) for q_word in query if + query_to_identifier.get(q_word, False)] + if len(query_id_sets) == 0: + continue + query_identifier_set = set.intersection(*query_id_sets) + if len(query_identifier_set) == 0: + # there is no match for this combination of query words, skip + break + + # now we convert the query identifiers to a Counter of 'occurrence', + # where we weigh queries with only original words higher + query_identifiers = Counter() + for identifier in query_identifier_set: + weight = 0 + for query_word in query: + # if the query_word and identifier combination exist get score, otherwise 0 + weight += query_to_identifier.get(query_word, {}).get(identifier, 0) + + query_identifiers[identifier] = weight + + # we now add these identifiers to a counter for this query name, + query_name = " ".join(query) + + weight = self.base_weight * q_len + query_to_identifier[query_name] = self.weigh_identifiers(query_identifiers, weight, Counter()) + + # now search for all permutations of this query combined with a space + query_df = search_df[search_df[self.identifier_name].isin(query_identifiers)] + for query_perm in permutations(query): + mask = self.filter_dataframe(query_df, " ".join(query_perm), search_columns=["query_col"]) + new_df = query_df.loc[mask].reset_index(drop=True) + if len(new_df) == 0: + # there is no match for this permutation of words, skip + continue + new_id_list = new_df[self.identifier_name] + + new_ids = Counter() + for new_id in new_id_list: + new_ids[new_id] = query_identifiers[new_id] + + # we weigh a combination of words that is next also to each other even higher than just the words separately + query_to_identifier[query_name] = self.weigh_identifiers(new_ids, weight, + query_to_identifier[query_name]) + # now 
finally, move to one object sorted list by highest score + all_identifiers = Counter() + for identifiers in query_to_identifier.values(): + all_identifiers += identifiers + + if return_counter: + return all_identifiers + # now sort on highest weights and make list type + sorted_identifiers = [identifier for identifier, _ in all_identifiers.most_common()] + return sorted_identifiers + + def search(self, text) -> list: + """Search the dataframe on this text, return a sorted list of identifiers.""" + t = time() + text = text.strip() + + if len(text) == 0: + log.debug(f"Empty search, returned all items") + return self.df.index.to_list() + + fuzzy_identifiers = self.fuzzy_search(text) + if len(fuzzy_identifiers) == 0: + log.debug(f"Found 0 search results for '{text}' in {len(self.df)} items in {time() - t:.2f} seconds") + return [] + + # take the fuzzy search sub-set of data and search it literally + df = self.df.loc[fuzzy_identifiers].copy() + + literal_identifiers = self.literal_search(text, df) + if len(literal_identifiers) == 0: + log.debug( + f"Found {len(fuzzy_identifiers)} search results for '{text}' in {len(self.df)} items in {time() - t:.2f} seconds") + return fuzzy_identifiers + + # append any fuzzy identifiers that were not found in the literal search + literal_id_set = set(literal_identifiers) + remaining_fuzzy_identifiers = [ + _id for _id in fuzzy_identifiers if _id not in literal_id_set] + identifiers = literal_identifiers + remaining_fuzzy_identifiers + + log.debug( + f"Found {len(identifiers)} ({len(literal_identifiers)} literal) search results for '{text}' in {len(self.df)} items in {time() - t:.2f} seconds") + return identifiers diff --git a/activity_browser/bwutils/searchengine/metadata_search.py b/activity_browser/bwutils/searchengine/metadata_search.py new file mode 100644 index 000000000..1814a3e8a --- /dev/null +++ b/activity_browser/bwutils/searchengine/metadata_search.py @@ -0,0 +1,447 @@ +from itertools import permutations +from collections import Counter, OrderedDict +from logging import getLogger +from time import time +from typing import Optional +import pandas as pd + +from activity_browser.bwutils.searchengine import SearchEngine + + +log = getLogger(__name__) + + +class MetaDataSearchEngine(SearchEngine): + + # caching for faster operation + def database_id_manager(self, database): + if not hasattr(self, "all_database_ids"): + self.all_database_ids = {} + + if database_ids := self.all_database_ids.get(database): + self.database_ids = database_ids + self.current_database = database + elif database is not None: + self.database_ids = set(self.df[self.df["database"] == database].index.to_list()) + self.all_database_ids[database] = self.database_ids + self.current_database = database + else: + self.database_ids = None + self.current_database = "_@@NO_DB_" + return self.database_ids + + def reset_database_id_manager(self): + if hasattr(self, "all_database_ids"): + del self.all_database_ids + if hasattr(self, "database_ids"): + del self.database_ids + + def database_word_manager(self, database): + if not hasattr(self, "all_database_words"): + self.all_database_words = {} + + if database_words := self.all_database_words.get(database): + self.database_words = database_words + elif database is not None: + ids = self.database_id_manager(database) + self.database_words = self.reverse_dict_many_to_one({_id: self.identifier_to_word[_id] for _id in ids}) + self.all_database_words[database] = self.database_words + else: + self.database_words = None + return self.database_words 
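# Illustrative sketch (assumed toy data and queries, not taken from any Activity Browser
# database) of how the SearchEngine defined in base.py above is constructed and queried:
import pandas as pd
from activity_browser.bwutils.searchengine import SearchEngine

df = pd.DataFrame({
    "id": [1, 2, 3],
    "name": ["steel production", "steel recycling", "copper production"],
    "location": ["DE", "NL", "GLO"],
})
engine = SearchEngine(df, identifier_name="id", searchable_columns=["name", "location"])

engine.literal_search("steel")         # plain substring matches, e.g. [1, 2]
engine.fuzzy_search("stel prodction")  # typo-tolerant, identifiers ranked by weight
engine.search("steel production")      # literal hits first, remaining fuzzy hits appended

# the index can also be updated in place (add_identifier / change_identifier / remove_identifier)
engine.add_identifier(pd.DataFrame({"id": [4], "name": ["aluminium production"]}))
engine.remove_identifier(4)
# MetaDataSearchEngine (this class) layers per-database filtering and caching on top of this
# behaviour and is reached through AB_metadata.db_search(), AB_metadata.search() and
# AB_metadata.auto_complete().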
+ + def reset_database_word_manager(self, database): + if hasattr(self, "all_database_words") and self.all_database_words.get(database): + del self.all_database_words[database] + if hasattr(self, "database_words"): + del self.database_words + + def database_search_cache(self, database, query, result = None): + if not hasattr(self, "search_cache"): + self.search_cache = {} + + if result: + if self.search_cache.get(database): + self.search_cache[database][query] = result + else: + self.search_cache[database] = {query: result} + return + if db_cache := self.search_cache.get(database): + if cached_result := db_cache.get(query): + return cached_result + return + + def reset_search_cache(self, database): + if hasattr(self, "search_cache") and self.search_cache.get(database): + del self.search_cache[database] + + def reset_all_caches(self, databases): + self.reset_database_id_manager() + for database in databases: + self.reset_database_word_manager(database) + self.reset_search_cache(database) + + def add_identifier(self, data: pd.DataFrame) -> None: + super().add_identifier(data) + self.reset_all_caches(data["database"].unique()) + + def remove_identifiers(self, identifiers, logging=True) -> None: + t = time() + + identifiers = set(identifiers) + current_identifiers = set(self.df.index.to_list()) + identifiers = identifiers | current_identifiers # only remove identifiers currently in the data + databases = self.df.loc[identifiers, ["databases"]].unique() # extract databases for cache cleaning + if len(identifiers) == 0: + return + + for identifier in identifiers: + super().remove_identifier(identifier, logging=False) + + if logging: + log.debug(f"Search index updated in {time() - t:.2f} seconds " + f"for {len(identifiers)} removed items ({len(self.df)} items ({self.size_of_index()}) currently).") + self.reset_all_caches(databases) + + def change_identifier(self, identifier, data: pd.DataFrame) -> None: + super().change_identifier(identifier, data) + self.reset_all_caches(data["database"].unique()) + + def auto_complete(self, word: str, context: Optional[set] = set(), database: Optional[str] = None) -> list: + """Based on spellchecker, make more useful for autocompletions + """ + def word_to_identifier_to_word(check_word): + if len(context) == 0: + return 1 + multiplier = 1 + for identifier in self.word_to_identifier[check_word]: + for context_word in context: + for spell_checked_context_word in spell_checked_context[context_word]: + if spell_checked_context_word in self.identifier_to_word[identifier]: + multiplier += 1 + if context_word not in self.word_to_identifier.keys(): + continue + if context_word in self.identifier_to_word[identifier]: + multiplier += 4 + return multiplier + + # count occurrences of a word, count double so word_to_identifier_to_word will never multiply by 1 + count_occurrence = lambda x: sum(self.word_to_identifier[x].values()) * 2 + + if len(word) <= 1: + return [] + + self.database_id_manager(database) + + if len(context) > 0: + spell_checked_context = {} + for context_word in context: + spell_checked_context[context_word] = self.spell_check(context_word).get(context_word, [])[:5] + + matches_min = 2 # ideally we have at least this many alternatives + matches_max = 4 # ideally don't much more than this many matches + never_accept_this = 4 # values this edit distance or over always rejected + # or max 2/3 of len(word) if less than never_accept_this + never_accept_this = int(round(max(1, min((len(word) * 0.66), never_accept_this)), 0)) + + # first, find possible matches 
quickly + q_grams = self.text_to_positional_q_gram(word) + possible_matches = self.find_q_gram_matches(set(q_grams), return_all=True) + + first_matches = Counter() + other_matches = {} + probably_keys = Counter() # if we suspect it's a key hash, dump it at the end of the list + + # now, refine with edit distance + for row in possible_matches.itertuples(): + if word == row[1]: + continue + # find edit distance of same size strings + edit_distance = self.osa_distance(word, row[1][:len(word)], cutoff=never_accept_this) + if len(row[1]) == 32 and edit_distance <= 1: + probably_keys[row[1]] = 100 - edit_distance # keys need to be sorted on edit distance, not on occurence + elif edit_distance == 0: + first_matches[row[1]] = count_occurrence(row[1]) * word_to_identifier_to_word(row[1]) + elif edit_distance < never_accept_this and len(first_matches) < matches_min: + if not other_matches.get(edit_distance): + other_matches[edit_distance] = Counter() + other_matches[edit_distance][row[1]] = count_occurrence(row[1]) * word_to_identifier_to_word(row[1]) + else: + continue + + # add matches in correct order: + matches = [match for match, _ in first_matches.most_common()] + # if we have fewer matches than goal, add more 'less good' matches + if len(matches) < matches_min: + for i in range(1, never_accept_this): + # iteratively increase matches with 'worse' results so we hit goal of minimum alternatives + if new := other_matches.get(i): + prev_num = 10e100 + for match, num in new.most_common(): + if num == prev_num: + matches.append(match) + elif num != prev_num and len(matches) <= matches_max: + matches.append(match) + else: + break + prev_num = num + + matches = matches + [match for match, _ in probably_keys.most_common()] + return matches + + def find_q_gram_matches(self, q_grams: set, return_all: bool = False) -> pd.DataFrame: + """Overwritten for extra database specific reduction of results. + """ + n_q_grams = len(q_grams) + + matches = {} + + # find words that match our q-grams + for q_gram in q_grams: + if words := self.q_gram_to_word.get(q_gram, False): + # q_gram exists in our search index + for word in words: + if isinstance(self.database_ids, set): + # DATABASE SPECIFIC now filter on whether word is in the database + in_db = False + for _id in self.word_to_identifier[word]: + if _id in self.database_ids: + in_db = True + break + else: + in_db = True + if in_db: + matches[word] = matches.get(word, 0) + words[word] + + # if we find no results, return an empty dataframe + if len(matches) == 0: + return pd.DataFrame({"word": [], "matches": []}) + + # otherwise, create a dataframe and + # reduce search results to most relevant results + matches = {"word": matches.keys(), "matches": matches.values()} + matches = pd.DataFrame(matches) + max_q = max(matches["matches"]) # this has the most matching q-grams + + # determine how many results we want to keep based on how good our results are + if not return_all: + min_q = min(max(max_q * 0.32, # have at least a third of q-grams of best match or... + max(n_q_grams * 0.5, # if more, at least half the q-grams in the query word? 
+ 1)), # okay just do 1 q-gram if there are no more in the word + max_q) # never have min_q be over max_q + else: + min_q = 0 + + matches = matches[matches["matches"] >= min_q] + matches = matches.sort_values(by="matches", ascending=False) + matches = matches.reset_index(drop=True) + + return matches.iloc[:min(len(matches), 2500), :] # return at most this many results + + def search_size_1(self, queries: list, original_words: set, orig_word_weight=5, exact_word_weight=1) -> dict: + """Return a dict of {query_word: Counter(identifier)}. + + queries: is a list of len 1 tuple/lists of words that are a searched word or a 'spell checked' similar word + original words: a list of words actually searched for (not including spellchecked) + + orig_word_weight: additional weight to add to original words + exact_word_weight: additional weight to add to exact word matches (as opposed to be 'in' str) + + First, we find all matching words, creating a dict of words in 'queries' as keys and words matching that query word as list of values + Next, we convert this to identifiers and add weights: + Weight will be increased if matching 'orig_word_weight' or 'exact_word_weight' + """ + matches = {} + t2 = time() + # add each word in search index if query_word in word + for word in self.database_words.keys(): + for query in queries: + # query is list/tuple of len 1 + query_word = query[0] # only use the word + if query_word in word: + words = matches.get(query_word, []) + words.extend([word]) + matches[query_word] = words + + # now convert matched words to matched identifiers + matched_identifiers = {} + for word, matching_words in matches.items(): + if result := self.database_search_cache(self.current_database, word): + matched_identifiers[word] = result + continue + id_counter = matched_identifiers.get(word, Counter()) + for matched_word in matching_words: + weight = self.base_weight + + # add the word n times, where n is the weight, original search word is weighted higher than alternatives + if matched_word in original_words: + weight += orig_word_weight # increase weight for original word + if matched_word == word: + weight += exact_word_weight # increase weight for exact matching word + + id_counter = self.weigh_identifiers(self.database_words[matched_word], weight, id_counter) + matched_identifiers[word] = id_counter + self.database_search_cache(self.current_database, word, matched_identifiers[word]) + + return matched_identifiers + + def fuzzy_search(self, text: str, database: Optional[str] = None, return_counter: bool = False, logging: bool = True) -> list: + """Overwritten for extra database specific reduction of results. 
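+ The ranking logic is the same as in SearchEngine.fuzzy_search(); in addition, candidate identifiers are restricted to the given database (when one is passed) and per-database word and word-combination results are cached between calls.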
+ """ + t = time() + text = text.strip() + + if len(text) == 0: + log.debug(f"Empty search, returned all items") + return self.df.index.to_list() + + # DATABASE SPECIFIC get the set of ids that is in this database + self.database_id_manager(database) + self.database_word_manager(database) + + queries = self.build_queries(text) + + # make list of unique original words + orig_words = OrderedDict() + for word in text.split(" "): + orig_words[word] = False + orig_words = orig_words.keys() + orig_words = {self.clean_text(word) for word in orig_words} + + # order the queries by the amount of words they contain + # we do this because longer queries (more words) are harder to find, but we have many alternatives so we search in a smaller search space + queries_by_size = OrderedDict() + longest_query = max([len(q) for q in queries]) + for query_len in range(1, longest_query + 1): + queries_by_size[query_len] = [q for q in queries if len(q) == query_len] + + # first handle queries of length 1 + query_to_identifier = self.search_size_1(queries_by_size[1], orig_words) + + # DATABASE SPECIFIC ensure all identifiers are in the database + if isinstance(self.database_ids, set): + new_q2i = {} + for word, _ids in query_to_identifier.items(): + keep = set.intersection(set(_ids.keys()), self.database_ids) + new_id_counter = Counter() + for _id in keep: + new_id_counter[_id] = _ids[_id] + if len(new_id_counter) > 0: + new_q2i[word] = new_id_counter + query_to_identifier = new_q2i + + # get all results into a df, we rank further later + all_identifiers = set() + for id_list in [id_list for id_list in query_to_identifier.values()]: + all_identifiers.update(id_list) + search_df = self.df.loc[list(all_identifiers)] + + # now, we search for combinations of query words and get only those identifiers + # we then reduce de search_df further for only those matching identifiers + # we then search the permutations of that set of words + for q_len, query_set in queries_by_size.items(): + if q_len == 1: + # we already did these above + continue + for query in query_set: + # get the intersection of all identifiers + # meaning, a set of identifiers that occur in ALL sets of len(1) for the individual words in the query + # this ensures we only ever search data where ALL items occur to substantially reduce search-space + # finally, make this a Counter (with each item=1) so we can properly weigh things later + query_id_sets = [set(query_to_identifier.get(q_word)) for q_word in query if + query_to_identifier.get(q_word, False)] + if len(query_id_sets) == 0: + continue + query_identifier_set = set.intersection(*query_id_sets) + if len(query_identifier_set) == 0: + # there is no match for this combination of query words, skip + break + + # now we convert the query identifiers to a Counter of 'occurrence', + # where we weigh queries with only original words higher + query_identifiers = Counter() + for identifier in query_identifier_set: + weight = 0 + for query_word in query: + # if the query_word and identifier combination exist get score, otherwise 0 + weight += query_to_identifier.get(query_word, {}).get(identifier, 0) + + query_identifiers[identifier] = weight + + # we now add these identifiers to a counter for this query name, + query_name = " ".join(query) + + weight = self.base_weight * q_len + query_to_identifier[query_name] = self.weigh_identifiers(query_identifiers, weight, Counter()) + + # now search for all permutations of this query combined with a space + query_df = 
search_df[search_df[self.identifier_name].isin(query_identifiers)] + for query_perm in permutations(query): + query_perm_str = " ".join(query_perm) + if result := self.database_search_cache(self.current_database, query_perm_str): + new_ids = result + else: + mask = self.filter_dataframe(query_df, query_perm_str, search_columns=["query_col"]) + new_df = query_df.loc[mask].reset_index(drop=True) + if len(new_df) == 0: + # there is no match for this permutation of words, skip + continue + new_id_list = new_df[self.identifier_name] + + new_ids = Counter() + for new_id in new_id_list: + new_ids[new_id] = query_identifiers[new_id] + self.database_search_cache(self.current_database, query_perm_str, new_ids) + # we weigh a combination of words that is next also to each other even higher than just the words separately + query_to_identifier[query_name] = self.weigh_identifiers(new_ids, weight, + query_to_identifier[query_name]) + # now finally, move to one object sorted list by highest score + all_identifiers = Counter() + for identifiers in query_to_identifier.values(): + all_identifiers += identifiers + + if return_counter: + return_this = all_identifiers + else: + # now sort on highest weights and make list type + return_this = [identifier[0] for identifier in all_identifiers.most_common()] + if logging: + log.debug( + f"Found {len(all_identifiers)} search results for '{text}' in {len(self.df)} items in {time() - t:.2f} seconds") + return return_this + + def search(self, text, database: Optional[str] = None) -> list: + """Search the dataframe on this text, return a sorted list of identifiers.""" + t = time() + text = text.strip() + + if len(text) == 0: + log.debug(f"Empty search, returned all items") + return self.df.index.to_list() + + # get the set of ids that is in this database + self.database_id_manager(database) + + fuzzy_identifiers = self.fuzzy_search(text, database=database, logging=False) + if len(fuzzy_identifiers) == 0: + log.debug(f"Found 0 search results for '{text}' in {len(self.df)} items in {time() - t:.2f} seconds") + return [] + + # take the fuzzy search sub-set of data and search it literally + df = self.df.loc[fuzzy_identifiers].copy() + + literal_identifiers = self.literal_search(text, df) + if len(literal_identifiers) == 0: + log.debug( + f"Found {len(fuzzy_identifiers)} search results for '{text}' in {len(self.df)} items in {time() - t:.2f} seconds") + return fuzzy_identifiers + + # append any fuzzy identifiers that were not found in the literal search + literal_id_set = set(literal_identifiers) + remaining_fuzzy_identifiers = [ + _id for _id in fuzzy_identifiers if _id not in literal_id_set] + identifiers = literal_identifiers + remaining_fuzzy_identifiers + + log.debug( + f"Found {len(identifiers)} ({len(literal_identifiers)} literal) search results for '{text}' in {len(self.df)} items in {time() - t:.2f} seconds") + return identifiers diff --git a/activity_browser/layouts/panes/database_products.py b/activity_browser/layouts/panes/database_products.py index 4a49851c5..475824266 100644 --- a/activity_browser/layouts/panes/database_products.py +++ b/activity_browser/layouts/panes/database_products.py @@ -1,5 +1,6 @@ from logging import getLogger from time import time +from collections import Counter import pandas as pd from qtpy import QtWidgets, QtCore, QtGui @@ -56,8 +57,11 @@ def __init__(self, parent, db_name: str): self.table_view = ProductView(self) self.table_view.setModel(self.model) self.model.setDataFrame(self.build_df()) + self.model.has_external_search = True 
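+ # has_external_search / external_col_name hand the quick-search query to the metadata search engine for this database; see ProductModel.external_search below.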
+ self.model.external_col_name = db_name - self.search = widgets.ABLineEdit(self) + self.search = widgets.MetaDataAutoCompleteTextEdit(self) + self.search.database_name = db_name self.search.setMaximumHeight(30) self.search.setPlaceholderText("Quick Search") @@ -81,7 +85,11 @@ def connect_signals(self): signals.database.deleted.connect(self.on_database_deleted) self.table_view.filtered.connect(self.search_error) - self.search.textChangedDebounce.connect(self.table_view.setAllFilter) + self.search.textChangedDebounce.connect(self.set_queries) + + def set_queries(self, query: str) -> None: + self.model.set_external_query(query) + self.table_view.setAllFilter(query) def saveState(self): """ @@ -360,6 +368,27 @@ def selected_activities(self) -> [tuple]: items = [i.internalPointer() for i in self.selectedIndexes() if isinstance(i.internalPointer(), ProductItem)] return list({item["activity_key"] for item in items if item["activity_key"] is not None}) + def buildQuery(self) -> str: + queries = ["(index == index)"] + + # query for the column filters + for col in list(self.columnFilters): + if col not in self.model().columns(): + del self.columnFilters[col] + + for col, query in self.columnFilters.items(): + q = f"({col}.astype('str').str.contains('{self.format_query(query)}'))" + queries.append(q) + + # query for the all filter + if self.allFilter.startswith('='): + queries.append(f"({self.allFilter[1:]})") + + query = " & ".join(queries) + log.debug(f"{self.__class__.__name__} built query: {query}") + + return query + class ProductItem(ui.widgets.ABDataItem): """ @@ -454,3 +483,35 @@ def values_from_indices(key: str, indices: list[QtCore.QModelIndex]): continue values.append(item[key]) return values + + def external_search(self, query): + t = time() + results = AB_metadata.db_search(query, database=self.external_col_name, return_counter=True, logging=False) + t2 = time() + + # extract a dict with 'key' as key and 'id' as values from the metadata + result_ids = set(results.keys()) + # extract df with only result IDs and columns 'id' and 'key' + df = AB_metadata.dataframe[AB_metadata.dataframe["id"].isin(result_ids)].loc[:, ["id", "key"]] + df = df.set_index("key", drop=True) + translate_dict = df.to_dict()["id"] + result_keys = set(translate_dict.keys()) + + # convert the metadata id scores to row id scores + row_scores = Counter() + match_df = self.dataframe[self.dataframe["activity_key"].isin(result_keys) | self.dataframe["product_key"].isin(result_keys)] + cols = ["activity_key", "product_key"] + match_df = match_df.loc[:, cols] + for row in match_df.itertuples(): + act_score = results.get(translate_dict.get(row[1]), 0) + prd_score = results.get(translate_dict.get(row[2]), 0) + row_scores[row[0]] = act_score + prd_score + + # finally only return the indices + sorted_indices = [identifier for identifier, _ in row_scores.most_common()] + log.debug( + f"ProductModel search in '{self.external_col_name}' ({len(self.dataframe)} items) " + f"found {len(sorted_indices)} results " + f"for '{query}' in {time() - t:.2f} seconds ({t2 - t:.2f}s actual search, {time() - t2:.2f}s reorder for table)" + ) + return sorted_indices diff --git a/activity_browser/ui/widgets/__init__.py b/activity_browser/ui/widgets/__init__.py index f8c0c439b..89d2c30ca 100644 --- a/activity_browser/ui/widgets/__init__.py +++ b/activity_browser/ui/widgets/__init__.py @@ -1,8 +1,8 @@ from .abstract_pane import ABAbstractPane from .comparison_switch import SwitchComboBox from .cutoff_menu import CutoffMenu -from .line_edit import 
(ABLineEdit, SignalledComboEdit, SignalledLineEdit,
-                        SignalledPlainTextEdit)
+from .line_edit import ABLineEdit, SignalledComboEdit, SignalledLineEdit, SignalledPlainTextEdit
+from .text_edit import MetaDataAutoCompleteTextEdit
 from .treeview import ABTreeView
 from .item_model import ABItemModel
 from .item import ABAbstractItem, ABBranchItem, ABDataItem
diff --git a/activity_browser/ui/widgets/item_model.py b/activity_browser/ui/widgets/item_model.py
index 2772c8f0f..3a2fd6d74 100644
--- a/activity_browser/ui/widgets/item_model.py
+++ b/activity_browser/ui/widgets/item_model.py
@@ -26,6 +26,9 @@ def __init__(self, parent=None, dataframe=None):
         self.sort_column: int = 0  # column that is currently sorted
         self.sort_order: Qt.SortOrder = Qt.SortOrder.AscendingOrder
         self._query = ""  # Pandas query currently applied to the dataframe
+        self.has_external_search = False
+        self._external_query = ""
+        self.external_col_name = ""

         self.setDataFrame(self.dataframe)

@@ -192,17 +195,22 @@ def endResetModel(self):
         # apply any queries to the dataframe
         if q := self.query():
-            df = self.dataframe.query(q).reset_index(drop=True).copy()
+            df = self.dataframe.copy()
+            if self.has_external_search and self._external_query != "":
+                indices = self.external_search(self._external_query)
+                df = df.loc[indices]
+            df = df.query(q).reset_index(drop=True)
         else:
             df = self.dataframe.copy()

-        if not self.sort_column > len(self.columns()) - 1:
-            # apply the sorting
-            df.sort_values(
-                by=self.columns()[self.sort_column],
-                ascending=(self.sort_order == Qt.SortOrder.AscendingOrder),
-                inplace=True, ignore_index=True
-            )
+        if not (self.has_external_search and self._external_query != ""):
+            if not self.sort_column > len(self.columns()) - 1:
+                # apply the sorting
+                df.sort_values(
+                    by=self.columns()[self.sort_column],
+                    ascending=(self.sort_order == Qt.SortOrder.AscendingOrder),
+                    inplace=True, ignore_index=True
+                )

         # rebuild the ABItem tree
         self.root = self.branchItemClass("root")
@@ -271,11 +279,17 @@ def setQuery(self, query: str):
         self._query = query
         self.endResetModel()

+    def set_external_query(self, query: str):
+        if not query.startswith("="):
+            self._external_query = query
+        else:
+            self._external_query = ""
+
+    def external_search(self, query):
+        raise NotImplementedError
+
     def hasChildren(self, parent: QtCore.QModelIndex):
         item = parent.internalPointer()
         if isinstance(item, ABAbstractItem):
             return item.has_children()
         return super().hasChildren(parent)
-
-
-
diff --git a/activity_browser/ui/widgets/line_edit.py b/activity_browser/ui/widgets/line_edit.py
index 655d269d5..427663938 100644
--- a/activity_browser/ui/widgets/line_edit.py
+++ b/activity_browser/ui/widgets/line_edit.py
@@ -1,7 +1,6 @@
 from qtpy import QtWidgets
 from qtpy.QtCore import QTimer, Slot, Signal, SignalInstance
 from qtpy.QtGui import QTextFormat
-from qtpy.QtWidgets import QCompleter


 class ABLineEdit(QtWidgets.QLineEdit):
@@ -111,12 +110,3 @@ def focusOutEvent(self, event):
             self._before = after
             actions.ActivityModify.run(self._key, self._field, after)
         super(SignalledComboEdit, self).focusOutEvent(event)
-
-
-class AutoCompleteLineEdit(QtWidgets.QLineEdit):
-    """Line Edit with a completer attached"""
-
-    def __init__(self, items: list[str], parent=None):
-        super().__init__(parent=parent)
-        completer = QCompleter(items, self)
-        self.setCompleter(completer)
diff --git a/activity_browser/ui/widgets/text_edit.py b/activity_browser/ui/widgets/text_edit.py
new file mode 100644
index 000000000..9daf4fabe
--- /dev/null
+++ b/activity_browser/ui/widgets/text_edit.py
@@ -0,0 +1,247 @@ +from qtpy import QtWidgets +from qtpy.QtCore import QTimer, Signal, SignalInstance, QStringListModel, Qt +from qtpy.QtGui import QSyntaxHighlighter, QTextCharFormat, QTextDocument, QFont +from qtpy.QtWidgets import QCompleter, QStyledItemDelegate, QStyle + +from activity_browser.bwutils import AB_metadata + + +class UnknownWordHighlighter(QSyntaxHighlighter): + def __init__(self, parent: QTextDocument, known_words: set): + super().__init__(parent) + self.known_words = known_words + + # define the format for unknown words + self.unknown_format = QTextCharFormat() + self.unknown_format.setUnderlineStyle(QTextCharFormat.SpellCheckUnderline) + self.unknown_format.setUnderlineColor(Qt.red) + + def highlightBlock(self, text: str): + if text.startswith("="): + return + words = text.split() + index = 0 + for word in words: + word_len = len(word) + if word and word not in self.known_words: + self.setFormat(index, word_len, self.unknown_format) + index += word_len + 1 # +1 for the space + + +class AutoCompleteDelegate(QStyledItemDelegate): + def __init__(self, parent=None): + super().__init__(parent) + self.current_word_index = -1 + + def paint(self, painter, option, index): + text = index.data(Qt.DisplayRole) + + painter.save() + + # Draw selection background if selected + if option.state & QStyle.State_Selected: + painter.fillRect(option.rect, option.palette.highlight()) + painter.setPen(option.palette.highlightedText().color()) + else: + painter.setPen(option.palette.text().color()) + + # Split text into words and draw each with appropriate font + words = text.split(" ") + x = option.rect.x() + y = option.rect.y() + spacing = 4 # space between words + font = option.font + metrics = painter.fontMetrics() + + for i, word in enumerate(words): + word_font = QFont(font) + if i+1 == self.current_word_index: + word_font.setBold(True) + painter.setFont(word_font) + + word_width = metrics.horizontalAdvance(word) + painter.drawText(x, y + metrics.ascent() + (option.rect.height() - metrics.height()) // 2, word) + x += word_width + spacing + painter.restore() + + +class ABTextEdit(QtWidgets.QTextEdit): + textChangedDebounce: SignalInstance = Signal(str) + _debounce_ms = 250 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._debounce_timer = QTimer(self, singleShot=True) + + self.textChanged.connect(self._set_debounce) + self._debounce_timer.timeout.connect(self._emit_debounce) + + def _set_debounce(self): + self._debounce_timer.setInterval(self._debounce_ms) + self._debounce_timer.start() + + def _emit_debounce(self): + self.textChangedDebounce.emit(self.toPlainText()) + + def debounce(self): + return self._debounce_ms + + def setDebounce(self, ms: int): + self._debounce_ms = ms + + +class ABAutoCompleTextEdit(ABTextEdit): + def __init__(self, parent=None, highlight_unknown=False): + super().__init__(parent=parent) + self.auto_complete_word = "" + + # autocompleter settings + self.model = QStringListModel() + self.completer = QCompleter(self.model) + self.completer.setWidget(self) + self.popup = self.completer.popup() + self.delegate = AutoCompleteDelegate(self.popup) # set custom delegate to bold the current word + self.popup.setItemDelegate(self.delegate) + self.popup.setHorizontalScrollBarPolicy(Qt.ScrollBarAsNeeded) + self.completer.setPopup(self.popup) + self.completer.setCompletionMode(QCompleter.UnfilteredPopupCompletion) # allow all items in popup list + self.completer.activated.connect(self._insert_auto_complete) + + 
self.textChanged.connect(self._sanitize_input)
+        if highlight_unknown:
+            self.highlighter = UnknownWordHighlighter(self.document(), set())
+        self.cursorPositionChanged.connect(self._set_autocomplete_items)
+
+    def keyPressEvent(self, event):
+        key = event.key()
+
+        if key in (Qt.Key_Enter, Qt.Key_Return, Qt.Key_Tab):
+            # insert an autocomplete item
+            # capture enter/return/tab key
+            index = self.popup.currentIndex()
+            completion_text = index.data(Qt.DisplayRole)
+            self.completer.activated.emit(completion_text)
+            return
+        elif key in (Qt.Key_Space,):
+            self.popup.close()
+
+        super().keyPressEvent(event)
+
+        # trigger on text input keys
+        if event.text() or key in (Qt.Key_Left, Qt.Key_Right):  # filters out non-text keys except l/r arrows
+            self._set_autocomplete_items()
+
+    def _sanitize_input(self):
+        raise NotImplementedError
+
+    def _set_autocomplete_items(self):
+        raise NotImplementedError
+
+    def _insert_auto_complete(self, completion):
+        cursor = self.textCursor()
+        position = cursor.position()
+        completion = completion + " "  # add space to end of new text
+
+        # find where to put cursor back
+        new_position = position
+        while new_position < len(completion) and completion[new_position] != " ":
+            new_position += 1
+        new_position += 1  # add one char for space
+
+        # set new text from completion
+        self.blockSignals(True)
+        self.clear()
+        self.setText(completion)
+        # set the cursor location
+        cursor.setPosition(min(new_position, len(completion)))
+        self.setTextCursor(cursor)
+        self.blockSignals(False)
+
+        # house keeping
+        self._emit_debounce()
+        self.popup.close()
+        self.auto_complete_word = ""
+        self.model.setStringList([])
+
+
+class MetaDataAutoCompleteTextEdit(ABAutoCompleTextEdit):
+    """TextEdit with MetaDataStore completer attached."""
+    def __init__(self, parent=None):
+        super().__init__(parent=parent, highlight_unknown=True)
+        self.database_name = ""
+
+    def _sanitize_input(self):
+        self._debounce_timer.stop()
+        text = self.toPlainText()
+        clean_text = AB_metadata.search_engine.ONE_SPACE_PATTERN.sub(" ", text)
+
+        if clean_text != text:
+            cursor = self.textCursor()
+            position = cursor.position()
+            self.blockSignals(True)
+            self.clear()
+            self.insertPlainText(clean_text)
+            self.blockSignals(False)
+            cursor.setPosition(min(position, len(clean_text)))
+            self.setTextCursor(cursor)
+
+        known_words = set()
+        for identifier in AB_metadata.search_engine.database_id_manager(self.database_name):
+            known_words.update(AB_metadata.search_engine.identifier_to_word[identifier].keys())
+        self.highlighter.known_words = known_words
+
+        if len(text) == 0:
+            self.popup.close()
+        self._set_debounce()
+
+    def _set_autocomplete_items(self):
+        text = self.toPlainText()
+        if text.startswith("="):
+            self.model.setStringList([])
+            self.auto_complete_word = ""
+            self.popup.close()
+            return
+
+        # find the start and end of the word under the cursor
+        cursor = self.textCursor()
+        position = cursor.position()
+        start = position
+        while start > 0 and text[start - 1] != " ":
+            start -= 1
+        end = position
+        while end < len(text) and text[end] != " ":
+            end += 1
+        current_word = text[start:end]
+        if not current_word:
+            self.model.setStringList([])
+            self.popup.close()
+            self.auto_complete_word = ""
+            return
+        if self.auto_complete_word == current_word:
+            # avoid unnecessary auto_complete calls if the current word didn't change
+            return
+        self.auto_complete_word = current_word
+
+        context = set((text[:start] + text[end:]).split(" "))
+        self.delegate.current_word_index = len(text[:start].split(" "))  # current word index
for bolding + # get suggestions for the current word + suggestions = AB_metadata.auto_complete(current_word, context=context, database=self.database_name) + suggestions = suggestions[:6] # at most 6, though we should get ~3 usually + # replace the current word with each alternative + items = [] + for alt in suggestions: + new_text = text[:start] + alt + text[end:] + items.append(new_text) + if len(items) == 0: + self.popup.close() + return + + self.model.setStringList(items) + # set correct height now that we have data + max_height = max( + 20, + self.popup.sizeHintForRow(0) * 3 + 2 * self.popup.frameWidth() + ) + self.popup.setMaximumHeight(max_height) + self.completer.complete() diff --git a/activity_browser/ui/widgets/treeview.py b/activity_browser/ui/widgets/treeview.py index 89cb49aa0..36222b1bb 100644 --- a/activity_browser/ui/widgets/treeview.py +++ b/activity_browser/ui/widgets/treeview.py @@ -6,6 +6,7 @@ from qtpy.QtCore import Qt from .item_model import ABItemModel +from activity_browser.ui import widgets log = getLogger(__name__) @@ -25,11 +26,11 @@ def __init__(self, pos: QtCore.QPoint, view: "ABTreeView"): col_index = view.columnAt(pos.x()) col_name = model.columns()[col_index] - search_box = QtWidgets.QLineEdit(self) + search_box = widgets.ABLineEdit(self) search_box.setText(view.columnFilters.get(col_name, "")) search_box.setPlaceholderText("Search") search_box.selectAll() - search_box.textChanged.connect(lambda query: view.setColumnFilter(col_name, query)) + search_box.textChangedDebounce.connect(lambda query: view.setColumnFilter(col_name, query)) widget_action = QtWidgets.QWidgetAction(self) widget_action.setDefaultWidget(search_box) self.addAction(widget_action) diff --git a/tests/test_search.py b/tests/test_search.py new file mode 100644 index 000000000..0c40f4340 --- /dev/null +++ b/tests/test_search.py @@ -0,0 +1,243 @@ +import pytest +import pandas as pd +from activity_browser.bwutils import SearchEngine + + +def data_for_test(): + return pd.DataFrame([ + ["a", "coal production", "coal"], + ["b", "coal production", "something"], + ["c", "coal production", "coat"], + ["d", "coal hello production", "something"], + ["e", "dont zzfind me", "hello world"], + ["f", "coat", "zzanother word"], + ["g", "coalispartofthisword", "things"], + ["h", "coal", "coal"], + ], + columns = ["id", "col1", "col2"]) + + +# test standard init +def test_search_init(): + """Do initialization tests.""" + df = data_for_test() + + # init search class with non-existent identifier col and fail + with pytest.raises(Exception): + _ = SearchEngine(df, identifier_name="non_existent_col_name") + # init search class with non-unique identifiers and fail + df2 = df.copy() + df2.iloc[0, 0] = "b" + with pytest.raises(Exception): + _ = SearchEngine(df2, identifier_name="id") + # init search class correctly + se = SearchEngine(df, identifier_name="id") + + +# test internals +def test_reverse_dict(): + """Do test to reverse the special Counter dict.""" + df = data_for_test() + se = SearchEngine(df, identifier_name="id") + + # reverse once and verify + w2i = se.reverse_dict_many_to_one(se.identifier_to_word) + assert w2i == se.word_to_identifier + + # reverse again and verify is same as original + i2w = se.reverse_dict_many_to_one(w2i) + assert i2w == se.identifier_to_word + + +def test_string_distance(): + """Do tests specifically for string distance function.""" + df = data_for_test() + se = SearchEngine(df, identifier_name="id") + + # same word + assert se.osa_distance("coal", "coal") == 0 + # empty string is 
length of other word + assert se.osa_distance("coal", "") == 4 + + # insert + assert se.osa_distance("coal", "coa") == 1 + # delete + assert se.osa_distance("coal", "coall") == 1 + # substitute + assert se.osa_distance("coal", "coat") == 1 + # transpose + assert se.osa_distance("coal", "cola") == 1 + + # longer edit distance + assert se.osa_distance("coal", "chocolate") == 6 + # reverse order gives same result + assert se.osa_distance("coal", "chocolate") == se.osa_distance("chocolate", "coal") + # cutoff + assert se.osa_distance("coal", "chocolate", cutoff=5, cutoff_return=1000) == 1000 + assert se.osa_distance("coal", "chocolate", cutoff=6, cutoff_return=1000) == 1000 + assert se.osa_distance("coal", "chocolate", cutoff=7, cutoff_return=1000) == 6 + # length cutoff + assert se.osa_distance("coal", "coallongword") == 8 + assert se.osa_distance("coal", "coallongword", cutoff=5, cutoff_return=1000) == 1000 + + # two entirely different words (test of early stopping) + assert se.osa_distance("brown", "jumped") == 6 + assert se.osa_distance("brown", "jumped", cutoff=6, cutoff_return=1000) == 1000 + assert se.osa_distance("brown", "jumped", cutoff=7, cutoff_return=1000) == 6 + + +# test functionality +def test_in_index(): + """Do checks for checking if word is in the index.""" + df = data_for_test() + se = SearchEngine(df, identifier_name="id") + + # use string with space + with pytest.raises(Exception): + se.word_in_index("coal and space") + + assert se.word_in_index("coal") + assert not se.word_in_index("coa") + + +def test_spellcheck(): + """Do checks spell checking.""" + df = data_for_test() + se = SearchEngine(df, identifier_name="id") + + checked = se.spell_check("coa productions something flintstones") + # coal HAS to be first, it is found more often in the data + assert checked["coa"] == ["coal", "coat"] + # find production + assert checked["productions"] == ["production"] + # should be empty as there is no alternative (but this word occurs) + assert checked["something"] == [] + # should be empty as there is no alternative (does not exist) + assert checked["flintstones"] == [] + + +def test_search_base(): + """Do checks for correct search ranking.""" + + df = data_for_test() + + # init search class and two searches + se = SearchEngine(df, identifier_name="id") + # do search on specific term + assert se.search("coal") == ["a", "h", "c", "b", "d", "g", "f"] + # do search on other term + assert se.search("coal production") == ["a", "c", "b", "d", "h", "f", "g"] + # do search on typo + assert se.search("cola") == ["a", "c", "h", "b", "d", "f", "g"] + # do search on longer typo + assert se.search("cola production") == ["c", "a", "b", "d", "h", "f", "g"] + # do search on something we will definitely not find + assert se.search("dontFindThis") == [] + + # init search class with 1 col searchable + se = SearchEngine(df, identifier_name="id", searchable_columns=["col2"]) + assert se.search("coal") == ["a", "h", "c"] + + +def test_search_add_identifier(): + """Do tests for adding identifier.""" + df = data_for_test() + + # create base item to add + new_base_item = pd.DataFrame([ + ["i", "coal production", "coal production"], + ], + columns=["id", "col1", "col2"]) + + # use existing identifier and fail + se = SearchEngine(df, identifier_name="id") + wrong_id = new_base_item.copy() + wrong_id.iloc[0, 0] = "a" + with pytest.raises(Exception): + se.add_identifier(wrong_id) + + # add data without identifier column + se = SearchEngine(df, identifier_name="id") + no_id = new_base_item.copy() + del 
no_id["id"] + with pytest.raises(Exception): + se.add_identifier(no_id) + + # use column more (and find data in new col) + se = SearchEngine(df, identifier_name="id") + col_more = new_base_item.copy() + col_more["col3"] = ["potatoes"] + se.add_identifier(col_more) + assert se.search("potatoes") == ["i"] + + # use column less (should be filled with empty string) + se = SearchEngine(df, identifier_name="id") + col_less = new_base_item.copy() + del col_less["col2"] + se.add_identifier(col_less) + assert se.df.loc["i", "col2"] == "" + + # do search, add item and verify results are different + se = SearchEngine(df, identifier_name="id") + assert se.search("coal production") == ["a", "c", "b", "d", "h", "f", "g"] + se.add_identifier(new_base_item) + assert se.search("coal production") == ["i", "a", "c", "b", "d", "h", "f", "g"] + + +def test_search_remove_identifier(): + """Do tests for removing identifier.""" + df = data_for_test() + + # use non-existent identifier and fail + se = SearchEngine(df, identifier_name="id") + with pytest.raises(Exception): + se.remove_identifier(identifier="i") + + # do search, remove item and verify results are different + se = SearchEngine(df, identifier_name="id") + assert se.search("coal production") == ["a", "c", "b", "d", "h", "f", "g"] + se.remove_identifier(identifier="a") + assert se.search("coal production") == ["c", "b", "d", "h", "f", "g"] + + # now search on something only in a column we later remove + assert se.search("find") == ["e"] + se.remove_identifier(identifier="e") + assert se.search("find") == [] + + +def test_search_change_identifier(): + """Do tests for changing identifier.""" + df = data_for_test() + + # create base item to add + edit_data = pd.DataFrame([ + ["a", "cant find me anymore", "something different"], + ], + columns=["id", "col1", "col2"]) + + # use non-existent identifier and fail + se = SearchEngine(df, identifier_name="id") + missing_id = edit_data.copy() + missing_id["id"] = ["i"] + with pytest.raises(Exception): + se.change_identifier(identifier="i", data=missing_id) + + # use mismatched identifier and fail + se = SearchEngine(df, identifier_name="id") + wrong_id = edit_data.copy() + wrong_id["id"] = ["i"] + with pytest.raises(Exception): + se.change_identifier(identifier="a", data=wrong_id) + + # do search, change item and verify results are different + se = SearchEngine(df, identifier_name="id") + assert se.search("coal production") == ["a", "c", "b", "d", "h", "f", "g"] + se.change_identifier(identifier="a", data=edit_data) + assert se.search("coal production") == ["c", "b", "d", "h", "f", "g"] + # now change the same item partially and verify results are different + new_edit_data = pd.DataFrame([ + ["a", "coal"], + ], + columns=["id", "col1"]) + se.change_identifier(identifier="a", data=new_edit_data) + assert se.search("coal production") == ["c", "b", "d", "h", "a", "f", "g"]
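A note on test_string_distance: the values asserted above are consistent with an optimal string alignment (OSA) distance, i.e. Damerau-Levenshtein restricted to single transpositions of adjacent characters, plus a cutoff that short-circuits obviously distant words. The sketch below is not the implementation in searchengine.py (the "early stopping" comment suggests the real one aborts while filling the matrix); it is a minimal reference, with the cutoff applied to the final value, that reproduces every assertion in the test.

def osa_distance(s1: str, s2: str, cutoff: int = None, cutoff_return: int = 1000) -> int:
    # a length difference of at least `cutoff` already guarantees a distance of at least `cutoff`
    if cutoff is not None and abs(len(s1) - len(s2)) >= cutoff:
        return cutoff_return
    # (len(s1)+1) x (len(s2)+1) distance matrix
    d = [[0] * (len(s2) + 1) for _ in range(len(s1) + 1)]
    for i in range(len(s1) + 1):
        d[i][0] = i
    for j in range(len(s2) + 1):
        d[0][j] = j
    for i in range(1, len(s1) + 1):
        for j in range(1, len(s2) + 1):
            cost = 0 if s1[i - 1] == s2[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1,           # deletion
                          d[i][j - 1] + 1,           # insertion
                          d[i - 1][j - 1] + cost)    # substitution
            if i > 1 and j > 1 and s1[i - 1] == s2[j - 2] and s1[i - 2] == s2[j - 1]:
                d[i][j] = min(d[i][j], d[i - 2][j - 2] + 1)  # adjacent transposition
    distance = d[len(s1)][len(s2)]
    if cutoff is not None and distance >= cutoff:
        return cutoff_return
    return distance

For example, osa_distance("coal", "cola") == 1 (one transposition) and osa_distance("coal", "chocolate", cutoff=6) == 1000, matching the expectations in the test.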
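test_reverse_dict expects reverse_dict_many_to_one to round-trip identifier_to_word and word_to_identifier exactly. That holds when both are dicts mapping keys to non-empty Counters that mirror each other; a minimal sketch under that assumption (the name and call pattern come from the test, the body is illustrative):

from collections import Counter, defaultdict

def reverse_dict_many_to_one(mapping: dict) -> dict:
    # {outer: Counter({inner: count})} -> {inner: Counter({outer: count})}
    reversed_mapping = defaultdict(Counter)
    for outer_key, counter in mapping.items():
        for inner_key, count in counter.items():
            reversed_mapping[inner_key][outer_key] = count
    return dict(reversed_mapping)

Applying the reversal twice restores the original mapping, which is exactly what the test asserts.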
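In fuzzy_search, self.database_search_cache is called with two arguments as a lookup and with three as a store; the helper itself is outside this diff. A possible shape that matches those call sites, assuming a plain dict keyed by (database, query string) — hypothetical, and the real helper presumably also needs invalidating whenever the index changes:

from collections import Counter

class SearchCacheSketch:
    """Illustrative only: the dual get/set pattern used by fuzzy_search."""
    def __init__(self):
        self._cache = {}  # (database, query) -> Counter of identifier scores

    def database_search_cache(self, database, query, result: Counter = None):
        if result is None:                      # lookup
            return self._cache.get((database, query))
        self._cache[(database, query)] = result  # store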
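Similarly, weigh_identifiers only appears through its call site in fuzzy_search, where a Counter of hits for one word permutation, a weight, and the running Counter for the query go in and an updated Counter comes out. One plausible reading, purely illustrative and likely simpler than the actual weighting scheme in searchengine.py:

from collections import Counter

def weigh_identifiers(new_ids: Counter, weight: float, running: Counter) -> Counter:
    # boost identifiers matched by this word combination and fold them into the running scores
    for identifier, score in new_ids.items():
        running[identifier] += score * weight
    return running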