mirror of
https://github.com/Sosokker/plain-rag.git
synced 2025-12-18 14:34:05 +01:00
feat: add reranker interface and makefile
This commit is contained in:
parent
3294dafaa6
commit
cf7d1e8218
16
Makefile
16
Makefile
@ -0,0 +1,16 @@
|
||||
.PHONY: create-tables help
|
||||
|
||||
help:
|
||||
@echo "Available targets:"
|
||||
@echo " create-tables - Create database tables using create_tables.py"
|
||||
@echo " install-deps - Install Python dependencies using uv"
|
||||
@echo " help - Show this help message"
|
||||
|
||||
install-deps:
|
||||
uv sync --locked --no-install-project --no-dev
|
||||
|
||||
create-tables:
|
||||
@echo "Creating database tables..."
|
||||
uv run python scripts/create_tables.py
|
||||
|
||||
.DEFAULT_GOAL := help
|
||||
@ -4,3 +4,7 @@ class DocumentInsertionError(Exception):
|
||||
|
||||
class DocumentExtractionError(Exception):
|
||||
"""Exception raised when document extraction from PDF fails."""
|
||||
|
||||
|
||||
class ModelNotFoundError(Exception):
|
||||
"""Exception raised when model is not found."""
|
||||
|
||||
@ -16,6 +16,12 @@ class EmbeddingModel(Protocol):
|
||||
def embed_query(self, text: str) -> np.ndarray: ...
|
||||
|
||||
|
||||
class Reranker(Protocol):
|
||||
def rerank(
|
||||
self, documents: list[SearchResult], query: str
|
||||
) -> list[SearchResult]: ...
|
||||
|
||||
|
||||
class VectorDB(Protocol):
|
||||
def upsert_documents(self, documents: list[dict]) -> None: ...
|
||||
|
||||
|
||||
61
app/services/rerankers.py
Normal file
61
app/services/rerankers.py
Normal file
@ -0,0 +1,61 @@
|
||||
import logging
|
||||
|
||||
# pyright: reportArgumentType=false
|
||||
from sentence_transformers import CrossEncoder
|
||||
|
||||
from app.core.exception import ModelNotFoundError
|
||||
from app.core.interfaces import Reranker, SearchResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# pyright: reportCallIssue=false
|
||||
|
||||
|
||||
class MiniLMReranker(Reranker):
|
||||
def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
|
||||
try:
|
||||
self.model = CrossEncoder(model_name)
|
||||
except Exception as er:
|
||||
err = f"Failed to load model '{model_name}'"
|
||||
logger.exception(err)
|
||||
raise ModelNotFoundError(err) from er
|
||||
|
||||
def rerank(self, documents: list[SearchResult], query: str) -> list[SearchResult]:
|
||||
if not documents:
|
||||
logger.warning("No documents to rerank.")
|
||||
return []
|
||||
|
||||
# Preprocess pairs and keep track of original indexes
|
||||
pairs = []
|
||||
valid_docs = []
|
||||
for i, doc in enumerate(documents):
|
||||
content = doc.get("content", "")
|
||||
if not content:
|
||||
err = f"Document at index {i} has no content."
|
||||
logger.warning(err)
|
||||
continue
|
||||
pairs.append((query, content))
|
||||
valid_docs.append(doc)
|
||||
|
||||
if not pairs:
|
||||
logger.warning("No valid document pairs to rerank.")
|
||||
return []
|
||||
|
||||
try:
|
||||
scores = self.model.predict(pairs)
|
||||
except Exception as e:
|
||||
err = f"Model prediction failed: {e}"
|
||||
logger.exception(err)
|
||||
return valid_docs # fallback: return unranked valid docs
|
||||
|
||||
# Sort by score descending
|
||||
if len(scores) != len(valid_docs):
|
||||
logger.warning("Mismatch in number of scores and documents")
|
||||
return valid_docs # or handle the mismatch appropriately
|
||||
|
||||
result = sorted(
|
||||
zip(scores, valid_docs, strict=False),
|
||||
key=lambda x: x[0],
|
||||
reverse=True,
|
||||
)
|
||||
return [doc for _, doc in result]
|
||||
153
scripts/create_tables.py
Normal file
153
scripts/create_tables.py
Normal file
@ -0,0 +1,153 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Database table creation script with pgvector support.
|
||||
|
||||
This script initializes the database with the required tables and extensions
|
||||
for vector similarity search using pgvector.
|
||||
"""
|
||||
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
|
||||
import psycopg2
|
||||
import structlog
|
||||
from dotenv import load_dotenv
|
||||
from psycopg2.extensions import connection as pg_connection
|
||||
from psycopg2.extensions import cursor as pg_cursor
|
||||
|
||||
# Configure structlog
|
||||
structlog.configure(
|
||||
processors=[
|
||||
structlog.processors.add_log_level,
|
||||
structlog.processors.StackInfoRenderer(),
|
||||
structlog.dev.ConsoleRenderer(),
|
||||
],
|
||||
wrapper_class=structlog.make_filtering_bound_logger(structlog.INFO),
|
||||
context_class=dict,
|
||||
logger_factory=structlog.PrintLoggerFactory(),
|
||||
cache_logger_on_first_use=False,
|
||||
)
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
def get_db_config() -> dict:
|
||||
"""
|
||||
Retrieve database configuration from environment variables.
|
||||
|
||||
Returns:
|
||||
dict: Database connection parameters
|
||||
|
||||
"""
|
||||
return {
|
||||
"host": os.getenv("DB_HOST", "localhost"),
|
||||
"port": os.getenv("DB_PORT", "5432"),
|
||||
"user": os.getenv("DB_USER", "user"),
|
||||
"password": os.getenv("DB_PASSWORD", "password"),
|
||||
"database": os.getenv("DB_NAME", "mydatabase"),
|
||||
}
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db_connection() -> Generator[tuple[pg_connection, pg_cursor], None, None]:
|
||||
"""
|
||||
Context manager for database connection handling.
|
||||
|
||||
Yields:
|
||||
Tuple containing connection and cursor objects
|
||||
|
||||
"""
|
||||
conn = None
|
||||
try:
|
||||
db_config = get_db_config()
|
||||
conn = psycopg2.connect(**db_config)
|
||||
cursor = conn.cursor()
|
||||
logger.info(
|
||||
"Successfully connected to PostgreSQL database",
|
||||
database=db_config["database"],
|
||||
host=db_config["host"],
|
||||
)
|
||||
yield conn, cursor
|
||||
except Exception:
|
||||
logger.exception("Database connection failed")
|
||||
raise
|
||||
finally:
|
||||
if conn:
|
||||
cursor.close()
|
||||
conn.close()
|
||||
|
||||
|
||||
def create_vector_extension(conn: pg_connection, cursor: pg_cursor) -> None:
|
||||
"""Create pgvector extension if it doesn't exist."""
|
||||
try:
|
||||
cursor.execute("CREATE EXTENSION IF NOT EXISTS vector;")
|
||||
conn.commit()
|
||||
logger.info("pgvector extension is enabled")
|
||||
except Exception:
|
||||
logger.exception("Failed to create pgvector extension")
|
||||
conn.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def create_documents_table(conn: pg_connection, cursor: pg_cursor) -> None:
|
||||
"""Create documents table with vector support."""
|
||||
try:
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS documents (
|
||||
id SERIAL PRIMARY KEY,
|
||||
content TEXT NOT NULL,
|
||||
embedding VECTOR(384), -- Match the dimension of your embedding model
|
||||
source VARCHAR(255),
|
||||
created_at TIMESTAMPTZ DEFAULT NOW()
|
||||
);
|
||||
""")
|
||||
conn.commit()
|
||||
logger.info("Table 'documents' created successfully")
|
||||
except Exception:
|
||||
logger.exception("Failed to create documents table")
|
||||
conn.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def create_vector_index(
|
||||
conn: pg_connection, cursor: pg_cursor, dimensions: int = 384
|
||||
) -> None:
|
||||
"""Create HNSW index on the vector column."""
|
||||
try:
|
||||
logger.info("Creating HNSW index on vectors of dimension %s", dimensions)
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS documents_embedding_idx
|
||||
ON documents
|
||||
USING HNSW (embedding vector_cosine_ops)
|
||||
WITH (m = 16, ef_construction = 64);
|
||||
""")
|
||||
conn.commit()
|
||||
logger.info("HNSW index created successfully")
|
||||
except Exception:
|
||||
logger.exception("Failed to create vector index")
|
||||
conn.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Main function to set up the database schema."""
|
||||
load_dotenv()
|
||||
logger.info("Starting database setup")
|
||||
|
||||
try:
|
||||
with get_db_connection() as (conn, cursor):
|
||||
create_vector_extension(conn, cursor)
|
||||
create_documents_table(conn, cursor)
|
||||
create_vector_index(conn, cursor)
|
||||
|
||||
logger.info("Database setup completed successfully")
|
||||
except Exception:
|
||||
logger.exception("Database setup failed")
|
||||
raise SystemExit(1) from None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
main()
|
||||
Loading…
Reference in New Issue
Block a user