feat: add reranker interface and makefile

This commit is contained in:
Sosokker 2025-06-24 17:42:31 +07:00
parent 3294dafaa6
commit cf7d1e8218
5 changed files with 240 additions and 0 deletions

View File

@ -0,0 +1,16 @@
.PHONY: create-tables help
help:
@echo "Available targets:"
@echo " create-tables - Create database tables using create_tables.py"
@echo " install-deps - Install Python dependencies using uv"
@echo " help - Show this help message"
install-deps:
uv sync --locked --no-install-project --no-dev
create-tables:
@echo "Creating database tables..."
uv run python scripts/create_tables.py
.DEFAULT_GOAL := help

View File

@ -4,3 +4,7 @@ class DocumentInsertionError(Exception):
class DocumentExtractionError(Exception):
"""Exception raised when document extraction from PDF fails."""
class ModelNotFoundError(Exception):
"""Exception raised when model is not found."""

View File

@ -16,6 +16,12 @@ class EmbeddingModel(Protocol):
def embed_query(self, text: str) -> np.ndarray: ...
class Reranker(Protocol):
def rerank(
self, documents: list[SearchResult], query: str
) -> list[SearchResult]: ...
class VectorDB(Protocol):
def upsert_documents(self, documents: list[dict]) -> None: ...

61
app/services/rerankers.py Normal file
View File

@ -0,0 +1,61 @@
import logging
# pyright: reportArgumentType=false
from sentence_transformers import CrossEncoder
from app.core.exception import ModelNotFoundError
from app.core.interfaces import Reranker, SearchResult
logger = logging.getLogger(__name__)
# pyright: reportCallIssue=false
class MiniLMReranker(Reranker):
def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
try:
self.model = CrossEncoder(model_name)
except Exception as er:
err = f"Failed to load model '{model_name}'"
logger.exception(err)
raise ModelNotFoundError(err) from er
def rerank(self, documents: list[SearchResult], query: str) -> list[SearchResult]:
if not documents:
logger.warning("No documents to rerank.")
return []
# Preprocess pairs and keep track of original indexes
pairs = []
valid_docs = []
for i, doc in enumerate(documents):
content = doc.get("content", "")
if not content:
err = f"Document at index {i} has no content."
logger.warning(err)
continue
pairs.append((query, content))
valid_docs.append(doc)
if not pairs:
logger.warning("No valid document pairs to rerank.")
return []
try:
scores = self.model.predict(pairs)
except Exception as e:
err = f"Model prediction failed: {e}"
logger.exception(err)
return valid_docs # fallback: return unranked valid docs
# Sort by score descending
if len(scores) != len(valid_docs):
logger.warning("Mismatch in number of scores and documents")
return valid_docs # or handle the mismatch appropriately
result = sorted(
zip(scores, valid_docs, strict=False),
key=lambda x: x[0],
reverse=True,
)
return [doc for _, doc in result]

153
scripts/create_tables.py Normal file
View File

@ -0,0 +1,153 @@
#!/usr/bin/env python3
"""
Database table creation script with pgvector support.
This script initializes the database with the required tables and extensions
for vector similarity search using pgvector.
"""
import os
from collections.abc import Generator
from contextlib import contextmanager
import psycopg2
import structlog
from dotenv import load_dotenv
from psycopg2.extensions import connection as pg_connection
from psycopg2.extensions import cursor as pg_cursor
# Configure structlog
structlog.configure(
processors=[
structlog.processors.add_log_level,
structlog.processors.StackInfoRenderer(),
structlog.dev.ConsoleRenderer(),
],
wrapper_class=structlog.make_filtering_bound_logger(structlog.INFO),
context_class=dict,
logger_factory=structlog.PrintLoggerFactory(),
cache_logger_on_first_use=False,
)
logger = structlog.get_logger()
def get_db_config() -> dict:
"""
Retrieve database configuration from environment variables.
Returns:
dict: Database connection parameters
"""
return {
"host": os.getenv("DB_HOST", "localhost"),
"port": os.getenv("DB_PORT", "5432"),
"user": os.getenv("DB_USER", "user"),
"password": os.getenv("DB_PASSWORD", "password"),
"database": os.getenv("DB_NAME", "mydatabase"),
}
@contextmanager
def get_db_connection() -> Generator[tuple[pg_connection, pg_cursor], None, None]:
"""
Context manager for database connection handling.
Yields:
Tuple containing connection and cursor objects
"""
conn = None
try:
db_config = get_db_config()
conn = psycopg2.connect(**db_config)
cursor = conn.cursor()
logger.info(
"Successfully connected to PostgreSQL database",
database=db_config["database"],
host=db_config["host"],
)
yield conn, cursor
except Exception:
logger.exception("Database connection failed")
raise
finally:
if conn:
cursor.close()
conn.close()
def create_vector_extension(conn: pg_connection, cursor: pg_cursor) -> None:
"""Create pgvector extension if it doesn't exist."""
try:
cursor.execute("CREATE EXTENSION IF NOT EXISTS vector;")
conn.commit()
logger.info("pgvector extension is enabled")
except Exception:
logger.exception("Failed to create pgvector extension")
conn.rollback()
raise
def create_documents_table(conn: pg_connection, cursor: pg_cursor) -> None:
"""Create documents table with vector support."""
try:
cursor.execute("""
CREATE TABLE IF NOT EXISTS documents (
id SERIAL PRIMARY KEY,
content TEXT NOT NULL,
embedding VECTOR(384), -- Match the dimension of your embedding model
source VARCHAR(255),
created_at TIMESTAMPTZ DEFAULT NOW()
);
""")
conn.commit()
logger.info("Table 'documents' created successfully")
except Exception:
logger.exception("Failed to create documents table")
conn.rollback()
raise
def create_vector_index(
conn: pg_connection, cursor: pg_cursor, dimensions: int = 384
) -> None:
"""Create HNSW index on the vector column."""
try:
logger.info("Creating HNSW index on vectors of dimension %s", dimensions)
cursor.execute("""
CREATE INDEX IF NOT EXISTS documents_embedding_idx
ON documents
USING HNSW (embedding vector_cosine_ops)
WITH (m = 16, ef_construction = 64);
""")
conn.commit()
logger.info("HNSW index created successfully")
except Exception:
logger.exception("Failed to create vector index")
conn.rollback()
raise
def main() -> None:
"""Main function to set up the database schema."""
load_dotenv()
logger.info("Starting database setup")
try:
with get_db_connection() as (conn, cursor):
create_vector_extension(conn, cursor)
create_documents_table(conn, cursor)
create_vector_index(conn, cursor)
logger.info("Database setup completed successfully")
except Exception:
logger.exception("Database setup failed")
raise SystemExit(1) from None
if __name__ == "__main__":
import logging
logging.basicConfig(level=logging.INFO)
main()