Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ Added
Changed
~~~~~~~

- Changed the storage backend of the `CitationMixin` cache from CSV to SQLite
for better performance and concurrency support. An existing CSV cache is
automatically migrated to SQLite on first use and removed once migrated.

Deprecated
~~~~~~~~~~

Expand Down
161 changes: 129 additions & 32 deletions bib_lookup/citation_mixin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
""" """

import sqlite3
import warnings
from typing import Optional, Sequence, Union

Expand All @@ -20,7 +21,40 @@ class CitationMixin(object):

_bl = BibLookup(timeout=1.0, ignore_errors=False)

citation_cache = CACHE_DIR / "bib-lookup-cache.csv"
citation_cache_csv = CACHE_DIR / "bib-lookup-cache.csv"
citation_cache_db = CACHE_DIR / "bib-lookup-cache.db"
citation_cache = citation_cache_db

def _init_db(self):
    """Create the SQLite citation cache and migrate a legacy CSV cache.

    Ensures the ``citations`` table (``doi`` as primary key, ``citation``
    as text) exists in ``self.citation_cache_db``. If the legacy CSV cache
    ``self.citation_cache_csv`` is present, its rows are copied into the
    table (existing DOIs are kept, not overwritten) and the CSV file is
    deleted once the copy has been committed.

    A failed migration is reported with a ``UserWarning`` instead of being
    raised; the CSV file is left in place in that case so the migration can
    be retried on the next call.

    Returns
    -------
    None

    """
    conn = sqlite3.connect(self.citation_cache_db)
    try:
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS citations (
                doi TEXT PRIMARY KEY,
                citation TEXT
            )
            """
        )
        # Persist the schema before migrating, so that rolling back a
        # failed migration can never discard the table creation.
        conn.commit()

        # Backward compatibility: migrate the CSV cache to SQLite.
        if self.citation_cache_csv.exists():
            try:
                df = pd.read_csv(self.citation_cache_csv)
                # Rows lacking a DOI or a citation would be stored as NULL
                # and later break joining of cached citations; skip them.
                df = df.dropna(subset=["doi", "citation"])
                data = list(df[["doi", "citation"]].itertuples(index=False, name=None))
                # OR IGNORE keeps entries already present in the database.
                conn.executemany(
                    "INSERT OR IGNORE INTO citations (doi, citation) VALUES (?, ?)",
                    data,
                )
                conn.commit()
                # Delete the CSV only after the rows are safely committed.
                self.citation_cache_csv.unlink()
            except Exception as e:
                # Drop any partially inserted rows; the CSV stays in place.
                conn.rollback()
                warnings.warn(f"Failed to migrate CSV cache: {e}", UserWarning)
            else:
                warnings.warn(f"Migrated citation cache from CSV to SQLite: {self.citation_cache_db}", UserWarning)
    finally:
        # Always release the connection, even if table creation failed.
        conn.close()

def get_citation(
self,
Expand Down Expand Up @@ -56,34 +90,73 @@ def get_citation(

"""
self._bl.clear_cache()
if self.citation_cache.exists():
df_cc = pd.read_csv(self.citation_cache)
else:
df_cc = pd.DataFrame(columns=["doi", "citation"])
df_cc.to_csv(self.citation_cache, index=False)
self._init_db()

if self.doi is not None:
if isinstance(self.doi, str):
doi = [self.doi]
else:
doi = self.doi

# If doi is empty, return empty result
if not doi:
if print_result:
return
else:
return ""

if not lookup:
citation = "\n".join(doi)
if print_result:
print(citation)
return
else:
return citation

# Fetch from cache
conn = sqlite3.connect(self.citation_cache_db)
cursor = conn.cursor()

cached_citations = []
existing_dois = set()

# SQLite limits the number of bound variables per statement (999 by default
# before SQLite 3.32.0, 32766 from 3.32.0 on). Chunk the DOIs so the query
# stays below the older, smaller limit.
chunk_size = 900
for i in range(0, len(doi), chunk_size):
chunk = doi[i : i + chunk_size]
if not chunk:
continue

placeholders = ",".join("?" * len(chunk))

