diff --git a/app/schemas/models.py b/app/schemas/models.py
index 2a43a63..f968e7b 100644
--- a/app/schemas/models.py
+++ b/app/schemas/models.py
@@ -1,5 +1,4 @@
 from pydantic import BaseModel, Field
-from typing import Optional


 class QueryRequest(BaseModel):
@@ -17,13 +16,13 @@ class IngestResponse(BaseModel):


 class ConfigUpdateRequest(BaseModel):
-    embedding_model: Optional[str] = Field(
+    embedding_model: str | None = Field(
         None, description="Name of the embedding model to use"
     )
-    reranker_model: Optional[str] = Field(
+    reranker_model: str | None = Field(
         None, description="Name of the reranker model to use"
     )
-    llm_model: Optional[str] = Field(None, description="Name of the LLM model to use")
-    llm_provider: Optional[str] = Field(
+    llm_model: str | None = Field(None, description="Name of the LLM model to use")
+    llm_provider: str | None = Field(
         None, description="Name of the LLM provider to use"
-    )
\ No newline at end of file
+    )
diff --git a/app/services/config_service.py b/app/services/config_service.py
index 2b03cc6..d7d06a5 100644
--- a/app/services/config_service.py
+++ b/app/services/config_service.py
@@ -18,13 +18,11 @@ class ConfigService:
             "reranker_model": False,
         }

-        # Register available models
         self._register_models()

     def _register_models(self):
-        # Register embedding models
+        """Register all default models"""
         embedding_model_registry.register("MiniLMEmbeddingModel", MiniLMEmbeddingModel)
-        # Register reranker models
         reranker_registry.register("MiniLMReranker", MiniLMReranker)

     async def initialize_models(self):
@@ -33,7 +31,7 @@ class ConfigService:
         default_embedding_model_name = settings.EMBEDDING_MODEL
         await self.set_embedding_model(default_embedding_model_name)
         logger.info(
-            f"Default embedding model initialized: {default_embedding_model_name}"
+            "Default embedding model initialized: %s", default_embedding_model_name
         )

         # Initialize default reranker model (if any)
@@ -42,7 +40,7 @@ class ConfigService:
         default_reranker_model_name = "MiniLMReranker"  # Or from settings
         await self.set_reranker_model(default_reranker_model_name)
         logger.info(
-            f"Default reranker model initialized: {default_reranker_model_name}"
+            "Default reranker model initialized: %s", default_reranker_model_name
         )

     async def set_embedding_model(self, model_name: str) -> str:
@@ -57,21 +55,22 @@ class ConfigService:
         try:
             self._loading_status["embedding_model"] = True
-            logger.info(f"Attempting to load embedding model: {model_name}")
+            logger.info("Attempting to load embedding model: %s", model_name)
             model_constructor = embedding_model_registry.get(model_name)
             self._current_embedding_model = model_constructor()
             settings.EMBEDDING_MODEL = model_name  # Update settings
-            logger.info(f"Successfully loaded embedding model: {model_name}")
-            return f"Embedding model set to '{model_name}' successfully."
         except KeyError:
-            logger.warning(f"Embedding model '{model_name}' not found in registry.")
+            logger.warning("Embedding model '%s' not found in registry.", model_name)
             return (
                 f"Embedding model '{model_name}' not available. "
                 f"Current model remains '{self._current_embedding_model.__class__.__name__ if self._current_embedding_model else 'None'}'."
             )
         except Exception as e:
-            logger.exception(f"Error loading embedding model {model_name}: {e}")
+            logger.exception("Error loading embedding model %s: %s", model_name, e)
             return f"Failed to load embedding model '{model_name}': {e}"
+        else:
+            logger.info("Successfully loaded embedding model: %s", model_name)
+            return f"Embedding model set to '{model_name}' successfully."
         finally:
             self._loading_status["embedding_model"] = False

@@ -87,21 +86,22 @@ class ConfigService:
         try:
             self._loading_status["reranker_model"] = True
-            logger.info(f"Attempting to load reranker model: {model_name}")
+            logger.info("Attempting to load reranker model: %s", model_name)
             model_constructor = reranker_registry.get(model_name)
             self._current_reranker_model = model_constructor()
-            # settings.RERANKER_MODEL = model_name  # Add this to settings if you want to persist
-            logger.info(f"Successfully loaded reranker model: {model_name}")
-            return f"Reranker model set to '{model_name}' successfully."
+            # settings.RERANKER_MODEL = model_name
         except KeyError:
-            logger.warning(f"Reranker model '{model_name}' not found in registry.")
+            logger.warning("Reranker model '%s' not found in registry.", model_name)
             return (
                 f"Reranker model '{model_name}' not available. "
                 f"Current model remains '{self._current_reranker_model.__class__.__name__ if self._current_reranker_model else 'None'}'."
             )
         except Exception as e:
-            logger.exception(f"Error loading reranker model {model_name}: {e}")
+            logger.exception("Error loading reranker model %s: %s", model_name, e)
             return f"Failed to load reranker model '{model_name}': {e}"
+        else:
+            logger.info("Successfully loaded reranker model: %s", model_name)
+            return f"Reranker model set to '{model_name}' successfully."
         finally:
             self._loading_status["reranker_model"] = False
diff --git a/app/services/rag_service.py b/app/services/rag_service.py
index 7d69acd..05a5cd0 100644
--- a/app/services/rag_service.py
+++ b/app/services/rag_service.py
@@ -4,6 +4,8 @@ from pathlib import Path
 from typing import TypedDict

 import litellm
+from PyPDF2 import PdfReader
+from PyPDF2.errors import PyPdfError
 from structlog import get_logger

 from app.core.interfaces import EmbeddingModel, Reranker, VectorDB
@@ -67,8 +69,29 @@ Answer:"""
         self.vector_db.upsert_documents(documents_to_upsert)

     def ingest_document(self, file_path: Path, source_name: str):
-        with Path(file_path).open("r", encoding="utf-8") as f:
-            text = f.read()
+        path = Path(file_path)
+        ext = path.suffix
+        text = ""
+        if ext == ".pdf":
+            try:
+                reader = PdfReader(str(file_path))
+                text = "\n".join(page.extract_text() or "" for page in reader.pages)
+            except PyPdfError as e:
+                logger.exception("PDF processing error for %s", file_path)
+                raise ValueError(
+                    f"Failed to extract text from PDF due to a PDF processing error: {e}"
+                ) from e
+            except Exception as e:
+                logger.exception(
+                    "An unexpected error occurred during PDF processing for %s",
+                    file_path,
+                )
+                raise RuntimeError(
+                    f"An unexpected error occurred during PDF processing: {e}"
+                ) from e
+        else:
+            with Path(file_path).open("r", encoding="utf-8") as f:
+                text = f.read()

         text_chunks = self._split_text(text)
         self._ingest_document(text_chunks, source_name)
@@ -102,8 +125,14 @@ Answer:"""
                 max_tokens=500,
             )

