diff --git a/Makefile b/Makefile
index e69de29..4aef8b4 100644
--- a/Makefile
+++ b/Makefile
@@ -0,0 +1,16 @@
+.PHONY: create-tables install-deps help
+
+help:
+	@echo "Available targets:"
+	@echo "  create-tables  - Create database tables using create_tables.py"
+	@echo "  install-deps   - Install Python dependencies using uv"
+	@echo "  help           - Show this help message"
+
+install-deps:
+	uv sync --locked --no-install-project --no-dev
+
+create-tables:
+	@echo "Creating database tables..."
+	uv run python scripts/create_tables.py
+
+.DEFAULT_GOAL := help
\ No newline at end of file
diff --git a/app/core/exception.py b/app/core/exception.py
index 88b2da3..1383495 100644
--- a/app/core/exception.py
+++ b/app/core/exception.py
@@ -4,3 +4,7 @@ class DocumentInsertionError(Exception):
 
 class DocumentExtractionError(Exception):
     """Exception raised when document extraction from PDF fails."""
+
+
+class ModelNotFoundError(Exception):
+    """Exception raised when a model cannot be found or loaded."""
diff --git a/app/core/interfaces.py b/app/core/interfaces.py
index bd4081e..cf157af 100644
--- a/app/core/interfaces.py
+++ b/app/core/interfaces.py
@@ -16,6 +16,12 @@ class EmbeddingModel(Protocol):
     def embed_query(self, text: str) -> np.ndarray: ...
 
 
+class Reranker(Protocol):
+    def rerank(
+        self, documents: list[SearchResult], query: str
+    ) -> list[SearchResult]: ...
+
+
 class VectorDB(Protocol):
     def upsert_documents(self, documents: list[dict]) -> None: ...
 
diff --git a/app/services/rerankers.py b/app/services/rerankers.py
new file mode 100644
index 0000000..683cbe5
--- /dev/null
+++ b/app/services/rerankers.py
@@ -0,0 +1,61 @@
+import logging
+
+# pyright: reportArgumentType=false
+from sentence_transformers import CrossEncoder
+
+from app.core.exception import ModelNotFoundError
+from app.core.interfaces import Reranker, SearchResult
+
+logger = logging.getLogger(__name__)
+
+# pyright: reportCallIssue=false
+
+
+class MiniLMReranker(Reranker):
+    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
+        try:
+            self.model = CrossEncoder(model_name)
+        except Exception as er:
+            err = f"Failed to load model '{model_name}'"
+            logger.exception(err)
+            raise ModelNotFoundError(err) from er
+
+    def rerank(self, documents: list[SearchResult], query: str) -> list[SearchResult]:
+        if not documents:
+            logger.warning("No documents to rerank.")
+            return []
+
+        # Build (query, content) pairs, skipping documents without content
+        pairs = []
+        valid_docs = []
+        for i, doc in enumerate(documents):
+            content = doc.get("content", "")
+            if not content:
+                err = f"Document at index {i} has no content."
+                logger.warning(err)
+                continue
+            pairs.append((query, content))
+            valid_docs.append(doc)
+
+        if not pairs:
+            logger.warning("No valid document pairs to rerank.")
+            return []
+
+        try:
+            scores = self.model.predict(pairs)
+        except Exception as e:
+            err = f"Model prediction failed: {e}"
+            logger.exception(err)
+            return valid_docs  # fallback: return unranked valid docs
+
+        # Sort by score descending
+        if len(scores) != len(valid_docs):
+            logger.warning("Mismatch in number of scores and documents")
+            return valid_docs  # fallback: keep the original order
+
+        result = sorted(
+            zip(scores, valid_docs, strict=False),
+            key=lambda x: x[0],
+            reverse=True,
+        )
+        return [doc for _, doc in result]
diff --git a/scripts/create_tables.py b/scripts/create_tables.py
new file mode 100644
index 0000000..a96ce71
--- /dev/null
+++ b/scripts/create_tables.py
@@ -0,0 +1,154 @@
+#!/usr/bin/env python3
+"""
+Database table creation script with pgvector support.
+
+This script initializes the database with the required tables and extensions
+for vector similarity search using pgvector.
+"""
+
+import logging
+import os
+from collections.abc import Generator
+from contextlib import contextmanager
+
+import psycopg2
+import structlog
+from dotenv import load_dotenv
+from psycopg2.extensions import connection as pg_connection
+from psycopg2.extensions import cursor as pg_cursor
+
+# Configure structlog
+structlog.configure(
+    processors=[
+        structlog.processors.add_log_level,
+        structlog.processors.StackInfoRenderer(),
+        structlog.dev.ConsoleRenderer(),
+    ],
+    wrapper_class=structlog.make_filtering_bound_logger(logging.INFO),
+    context_class=dict,
+    logger_factory=structlog.PrintLoggerFactory(),
+    cache_logger_on_first_use=False,
+)
+logger = structlog.get_logger()
+
+
+def get_db_config() -> dict:
+    """
+    Retrieve database configuration from environment variables.
+
+    Returns:
+        dict: Database connection parameters
+
+    """
+    return {
+        "host": os.getenv("DB_HOST", "localhost"),
+        "port": os.getenv("DB_PORT", "5432"),
+        "user": os.getenv("DB_USER", "user"),
+        "password": os.getenv("DB_PASSWORD", "password"),
+        "database": os.getenv("DB_NAME", "mydatabase"),
+    }
+
+
+@contextmanager
+def get_db_connection() -> Generator[tuple[pg_connection, pg_cursor], None, None]:
+    """
+    Context manager for database connection handling.
+
+    Yields:
+        Tuple containing connection and cursor objects
+
+    """
+    conn = None
+    try:
+        db_config = get_db_config()
+        conn = psycopg2.connect(**db_config)
+        cursor = conn.cursor()
+        logger.info(
+            "Successfully connected to PostgreSQL database",
+            database=db_config["database"],
+            host=db_config["host"],
+        )
+        yield conn, cursor
+    except Exception:
+        logger.exception("Database connection failed")
+        raise
+    finally:
+        if conn:
+            cursor.close()
+            conn.close()
+
+
+def create_vector_extension(conn: pg_connection, cursor: pg_cursor) -> None:
+    """Create pgvector extension if it doesn't exist."""
+    try:
+        cursor.execute("CREATE EXTENSION IF NOT EXISTS vector;")
+        conn.commit()
+        logger.info("pgvector extension is enabled")
+    except Exception:
+        logger.exception("Failed to create pgvector extension")
+        conn.rollback()
+        raise
+
+
+def create_documents_table(conn: pg_connection, cursor: pg_cursor) -> None:
+    """Create documents table with vector support."""
+    try:
+        cursor.execute("""
+            CREATE TABLE IF NOT EXISTS documents (
+                id SERIAL PRIMARY KEY,
+                content TEXT NOT NULL,
+                embedding VECTOR(384),  -- Match the dimension of your embedding model
+                source VARCHAR(255),
+                created_at TIMESTAMPTZ DEFAULT NOW()
+            );
+        """)
+        conn.commit()
+        logger.info("Table 'documents' created successfully")
+    except Exception:
+        logger.exception("Failed to create documents table")
+        conn.rollback()
+        raise
+
+
+def create_vector_index(
+    conn: pg_connection, cursor: pg_cursor, dimensions: int = 384
+) -> None:
+    """Create HNSW index on the vector column."""
+    try:
+        logger.info("Creating HNSW index", dimensions=dimensions)
+        cursor.execute("""
+            CREATE INDEX IF NOT EXISTS documents_embedding_idx
+            ON documents
+            USING HNSW (embedding vector_cosine_ops)
+            WITH (m = 16, ef_construction = 64);
+        """)
+        conn.commit()
+        logger.info("HNSW index created successfully")
+    except Exception:
+        logger.exception("Failed to create vector index")
+        conn.rollback()
+        raise
+
+
+def main() -> None:
+    """Main function to set up the database schema."""
+    load_dotenv()
+    logger.info("Starting database setup")
+
+    try:
+        with get_db_connection() as (conn, cursor):
+            create_vector_extension(conn, cursor)
+            create_documents_table(conn, cursor)
+            create_vector_index(conn, cursor)
+
+        logger.info("Database setup completed successfully")
+    except Exception:
+        logger.exception("Database setup failed")
+        raise SystemExit(1) from None
+
+
+if __name__ == "__main__":
+    import logging
+
+    logging.basicConfig(level=logging.INFO)
+    main()
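
Not part of the patch, but for context: a minimal sketch of how the `documents` table and cosine HNSW index created by `scripts/create_tables.py` might be queried. The connection defaults mirror `get_db_config()`; the 384-dimension query embedding is assumed to come from the app's embedding model, and the helper name `search_documents` is illustrative only.

```python
# Illustrative sketch, not part of this diff: nearest-neighbour lookup against
# the schema created by scripts/create_tables.py. Assumes a 384-dimension
# query embedding computed elsewhere; connection defaults mirror get_db_config().
import psycopg2


def search_documents(query_embedding: list[float], top_k: int = 5) -> list[tuple]:
    """Return (id, content, source, distance) rows ordered by cosine distance."""
    conn = psycopg2.connect(
        host="localhost",
        port="5432",
        user="user",
        password="password",
        database="mydatabase",
    )
    # pgvector accepts the textual form "[x1,x2,...]", cast to vector below.
    vector_literal = "[" + ",".join(str(x) for x in query_embedding) + "]"
    try:
        with conn.cursor() as cursor:
            # <=> is pgvector's cosine-distance operator; the HNSW index built
            # with vector_cosine_ops can serve this ORDER BY.
            cursor.execute(
                """
                SELECT id, content, source, embedding <=> %s::vector AS distance
                FROM documents
                ORDER BY embedding <=> %s::vector
                LIMIT %s;
                """,
                (vector_literal, vector_literal, top_k),
            )
            return cursor.fetchall()
    finally:
        conn.close()
```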
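
Also not part of the patch: a small usage sketch for the new `Reranker` protocol and `MiniLMReranker`. The hard-coded candidate dicts stand in for `SearchResult` items returned by the vector store (the reranker only reads their `content` key), and the query string and `top_k` helper are made up for illustration.

```python
# Illustrative sketch, not part of this diff: reranking vector-search candidates
# with the new MiniLMReranker. The candidates below stand in for SearchResult
# items returned by the vector store; only their "content" key is used here.
from app.core.interfaces import Reranker
from app.services.rerankers import MiniLMReranker


def top_k(reranker: Reranker, candidates: list, query: str, k: int = 3) -> list:
    """Rerank candidates with a cross-encoder and keep the best k."""
    return reranker.rerank(candidates, query)[:k]


if __name__ == "__main__":
    candidates = [
        {"content": "pgvector adds vector similarity search to PostgreSQL."},
        {"content": "HNSW is an approximate nearest-neighbour index."},
        {"content": ""},  # dropped by rerank(): no content
    ]
    # Downloads cross-encoder/ms-marco-MiniLM-L-6-v2 on first use.
    reranker = MiniLMReranker()
    for doc in top_k(reranker, candidates, query="How do I search vectors in Postgres?"):
        print(doc["content"])
```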