diff --git a/src/book.py b/src/book.py index a41e0c57..79dc8391 100644 --- a/src/book.py +++ b/src/book.py @@ -26,6 +26,7 @@ import misc import transaction as tr from core import kraken_asset_map +from database import set_price_db from price_data import PriceData log = logging.getLogger(__name__) @@ -284,7 +285,7 @@ def _read_coinbase(self, file_path: Path) -> None: assert _currency_spot == "EUR" # Save price in our local database for later. - self.price_data.set_price_db(platform, coin, "EUR", utc_time, eur_spot) + set_price_db(platform, coin, "EUR", utc_time, eur_spot) if operation == "Convert": # Parse change + coin from remark, which is @@ -317,7 +318,7 @@ def _read_coinbase(self, file_path: Path) -> None: ) # Save convert price in local database, too. - self.price_data.set_price_db( + set_price_db( platform, convert_coin, "EUR", utc_time, convert_eur_spot ) else: @@ -721,9 +722,9 @@ def _read_bitpanda_pro_trades(self, file_path: Path) -> None: # Save price in our local database for later. price = misc.force_decimal(_price) - self.price_data.set_price_db(platform, coin, "EUR", utc_time, price) + set_price_db(platform, coin, "EUR", utc_time, price) if best_price: - self.price_data.set_price_db( + set_price_db( platform, "BEST", "EUR", @@ -862,9 +863,7 @@ def _read_bitpanda(self, file_path: Path) -> None: change_fiat = misc.force_decimal(amount_fiat) # Save price in our local database for later. price = misc.force_decimal(asset_price) - self.price_data.set_price_db( - platform, asset, config.FIAT.upper(), utc_time, price - ) + set_price_db(platform, asset, config.FIAT.upper(), utc_time, price) if change < 0: log.error( @@ -1048,6 +1047,66 @@ def detect_exchange(self, file_path: Path) -> Optional[str]: return None + def get_price_from_csv(self) -> None: + """Calculate coin prices from buy/sell operations in CSV files. + + When exactly one buy and sell happend at the exact same time, + these two operations might belong together and we can calculate + the paid price for this transaction. + """ + # Group operations by platform. + for platform, platform_operations in misc.group_by( + self.operations, "platform" + ).items(): + # Group operations by time. + # Look at all operations which happend at the same time. + for timestamp, time_operations in misc.group_by( + platform_operations, "utc_time" + ).items(): + buytr = selltr = None + buycount = sellcount = 0 + + # Extract the buy and sell operation. + for operation in time_operations: + if isinstance(operation, tr.Buy): + buytr = operation + buycount += 1 + elif isinstance(operation, tr.Sell): + selltr = operation + sellcount += 1 + + # Skip the operations of this timestamp when there aren't + # exactly one buy and one sell operation. + # We can only match the buy and sell operations, when there + # are exactly one buy and one sell operation. + if not (buycount == 1 and sellcount == 1): + continue + + assert isinstance(timestamp, datetime.datetime) + assert isinstance(buytr, tr.Buy) + assert isinstance(selltr, tr.Sell) + + # Price definition example for buying BTC with EUR: + # Symbol: BTCEUR + # coin: BTC (buytr.coin) + # reference coin: EUR (selltr.coin) + # price = traded EUR / traded BTC + price = decimal.Decimal(selltr.change / buytr.change) + + logging.debug( + f"Adding {buytr.coin}/{selltr.coin} price from CSV: " + f"{price} for {platform} at {timestamp}" + ) + + set_price_db( + platform, + buytr.coin, + selltr.coin, + timestamp, + price, + overwrite=True, + ) + def read_file(self, file_path: Path) -> None: """Import transactions form an account statement. @@ -1060,6 +1119,7 @@ def read_file(self, file_path: Path) -> None: assert file_path.is_file() if exchange := self.detect_exchange(file_path): + try: read_file = getattr(self, f"_read_{exchange}") except AttributeError: diff --git a/src/database.py b/src/database.py new file mode 100644 index 00000000..211b57d5 --- /dev/null +++ b/src/database.py @@ -0,0 +1,314 @@ +import datetime +import decimal +import logging +import sqlite3 +from pathlib import Path +from typing import Optional, Tuple + +import config +import misc + +log = logging.getLogger(__name__) + + +def get_version(db_path: Path) -> int: + """Get database version from a database file. + + If the version table is missing, one is created. + + Args: + db_path (str): Path to database file. + + Raises: + RuntimeError: The database version is ambiguous. + + Returns: + int: Version of database file. + """ + with sqlite3.connect(db_path) as conn: + cur = conn.cursor() + try: + cur.execute("SELECT version FROM §version;") + versions = [int(v[0]) for v in cur.fetchall()] + except sqlite3.OperationalError as e: + if str(e) == "no such table: §version": + # The §version table doesn't exist. Create one. + cur.execute("CREATE TABLE §version(version INT);") + cur.execute("INSERT INTO §version (version) VALUES (0);") + return 0 + else: + raise e + + if len(versions) == 1: + version = versions[0] + return version + else: + raise RuntimeError( + f"The database version of the file `{db_path.name}` is ambigious. " + f"The table `§version` should have one entry, but has {len(versions)}." + ) + + +def get_price_db( + db_path: Path, + tablename: str, + utc_time: datetime.datetime, +) -> Optional[decimal.Decimal]: + """Try to retrieve the price from our local database. + + Args: + db_path (Path) + tablename (str) + utc_time (datetime.datetime) + + Returns: + Optional[decimal.Decimal]: Price. + """ + if db_path.is_file(): + with sqlite3.connect(db_path) as conn: + cur = conn.cursor() + query = f"SELECT price FROM `{tablename}` WHERE utc_time=?;" + + try: + cur.execute(query, (utc_time,)) + except sqlite3.OperationalError as e: + if str(e) == f"no such table: {tablename}": + return None + raise e + + if prices := cur.fetchone(): + return misc.force_decimal(prices[0]) + + return None + + +def mean_price_db( + db_path: Path, + tablename: str, + utc_time: datetime.datetime, +) -> decimal.Decimal: + """Try to retrieve the price right before and after `utc_time` + from our local database. + + Return 0 if the price could not be estimated. + The function does not check, if a price for `utc_time` exists. + + Args: + db_path (Path) + tablename (str) + utc_time (datetime.datetime) + + Returns: + decimal.Decimal: Price. + """ + if db_path.is_file(): + with sqlite3.connect(db_path) as conn: + cur = conn.cursor() + + before_query = ( + f"SELECT utc_time, price FROM `{tablename}` " + f"WHERE utc_time 0 " + "ORDER BY utc_time DESC " + "LIMIT 1" + ) + try: + cur.execute(before_query, (utc_time,)) + except sqlite3.OperationalError as e: + if str(e) == f"no such table: {tablename}": + return decimal.Decimal() + raise e + if result := cur.fetchone(): + before_time = misc.parse_iso_timestamp_to_decimal_timestamp(result[0]) + before_price = misc.force_decimal(result[1]) + else: + return decimal.Decimal() + + after_query = ( + f"SELECT utc_time, price FROM `{tablename}` " + f"WHERE utc_time>? AND price > 0 " + "ORDER BY utc_time ASC " + "LIMIT 1" + ) + try: + cur.execute(after_query, (utc_time,)) + except sqlite3.OperationalError as e: + if str(e) == f"no such table: {tablename}": + return decimal.Decimal() + raise e + if result := cur.fetchone(): + after_time = misc.parse_iso_timestamp_to_decimal_timestamp(result[0]) + after_price = misc.force_decimal(result[1]) + else: + return decimal.Decimal() + + if before_price and after_price: + d_utc_time = misc.to_decimal_timestamp(utc_time) + # Linear gradiant between the neighbored transactions. + m = (after_price - before_price) / (after_time - before_time) + price = before_price + (d_utc_time - before_time) * m + return price + + return decimal.Decimal() + + +def __delete_price_db( + db_path: Path, + tablename: str, + utc_time: datetime.datetime, +) -> None: + """Delete price from database + + Args: + db_path (Path) + tablename (str) + utc_time (datetime.datetime) + """ + + with sqlite3.connect(db_path) as conn: + cur = conn.cursor() + query = f"DELETE FROM `{tablename}` WHERE utc_time=?;" + cur.execute(query, (utc_time,)) + conn.commit() + + +def __set_price_db( + db_path: Path, + tablename: str, + utc_time: datetime.datetime, + price: decimal.Decimal, +) -> None: + """Write price to database. + + Create database/table if necessary. + + Args: + db_path (Path) + tablename (str) + utc_time (datetime.datetime) + price (decimal.Decimal) + """ + if not db_path.exists(): + from patch_database import create_new_database + + create_new_database(db_path) + + with sqlite3.connect(db_path) as conn: + cur = conn.cursor() + query = f"INSERT INTO `{tablename}`" "('utc_time', 'price') VALUES (?, ?);" + try: + cur.execute(query, (utc_time, str(price))) + except sqlite3.OperationalError as e: + if str(e) == f"no such table: {tablename}": + create_query = ( + f"CREATE TABLE `{tablename}`" + "(utc_time DATETIME PRIMARY KEY, " + "price VARCHAR(255) NOT NULL);" + ) + cur.execute(create_query) + cur.execute(query, (utc_time, str(price))) + else: + raise e + conn.commit() + + +def set_price_db( + platform: str, + coin: str, + reference_coin: str, + utc_time: datetime.datetime, + price: decimal.Decimal, + db_path: Optional[Path] = None, + overwrite: bool = False, +) -> None: + """Write price to database. + + Tries to insert a historical price into the local database. + + A warning will be raised, if there is already a different price. + + Args: + platform (str) + coin (str) + reference_coin (str) + utc_time (datetime.datetime) + price (decimal.Decimal) + """ + assert coin != reference_coin + + coin_a, coin_b, inverted = _sort_pair(coin, reference_coin) + tablename = get_tablename(coin_a, coin_b) + + if inverted: + price = misc.reciprocal(price) + + if db_path is None and platform: + db_path = get_db_path(platform) + + assert isinstance(db_path, Path), "no db path given" + + try: + __set_price_db(db_path, tablename, utc_time, price) + except sqlite3.IntegrityError as e: + if str(e) == f"UNIQUE constraint failed: {tablename}.utc_time": + # Trying to add an already existing price in db. + if overwrite: + # Overwrite price. + log.debug( + "Overwriting price information for " + f"{platform=}, {tablename=} at {utc_time=}" + ) + __delete_price_db(db_path, tablename, utc_time) + __set_price_db(db_path, tablename, utc_time, price) + else: + # Check price from db and issue warning, if prices do not match. + price_db = get_price_db(db_path, tablename, utc_time) + if price != price_db: + log.warning( + "Tried to write price to database, " + "but a different price exists already: " + f"{platform=}, {tablename=}, {utc_time=}, " + f"{price=} != {price_db=}" + ) + else: + raise e + + +def _sort_pair(coin: str, reference_coin: str) -> Tuple[str, str, bool]: + """Sort the coin pair in alphanumerical order. + + Args: + coin (str) + reference_coin (str) + + Returns: + Tuple[str, str, bool]: First coin, second coin, inverted + """ + if inverted := coin > reference_coin: + coin_a = reference_coin + coin_b = coin + else: + coin_a = coin + coin_b = reference_coin + return coin_a, coin_b, inverted + + +def get_tablename(coin: str, reference_coin: str) -> str: + return f"{coin}/{reference_coin}" + + +def get_tablenames_from_db(cur: sqlite3.Cursor) -> list[str]: + cur.execute("SELECT name FROM sqlite_master WHERE type='table';") + tablenames = [result[0] for result in cur.fetchall()] + return tablenames + + +def get_db_path(platform: str) -> Path: + return Path(config.DATA_PATH, f"{platform}.db") + + +def check_database_or_create(platform: str) -> None: + from patch_database import create_new_database + + db_path = get_db_path(platform) + if not db_path.exists(): + create_new_database(db_path) diff --git a/src/main.py b/src/main.py index e6d13374..842f1da7 100644 --- a/src/main.py +++ b/src/main.py @@ -18,6 +18,7 @@ import log_config # noqa: F401 from book import Book +from patch_database import patch_databases from price_data import PriceData from taxman import Taxman @@ -25,6 +26,8 @@ def main() -> None: + patch_databases() + price_data = PriceData() book = Book(price_data) taxman = Taxman(book, price_data) @@ -35,6 +38,7 @@ def main() -> None: log.warning("Stopping CoinTaxman.") return + book.get_price_from_csv() taxman.evaluate_taxation() taxman.export_evaluation_as_csv() taxman.print_evaluation() diff --git a/src/misc.py b/src/misc.py index cb4846ca..32f80cdb 100644 --- a/src/misc.py +++ b/src/misc.py @@ -170,7 +170,7 @@ def parse_iso_timestamp_to_decimal_timestamp(d: str) -> decimal.Decimal: return to_decimal_timestamp(datetime.datetime.fromisoformat(d)) -def group_by(lst: L, key: str) -> dict[str, L]: +def group_by(lst: L, key: str) -> dict[Any, L]: """Group a list of objects by `key`. Args: @@ -178,7 +178,7 @@ def group_by(lst: L, key: str) -> dict[str, L]: key (str) Returns: - dict[str, list]: Dict with different `key`as keys. + dict[Any, list]: Dict with different `key`as keys. """ d = collections.defaultdict(list) for e in lst: diff --git a/src/patch_database.py b/src/patch_database.py new file mode 100644 index 00000000..35c40854 --- /dev/null +++ b/src/patch_database.py @@ -0,0 +1,223 @@ +# CoinTaxman +# Copyright (C) 2021 Carsten Docktor + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import datetime +import decimal +import logging +import sqlite3 +import sys +from inspect import getmembers, isfunction +from pathlib import Path +from typing import Iterator, Optional + +import config +from database import set_price_db + +FUNC_PREFIX = "__patch_" +log = logging.getLogger(__name__) + + +def get_version(db_path: Path) -> int: + """Get database version from a database file. + + If the version table is missing, one is created. + + Args: + db_path (str): Path to database file. + + Raises: + RuntimeError: The database version is ambiguous. + + Returns: + int: Version of database file. + """ + with sqlite3.connect(db_path) as conn: + cur = conn.cursor() + try: + cur.execute("SELECT version FROM §version;") + versions = [int(v[0]) for v in cur.fetchall()] + except sqlite3.OperationalError as e: + if str(e) == "no such table: §version": + # The §version table doesn't exist. Create one. + update_version(db_path, 0) + return 0 + else: + raise e + + if len(versions) == 1: + version = versions[0] + return version + else: + raise RuntimeError( + f"The database version of the file `{db_path.name}` is ambigious. " + f"The table `§version` should have one entry, but has {len(versions)}." + ) + + +def update_version(db_path: Path, version: int) -> None: + with sqlite3.connect(db_path) as conn: + cur = conn.cursor() + + try: + cur.execute("DELETE FROM §version;") + except sqlite3.OperationalError as e: + if str(e) == "no such table: §version": + cur.execute("CREATE TABLE §version(version INT);") + else: + raise e + + assert isinstance(version, int) + log.debug(f"Updating version of {db_path} to {version}") + cur.execute(f"INSERT INTO §version (version) VALUES ({version});") + + +def create_new_database(db_path: Path) -> None: + assert not db_path.exists() + version = get_latest_version() + update_version(db_path, version) + + +def get_patch_func_version(func_name: str) -> int: + assert func_name.startswith( + FUNC_PREFIX + ), f"Patch function `{func_name}` should start with {FUNC_PREFIX}." + len_func_prefix = len(FUNC_PREFIX) + version_str = func_name[len_func_prefix:] + version = int(version_str) + return version + + +def get_tablenames(cur: sqlite3.Cursor, ignore_version_table: bool = True) -> list[str]: + query = "SELECT name FROM sqlite_master WHERE type='table'" + if ignore_version_table: + query += " AND name != '§version'" + cur.execute(f"{query};") + tablenames = [result[0] for result in cur.fetchall()] + return tablenames + + +def __patch_001(db_path: Path) -> None: + """Convert prices from float to string + + Args: + db_path (Path) + """ + with sqlite3.connect(db_path) as conn: + query = "SELECT name,sql FROM sqlite_master WHERE type='table'" + cur = conn.execute(query) + for tablename, sql in cur.fetchall(): + if "price str" not in sql.lower(): + query = f""" + CREATE TABLE "sql_temp_table" ( + "utc_time" DATETIME PRIMARY KEY, + "price" STR NOT NULL + ); + INSERT INTO "sql_temp_table" ("price","utc_time") + SELECT "price","utc_time" FROM "{tablename}"; + DROP TABLE "{tablename}"; + ALTER TABLE "sql_temp_table" "{tablename}"; + """ + + +def __patch_002(db_path: Path) -> None: + """Group tablenames, so that the symbols are alphanumerical. + + Args: + db_path (Path) + """ + with sqlite3.connect(db_path) as conn: + cur = conn.cursor() + tablenames = get_tablenames(cur) + # Iterate over all tables. + for tablename in tablenames: + base_asset, quote_asset = tablename.split("/") + + # Adjust the order, when the symbols aren't ordered alphanumerical. + if base_asset > quote_asset: + + # Query all prices from the table. + cur = conn.execute(f"Select utc_time, price FROM `{tablename}`;") + + for _utc_time, _price in list(cur.fetchall()): + # Convert the data. + # Try non-fractional seconds first, then fractional seconds + try: + utc_time = datetime.datetime.strptime( + _utc_time, "%Y-%m-%d %H:%M:%S%z" + ) + except ValueError: + utc_time = datetime.datetime.strptime( + _utc_time, "%Y-%m-%d %H:%M:%S.%f%z" + ) + price = decimal.Decimal(_price) + set_price_db("", base_asset, quote_asset, utc_time, price, db_path) + cur = conn.execute(f"DROP TABLE `{tablename}`;") + + +def _get_patch_func_names() -> Iterator[str]: + func_names = ( + f[0] + for f in getmembers(sys.modules[__name__], isfunction) + if f[0].startswith(FUNC_PREFIX) + ) + return func_names + + +def _get_patch_func_versions() -> Iterator[int]: + func_names = _get_patch_func_names() + func_version = map(get_patch_func_version, func_names) + return func_version + + +def get_sorted_patch_func_names(current_version: Optional[int] = None) -> list[str]: + func_names = ( + f + for f in _get_patch_func_names() + if current_version is None or get_patch_func_version(f) > current_version + ) + # Sort patch functions chronological. + return sorted(func_names, key=get_patch_func_version) + + +def get_latest_version() -> int: + func_versions = _get_patch_func_versions() + return max(func_versions) + + +def patch_databases() -> None: + # Check if any database paths exist. + database_paths = [p for p in Path(config.DATA_PATH).glob("*.db") if p.is_file()] + if not database_paths: + return + + # Patch all databases separatly. + for db_path in database_paths: + # Read version from database. + current_version = get_version(db_path) + + patch_func_names = get_sorted_patch_func_names(current_version=current_version) + if not patch_func_names: + continue + + # Run the patch functions. + for patch_func_name in patch_func_names: + logging.info("applying patch %s", patch_func_name.removeprefix(FUNC_PREFIX)) + patch_func = eval(patch_func_name) + patch_func(db_path) + + # Update version. + new_version = get_patch_func_version(patch_func_name) + update_version(db_path, new_version) diff --git a/src/price_data.py b/src/price_data.py index 719a7e1b..a7b52daf 100644 --- a/src/price_data.py +++ b/src/price_data.py @@ -22,7 +22,7 @@ import sqlite3 import time from pathlib import Path -from typing import Any, Optional, Union +from typing import Any, Union import requests @@ -30,6 +30,13 @@ import misc import transaction from core import kraken_pair_map +from database import ( + get_db_path, + get_price_db, + get_tablename, + mean_price_db, + set_price_db, +) log = logging.getLogger(__name__) @@ -279,8 +286,9 @@ def _get_price_bitpanda_pro( for num_offset in range(num_max_offsets): # if no trades can be found, move 30 min window to the past window_offset = num_offset * t - end = utc_time.astimezone(datetime.timezone.utc) \ - - datetime.timedelta(minutes=window_offset) + end = utc_time.astimezone(datetime.timezone.utc) - datetime.timedelta( + minutes=window_offset + ) begin = end - datetime.timedelta(minutes=t) # https://github.com/python/mypy/issues/3176 @@ -288,8 +296,8 @@ def _get_price_bitpanda_pro( "unit": "MINUTES", "period": t, # convert ISO 8601 format to RFC3339 timestamp - "from": begin.isoformat().replace('+00:00', 'Z'), - "to": end.isoformat().replace('+00:00', 'Z'), + "from": begin.isoformat().replace("+00:00", "Z"), + "to": end.isoformat().replace("+00:00", "Z"), } if num_offset: log.debug( @@ -457,185 +465,6 @@ def _get_price_kraken( ) return decimal.Decimal() - def __get_price_db( - self, - db_path: Path, - tablename: str, - utc_time: datetime.datetime, - ) -> Optional[decimal.Decimal]: - """Try to retrieve the price from our local database. - - Args: - db_path (Path) - tablename (str) - utc_time (datetime.datetime) - - Returns: - Optional[decimal.Decimal]: Price. - """ - if db_path.is_file(): - with sqlite3.connect(db_path) as conn: - cur = conn.cursor() - query = f"SELECT price FROM `{tablename}` WHERE utc_time=?;" - - try: - cur.execute(query, (utc_time,)) - except sqlite3.OperationalError as e: - if str(e) == f"no such table: {tablename}": - return None - raise e - - if prices := cur.fetchone(): - return misc.force_decimal(prices[0]) - - return None - - def __mean_price_db( - self, - db_path: Path, - tablename: str, - utc_time: datetime.datetime, - ) -> decimal.Decimal: - """Try to retrieve the price right before and after `utc_time` - from our local database. - - Return 0 if the price could not be estimated. - The function does not check, if a price for `utc_time` exists. - - Args: - db_path (Path) - tablename (str) - utc_time (datetime.datetime) - - Returns: - decimal.Decimal: Price. - """ - if db_path.is_file(): - with sqlite3.connect(db_path) as conn: - cur = conn.cursor() - - before_query = ( - f"SELECT utc_time, price FROM `{tablename}` " - f"WHERE utc_time 0 " - "ORDER BY utc_time DESC " - "LIMIT 1" - ) - try: - cur.execute(before_query, (utc_time,)) - except sqlite3.OperationalError as e: - if str(e) == f"no such table: {tablename}": - return decimal.Decimal() - raise e - if result := cur.fetchone(): - before_time = misc.parse_iso_timestamp_to_decimal_timestamp( - result[0] - ) - before_price = misc.force_decimal(result[1]) - else: - return decimal.Decimal() - - after_query = ( - f"SELECT utc_time, price FROM `{tablename}` " - f"WHERE utc_time>? AND price > 0 " - "ORDER BY utc_time ASC " - "LIMIT 1" - ) - try: - cur.execute(after_query, (utc_time,)) - except sqlite3.OperationalError as e: - if str(e) == f"no such table: {tablename}": - return decimal.Decimal() - raise e - if result := cur.fetchone(): - after_time = misc.parse_iso_timestamp_to_decimal_timestamp( - result[0] - ) - after_price = misc.force_decimal(result[1]) - else: - return decimal.Decimal() - - if before_price and after_price: - d_utc_time = misc.to_decimal_timestamp(utc_time) - # Linear gradiant between the neighbored transactions. - m = (after_price - before_price) / (after_time - before_time) - price = before_price + (d_utc_time - before_time) * m - return price - - return decimal.Decimal() - - def __set_price_db( - self, - db_path: Path, - tablename: str, - utc_time: datetime.datetime, - price: decimal.Decimal, - ) -> None: - """Write price to database. - - Create database/table if necessary. - - Args: - db_path (Path) - tablename (str) - utc_time (datetime.datetime) - price (decimal.Decimal) - """ - with sqlite3.connect(db_path) as conn: - cur = conn.cursor() - query = f"INSERT INTO `{tablename}`" "('utc_time', 'price') VALUES (?, ?);" - try: - cur.execute(query, (utc_time, str(price))) - except sqlite3.OperationalError as e: - if str(e) == f"no such table: {tablename}": - create_query = ( - f"CREATE TABLE `{tablename}`" - "(utc_time DATETIME PRIMARY KEY, " - "price FLOAT NOT NULL);" - ) - cur.execute(create_query) - cur.execute(query, (utc_time, str(price))) - else: - raise e - conn.commit() - - def set_price_db( - self, - platform: str, - coin: str, - reference_coin: str, - utc_time: datetime.datetime, - price: decimal.Decimal, - ) -> None: - """Write price to database. - - Tries to insert a historical price into the local database. - - A warning will be raised, if there is already a different price. - - Args: - platform (str): [description] - coin (str): [description] - reference_coin (str): [description] - utc_time (datetime.datetime): [description] - price (decimal.Decimal): [description] - """ - assert coin != reference_coin - db_path = self.get_db_path(platform) - tablename = self.get_tablename(coin, reference_coin) - try: - self.__set_price_db(db_path, tablename, utc_time, price) - except sqlite3.IntegrityError as e: - if str(e) == f"UNIQUE constraint failed: {tablename}.utc_time": - price_db = self.get_price(platform, coin, utc_time, reference_coin) - if price != price_db: - log.warning( - "Tried to write price to database, " - "but a different price exists already." - f"({platform=}, {tablename=}, {utc_time=}, {price=})" - ) - else: - raise e - def get_price( self, platform: str, @@ -645,32 +474,29 @@ def get_price( **kwargs: Any, ) -> decimal.Decimal: """Get the price of a coin pair from a specific `platform` at `utc_time`. - The function tries to retrieve the price from the local database first. If the price does not exist, its gathered from a platform specific function and saved to our local database for future access. - Args: platform (str) coin (str) utc_time (datetime.datetime) reference_coin (str, optional): Defaults to config.FIAT. - Raises: NotImplementedError: Platform specific GET function is not - implemented. - + implemented. Returns: decimal.Decimal: Price of the coin pair. """ if coin == reference_coin: return decimal.Decimal("1") - db_path = self.get_db_path(platform) - tablename = self.get_tablename(coin, reference_coin) + db_path = get_db_path(platform) + tablename = get_tablename(coin, reference_coin) # Check if price exists already in our database. - if (price := self.__get_price_db(db_path, tablename, utc_time)) is None: + if (price := get_price_db(db_path, tablename, utc_time)) is None: + # Price doesn't exists. Fetch price from platform. try: get_price = getattr(self, f"_get_price_{platform}") except AttributeError: @@ -678,13 +504,15 @@ def get_price( price = get_price(coin, utc_time, reference_coin, **kwargs) assert isinstance(price, decimal.Decimal) - self.__set_price_db(db_path, tablename, utc_time, price) + set_price_db( + platform, coin, reference_coin, utc_time, price, db_path=db_path + ) if config.MEAN_MISSING_PRICES and price <= 0.0: # The price is missing. Check for prices before and after the # transaction and estimate the price. # Do not save price in database. - price = self.__mean_price_db(db_path, tablename, utc_time) + price = mean_price_db(db_path, tablename, utc_time) return price