if format is not None and format != self._bl.format:
# no cache for format other than bibtex
pass
else:
query = f"SELECT citation FROM citations WHERE doi IN ({placeholders})"
cursor.execute(query, chunk)
cached_citations.extend([row[0] for row in cursor.fetchall()])

query_exist = f"SELECT doi FROM citations WHERE doi IN ({placeholders})"
cursor.execute(query_exist, chunk)
existing_dois.update({row[0] for row in cursor.fetchall()})

if format is not None and format != self._bl.format:
citation = "" # no cache for format other than bibtex
citation = ""
else:
citation = "\n".join(df_cc[df_cc["doi"].isin(doi)]["citation"].tolist())
doi = [item for item in doi if item not in df_cc["doi"].tolist()]
if print_result:
citation = "\n".join(cached_citations)
if print_result and citation:
print(citation)
if len(doi) > 0:

conn.close()

# Filter out DOIs that were found in cache
doi_to_fetch = [item for item in doi if item not in existing_dois]

if len(doi_to_fetch) > 0:
new_citations = []
for item in doi:
for item in doi_to_fetch:
try:
bl_res = self._bl(
item,
Expand All @@ -106,18 +179,31 @@ def get_citation(
except Exception:
if print_result:
print(f"Failed to lookup citation for {item}")

if format is None or format == self._bl.format:
# only cache bibtex format
new_citations = [
item for item in new_citations if item["citation"] is not None and item["citation"].startswith("@")
# Filter for valid bibtex citations (starting with @)
valid_new_citations = [
item
for item in new_citations
if item["citation"] is not None and item["citation"].strip().startswith("@")
]
df_new = pd.DataFrame(new_citations)
if len(df_new) > 0:
df_new.to_csv(self.citation_cache, mode="a", header=False, index=False)
else:
df_new = pd.DataFrame(new_citations)
if len(df_new) > 0:
citation += "\n" + "\n".join(df_new["citation"].tolist())

if valid_new_citations:
conn = sqlite3.connect(self.citation_cache_db)
cursor = conn.cursor()
data_to_insert = [(item["doi"], item["citation"]) for item in valid_new_citations]
cursor.executemany("INSERT OR REPLACE INTO citations (doi, citation) VALUES (?, ?)", data_to_insert)
conn.commit()
conn.close()

# Update the returned citation string with newly fetched citations
additional_citations = [item["citation"] for item in new_citations]
if additional_citations:
if citation:
citation += "\n" + "\n".join(additional_citations)
else:
citation = "\n".join(additional_citations)
else:
citation = ""
doi = []
Expand Down Expand Up @@ -145,17 +231,23 @@ def update_cache(self, doi: Optional[Union[str, Sequence[str]]] = None) -> None:
None

"""
if self.citation_cache.exists():
df_cc = pd.read_csv(self.citation_cache)
else:
df_cc = pd.DataFrame(columns=["doi", "citation"])
self._init_db()
conn = sqlite3.connect(self.citation_cache_db)
cursor = conn.cursor()

if doi is None:
doi = df_cc.doi.tolist()
if isinstance(doi, str):
doi = [doi]
cursor.execute("SELECT doi FROM citations")
doi_list = [row[0] for row in cursor.fetchall()]
else:
if isinstance(doi, str):
doi_list = [doi]
else:
doi_list = list(doi)

conn.close() # Close connection while fetching to avoid locking issues if fetching takes long

new_citations = []
for item in doi:
for item in doi_list:
try:
bl_res = self._bl(item, timeout=10.0)
if bl_res not in self._bl.lookup_errors:
Expand All @@ -167,6 +259,11 @@ def update_cache(self, doi: Optional[Union[str, Sequence[str]]] = None) -> None:
)
except Exception:
print(f"Failed to lookup citation for {item}")
df_cc = pd.concat([df_cc, pd.DataFrame(new_citations)])
df_cc = df_cc.drop_duplicates(subset="doi", keep="last", ignore_index=True)
df_cc.to_csv(self.citation_cache, index=False)

if new_citations:
conn = sqlite3.connect(self.citation_cache_db)
cursor = conn.cursor()
data_to_insert = [(item["doi"], item["citation"]) for item in new_citations]
cursor.executemany("INSERT OR REPLACE INTO citations (doi, citation) VALUES (?, ?)", data_to_insert)
conn.commit()
conn.close()
Loading
Loading