From 3294dafaa64e2ad8eb8f8001dd2365a76ef81bce Mon Sep 17 00:00:00 2001 From: Sosokker Date: Tue, 24 Jun 2025 16:44:52 +0700 Subject: [PATCH] refactor(interface): use protocol to create module interface --- app/api/endpoints.py | 4 +- app/core/interfaces.py | 22 +++ app/main.py | 16 +- app/services/embedding_providers.py | 15 ++ app/services/rag_service.py | 232 +++++------------------ app/services/rag_service_v1.py | 283 ++++++++++++++++++++++++++++ app/services/vector_stores.py | 113 +++++++++++ 7 files changed, 493 insertions(+), 192 deletions(-) create mode 100644 app/core/interfaces.py create mode 100644 app/services/embedding_providers.py create mode 100644 app/services/rag_service_v1.py create mode 100644 app/services/vector_stores.py diff --git a/app/api/endpoints.py b/app/api/endpoints.py index 2ccaca3..76cf97b 100644 --- a/app/api/endpoints.py +++ b/app/api/endpoints.py @@ -38,9 +38,7 @@ async def ingest_file( shutil.copyfileobj(file.file, buffer) # Add the ingestion task to run in the background - background_tasks.add_task( - rag_service.ingest_document, file_path.as_posix(), file.filename - ) + background_tasks.add_task(rag_service.ingest_document, file_path, file.filename) # Immediately return a response to the user return { diff --git a/app/core/interfaces.py b/app/core/interfaces.py new file mode 100644 index 0000000..bd4081e --- /dev/null +++ b/app/core/interfaces.py @@ -0,0 +1,22 @@ +from typing import Protocol, TypedDict + +import numpy as np + + +class SearchResult(TypedDict): + """Type definition for search results.""" + + content: str + source: str + + +class EmbeddingModel(Protocol): + def embed_documents(self, texts: list[str]) -> list[np.ndarray]: ... + + def embed_query(self, text: str) -> np.ndarray: ... + + +class VectorDB(Protocol): + def upsert_documents(self, documents: list[dict]) -> None: ... + + def search(self, vector: np.ndarray, top_k: int) -> list[SearchResult]: ... 
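The interfaces above rely on structural subtyping: any object whose methods match the Protocol signatures satisfies EmbeddingModel or VectorDB, with or without explicit inheritance. A minimal sketch of what that buys the refactor, using in-memory fakes (FakeEmbeddingModel and InMemoryVectorDB are illustrative names, not part of this patch):

    import numpy as np

    from app.core.interfaces import EmbeddingModel, SearchResult, VectorDB


    class FakeEmbeddingModel:
        """Satisfies EmbeddingModel structurally; no inheritance required."""

        def embed_documents(self, texts: list[str]) -> list[np.ndarray]:
            return [np.zeros(384) for _ in texts]

        def embed_query(self, text: str) -> np.ndarray:
            return np.zeros(384)


    class InMemoryVectorDB:
        """Satisfies VectorDB structurally; keeps documents in a plain list."""

        def __init__(self) -> None:
            self._docs: list[dict] = []

        def upsert_documents(self, documents: list[dict]) -> None:
            self._docs.extend(documents)

        def search(self, vector: np.ndarray, top_k: int) -> list[SearchResult]:
            # Ignores similarity for brevity; returns the first top_k documents.
            return [
                SearchResult(content=d["content"], source=d["source"])
                for d in self._docs[:top_k]
            ]


    # Static type checkers accept both assignments because the shapes match.
    embedder: EmbeddingModel = FakeEmbeddingModel()
    store: VectorDB = InMemoryVectorDB()

This is what allows RAGService to be exercised in tests without a Postgres instance or a SentenceTransformer download.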
diff --git a/app/main.py b/app/main.py index f3767ab..f56cfe0 100644 --- a/app/main.py +++ b/app/main.py @@ -5,7 +5,9 @@ from fastapi import FastAPI from structlog import get_logger from app.api import endpoints +from app.services.embedding_providers import MiniLMEmbeddingModel from app.services.rag_service import RAGService +from app.services.vector_stores import PGVectorStore logger = get_logger() @@ -15,23 +17,27 @@ load_dotenv() # Dictionary to hold our application state, including the RAG service instance app_state = {} + @asynccontextmanager async def lifespan(app: FastAPI): + embedding_provider = MiniLMEmbeddingModel() + vector_store_provider = PGVectorStore() + # This code runs on startup logger.info("Application starting up...") # Initialize the RAG Service and store it in the app_state - app_state["rag_service"] = RAGService() + app_state["rag_service"] = RAGService( + embedding_model=embedding_provider, vector_db=vector_store_provider + ) yield - # This code runs on shutdown - logger.info("Application shutting down...") - app_state["rag_service"].db_conn.close() # Clean up DB connection - app_state.clear() + app = FastAPI(lifespan=lifespan) # Include the API router app.include_router(endpoints.router) + @app.get("/") def read_root(): return {"message": "Welcome to the Custom RAG API"} diff --git a/app/services/embedding_providers.py b/app/services/embedding_providers.py new file mode 100644 index 0000000..0cbcc70 --- /dev/null +++ b/app/services/embedding_providers.py @@ -0,0 +1,15 @@ +import numpy as np +from sentence_transformers import SentenceTransformer + +from app.core.interfaces import EmbeddingModel + + +class MiniLMEmbeddingModel(EmbeddingModel): + def __init__(self, model_name: str = "all-MiniLM-L6-v2"): + self.model = SentenceTransformer(model_name) + + def embed_documents(self, texts: list[str]) -> list[np.ndarray]: + return self.model.encode(texts).tolist() + + def embed_query(self, text: str) -> np.ndarray: + return self.model.encode([text])[0].tolist() diff --git a/app/services/rag_service.py b/app/services/rag_service.py index 4874594..8da7a42 100644 --- a/app/services/rag_service.py +++ b/app/services/rag_service.py @@ -1,37 +1,16 @@ -import os +import json from collections.abc import Generator from pathlib import Path from typing import TypedDict import litellm -import numpy as np -import psycopg2 -from dotenv import load_dotenv -from psycopg2 import extras -from psycopg2.extensions import AsIs, register_adapter -from PyPDF2 import PdfReader -from sentence_transformers import SentenceTransformer from structlog import get_logger -from app.core.config import settings -from app.core.exception import DocumentExtractionError, DocumentInsertionError +from app.core.interfaces import EmbeddingModel, VectorDB from app.core.utils import RecursiveCharacterTextSplitter -register_adapter(np.ndarray, AsIs) # for psycopg2 adapt -register_adapter(np.float32, AsIs) # for psycopg2 adapt logger = get_logger() -# pyright: reportArgumentType=false - -# Load environment variables -load_dotenv() - -# Initialize the embedding model globally to load it only once -EMBEDDING_MODEL = SentenceTransformer("all-MiniLM-L6-v2") -EMBEDDING_DIM = 384 # Dimension of the all-MiniLM-L6-v2 model - -os.environ["GEMINI_API_KEY"] = settings.GEMINI_API_KEY - class AnswerResult(TypedDict): answer: str @@ -39,20 +18,9 @@ class AnswerResult(TypedDict): class RAGService: - def __init__(self): - logger.info("Initializing RAGService...") - # Load the embedding model ONCE - self.embedding_model = 
SentenceTransformer( - "all-MiniLM-L6-v2", device="cpu" - ) # Use 'cuda' if GPU is available - self.db_conn = psycopg2.connect( - host=settings.POSTGRES_SERVER, - port=settings.POSTGRES_PORT, - user=settings.POSTGRES_USER, - password=settings.POSTGRES_PASSWORD, - dbname=settings.POSTGRES_DB, - ) - logger.info("RAGService initialized.") + def __init__(self, embedding_model: EmbeddingModel, vector_db: VectorDB): + self.embedding_model = embedding_model + self.vector_db = vector_db self.prompt = """Answer the question based on the following context. If you don't know the answer, say you don't know. Don't make up an answer. @@ -84,135 +52,26 @@ Answer:""" ) return text_splitter.split_text(text) - def _get_embedding(self, text: str, show_progress_bar: bool = False) -> np.ndarray: - """ - Generate embedding for a text chunk. - - Args: - text: Input text to embed - show_progress_bar: Whether to show a progress bar - - Returns: - Numpy array containing the embedding vector - - """ - return EMBEDDING_MODEL.encode( - text, convert_to_numpy=True, show_progress_bar=show_progress_bar - ) - - def _store_document( - self, contents: list[str], embeddings: list[np.ndarray], source: str - ) -> int: - """ - Store a document chunk in the database. - - Args: - contents: List of text content of the chunk - embeddings: List of embedding vectors of the chunk - source: Source file path - - Returns: - ID of the inserted document - - """ - data_to_insert = [ - (chunk, f"[{', '.join(map(str, embedding))}]", source) - for chunk, embedding in zip(contents, embeddings, strict=True) + def _ingest_document(self, text_chunks: list[str], source_name: str): + embeddings = self.embedding_model.embed_documents(text_chunks) + documents_to_upsert = [ + {"content": chunk, "embedding": emb, "source": source_name} + for chunk, emb in zip(text_chunks, embeddings, strict=False) ] + self.vector_db.upsert_documents(documents_to_upsert) - query = """ - INSERT INTO documents (content, embedding, source) - VALUES %s - RETURNING id - """ - with self.db_conn.cursor() as cursor: - extras.execute_values( - cursor, - query, - data_to_insert, - template="(%s, %s::vector, %s)", - page_size=100, - ) - inserted_ids = [row[0] for row in cursor.fetchall()] - self.db_conn.commit() - - if not inserted_ids: - raise DocumentInsertionError("No documents were inserted.") - - logger.info("Successfully bulk-ingested %d documents", len(inserted_ids)) - logger.info("Inserted document IDs: %s", inserted_ids) - return inserted_ids[0] - - def _extract_text_from_pdf(self, pdf_path: str) -> str: - """ - Extract text from a PDF file. 
- - Args: - pdf_path: Path to the PDF file - - Returns: - Extracted text as a single string - - """ - try: - reader = PdfReader(pdf_path) - text = "" - for page in reader.pages: - text += page.extract_text() + "\n" - return text.strip() - except Exception as e: - raise DocumentExtractionError( - "Error extracting text from PDF: " + str(e) - ) from e - - def _get_relevant_context(self, question: str, top_k: int) -> list[tuple[str, str]]: - """Get the most relevant document chunks for a given question""" - question_embedding = self.embedding_model.encode( - question, convert_to_numpy=True - ) - - try: - with self.db_conn.cursor() as cursor: - cursor.execute( - """ - SELECT content, source - FROM documents - ORDER BY embedding <-> %s::vector - LIMIT %s - """, - (question_embedding.tolist(), top_k), - ) - results = cursor.fetchall() - return results - except Exception as e: - logger.exception("Error retrieving context: %s", e) - return [] - - def ingest_document(self, file_path: str, filename: str): - logger.info("Ingesting %s...", filename) - if not Path(file_path).exists(): - err = f"File not found: {filename}" - raise FileNotFoundError(err) - - logger.info("Processing PDF: %s : %s", filename, file_path) - - text = self._extract_text_from_pdf(file_path) - if not text.strip(): - err = "No text could be extracted from the PDF" - raise ValueError(err) - - chunks = self._split_text(text) - logger.info("Split PDF into %d chunks", len(chunks)) - - embeddings = self._get_embedding(chunks, show_progress_bar=True) - self._store_document(chunks, embeddings, filename) - - logger.info("Successfully processed %d chunks from %s", len(chunks), filename) + def ingest_document(self, file_path: Path, source_name: str): + with Path(file_path).open("r", encoding="utf-8") as f: + text = f.read() + text_chunks = self._split_text(text) + self._ingest_document(text_chunks, source_name) def answer_query(self, question: str) -> AnswerResult: - relevant_context = self._get_relevant_context(question, 5) - context_str = "\n\n".join([chunk[0] for chunk in relevant_context]) - sources = list({chunk[1] for chunk in relevant_context if chunk[1]}) + query_embedding = self.embedding_model.embed_query(question) + search_results = self.vector_db.search(query_embedding, top_k=5) + sources = list({chunk["source"] for chunk in search_results if chunk["source"]}) + + context_str = "\n\n".join([chunk["content"] for chunk in search_results]) try: response = litellm.completion( @@ -233,14 +92,13 @@ Answer:""" max_tokens=500, ) - answer_text = response.choices[0].message.content.strip() + answer_text = response.choices[0].message.content.strip() # type: ignore if not answer_text: answer_text = "No answer generated" sources = ["No sources"] return AnswerResult(answer=answer_text, sources=sources) - except Exception: logger.exception("Error generating response") return AnswerResult( @@ -248,36 +106,42 @@ Answer:""" ) def answer_query_stream(self, question: str) -> Generator[str, None, None]: - """Answer a query using streaming.""" - relevant_context = self._get_relevant_context(question, 5) - context_str = "\n\n".join([chunk[0] for chunk in relevant_context]) - sources = list({chunk[1] for chunk in relevant_context if chunk[1]}) - - prompt = self.prompt.format(context=context_str, question=question) + query_embedding = self.embedding_model.embed_query(question) + search_results = self.vector_db.search(query_embedding, top_k=5) + sources = list({chunk["source"] for chunk in search_results if chunk["source"]}) + context_str = 
"\n\n".join([chunk["content"] for chunk in search_results]) try: response = litellm.completion( model="gemini/gemini-2.0-flash", - messages=[{"role": "user", "content": prompt}], + messages=[ + { + "role": "system", + "content": "You are a helpful assistant that answers questions based on the provided context.", + }, + { + "role": "user", + "content": self.prompt.format( + context=context_str, question=question + ), + }, + ], + temperature=0.1, + max_tokens=500, stream=True, ) - # First, yield the sources so the UI can display them immediately - import json - - sources_json = json.dumps(sources) - yield f'data: {{"sources": {sources_json}}}\n\n' - - # Then, stream the answer tokens + # Yield each chunk of the response as it's generated for chunk in response: - token = chunk.choices[0].delta.content - if token: # Ensure there's content to send - # SSE format: data: {"token": "..."}\n\n - yield f'data: {{"token": "{json.dumps(token)}"}}\n\n' + if chunk.choices: + delta = chunk.choices[0].delta + if hasattr(delta, "content") and delta.content: + yield f'data: {{"token": "{json.dumps(delta.content)}"}}\n\n' - # Signal the end of the stream with a special message + # Yield sources at the end + yield f'data: {{"sources": {json.dumps(sources)}}}\n\n' yield 'data: {"end_of_stream": true}\n\n' except Exception: - logger.exception("Error generating response") + logger.exception("Error generating streaming response") yield 'data: {"error": "Error generating response"}\n\n' diff --git a/app/services/rag_service_v1.py b/app/services/rag_service_v1.py new file mode 100644 index 0000000..4874594 --- /dev/null +++ b/app/services/rag_service_v1.py @@ -0,0 +1,283 @@ +import os +from collections.abc import Generator +from pathlib import Path +from typing import TypedDict + +import litellm +import numpy as np +import psycopg2 +from dotenv import load_dotenv +from psycopg2 import extras +from psycopg2.extensions import AsIs, register_adapter +from PyPDF2 import PdfReader +from sentence_transformers import SentenceTransformer +from structlog import get_logger + +from app.core.config import settings +from app.core.exception import DocumentExtractionError, DocumentInsertionError +from app.core.utils import RecursiveCharacterTextSplitter + +register_adapter(np.ndarray, AsIs) # for psycopg2 adapt +register_adapter(np.float32, AsIs) # for psycopg2 adapt +logger = get_logger() + +# pyright: reportArgumentType=false + +# Load environment variables +load_dotenv() + +# Initialize the embedding model globally to load it only once +EMBEDDING_MODEL = SentenceTransformer("all-MiniLM-L6-v2") +EMBEDDING_DIM = 384 # Dimension of the all-MiniLM-L6-v2 model + +os.environ["GEMINI_API_KEY"] = settings.GEMINI_API_KEY + + +class AnswerResult(TypedDict): + answer: str + sources: list[str] + + +class RAGService: + def __init__(self): + logger.info("Initializing RAGService...") + # Load the embedding model ONCE + self.embedding_model = SentenceTransformer( + "all-MiniLM-L6-v2", device="cpu" + ) # Use 'cuda' if GPU is available + self.db_conn = psycopg2.connect( + host=settings.POSTGRES_SERVER, + port=settings.POSTGRES_PORT, + user=settings.POSTGRES_USER, + password=settings.POSTGRES_PASSWORD, + dbname=settings.POSTGRES_DB, + ) + logger.info("RAGService initialized.") + self.prompt = """Answer the question based on the following context. +If you don't know the answer, say you don't know. Don't make up an answer. 
+ +Context: +{context} + +Question: {question} + +Answer:""" + + def _split_text( + self, text: str, chunk_size: int = 500, chunk_overlap: int = 100 + ) -> list[str]: + """ + Split text into chunks with specified size and overlap. + + Args: + text: Input text to split + chunk_size: Maximum size of each chunk in characters + chunk_overlap: Number of characters to overlap between chunks + + Returns: + List of text chunks + + """ + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + return text_splitter.split_text(text) + + def _get_embedding(self, text: str, show_progress_bar: bool = False) -> np.ndarray: + """ + Generate embedding for a text chunk. + + Args: + text: Input text to embed + show_progress_bar: Whether to show a progress bar + + Returns: + Numpy array containing the embedding vector + + """ + return EMBEDDING_MODEL.encode( + text, convert_to_numpy=True, show_progress_bar=show_progress_bar + ) + + def _store_document( + self, contents: list[str], embeddings: list[np.ndarray], source: str + ) -> int: + """ + Store a document chunk in the database. + + Args: + contents: List of text content of the chunk + embeddings: List of embedding vectors of the chunk + source: Source file path + + Returns: + ID of the inserted document + + """ + data_to_insert = [ + (chunk, f"[{', '.join(map(str, embedding))}]", source) + for chunk, embedding in zip(contents, embeddings, strict=True) + ] + + query = """ + INSERT INTO documents (content, embedding, source) + VALUES %s + RETURNING id + """ + with self.db_conn.cursor() as cursor: + extras.execute_values( + cursor, + query, + data_to_insert, + template="(%s, %s::vector, %s)", + page_size=100, + ) + inserted_ids = [row[0] for row in cursor.fetchall()] + self.db_conn.commit() + + if not inserted_ids: + raise DocumentInsertionError("No documents were inserted.") + + logger.info("Successfully bulk-ingested %d documents", len(inserted_ids)) + logger.info("Inserted document IDs: %s", inserted_ids) + return inserted_ids[0] + + def _extract_text_from_pdf(self, pdf_path: str) -> str: + """ + Extract text from a PDF file. 
+ + Args: + pdf_path: Path to the PDF file + + Returns: + Extracted text as a single string + + """ + try: + reader = PdfReader(pdf_path) + text = "" + for page in reader.pages: + text += page.extract_text() + "\n" + return text.strip() + except Exception as e: + raise DocumentExtractionError( + "Error extracting text from PDF: " + str(e) + ) from e + + def _get_relevant_context(self, question: str, top_k: int) -> list[tuple[str, str]]: + """Get the most relevant document chunks for a given question""" + question_embedding = self.embedding_model.encode( + question, convert_to_numpy=True + ) + + try: + with self.db_conn.cursor() as cursor: + cursor.execute( + """ + SELECT content, source + FROM documents + ORDER BY embedding <-> %s::vector + LIMIT %s + """, + (question_embedding.tolist(), top_k), + ) + results = cursor.fetchall() + return results + except Exception as e: + logger.exception("Error retrieving context: %s", e) + return [] + + def ingest_document(self, file_path: str, filename: str): + logger.info("Ingesting %s...", filename) + if not Path(file_path).exists(): + err = f"File not found: {filename}" + raise FileNotFoundError(err) + + logger.info("Processing PDF: %s : %s", filename, file_path) + + text = self._extract_text_from_pdf(file_path) + if not text.strip(): + err = "No text could be extracted from the PDF" + raise ValueError(err) + + chunks = self._split_text(text) + logger.info("Split PDF into %d chunks", len(chunks)) + + embeddings = self._get_embedding(chunks, show_progress_bar=True) + self._store_document(chunks, embeddings, filename) + + logger.info("Successfully processed %d chunks from %s", len(chunks), filename) + + def answer_query(self, question: str) -> AnswerResult: + relevant_context = self._get_relevant_context(question, 5) + context_str = "\n\n".join([chunk[0] for chunk in relevant_context]) + sources = list({chunk[1] for chunk in relevant_context if chunk[1]}) + + try: + response = litellm.completion( + model="gemini/gemini-2.0-flash", + messages=[ + { + "role": "system", + "content": "You are a helpful assistant that answers questions based on the provided context.", + }, + { + "role": "user", + "content": self.prompt.format( + context=context_str, question=question + ), + }, + ], + temperature=0.1, + max_tokens=500, + ) + + answer_text = response.choices[0].message.content.strip() + + if not answer_text: + answer_text = "No answer generated" + sources = ["No sources"] + + return AnswerResult(answer=answer_text, sources=sources) + + except Exception: + logger.exception("Error generating response") + return AnswerResult( + answer="Error generating response", sources=["No sources"] + ) + + def answer_query_stream(self, question: str) -> Generator[str, None, None]: + """Answer a query using streaming.""" + relevant_context = self._get_relevant_context(question, 5) + context_str = "\n\n".join([chunk[0] for chunk in relevant_context]) + sources = list({chunk[1] for chunk in relevant_context if chunk[1]}) + + prompt = self.prompt.format(context=context_str, question=question) + + try: + response = litellm.completion( + model="gemini/gemini-2.0-flash", + messages=[{"role": "user", "content": prompt}], + stream=True, + ) + + # First, yield the sources so the UI can display them immediately + import json + + sources_json = json.dumps(sources) + yield f'data: {{"sources": {sources_json}}}\n\n' + + # Then, stream the answer tokens + for chunk in response: + token = chunk.choices[0].delta.content + if token: # Ensure there's content to send + # SSE format: data: 
{"token": "..."}\n\n + yield f'data: {{"token": "{json.dumps(token)}"}}\n\n' + + # Signal the end of the stream with a special message + yield 'data: {"end_of_stream": true}\n\n' + + except Exception: + logger.exception("Error generating response") + yield 'data: {"error": "Error generating response"}\n\n' diff --git a/app/services/vector_stores.py b/app/services/vector_stores.py new file mode 100644 index 0000000..111e053 --- /dev/null +++ b/app/services/vector_stores.py @@ -0,0 +1,113 @@ +import numpy as np +import psycopg2 +from psycopg2.extensions import AsIs, register_adapter +from psycopg2.extras import execute_values + +from app.core.config import settings +from app.core.interfaces import SearchResult, VectorDB + +# Register NumPy array and float32 adapters for psycopg2 +register_adapter(np.ndarray, AsIs) +register_adapter(np.float32, AsIs) + + +class PGVectorStore(VectorDB): + """PostgreSQL vector store implementation for document storage and retrieval.""" + + def __init__(self): + pass + + def _get_connection(self): + """Get a new database connection.""" + return psycopg2.connect( + host=settings.POSTGRES_SERVER, + port=settings.POSTGRES_PORT, + user=settings.POSTGRES_USER, + password=settings.POSTGRES_PASSWORD, + dbname=settings.POSTGRES_DB, + ) + + def upsert_documents(self, documents: list[dict]) -> None: + """ + Upsert documents into the vector store. + + Args: + documents: List of document dictionaries containing 'content', 'embedding', and 'source'. + + Raises: + ValueError: If required fields are missing from documents. + psycopg2.Error: For database-related errors. + + """ + if not documents: + return + + # Validate document structure + for doc in documents: + if not all(key in doc for key in ["content", "embedding", "source"]): + err = "Document must contain 'content', 'embedding', and 'source' keys" + raise ValueError(err) + + data_to_insert = [ + (doc["content"], np.array(doc["embedding"]), doc["source"]) + for doc in documents + ] + + query = """ + INSERT INTO documents (content, embedding, source) + VALUES %s + ON CONFLICT (content, source) DO UPDATE SET + embedding = EXCLUDED.embedding, + updated_at = NOW() + RETURNING id + """ + + with self._get_connection() as conn, conn.cursor() as cursor: + try: + execute_values( + cursor, + query, + data_to_insert, + template="(%s, %s::vector, %s)", + page_size=100, + ) + conn.commit() + except Exception: + conn.rollback() + raise + + def search(self, vector: np.ndarray, top_k: int = 5) -> list[SearchResult]: + """ + Search for similar documents using vector similarity. + + Args: + vector: The query vector to search with. + top_k: Maximum number of results to return. + + Returns: + List of search results with content and source. + + Raises: + psycopg2.Error: For database-related errors. + + """ + if not vector: + return [] + + query = """ + SELECT content, source + FROM documents + ORDER BY embedding <-> %s::vector + LIMIT %s + """ + + with self._get_connection() as conn, conn.cursor() as cursor: + try: + cursor.execute(query, (np.array(vector).tolist(), top_k)) + return [ + SearchResult(content=row[0], source=row[1]) + for row in cursor.fetchall() + ] + except Exception: + conn.rollback() + raise