mirror of
https://github.com/Sosokker/plain-rag.git
synced 2025-12-18 22:44:03 +01:00
61 lines
1.9 KiB
Python
61 lines
1.9 KiB
Python
# pyright: reportArgumentType=false
|
|
from sentence_transformers import CrossEncoder
|
|
from structlog import get_logger
|
|
|
|
from app.core.exception import ModelNotFoundError
|
|
from app.core.interfaces import Reranker, SearchResult
|
|
|
|
logger = get_logger()
|
|
|
|
# pyright: reportCallIssue=false
|
|
|
|
|
|
class MiniLMReranker(Reranker):
|
|
def __init__(self, model_name: str):
|
|
try:
|
|
self.model = CrossEncoder(model_name)
|
|
except Exception as er:
|
|
err = f"Failed to load model '{model_name}'"
|
|
logger.exception(err)
|
|
raise ModelNotFoundError(err) from er
|
|
|
|
def rerank(self, documents: list[SearchResult], query: str) -> list[SearchResult]:
|
|
if not documents:
|
|
logger.warning("No documents to rerank.")
|
|
return []
|
|
|
|
# Preprocess pairs and keep track of original indexes
|
|
pairs = []
|
|
valid_docs = []
|
|
for i, doc in enumerate(documents):
|
|
content = doc.get("content", "")
|
|
if not content:
|
|
err = f"Document at index {i} has no content."
|
|
logger.warning(err)
|
|
continue
|
|
pairs.append((query, content))
|
|
valid_docs.append(doc)
|
|
|
|
if not pairs:
|
|
logger.warning("No valid document pairs to rerank.")
|
|
return []
|
|
|
|
try:
|
|
scores = self.model.predict(pairs)
|
|
except Exception as e:
|
|
err = f"Model prediction failed: {e}"
|
|
logger.exception(err)
|
|
return valid_docs # fallback: return unranked valid docs
|
|
|
|
# Sort by score descending
|
|
if len(scores) != len(valid_docs):
|
|
logger.warning("Mismatch in number of scores and documents")
|
|
return valid_docs # or handle the mismatch appropriately
|
|
|
|
result = sorted(
|
|
zip(scores, valid_docs, strict=False),
|
|
key=lambda x: x[0],
|
|
reverse=True,
|
|
)
|
|
return [doc for _, doc in result]
|