From 462279b60fff180ed189415fdfeb84a84ca27f61 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 23 Apr 2026 23:53:16 +0000 Subject: [PATCH 1/2] Initial plan From dbe127c6e3f8d80b1eb3199c35ecbc4aa8800c57 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 24 Apr 2026 00:02:57 +0000 Subject: [PATCH 2/2] =?UTF-8?q?Add=20BM25SRetriever=20=E2=80=93=20pure-Pyt?= =?UTF-8?q?hon=20BM25=20without=20Java/Pyserini=20dependency?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Agent-Logs-Url: https://github.com/DataScienceUIBK/Rankify/sessions/9e4d6c84-c9c5-4b75-b11a-e82fe15ba9a1 Co-authored-by: abdoelsayed2016 <27821589+abdoelsayed2016@users.noreply.github.com> --- pyproject.toml | 3 + rankify/retrievers/__init__.py | 86 ++++-- rankify/retrievers/bm25_retriever.py | 18 +- rankify/retrievers/bm25s_retriever.py | 306 +++++++++++++++++++++ rankify/retrievers/diver_bm25_retriever.py | 19 +- rankify/retrievers/retriever.py | 121 ++++++-- tests/test_bm25s_retriever.py | 230 ++++++++++++++++ 7 files changed, 733 insertions(+), 50 deletions(-) create mode 100644 rankify/retrievers/bm25s_retriever.py create mode 100644 tests/test_bm25s_retriever.py diff --git a/pyproject.toml b/pyproject.toml index 4a7f4e4..26e1bfd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,6 +80,9 @@ retriever = [ # Sparse retrieval (BM25) "pyserini==0.43.0", + # Sparse retrieval (BM25S - pure Python, no Java dependency) + "bm25s>=0.2.0", + # Dense retrieval "faiss-cpu==1.9.0.post1", "h5py==3.12.1", diff --git a/rankify/retrievers/__init__.py b/rankify/retrievers/__init__.py index 5362bfb..62e50bc 100644 --- a/rankify/retrievers/__init__.py +++ b/rankify/retrievers/__init__.py @@ -1,28 +1,82 @@ -# rankify/retrievers/__init__.py - MODIFIED VERSION +# rankify/retrievers/__init__.py from .retriever import Retriever from .base_retriever import BaseRetriever -from .bm25_retriever import BM25Retriever -from .dense_retriever import DenseRetriever -from .ance_retriever import ANCERetriever # NEW IMPORT -from .bge_retriever import BGERetriever -from .colbert_retriever import ColBERTRetriever -from .contriever_retriever import ContrieverRetriever -from .online_retriever import OnlineRetriever -from .hyde_retriever import HydeRetriever -from .diver_dense_retriever import DiverDenseRetriever -from .diver_bm25_retriever import DiverBM25Retriever -from .reasonir_retriever import ReasonIRRetriever -from .reasonembed_retriever import ReasonEmbedRetriever -from .bge_reasoner_retriever import BgeReasonerRetriever +from .bm25s_retriever import BM25SRetriever + +try: + from .bm25_retriever import BM25Retriever +except ImportError: + BM25Retriever = None # type: ignore[assignment,misc] + +try: + from .dense_retriever import DenseRetriever +except ImportError: + DenseRetriever = None # type: ignore[assignment,misc] + +try: + from .ance_retriever import ANCERetriever +except ImportError: + ANCERetriever = None # type: ignore[assignment,misc] + +try: + from .bge_retriever import BGERetriever +except ImportError: + BGERetriever = None # type: ignore[assignment,misc] + +try: + from .colbert_retriever import ColBERTRetriever +except ImportError: + ColBERTRetriever = None # type: ignore[assignment,misc] + +try: + from .contriever_retriever import ContrieverRetriever +except ImportError: + ContrieverRetriever = None # type: ignore[assignment,misc] + +try: + from .online_retriever import OnlineRetriever +except ImportError: + OnlineRetriever = None # type: ignore[assignment,misc] + +try: + from .hyde_retriever import HydeRetriever +except ImportError: + HydeRetriever = None # type: ignore[assignment,misc] + +try: + from .diver_dense_retriever import DiverDenseRetriever +except ImportError: + DiverDenseRetriever = None # type: ignore[assignment,misc] + +try: + from .diver_bm25_retriever import DiverBM25Retriever +except ImportError: + DiverBM25Retriever = None # type: ignore[assignment,misc] + +try: + from .reasonir_retriever import ReasonIRRetriever +except ImportError: + ReasonIRRetriever = None # type: ignore[assignment,misc] + +try: + from .reasonembed_retriever import ReasonEmbedRetriever +except ImportError: + ReasonEmbedRetriever = None # type: ignore[assignment,misc] + +try: + from .bge_reasoner_retriever import BgeReasonerRetriever +except ImportError: + BgeReasonerRetriever = None # type: ignore[assignment,misc] __all__ = [ "Retriever", - "BaseRetriever", + "BaseRetriever", + "BM25SRetriever", "BM25Retriever", "DenseRetriever", - "ANCERetriever", # NEW EXPORT + "ANCERetriever", "BGERetriever", "ColBERTRetriever", "ContrieverRetriever", diff --git a/rankify/retrievers/bm25_retriever.py b/rankify/retrievers/bm25_retriever.py index 67afc58..ae671cd 100644 --- a/rankify/retrievers/bm25_retriever.py +++ b/rankify/retrievers/bm25_retriever.py @@ -1,14 +1,19 @@ # bm25_retriever.py import json from typing import List -from pyserini.search.lucene import LuceneSearcher -from pyserini.eval.evaluate_dpr_retrieval import has_answers, SimpleTokenizer from tqdm import tqdm import os from .base_retriever import BaseRetriever from .index_manager import IndexManager from rankify.dataset.dataset import Document, Context +try: + from pyserini.search.lucene import LuceneSearcher + from pyserini.eval.evaluate_dpr_retrieval import has_answers, SimpleTokenizer + _PYSERINI_AVAILABLE = True +except ImportError: + _PYSERINI_AVAILABLE = False + class BM25Retriever(BaseRetriever): """ BM25 retriever implementation using Pyserini's LuceneSearcher. @@ -17,6 +22,13 @@ class BM25Retriever(BaseRetriever): """ def __init__(self, index_type: str = "wiki", index_folder: str = None, **kwargs): + if not _PYSERINI_AVAILABLE: + raise ImportError( + "pyserini is required for BM25Retriever. " + "Install it with: pip install pyserini " + "Or use BM25SRetriever for a pure-Python alternative: " + "Retriever(method='bm25s', ...)" + ) super().__init__(**kwargs) self.index_type = index_type self.index_folder = index_folder @@ -48,7 +60,7 @@ def _load_reverse_mapping(self): fwd = json.load(f) # { "orig_id": 123 } m = {str(v): k for k, v in fwd.items()} # ensure string keys return m - def _initialize_searcher(self) -> LuceneSearcher: + def _initialize_searcher(self): """Initialize Lucene searcher.""" if self.index_path.startswith("wikipedia-") or "prebuilt" in self.index_path: return LuceneSearcher.from_prebuilt_index(self.index_path) diff --git a/rankify/retrievers/bm25s_retriever.py b/rankify/retrievers/bm25s_retriever.py new file mode 100644 index 0000000..9d9312f --- /dev/null +++ b/rankify/retrievers/bm25s_retriever.py @@ -0,0 +1,306 @@ +# bm25s_retriever.py +""" +BM25S Retriever - a pure-Python BM25 retriever backed by the ``bm25s`` library. + +Unlike the Pyserini-based :class:`BM25Retriever`, this retriever has no Java +dependency and is considerably lighter (no JVM, no Lucene). + +Usage:: + + from rankify.retrievers import BM25SRetriever + from rankify.dataset.dataset import Document, Question, Answer + + retriever = BM25SRetriever( + n_docs=10, + corpus_path="/path/to/corpus.jsonl", # build index on first run + index_folder="/path/to/save/index", # persist index here + ) + docs = retriever.retrieve([Document(question=Question("What is BM25?"), + answers=Answer([]))]) + +The corpus file can be: + +* **JSONL** (one JSON object per line): each line must contain ``id``, + ``title``, and ``text`` (or ``contents``) fields. +* **TSV** (tab-separated): columns are ``id\\ttext\\ttitle`` (the same layout + used by the ``psgs_w100.tsv`` Wikipedia dump). + +Once built, the index is saved to ``/bm25s_index/`` so subsequent +calls skip the expensive indexing step. +""" + +import os +import json +from typing import List, Optional, Tuple + +from tqdm import tqdm + +from .base_retriever import BaseRetriever +from rankify.dataset.dataset import Document, Context + + +# --------------------------------------------------------------------------- +# Small helper – avoids importing pyserini +# --------------------------------------------------------------------------- + +def _has_answers(text: str, answers: List[str]) -> bool: + """Return *True* if any answer string appears (case-insensitively) in *text*.""" + text_lower = text.lower() + return any(ans.lower() in text_lower for ans in answers) + + +# --------------------------------------------------------------------------- +# Retriever +# --------------------------------------------------------------------------- + +class BM25SRetriever(BaseRetriever): + """ + BM25 retriever using `bm25s `_ – a pure + Python implementation with **no Java / JVM dependency**. + + Parameters + ---------- + index_type: + Logical corpus name used to locate a cached index (``"wiki"`` or + ``"msmarco"``). Ignored when *index_folder* is given explicitly. + index_folder: + Directory used to persist (or load) the bm25s index. A sub-directory + ``bm25s_index`` is created inside this path. If *None* the location is + derived from *index_type* via the :class:`~rankify.retrievers.index_manager.IndexManager`. + corpus_path: + Path to the raw corpus file (JSONL or TSV). Required when no + pre-built index exists yet. + stopwords: + Stopword list passed to ``bm25s.tokenize``. Use ``"en"`` for English + (default) or ``None`` / ``""`` to disable. + stemmer_lang: + ISO language code for the PyStemmer ``Stemmer.Stemmer`` (e.g. + ``"english"``). Requires the optional ``PyStemmer`` package. Set to + ``None`` (default) to skip stemming. + n_docs: + Number of documents to return per query. + batch_size / threads: + Inherited from :class:`BaseRetriever`; unused by this implementation + but kept for API compatibility. + """ + + def __init__( + self, + index_type: str = "wiki", + index_folder: str = None, + corpus_path: str = None, + stopwords: str = "en", + stemmer_lang: Optional[str] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.index_type = index_type + self.corpus_path = corpus_path + self.stopwords = stopwords or None # bm25s treats "" same as None + self.stemmer_lang = stemmer_lang + + # Resolve the directory where the index will live + if index_folder: + self.index_folder = index_folder + else: + from .index_manager import IndexManager + im = IndexManager() + try: + self.index_folder = im.get_index_path("bm25s", index_type) + except (ValueError, KeyError): + # bm25s not registered in IndexManager yet – fall back to cache dir + cache_dir = os.environ.get("RERANKING_CACHE_DIR", "./cache") + self.index_folder = os.path.join(cache_dir, "index", f"bm25s_{index_type}") + + self.stemmer = self._init_stemmer() + # _initialize_searcher is called here; it either loads or builds the index + self.searcher = self._initialize_searcher() + + # ------------------------------------------------------------------ + # BaseRetriever interface + # ------------------------------------------------------------------ + + def _initialize_searcher(self): + """Load an existing bm25s index or build one from *corpus_path*.""" + import bm25s # lazy import keeps package optional at import time + + index_path = os.path.join(self.index_folder, "bm25s_index") + + if os.path.isdir(index_path) and os.listdir(index_path): + print(f"Loading BM25S index from {index_path} …") + retriever = bm25s.BM25.load(index_path, load_corpus=True) + return retriever + + # Index does not exist yet – build it + if not self.corpus_path: + raise FileNotFoundError( + f"No pre-built BM25S index found at '{index_path}'. " + "Please provide 'corpus_path' to build the index." + ) + + return self._build_index(index_path) + + def retrieve(self, documents: List[Document]) -> List[Document]: + """Retrieve the top-*n_docs* contexts for every document in *documents*.""" + import bm25s # lazy import + + queries = [doc.question.question for doc in documents] + print(f"Retrieving {len(documents)} document(s) with BM25S …") + + query_tokens = bm25s.tokenize( + queries, + stopwords=self.stopwords, + stemmer=self.stemmer, + show_progress=False, + ) + + results, scores = self.searcher.retrieve(query_tokens, k=self.n_docs) + + for i, document in enumerate(tqdm(documents, desc="Processing documents")): + contexts: List[Context] = [] + num_hits = results.shape[1] + for j in range(num_hits): + doc_data = results[i, j] + score = float(scores[i, j]) + + doc_id = str(doc_data.get("id", "")) + title = doc_data.get("title", "") + text = doc_data.get("text", "") + + answers = document.answers.answers if document.answers else [] + context = Context( + id=doc_id, + title=title, + text=text, + score=score, + has_answer=_has_answers(text, answers), + ) + contexts.append(context) + + document.contexts = contexts + + return documents + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _init_stemmer(self): + """Return a PyStemmer stemmer, or *None* if unavailable / not requested.""" + if not self.stemmer_lang: + return None + try: + import Stemmer # PyStemmer (optional) + return Stemmer.Stemmer(self.stemmer_lang) + except ImportError: + print( + "Warning: PyStemmer is not installed; stemming is disabled. " + "Install it with: pip install PyStemmer" + ) + return None + + def _build_index(self, index_path: str): + """Build a bm25s index from *self.corpus_path* and save it.""" + import bm25s # lazy import + + corpus, corpus_texts = self._load_corpus(self.corpus_path) + + print(f"Tokenizing {len(corpus)} documents …") + corpus_tokens = bm25s.tokenize( + corpus_texts, + stopwords=self.stopwords, + stemmer=self.stemmer, + show_progress=True, + ) + + print("Indexing corpus …") + retriever = bm25s.BM25(corpus=corpus) + retriever.index(corpus_tokens) + + os.makedirs(index_path, exist_ok=True) + print(f"Saving BM25S index to {index_path} …") + retriever.save(index_path) + + return retriever + + def _load_corpus(self, corpus_path: str) -> Tuple[List[dict], List[str]]: + """ + Load the corpus from a JSONL or TSV file. + + Returns + ------- + corpus: + List of ``{"id": …, "title": …, "text": …}`` dicts. + corpus_texts: + List of strings (``"\\n<text>"``) used for tokenisation. + """ + corpus: List[dict] = [] + + lower = corpus_path.lower() + if lower.endswith(".jsonl") or lower.endswith(".json"): + corpus = self._load_jsonl(corpus_path) + elif lower.endswith(".tsv"): + corpus = self._load_tsv(corpus_path) + else: + # Try JSONL first, fall back to TSV heuristic + try: + corpus = self._load_jsonl(corpus_path) + except (json.JSONDecodeError, UnicodeDecodeError): + corpus = self._load_tsv(corpus_path) + + if not corpus: + raise ValueError(f"Corpus loaded from '{corpus_path}' is empty.") + + corpus_texts = [ + f"{doc['title']}\n{doc['text']}" if doc.get("title") else doc["text"] + for doc in corpus + ] + return corpus, corpus_texts + + @staticmethod + def _load_jsonl(corpus_path: str) -> List[dict]: + corpus = [] + with open(corpus_path, "r", encoding="utf-8") as fh: + for line in tqdm(fh, desc="Loading corpus (JSONL)"): + line = line.strip() + if not line: + continue + doc = json.loads(line) + doc_id = str(doc.get("id") or doc.get("docid") or "") + title = doc.get("title", "") + text = doc.get("text") or doc.get("contents", "") + corpus.append({"id": doc_id, "title": title, "text": text}) + return corpus + + @staticmethod + def _load_tsv(corpus_path: str) -> List[dict]: + """ + Load TSV corpus. Supports both ``id\\ttext\\ttitle`` (psgs_w100 layout) + and ``id\\ttitle\\ttext`` layouts; the header row is used to detect the + column order when present. + """ + corpus = [] + title_col, text_col = 2, 1 # psgs_w100 default: id | text | title + + with open(corpus_path, "r", encoding="utf-8") as fh: + for i, line in enumerate(tqdm(fh, desc="Loading corpus (TSV)")): + line = line.rstrip("\n") + if not line: + continue + parts = line.split("\t") + if i == 0 and parts[0].lower() in ("id", "docid"): + # Detect column order from header + lower_parts = [p.lower() for p in parts] + if "title" in lower_parts and "text" in lower_parts: + title_col = lower_parts.index("title") + text_col = lower_parts.index("text") + continue # skip header row + + if len(parts) < 2: + continue + + doc_id = parts[0] + text = parts[text_col] if len(parts) > text_col else "" + title = parts[title_col] if len(parts) > title_col else "" + corpus.append({"id": doc_id, "title": title, "text": text}) + return corpus diff --git a/rankify/retrievers/diver_bm25_retriever.py b/rankify/retrievers/diver_bm25_retriever.py index cbc75aa..7b917d8 100644 --- a/rankify/retrievers/diver_bm25_retriever.py +++ b/rankify/retrievers/diver_bm25_retriever.py @@ -9,11 +9,15 @@ from typing import List from tqdm import tqdm -from gensim.corpora import Dictionary -from gensim.models import LuceneBM25Model -from gensim.similarities import SparseMatrixSimilarity -from pyserini import analysis -from pyserini.eval.evaluate_dpr_retrieval import has_answers, SimpleTokenizer +try: + from gensim.corpora import Dictionary + from gensim.models import LuceneBM25Model + from gensim.similarities import SparseMatrixSimilarity + from pyserini import analysis + from pyserini.eval.evaluate_dpr_retrieval import has_answers, SimpleTokenizer + _DIVER_DEPS_AVAILABLE = True +except ImportError: + _DIVER_DEPS_AVAILABLE = False from .base_retriever import BaseRetriever from rankify.dataset.dataset import Document, Context @@ -64,6 +68,11 @@ def __init__( b: float = 0.4, **kwargs, ): + if not _DIVER_DEPS_AVAILABLE: + raise ImportError( + "pyserini and gensim are required for DiverBM25Retriever. " + "Install them with: pip install pyserini gensim" + ) super().__init__(**kwargs) if corpus_path is None: diff --git a/rankify/retrievers/retriever.py b/rankify/retrievers/retriever.py index 0e886b8..18a7680 100644 --- a/rankify/retrievers/retriever.py +++ b/rankify/retrievers/retriever.py @@ -2,35 +2,101 @@ from typing import List, Dict, Type from rankify.dataset.dataset import Document from .base_retriever import BaseRetriever -from .bm25_retriever import BM25Retriever -from .dense_retriever import DenseRetriever -from .ance_retriever import ANCERetriever -from .bge_retriever import BGERetriever -from .colbert_retriever import ColBERTRetriever -from .contriever_retriever import ContrieverRetriever -from .online_retriever import OnlineRetriever -from .hyde_retriever import HydeRetriever -from .diver_dense_retriever import DiverDenseRetriever -from .diver_bm25_retriever import DiverBM25Retriever -from .reasonir_retriever import ReasonIRRetriever -from .reasonembed_retriever import ReasonEmbedRetriever -from .bge_reasoner_retriever import BgeReasonerRetriever -from .unicoil_retriever import UniCOILRetriever -from .splade_v2_retriever import SpladeV2Retriever -from .api_embedding_retriever import APIEmbeddingRetriever - -# Method mapping - UPDATED WITH PROPER ANCE SUPPORT -METHOD_MAP: Dict[str, Type[BaseRetriever]] = { +from .bm25s_retriever import BM25SRetriever + +try: + from .bm25_retriever import BM25Retriever +except ImportError: + BM25Retriever = None # type: ignore[assignment,misc] + +try: + from .dense_retriever import DenseRetriever +except ImportError: + DenseRetriever = None # type: ignore[assignment,misc] + +try: + from .ance_retriever import ANCERetriever +except ImportError: + ANCERetriever = None # type: ignore[assignment,misc] + +try: + from .bge_retriever import BGERetriever +except ImportError: + BGERetriever = None # type: ignore[assignment,misc] + +try: + from .colbert_retriever import ColBERTRetriever +except ImportError: + ColBERTRetriever = None # type: ignore[assignment,misc] + +try: + from .contriever_retriever import ContrieverRetriever +except ImportError: + ContrieverRetriever = None # type: ignore[assignment,misc] + +try: + from .online_retriever import OnlineRetriever +except ImportError: + OnlineRetriever = None # type: ignore[assignment,misc] + +try: + from .hyde_retriever import HydeRetriever +except ImportError: + HydeRetriever = None # type: ignore[assignment,misc] + +try: + from .diver_dense_retriever import DiverDenseRetriever +except ImportError: + DiverDenseRetriever = None # type: ignore[assignment,misc] + +try: + from .diver_bm25_retriever import DiverBM25Retriever +except ImportError: + DiverBM25Retriever = None # type: ignore[assignment,misc] + +try: + from .reasonir_retriever import ReasonIRRetriever +except ImportError: + ReasonIRRetriever = None # type: ignore[assignment,misc] + +try: + from .reasonembed_retriever import ReasonEmbedRetriever +except ImportError: + ReasonEmbedRetriever = None # type: ignore[assignment,misc] + +try: + from .bge_reasoner_retriever import BgeReasonerRetriever +except ImportError: + BgeReasonerRetriever = None # type: ignore[assignment,misc] + +try: + from .unicoil_retriever import UniCOILRetriever +except ImportError: + UniCOILRetriever = None # type: ignore[assignment,misc] + +try: + from .splade_v2_retriever import SpladeV2Retriever +except ImportError: + SpladeV2Retriever = None # type: ignore[assignment,misc] + +try: + from .api_embedding_retriever import APIEmbeddingRetriever +except ImportError: + APIEmbeddingRetriever = None # type: ignore[assignment,misc] + +# Method mapping – entries whose class could not be imported are excluded +_CANDIDATE_MAP: Dict[str, object] = { "bm25": BM25Retriever, + "bm25s": BM25SRetriever, "dpr-multi": DenseRetriever, "dpr-single": DenseRetriever, - "ance-multi": ANCERetriever, + "ance-multi": ANCERetriever, "bpr-single": DenseRetriever, - "bge": BGERetriever, - "colbert": ColBERTRetriever, - "contriever": ContrieverRetriever, - "online": OnlineRetriever, - "hyde": HydeRetriever, + "bge": BGERetriever, + "colbert": ColBERTRetriever, + "contriever": ContrieverRetriever, + "online": OnlineRetriever, + "hyde": HydeRetriever, "diver-dense": DiverDenseRetriever, "diver-bm25": DiverBM25Retriever, "reasonir": ReasonIRRetriever, @@ -43,6 +109,9 @@ "cohere-embedding": APIEmbeddingRetriever, "voyage-embedding": APIEmbeddingRetriever, } +METHOD_MAP: Dict[str, Type[BaseRetriever]] = { + k: v for k, v in _CANDIDATE_MAP.items() if v is not None +} class Retriever: """ @@ -79,7 +148,7 @@ def __init__(self, method: str, n_docs: int = 10, index_type: str = "wiki", Initialize the retriever. Args: - method (str): Retrieval method ('bm25', 'dpr-multi', 'dpr-single', 'ance', 'ance-multi', 'bpr-single', etc.) + method (str): Retrieval method ('bm25', 'bm25s', 'dpr-multi', 'dpr-single', 'ance', 'ance-multi', 'bpr-single', etc.) n_docs (int): Number of documents to retrieve per query index_type (str): Index type ('wiki', 'msmarco') - ignored if index_folder is provided index_folder (str): Path to custom index folder (optional) diff --git a/tests/test_bm25s_retriever.py b/tests/test_bm25s_retriever.py new file mode 100644 index 0000000..16a859a --- /dev/null +++ b/tests/test_bm25s_retriever.py @@ -0,0 +1,230 @@ +""" +Unit tests for BM25SRetriever – pure-Python BM25 retriever using ``bm25s``. + +These tests run without any heavy dependencies (no Java, no Pyserini, no GPU). +They exercise: + - Building an index from JSONL corpus + - Building an index from TSV corpus + - Loading a persisted index + - Correctness of returned contexts (scores, has_answer flag, title/text) + - Edge-cases (no results, empty answers list) + - Unified ``Retriever(method='bm25s', ...)`` interface +""" + +import json +import os +import tempfile +import unittest + +from rankify.dataset.dataset import Answer, Context, Document, Question +from rankify.retrievers.bm25s_retriever import BM25SRetriever, _has_answers +from rankify.retrievers.retriever import Retriever + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_CORPUS = [ + {"id": "1", "title": "Anarchism", "text": "Anarchism is a political philosophy that advocates self-governed societies."}, + {"id": "2", "title": "Paris", "text": "Paris is the capital and largest city of France."}, + {"id": "3", "title": "Python", "text": "Python is a high-level, general-purpose programming language."}, + {"id": "4", "title": "BM25", "text": "BM25 is a bag-of-words retrieval function used to rank documents."}, + {"id": "5", "title": "Biology", "text": "Biology is the scientific study of life and living organisms."}, +] + + +def _write_jsonl(path: str, docs=_CORPUS): + with open(path, "w", encoding="utf-8") as fh: + for doc in docs: + fh.write(json.dumps(doc) + "\n") + + +def _write_tsv(path: str, docs=_CORPUS): + """Write PSGs-style TSV: id TAB text TAB title.""" + with open(path, "w", encoding="utf-8") as fh: + fh.write("id\ttext\ttitle\n") + for doc in docs: + fh.write(f"{doc['id']}\t{doc['text']}\t{doc['title']}\n") + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestHasAnswers(unittest.TestCase): + def test_present(self): + self.assertTrue(_has_answers("Paris is great", ["Paris"])) + + def test_case_insensitive(self): + self.assertTrue(_has_answers("PARIS is great", ["paris"])) + + def test_absent(self): + self.assertFalse(_has_answers("London is great", ["Paris"])) + + def test_empty_answers(self): + self.assertFalse(_has_answers("Paris is great", [])) + + +class TestBM25SRetrieverJSONL(unittest.TestCase): + """Build index from JSONL corpus and run queries.""" + + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + self.corpus_path = os.path.join(self.tmpdir, "corpus.jsonl") + _write_jsonl(self.corpus_path) + self.index_folder = os.path.join(self.tmpdir, "index") + self.retriever = BM25SRetriever( + n_docs=3, + corpus_path=self.corpus_path, + index_folder=self.index_folder, + ) + + def _query(self, question: str, answers=None): + doc = Document( + question=Question(question=question), + answers=Answer(answers=answers or []), + ) + return self.retriever.retrieve([doc])[0] + + def test_returns_n_docs(self): + result = self._query("French capital") + self.assertEqual(len(result.contexts), 3) + + def test_contexts_are_context_objects(self): + result = self._query("philosophy") + for ctx in result.contexts: + self.assertIsInstance(ctx, Context) + + def test_top_result_is_relevant(self): + result = self._query("capital of France") + self.assertEqual(result.contexts[0].title, "Paris") + + def test_has_answer_flag_true(self): + result = self._query("capital city", answers=["Paris"]) + top = result.contexts[0] + self.assertTrue(top.has_answer) + + def test_has_answer_flag_false(self): + result = self._query("programming language", answers=["Java"]) + # Java does not appear in any doc text + for ctx in result.contexts: + self.assertFalse(ctx.has_answer) + + def test_context_fields_populated(self): + result = self._query("political philosophy") + ctx = result.contexts[0] + self.assertIsInstance(ctx.id, str) + self.assertIsInstance(ctx.title, str) + self.assertIsInstance(ctx.text, str) + self.assertIsInstance(ctx.score, float) + + def test_scores_descending(self): + result = self._query("biology living organisms") + scores = [ctx.score for ctx in result.contexts] + self.assertEqual(scores, sorted(scores, reverse=True)) + + +class TestBM25SRetrieverTSV(unittest.TestCase): + """Build index from TSV corpus.""" + + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + self.corpus_path = os.path.join(self.tmpdir, "corpus.tsv") + _write_tsv(self.corpus_path) + self.index_folder = os.path.join(self.tmpdir, "index") + self.retriever = BM25SRetriever( + n_docs=2, + corpus_path=self.corpus_path, + index_folder=self.index_folder, + ) + + def test_top_result_tsv(self): + doc = Document( + question=Question(question="capital France"), + answers=Answer(answers=["Paris"]), + ) + result = self.retriever.retrieve([doc])[0] + self.assertEqual(result.contexts[0].title, "Paris") + self.assertTrue(result.contexts[0].has_answer) + + +class TestBM25SIndexPersistence(unittest.TestCase): + """Index built on first run should be loadable without corpus_path.""" + + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + corpus_path = os.path.join(self.tmpdir, "corpus.jsonl") + _write_jsonl(corpus_path) + self.index_folder = os.path.join(self.tmpdir, "index") + # Build the index + BM25SRetriever(n_docs=2, corpus_path=corpus_path, index_folder=self.index_folder) + + def test_load_without_corpus_path(self): + """Loading without corpus_path should succeed because the index already exists.""" + loaded = BM25SRetriever(n_docs=2, index_folder=self.index_folder) + doc = Document( + question=Question(question="programming language"), + answers=Answer(answers=[]), + ) + result = loaded.retrieve([doc])[0] + self.assertEqual(len(result.contexts), 2) + self.assertEqual(result.contexts[0].title, "Python") + + def test_missing_index_no_corpus_raises(self): + """Without a pre-built index and without corpus_path, expect FileNotFoundError.""" + with self.assertRaises(FileNotFoundError): + BM25SRetriever(n_docs=2, index_folder=os.path.join(self.tmpdir, "nonexistent")) + + +class TestBM25SRetrieverBatchQueries(unittest.TestCase): + """Multiple queries in a single call.""" + + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + corpus_path = os.path.join(self.tmpdir, "corpus.jsonl") + _write_jsonl(corpus_path) + index_folder = os.path.join(self.tmpdir, "index") + self.retriever = BM25SRetriever(n_docs=2, corpus_path=corpus_path, index_folder=index_folder) + + def test_batch_retrieval(self): + docs = [ + Document(question=Question(question="anarchism politics"), answers=Answer(answers=[])), + Document(question=Question(question="Paris France"), answers=Answer(answers=["Paris"])), + Document(question=Question(question="python programming"), answers=Answer(answers=[])), + ] + results = self.retriever.retrieve(docs) + self.assertEqual(len(results), 3) + for doc in results: + self.assertEqual(len(doc.contexts), 2) + + +class TestRetrieverUnifiedInterface(unittest.TestCase): + """Ensure ``Retriever(method='bm25s', ...)`` works end-to-end.""" + + def test_bm25s_in_supported_methods(self): + self.assertIn("bm25s", Retriever.supported_methods()) + + def test_unified_retriever(self): + with tempfile.TemporaryDirectory() as tmpdir: + corpus_path = os.path.join(tmpdir, "corpus.jsonl") + _write_jsonl(corpus_path) + index_folder = os.path.join(tmpdir, "index") + + retriever = Retriever( + method="bm25s", + n_docs=2, + index_folder=index_folder, + corpus_path=corpus_path, + ) + docs = [Document( + question=Question(question="capital France"), + answers=Answer(answers=["Paris"]), + )] + results = retriever.retrieve(docs) + self.assertEqual(len(results[0].contexts), 2) + self.assertEqual(results[0].contexts[0].title, "Paris") + + +if __name__ == "__main__": + unittest.main()