mirror of
https://github.com/Sosokker/plain-rag.git
synced 2025-12-19 23:04:05 +01:00
feat: add reranker interface and makefile
This commit is contained in:
parent
3294dafaa6
commit
cf7d1e8218
16
Makefile
16
Makefile
@ -0,0 +1,16 @@
|
|||||||
|
# BUG FIX: install-deps was missing from .PHONY — a file named
# "install-deps" in the repo root would silently shadow the target.
.PHONY: create-tables install-deps help

help:
	@echo "Available targets:"
	@echo " create-tables - Create database tables using create_tables.py"
	@echo " install-deps - Install Python dependencies using uv"
	@echo " help - Show this help message"

install-deps:
	uv sync --locked --no-install-project --no-dev

create-tables:
	@echo "Creating database tables..."
	uv run python scripts/create_tables.py

# Running plain `make` prints the help text instead of building anything.
.DEFAULT_GOAL := help
|
||||||
@ -4,3 +4,7 @@ class DocumentInsertionError(Exception):
|
|||||||
|
|
||||||
class DocumentExtractionError(Exception):
    """Raised when extracting document content from a PDF fails."""
|
class ModelNotFoundError(Exception):
    """Raised when a requested model cannot be located or loaded."""
|
|||||||
@ -16,6 +16,12 @@ class EmbeddingModel(Protocol):
|
|||||||
def embed_query(self, text: str) -> np.ndarray: ...
|
def embed_query(self, text: str) -> np.ndarray: ...
|
||||||
|
|
||||||
|
|
||||||
|
class Reranker(Protocol):
    """Structural interface: reorder search results by relevance to a query."""

    def rerank(self, documents: list[SearchResult], query: str) -> list[SearchResult]: ...
|
|
||||||
|
|
||||||
class VectorDB(Protocol):
|
class VectorDB(Protocol):
|
||||||
def upsert_documents(self, documents: list[dict]) -> None: ...
|
def upsert_documents(self, documents: list[dict]) -> None: ...
|
||||||
|
|
||||||
|
|||||||
61
app/services/rerankers.py
Normal file
61
app/services/rerankers.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
# pyright: reportArgumentType=false
|
||||||
|
from sentence_transformers import CrossEncoder
|
||||||
|
|
||||||
|
from app.core.exception import ModelNotFoundError
|
||||||
|
from app.core.interfaces import Reranker, SearchResult
|
||||||
|
|
||||||
|
# Module-level logger, named after this module per stdlib convention.
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# pyright: reportCallIssue=false
|
||||||
|
|
||||||
|
|
||||||
|
class MiniLMReranker(Reranker):
    """Cross-encoder reranker backed by a sentence-transformers model.

    Scores (query, document content) pairs with a CrossEncoder and returns
    the documents ordered by descending relevance score.
    """

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        """Load the cross-encoder model.

        Raises:
            ModelNotFoundError: If the model cannot be loaded.

        """
        try:
            self.model = CrossEncoder(model_name)
        except Exception as exc:
            err = f"Failed to load model '{model_name}'"
            logger.exception(err)
            raise ModelNotFoundError(err) from exc

    def rerank(self, documents: list[SearchResult], query: str) -> list[SearchResult]:
        """Return *documents* sorted by cross-encoder relevance to *query*.

        Documents without content are skipped.  If prediction fails, the
        valid documents are returned in their original order (best effort).
        """
        if not documents:
            logger.warning("No documents to rerank.")
            return []

        # Build (query, content) pairs, dropping documents with no content.
        pairs = []
        valid_docs = []
        for i, doc in enumerate(documents):
            content = doc.get("content", "")
            if not content:
                # Lazy %-formatting: no string work unless the record is emitted.
                logger.warning("Document at index %d has no content.", i)
                continue
            pairs.append((query, content))
            valid_docs.append(doc)

        if not pairs:
            logger.warning("No valid document pairs to rerank.")
            return []

        try:
            scores = self.model.predict(pairs)
        except Exception:
            # logger.exception already records the traceback; embedding the
            # exception text in the message would duplicate it.
            logger.exception("Model prediction failed")
            return valid_docs  # fallback: return unranked valid docs

        if len(scores) != len(valid_docs):
            logger.warning("Mismatch in number of scores and documents")
            return valid_docs

        # Lengths were verified above, so strict zipping cannot raise; it
        # guards against future edits that break the invariant.
        ranked = sorted(
            zip(scores, valid_docs, strict=True),
            key=lambda pair: pair[0],
            reverse=True,
        )
        return [doc for _, doc in ranked]
|
||||||
153
scripts/create_tables.py
Normal file
153
scripts/create_tables.py
Normal file
@ -0,0 +1,153 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Database table creation script with pgvector support.
|
||||||
|
|
||||||
|
This script initializes the database with the required tables and extensions
|
||||||
|
for vector similarity search using pgvector.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
import os
from collections.abc import Generator
from contextlib import contextmanager