-            answer_text = response.choices[0].message.content.strip()
-
+            answer_text = None
+            choices = getattr(response, "choices", None)
+            if choices and len(choices) > 0:
+                first_choice = choices[0]
+                message = getattr(first_choice, "message", None)
+                content = getattr(message, "content", None)
+                if content:
+                    answer_text = content.strip()
             if not answer_text:
                 answer_text = "No answer generated"
                 sources = ["No sources"]
@@ -146,12 +175,13 @@ Answer:"""
                 stream=True,
             )

-            # Yield each chunk of the response as it's generated
             for chunk in response:
-                if chunk.choices:
-                    delta = chunk.choices[0].delta
-                    if hasattr(delta, "content") and delta.content:
-                        yield f'data: {{"token": "{json.dumps(delta.content)}"}}\n\n'
"{json.dumps(delta.content)}"}}\n\n' + choices = getattr(chunk, "choices", None) + if choices and len(choices) > 0: + delta = getattr(choices[0], "delta", None) + content = getattr(delta, "content", None) + if content: + yield f'data: {{"token": {json.dumps(content)}}}\n\n' # Yield sources at the end yield f'data: {{"sources": {json.dumps(sources)}}}\n\n' diff --git a/app/services/rag_service_v1.py b/app/services/rag_service_v1.py deleted file mode 100644 index 48ec0db..0000000 --- a/app/services/rag_service_v1.py +++ /dev/null @@ -1,283 +0,0 @@ -import os -from collections.abc import Generator -from pathlib import Path -from typing import TypedDict - -import litellm -import numpy as np -import psycopg2 -from dotenv import load_dotenv -from psycopg2 import extras -from psycopg2.extensions import AsIs, register_adapter -from PyPDF2 import PdfReader -from sentence_transformers import SentenceTransformer -from structlog import get_logger - -from app.core.config import settings -from app.core.exception import DocumentExtractionError, DocumentInsertionError -from app.core.utils import RecursiveCharacterTextSplitter - -register_adapter(np.ndarray, AsIs) # for psycopg2 adapt -register_adapter(np.float32, AsIs) # for psycopg2 adapt -logger = get_logger() - -# pyright: reportArgumentType=false - -# Load environment variables -load_dotenv() - -# Initialize the embedding model globally to load it only once -EMBEDDING_MODEL = SentenceTransformer("all-MiniLM-L6-v2") -EMBEDDING_DIM = 384 # Dimension of the all-MiniLM-L6-v2 model - -os.environ["GEMINI_API_KEY"] = settings.GEMINI_API_KEY - - -class AnswerResult(TypedDict): - answer: str - sources: list[str] - - -class RAGService: - def __init__(self): - logger.info("Initializing RAGService...") - # Load the embedding model ONCE - self.embedding_model = SentenceTransformer( - "all-MiniLM-L6-v2", device="cpu" - ) # Use 'cuda' if GPU is available - self.db_conn = psycopg2.connect( - host=settings.POSTGRES_SERVER, - port=settings.POSTGRES_PORT, - user=settings.POSTGRES_USER, - password=settings.POSTGRES_PASSWORD, - dbname=settings.POSTGRES_DB, - ) - logger.info("RAGService initialized.") - self.prompt = """Answer the question based on the following context. -If you don't know the answer, say you don't know. Don't make up an answer. - -Context: -{context} - -Question: {question} - -Answer:""" - - def _split_text( - self, text: str, chunk_size: int = 500, chunk_overlap: int = 100 - ) -> list[str]: - """ - Split text into chunks with specified size and overlap. - - Args: - text: Input text to split - chunk_size: Maximum size of each chunk in characters - chunk_overlap: Number of characters to overlap between chunks - - Returns: - List of text chunks - - """ - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - ) - return text_splitter.split_text(text) - - def _get_embedding(self, text: str, show_progress_bar: bool = False) -> np.ndarray: - """ - Generate embedding for a text chunk. - - Args: - text: Input text to embed - show_progress_bar: Whether to show a progress bar - - Returns: - Numpy array containing the embedding vector - - """ - return EMBEDDING_MODEL.encode( - text, convert_to_numpy=True, show_progress_bar=show_progress_bar - ) - - def _store_document( - self, contents: list[str], embeddings: list[np.ndarray], source: str - ) -> int: - """ - Store a document chunk in the database. 
-
-        Args:
-            contents: List of text content of the chunk
-            embeddings: List of embedding vectors of the chunk
-            source: Source file path
-
-        Returns:
-            ID of the inserted document
-
-        """
-        data_to_insert = [
-            (chunk, f"[{', '.join(map(str, embedding))}]", source)
-            for chunk, embedding in zip(contents, embeddings, strict=True)
-        ]
-
-        query = """
-            INSERT INTO documents (content, embedding, source)
-            VALUES %s
-            RETURNING id
-        """
-        with self.db_conn.cursor() as cursor:
-            extras.execute_values(
-                cursor,
-                query,
-                data_to_insert,
-                template="(%s, %s::vector, %s)",
-                page_size=100,
-            )
-            inserted_ids = [row[0] for row in cursor.fetchall()]
-            self.db_conn.commit()
-
-        if not inserted_ids:
-            raise DocumentInsertionError("No documents were inserted.")
-
-        logger.info("Successfully bulk-ingested %d documents", len(inserted_ids))
-        logger.info("Inserted document IDs: %s", inserted_ids)
-        return inserted_ids[0]
-
-    def _extract_text_from_pdf(self, pdf_path: str) -> str:
-        """
-        Extract text from a PDF file.
-
-        Args:
-            pdf_path: Path to the PDF file
-
-        Returns:
-            Extracted text as a single string
-
-        """
-        try:
-            reader = PdfReader(pdf_path)
-            text = ""
-            for page in reader.pages:
-                text += page.extract_text() + "\n"
-            return text.strip()
-        except Exception as e:
-            raise DocumentExtractionError(
-                "Error extracting text from PDF: " + str(e)
-            ) from e
-
-    def _get_relevant_context(self, question: str, top_k: int) -> list[tuple[str, str]]:
-        """Get the most relevant document chunks for a given question"""
-        question_embedding = self.embedding_model.encode(
-            question, convert_to_numpy=True
-        )
-
-        try:
-            with self.db_conn.cursor() as cursor:
-                cursor.execute(
-                    """
-                    SELECT content, source
-                    FROM documents
-                    ORDER BY embedding <-> %s::vector
-                    LIMIT %s
-                    """,
-                    (question_embedding.tolist(), top_k),
-                )
-                results = cursor.fetchall()
-                return results
-        except Exception as e:
-            logger.exception("Error retrieving context: %s", e)
-            return []
-
-    def ingest_document(self, file_path: str, filename: str):
-        logger.info("Ingesting %s...", filename)
-        if not Path(file_path).exists():
-            err = f"File not found: {filename}"
-            raise FileNotFoundError(err)
-
-        logger.info("Processing PDF: %s : %s", filename, file_path)
-
-        text = self._extract_text_from_pdf(file_path)
-        if not text.strip():
-            err = "No text could be extracted from the PDF"
-            raise ValueError(err)
-
-        chunks = self._split_text(text)
-        logger.info("Split PDF into %d chunks", len(chunks))
-
-        embeddings = self._get_embedding(chunks, show_progress_bar=True)
-        self._store_document(chunks, embeddings, filename)
-
-        logger.info("Successfully processed %d chunks from %s", len(chunks), filename)
-
-    def answer_query(self, question: str) -> AnswerResult:
-        relevant_context = self._get_relevant_context(question, 5)
-        context_str = "\n\n".join([chunk[0] for chunk in relevant_context])
-        sources = list({chunk[1] for chunk in relevant_context if chunk[1]})
-
-        try:
-            response = litellm.completion(
-                model="gemini/gemini-2.0-flash",
-                messages=[
-                    {
-                        "role": "system",
-                        "content": "You are a helpful assistant that answers questions based on the provided context.",
-                    },
-                    {
-                        "role": "user",
-                        "content": self.prompt.format(
-                            context=context_str, question=question
-                        ),
-                    },
-                ],
-                temperature=0.1,
-                max_tokens=500,
-            )
-
-            answer_text = response.choices[0].message.content.strip()
-
-            if not answer_text:
-                answer_text = "No answer generated"
-                sources = ["No sources"]
-
-            return AnswerResult(answer=answer_text, sources=sources)
-
-        except Exception:
logger.exception("Error generating response") - return AnswerResult( - answer="Error generating response", sources=["No sources"] - ) - - def answer_query_stream(self, question: str) -> Generator[str, None, None]: - """Answer a query using streaming.""" - relevant_context = self._get_relevant_context(question, 5) - context_str = "\n\n".join([chunk[0] for chunk in relevant_context]) - sources = list({chunk[1] for chunk in relevant_context if chunk[1]}) - - prompt = self.prompt.format(context=context_str, question=question) - - try: - response = litellm.completion( - model="gemini/gemini-2.0-flash", - messages=[{"role": "user", "content": prompt}], - stream=True, - ) - - # First, yield the sources so the UI can display them immediately - import json - - sources_json = json.dumps(sources) - yield f'data: {{"sources": {sources_json}}}\n\n' - - # Then, stream the answer tokens - for chunk in response: - token = chunk.choices[0].delta.content - if token: # Ensure there's content to send - # SSE format: data: {"token": "..."}\n\n - yield f'data: {{"token": "{json.dumps(token)}"}}\n\n' - - # Signal the end of the stream with a special message - yield 'data: {"end_of_stream": true}\n\n' - - except Exception: - logger.exception("Error generating response") - yield 'data: {"error": "Error generating response"}\n\n'