import psycopg2
import structlog
from dotenv import load_dotenv
from psycopg2.extensions import connection as pg_connection
from psycopg2.extensions import cursor as pg_cursor
|
||||||
|
|
||||||
|
# Configure structlog
|
||||||
|
# Configure structlog for human-readable console output on stdout.
structlog.configure(
    processors=[
        structlog.processors.add_log_level,
        structlog.processors.StackInfoRenderer(),
        structlog.dev.ConsoleRenderer(),
    ],
    # BUG FIX: structlog exposes no `INFO` constant — `structlog.INFO` raises
    # AttributeError at import time.  make_filtering_bound_logger takes the
    # stdlib logging level numbers, i.e. logging.INFO.
    wrapper_class=structlog.make_filtering_bound_logger(logging.INFO),
    context_class=dict,
    logger_factory=structlog.PrintLoggerFactory(),
    cache_logger_on_first_use=False,
)
logger = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def get_db_config() -> dict:
    """
    Build psycopg2 connection parameters from environment variables.

    Each setting falls back to a local-development default when the
    corresponding environment variable is unset.

    Returns:
        dict: Database connection parameters

    """
    env_defaults = {
        "host": ("DB_HOST", "localhost"),
        "port": ("DB_PORT", "5432"),
        "user": ("DB_USER", "user"),
        "password": ("DB_PASSWORD", "password"),
        "database": ("DB_NAME", "mydatabase"),
    }
    return {
        key: os.getenv(var, fallback)
        for key, (var, fallback) in env_defaults.items()
    }
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
def get_db_connection() -> Generator[tuple[pg_connection, pg_cursor], None, None]:
    """
    Context manager for database connection handling.

    Connects using get_db_config(), yields the connection and a cursor, and
    guarantees both are closed on exit.  Connection errors are logged and
    re-raised.

    Yields:
        Tuple containing connection and cursor objects

    """
    conn = None
    # BUG FIX: track the cursor separately.  The original closed `cursor`
    # whenever `conn` was set, so if conn.cursor() raised after a successful
    # connect, `cursor` was unbound and the finally block raised NameError,
    # masking the real error and leaking the connection.
    cursor = None
    try:
        db_config = get_db_config()
        conn = psycopg2.connect(**db_config)
        cursor = conn.cursor()
        logger.info(
            "Successfully connected to PostgreSQL database",
            database=db_config["database"],
            host=db_config["host"],
        )
        yield conn, cursor
    except Exception:
        logger.exception("Database connection failed")
        raise
    finally:
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
||||||
|
|
||||||
|
|
||||||
|
def create_vector_extension(conn: pg_connection, cursor: pg_cursor) -> None:
    """Enable the pgvector extension (no-op if it is already installed)."""
    ddl = "CREATE EXTENSION IF NOT EXISTS vector;"
    try:
        cursor.execute(ddl)
        conn.commit()
        logger.info("pgvector extension is enabled")
    except Exception:
        logger.exception("Failed to create pgvector extension")
        conn.rollback()
        raise
|
||||||
|
|
||||||
|
|
||||||
|
def create_documents_table(conn: pg_connection, cursor: pg_cursor) -> None:
    """Create the documents table (id, content, embedding, source, timestamp)."""
    ddl = """
        CREATE TABLE IF NOT EXISTS documents (
            id SERIAL PRIMARY KEY,
            content TEXT NOT NULL,
            embedding VECTOR(384), -- Match the dimension of your embedding model
            source VARCHAR(255),
            created_at TIMESTAMPTZ DEFAULT NOW()
        );
    """
    try:
        cursor.execute(ddl)
        conn.commit()
        logger.info("Table 'documents' created successfully")
    except Exception:
        logger.exception("Failed to create documents table")
        conn.rollback()
        raise
|
||||||
|
|
||||||
|
|
||||||
|
def create_vector_index(
    conn: pg_connection, cursor: pg_cursor, dimensions: int = 384
) -> None:
    """Create HNSW index on the vector column."""
    # NOTE(review): `dimensions` is only reported in the log line; the DDL
    # relies on the column's declared VECTOR(384) type — confirm intent.
    ddl = """
        CREATE INDEX IF NOT EXISTS documents_embedding_idx
        ON documents
        USING HNSW (embedding vector_cosine_ops)
        WITH (m = 16, ef_construction = 64);
    """
    try:
        logger.info("Creating HNSW index on vectors of dimension %s", dimensions)
        cursor.execute(ddl)
        conn.commit()
        logger.info("HNSW index created successfully")
    except Exception:
        logger.exception("Failed to create vector index")
        conn.rollback()
        raise
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """Run the full schema setup: extension, table, then index."""
    load_dotenv()
    logger.info("Starting database setup")

    # Steps run in dependency order inside a single managed connection.
    setup_steps = (
        create_vector_extension,
        create_documents_table,
        create_vector_index,
    )
    try:
        with get_db_connection() as (conn, cursor):
            for step in setup_steps:
                step(conn, cursor)

            logger.info("Database setup completed successfully")
    except Exception:
        logger.exception("Database setup failed")
        raise SystemExit(1) from None
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Local import: stdlib logging is presumably only needed for ad-hoc
    # script runs, so basicConfig does not fire on plain module import.
    import logging

    # Emit INFO-level records from any stdlib loggers used by dependencies.
    logging.basicConfig(level=logging.INFO)
    main()
|
||||||
Loading…
Reference in New Issue
Block a user