refactor(interface): use protocol to create module interface

Sosokker 2025-06-24 16:44:52 +07:00
parent 80af71935f
commit 3294dafaa6
7 changed files with 493 additions and 192 deletions

View File: app/api/endpoints.py

@@ -38,9 +38,7 @@ async def ingest_file(
         shutil.copyfileobj(file.file, buffer)

     # Add the ingestion task to run in the background
-    background_tasks.add_task(
-        rag_service.ingest_document, file_path.as_posix(), file.filename
-    )
+    background_tasks.add_task(rag_service.ingest_document, file_path, file.filename)

     # Immediately return a response to the user
     return {
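The hunk only shows the tail of `ingest_file`. For orientation, a rough sketch of how such an endpoint typically hands the upload to `BackgroundTasks` is given below; `UPLOAD_DIR` and `get_rag_service` are illustrative assumptions, and only the `add_task` line comes from this commit.

```python
# Hypothetical reconstruction of the surrounding endpoint, for context only.
import shutil
from pathlib import Path

from fastapi import APIRouter, BackgroundTasks, Depends, UploadFile

router = APIRouter()
UPLOAD_DIR = Path("uploads")  # assumed upload location, not shown in the diff


def get_rag_service():
    # Assumed accessor; the real app keeps the service in an app_state dict that
    # the lifespan handler fills in at startup (wiring not shown in this hunk).
    raise NotImplementedError


@router.post("/ingest")
async def ingest_file(
    file: UploadFile,
    background_tasks: BackgroundTasks,
    rag_service=Depends(get_rag_service),
):
    file_path = UPLOAD_DIR / (file.filename or "upload.bin")
    with file_path.open("wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    # Runs after the response is sent; the task now receives the Path itself
    # rather than file_path.as_posix(), matching ingest_document's new signature.
    background_tasks.add_task(rag_service.ingest_document, file_path, file.filename)

    return {"status": "queued", "filename": file.filename}
```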

app/core/interfaces.py (Normal file, +22 lines)
View File

@@ -0,0 +1,22 @@
from typing import Protocol, TypedDict

import numpy as np


class SearchResult(TypedDict):
    """Type definition for search results."""

    content: str
    source: str


class EmbeddingModel(Protocol):
    def embed_documents(self, texts: list[str]) -> list[np.ndarray]: ...

    def embed_query(self, text: str) -> np.ndarray: ...


class VectorDB(Protocol):
    def upsert_documents(self, documents: list[dict]) -> None: ...

    def search(self, vector: np.ndarray, top_k: int) -> list[SearchResult]: ...
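Since `EmbeddingModel` and `VectorDB` are `typing.Protocol`s, conforming classes do not have to inherit from them: any class with matching method signatures satisfies the type at check time. A minimal sketch, assuming nothing beyond this new interfaces module (`HashEmbedding` is a made-up name, not part of the commit):

```python
# Structural typing: no subclassing of EmbeddingModel is required.
import numpy as np

from app.core.interfaces import EmbeddingModel


class HashEmbedding:
    """Toy embedder used only to show protocol conformance."""

    def embed_documents(self, texts: list[str]) -> list[np.ndarray]:
        return [self.embed_query(t) for t in texts]

    def embed_query(self, text: str) -> np.ndarray:
        # Deterministic pseudo-embedding seeded from the text hash.
        rng = np.random.default_rng(abs(hash(text)) % (2**32))
        return rng.random(384)


def embed_all(model: EmbeddingModel, texts: list[str]) -> list[np.ndarray]:
    # Accepts HashEmbedding, MiniLMEmbeddingModel, or any structural match.
    return model.embed_documents(texts)


vectors = embed_all(HashEmbedding(), ["alpha", "beta"])
```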

View File: app/main.py

@@ -5,7 +5,9 @@ from fastapi import FastAPI
 from structlog import get_logger

 from app.api import endpoints
+from app.services.embedding_providers import MiniLMEmbeddingModel
 from app.services.rag_service import RAGService
+from app.services.vector_stores import PGVectorStore

 logger = get_logger()
@@ -15,23 +17,27 @@ load_dotenv()
 # Dictionary to hold our application state, including the RAG service instance
 app_state = {}


 @asynccontextmanager
 async def lifespan(app: FastAPI):
+    embedding_provider = MiniLMEmbeddingModel()
+    vector_store_provider = PGVectorStore()
+
     # This code runs on startup
     logger.info("Application starting up...")
     # Initialize the RAG Service and store it in the app_state
-    app_state["rag_service"] = RAGService()
+    app_state["rag_service"] = RAGService(
+        embedding_model=embedding_provider, vector_db=vector_store_provider
+    )
     yield
     # This code runs on shutdown
     logger.info("Application shutting down...")
     app_state["rag_service"].db_conn.close()  # Clean up DB connection
     app_state.clear()


 app = FastAPI(lifespan=lifespan)

 # Include the API router
 app.include_router(endpoints.router)


 @app.get("/")
 def read_root():
     return {"message": "Welcome to the Custom RAG API"}
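Because `RAGService` now receives its collaborators through the constructor, it can be assembled without Postgres or a model download, for example in a unit test. A minimal sketch under that assumption; `StubEmbedding` and `StubVectorDB` are illustrative stand-ins, not part of the commit:

```python
# Illustrative test wiring; only RAGService and the interfaces come from the repo.
import numpy as np

from app.core.interfaces import SearchResult
from app.services.rag_service import RAGService


class StubEmbedding:
    def embed_documents(self, texts: list[str]) -> list[np.ndarray]:
        return [np.zeros(3) for _ in texts]

    def embed_query(self, text: str) -> np.ndarray:
        return np.zeros(3)


class StubVectorDB:
    def __init__(self) -> None:
        self.upserted: list[dict] = []

    def upsert_documents(self, documents: list[dict]) -> None:
        self.upserted.extend(documents)

    def search(self, vector: np.ndarray, top_k: int) -> list[SearchResult]:
        return [SearchResult(content="stub context", source="stub.txt")]


service = RAGService(embedding_model=StubEmbedding(), vector_db=StubVectorDB())
# ingest_document() and the retrieval half of answer_query() now run entirely
# against the stubs; only the litellm call would still reach the network.
```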

View File: app/services/embedding_providers.py

@@ -0,0 +1,15 @@
import numpy as np
from sentence_transformers import SentenceTransformer

from app.core.interfaces import EmbeddingModel


class MiniLMEmbeddingModel(EmbeddingModel):
    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        self.model = SentenceTransformer(model_name)

    def embed_documents(self, texts: list[str]) -> list[np.ndarray]:
        return self.model.encode(texts).tolist()

    def embed_query(self, text: str) -> np.ndarray:
        return self.model.encode([text])[0].tolist()
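Note that `encode` returns NumPy arrays, so the `.tolist()` calls mean this provider actually hands back plain Python lists rather than the `np.ndarray` values the protocol annotations advertise; the `np.array(doc["embedding"])` conversion in the vector store accepts either. A quick shape check, assuming the sentence-transformers package is installed and the all-MiniLM-L6-v2 weights have been downloaded:

```python
# Sanity check of the provider's output sizes; all-MiniLM-L6-v2 emits 384-dim vectors.
from app.services.embedding_providers import MiniLMEmbeddingModel

model = MiniLMEmbeddingModel()

doc_vectors = model.embed_documents(["first chunk", "second chunk"])
query_vector = model.embed_query("a question")

print(len(doc_vectors), len(doc_vectors[0]))  # 2 384
print(len(query_vector))                      # 384
```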

View File: app/services/rag_service.py

@@ -1,37 +1,16 @@
-import os
+import json
 from collections.abc import Generator
 from pathlib import Path
 from typing import TypedDict

 import litellm
-import numpy as np
-import psycopg2
-from dotenv import load_dotenv
-from psycopg2 import extras
-from psycopg2.extensions import AsIs, register_adapter
-from PyPDF2 import PdfReader
-from sentence_transformers import SentenceTransformer
 from structlog import get_logger

-from app.core.config import settings
-from app.core.exception import DocumentExtractionError, DocumentInsertionError
+from app.core.interfaces import EmbeddingModel, VectorDB
 from app.core.utils import RecursiveCharacterTextSplitter

-register_adapter(np.ndarray, AsIs)  # for psycopg2 adapt
-register_adapter(np.float32, AsIs)  # for psycopg2 adapt
-
 logger = get_logger()

-# pyright: reportArgumentType=false
-
-# Load environment variables
-load_dotenv()
-
-# Initialize the embedding model globally to load it only once
-EMBEDDING_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
-EMBEDDING_DIM = 384  # Dimension of the all-MiniLM-L6-v2 model
-
-os.environ["GEMINI_API_KEY"] = settings.GEMINI_API_KEY
-

 class AnswerResult(TypedDict):
     answer: str
@@ -39,20 +18,9 @@ class AnswerResult(TypedDict):

 class RAGService:
-    def __init__(self):
-        logger.info("Initializing RAGService...")
-        # Load the embedding model ONCE
-        self.embedding_model = SentenceTransformer(
-            "all-MiniLM-L6-v2", device="cpu"
-        )  # Use 'cuda' if GPU is available
-        self.db_conn = psycopg2.connect(
-            host=settings.POSTGRES_SERVER,
-            port=settings.POSTGRES_PORT,
-            user=settings.POSTGRES_USER,
-            password=settings.POSTGRES_PASSWORD,
-            dbname=settings.POSTGRES_DB,
-        )
-        logger.info("RAGService initialized.")
+    def __init__(self, embedding_model: EmbeddingModel, vector_db: VectorDB):
+        self.embedding_model = embedding_model
+        self.vector_db = vector_db

         self.prompt = """Answer the question based on the following context.
 If you don't know the answer, say you don't know. Don't make up an answer.
@@ -84,135 +52,26 @@ Answer:"""
         )
         return text_splitter.split_text(text)

-    def _get_embedding(self, text: str, show_progress_bar: bool = False) -> np.ndarray:
-        """
-        Generate embedding for a text chunk.
-
-        Args:
-            text: Input text to embed
-            show_progress_bar: Whether to show a progress bar
-
-        Returns:
-            Numpy array containing the embedding vector
-        """
-        return EMBEDDING_MODEL.encode(
-            text, convert_to_numpy=True, show_progress_bar=show_progress_bar
-        )
-
-    def _store_document(
-        self, contents: list[str], embeddings: list[np.ndarray], source: str
-    ) -> int:
-        """
-        Store a document chunk in the database.
-
-        Args:
-            contents: List of text content of the chunk
-            embeddings: List of embedding vectors of the chunk
-            source: Source file path
-
-        Returns:
-            ID of the inserted document
-        """
-        data_to_insert = [
-            (chunk, f"[{', '.join(map(str, embedding))}]", source)
-            for chunk, embedding in zip(contents, embeddings, strict=True)
-        ]
-
-        query = """
-            INSERT INTO documents (content, embedding, source)
-            VALUES %s
-            RETURNING id
-        """
-
-        with self.db_conn.cursor() as cursor:
-            extras.execute_values(
-                cursor,
-                query,
-                data_to_insert,
-                template="(%s, %s::vector, %s)",
-                page_size=100,
-            )
-            inserted_ids = [row[0] for row in cursor.fetchall()]
-            self.db_conn.commit()
-
-        if not inserted_ids:
-            raise DocumentInsertionError("No documents were inserted.")
-
-        logger.info("Successfully bulk-ingested %d documents", len(inserted_ids))
-        logger.info("Inserted document IDs: %s", inserted_ids)
-        return inserted_ids[0]
-
-    def _extract_text_from_pdf(self, pdf_path: str) -> str:
-        """
-        Extract text from a PDF file.
-
-        Args:
-            pdf_path: Path to the PDF file
-
-        Returns:
-            Extracted text as a single string
-        """
-        try:
-            reader = PdfReader(pdf_path)
-            text = ""
-            for page in reader.pages:
-                text += page.extract_text() + "\n"
-            return text.strip()
-        except Exception as e:
-            raise DocumentExtractionError(
-                "Error extracting text from PDF: " + str(e)
-            ) from e
-
-    def _get_relevant_context(self, question: str, top_k: int) -> list[tuple[str, str]]:
-        """Get the most relevant document chunks for a given question"""
-        question_embedding = self.embedding_model.encode(
-            question, convert_to_numpy=True
-        )
-        try:
-            with self.db_conn.cursor() as cursor:
-                cursor.execute(
-                    """
-                    SELECT content, source
-                    FROM documents
-                    ORDER BY embedding <-> %s::vector
-                    LIMIT %s
-                    """,
-                    (question_embedding.tolist(), top_k),
-                )
-                results = cursor.fetchall()
-                return results
-        except Exception as e:
-            logger.exception("Error retrieving context: %s", e)
-            return []
-
-    def ingest_document(self, file_path: str, filename: str):
-        logger.info("Ingesting %s...", filename)
-        if not Path(file_path).exists():
-            err = f"File not found: {filename}"
-            raise FileNotFoundError(err)
-
-        logger.info("Processing PDF: %s : %s", filename, file_path)
-        text = self._extract_text_from_pdf(file_path)
-        if not text.strip():
-            err = "No text could be extracted from the PDF"
-            raise ValueError(err)
-
-        chunks = self._split_text(text)
-        logger.info("Split PDF into %d chunks", len(chunks))
-
-        embeddings = self._get_embedding(chunks, show_progress_bar=True)
-        self._store_document(chunks, embeddings, filename)
-        logger.info("Successfully processed %d chunks from %s", len(chunks), filename)
+    def _ingest_document(self, text_chunks: list[str], source_name: str):
+        embeddings = self.embedding_model.embed_documents(text_chunks)
+        documents_to_upsert = [
+            {"content": chunk, "embedding": emb, "source": source_name}
+            for chunk, emb in zip(text_chunks, embeddings, strict=False)
+        ]
+        self.vector_db.upsert_documents(documents_to_upsert)
+
+    def ingest_document(self, file_path: Path, source_name: str):
+        with Path(file_path).open("r", encoding="utf-8") as f:
+            text = f.read()
+        text_chunks = self._split_text(text)
+        self._ingest_document(text_chunks, source_name)

     def answer_query(self, question: str) -> AnswerResult:
-        relevant_context = self._get_relevant_context(question, 5)
-        context_str = "\n\n".join([chunk[0] for chunk in relevant_context])
-        sources = list({chunk[1] for chunk in relevant_context if chunk[1]})
+        query_embedding = self.embedding_model.embed_query(question)
+        search_results = self.vector_db.search(query_embedding, top_k=5)
+        sources = list({chunk["source"] for chunk in search_results if chunk["source"]})
+        context_str = "\n\n".join([chunk["content"] for chunk in search_results])

         try:
             response = litellm.completion(
@@ -233,14 +92,13 @@ Answer:"""
                 max_tokens=500,
             )

-            answer_text = response.choices[0].message.content.strip()
+            answer_text = response.choices[0].message.content.strip()  # type: ignore

             if not answer_text:
                 answer_text = "No answer generated"
                 sources = ["No sources"]

             return AnswerResult(answer=answer_text, sources=sources)
         except Exception:
             logger.exception("Error generating response")
             return AnswerResult(
@@ -248,36 +106,42 @@ Answer:"""
             )

     def answer_query_stream(self, question: str) -> Generator[str, None, None]:
-        """Answer a query using streaming."""
-        relevant_context = self._get_relevant_context(question, 5)
-        context_str = "\n\n".join([chunk[0] for chunk in relevant_context])
-        sources = list({chunk[1] for chunk in relevant_context if chunk[1]})
-        prompt = self.prompt.format(context=context_str, question=question)
+        query_embedding = self.embedding_model.embed_query(question)
+        search_results = self.vector_db.search(query_embedding, top_k=5)
+        sources = list({chunk["source"] for chunk in search_results if chunk["source"]})
+        context_str = "\n\n".join([chunk["content"] for chunk in search_results])

         try:
             response = litellm.completion(
                 model="gemini/gemini-2.0-flash",
-                messages=[{"role": "user", "content": prompt}],
+                messages=[
+                    {
+                        "role": "system",
+                        "content": "You are a helpful assistant that answers questions based on the provided context.",
+                    },
+                    {
+                        "role": "user",
+                        "content": self.prompt.format(
+                            context=context_str, question=question
+                        ),
+                    },
+                ],
+                temperature=0.1,
+                max_tokens=500,
                 stream=True,
             )

-            # First, yield the sources so the UI can display them immediately
-            import json
-
-            sources_json = json.dumps(sources)
-            yield f'data: {{"sources": {sources_json}}}\n\n'
-
-            # Then, stream the answer tokens
+            # Yield each chunk of the response as it's generated
             for chunk in response:
-                token = chunk.choices[0].delta.content
-                if token:  # Ensure there's content to send
-                    # SSE format: data: {"token": "..."}\n\n
-                    yield f'data: {{"token": "{json.dumps(token)}"}}\n\n'
+                if chunk.choices:
+                    delta = chunk.choices[0].delta
+                    if hasattr(delta, "content") and delta.content:
+                        yield f'data: {{"token": "{json.dumps(delta.content)}"}}\n\n'

-            # Signal the end of the stream with a special message
+            # Yield sources at the end
+            yield f'data: {{"sources": {json.dumps(sources)}}}\n\n'
             yield 'data: {"end_of_stream": true}\n\n'
         except Exception:
-            logger.exception("Error generating response")
+            logger.exception("Error generating streaming response")
             yield 'data: {"error": "Error generating response"}\n\n'
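One detail worth noting about the streamed frames: `json.dumps` already returns a quoted JSON string, so the token frame above, which interpolates it inside another pair of quotes, emits payloads such as `{"token": ""Hi""}` that strict JSON parsers reject; the sources frame interpolates the list without extra quotes and stays valid. If fully valid JSON is wanted on the wire, one option is to serialize the whole event dict. A small sketch of that pattern, as an illustration rather than part of this commit:

```python
# Illustrative helper for building SSE data frames whose payload is valid JSON.
import json


def sse_event(payload: dict) -> str:
    """Format a dict as a single Server-Sent Events data frame."""
    return f"data: {json.dumps(payload)}\n\n"


# sse_event({"token": "Hello"})        -> 'data: {"token": "Hello"}\n\n'
# sse_event({"sources": ["doc.pdf"]})  -> 'data: {"sources": ["doc.pdf"]}\n\n'
# sse_event({"end_of_stream": True})   -> 'data: {"end_of_stream": true}\n\n'
```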

View File

@@ -0,0 +1,283 @@
import os
from collections.abc import Generator
from pathlib import Path
from typing import TypedDict

import litellm
import numpy as np
import psycopg2
from dotenv import load_dotenv
from psycopg2 import extras
from psycopg2.extensions import AsIs, register_adapter
from PyPDF2 import PdfReader
from sentence_transformers import SentenceTransformer
from structlog import get_logger

from app.core.config import settings
from app.core.exception import DocumentExtractionError, DocumentInsertionError
from app.core.utils import RecursiveCharacterTextSplitter

register_adapter(np.ndarray, AsIs)  # for psycopg2 adapt
register_adapter(np.float32, AsIs)  # for psycopg2 adapt

logger = get_logger()

# pyright: reportArgumentType=false

# Load environment variables
load_dotenv()

# Initialize the embedding model globally to load it only once
EMBEDDING_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
EMBEDDING_DIM = 384  # Dimension of the all-MiniLM-L6-v2 model

os.environ["GEMINI_API_KEY"] = settings.GEMINI_API_KEY


class AnswerResult(TypedDict):
    answer: str
    sources: list[str]


class RAGService:
    def __init__(self):
        logger.info("Initializing RAGService...")
        # Load the embedding model ONCE
        self.embedding_model = SentenceTransformer(
            "all-MiniLM-L6-v2", device="cpu"
        )  # Use 'cuda' if GPU is available
        self.db_conn = psycopg2.connect(
            host=settings.POSTGRES_SERVER,
            port=settings.POSTGRES_PORT,
            user=settings.POSTGRES_USER,
            password=settings.POSTGRES_PASSWORD,
            dbname=settings.POSTGRES_DB,
        )
        logger.info("RAGService initialized.")

        self.prompt = """Answer the question based on the following context.
If you don't know the answer, say you don't know. Don't make up an answer.

Context:
{context}

Question: {question}

Answer:"""

    def _split_text(
        self, text: str, chunk_size: int = 500, chunk_overlap: int = 100
    ) -> list[str]:
        """
        Split text into chunks with specified size and overlap.

        Args:
            text: Input text to split
            chunk_size: Maximum size of each chunk in characters
            chunk_overlap: Number of characters to overlap between chunks

        Returns:
            List of text chunks
        """
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
        return text_splitter.split_text(text)

    def _get_embedding(self, text: str, show_progress_bar: bool = False) -> np.ndarray:
        """
        Generate embedding for a text chunk.

        Args:
            text: Input text to embed
            show_progress_bar: Whether to show a progress bar

        Returns:
            Numpy array containing the embedding vector
        """
        return EMBEDDING_MODEL.encode(
            text, convert_to_numpy=True, show_progress_bar=show_progress_bar
        )

    def _store_document(
        self, contents: list[str], embeddings: list[np.ndarray], source: str
    ) -> int:
        """
        Store a document chunk in the database.

        Args:
            contents: List of text content of the chunk
            embeddings: List of embedding vectors of the chunk
            source: Source file path

        Returns:
            ID of the inserted document
        """
        data_to_insert = [
            (chunk, f"[{', '.join(map(str, embedding))}]", source)
            for chunk, embedding in zip(contents, embeddings, strict=True)
        ]

        query = """
            INSERT INTO documents (content, embedding, source)
            VALUES %s
            RETURNING id
        """

        with self.db_conn.cursor() as cursor:
            extras.execute_values(
                cursor,
                query,
                data_to_insert,
                template="(%s, %s::vector, %s)",
                page_size=100,
            )
            inserted_ids = [row[0] for row in cursor.fetchall()]
            self.db_conn.commit()

        if not inserted_ids:
            raise DocumentInsertionError("No documents were inserted.")

        logger.info("Successfully bulk-ingested %d documents", len(inserted_ids))
        logger.info("Inserted document IDs: %s", inserted_ids)
        return inserted_ids[0]

    def _extract_text_from_pdf(self, pdf_path: str) -> str:
        """
        Extract text from a PDF file.

        Args:
            pdf_path: Path to the PDF file

        Returns:
            Extracted text as a single string
        """
        try:
            reader = PdfReader(pdf_path)
            text = ""
            for page in reader.pages:
                text += page.extract_text() + "\n"
            return text.strip()
        except Exception as e:
            raise DocumentExtractionError(
                "Error extracting text from PDF: " + str(e)
            ) from e

    def _get_relevant_context(self, question: str, top_k: int) -> list[tuple[str, str]]:
        """Get the most relevant document chunks for a given question"""
        question_embedding = self.embedding_model.encode(
            question, convert_to_numpy=True
        )
        try:
            with self.db_conn.cursor() as cursor:
                cursor.execute(
                    """
                    SELECT content, source
                    FROM documents
                    ORDER BY embedding <-> %s::vector
                    LIMIT %s
                    """,
                    (question_embedding.tolist(), top_k),
                )
                results = cursor.fetchall()
                return results
        except Exception as e:
            logger.exception("Error retrieving context: %s", e)
            return []

    def ingest_document(self, file_path: str, filename: str):
        logger.info("Ingesting %s...", filename)
        if not Path(file_path).exists():
            err = f"File not found: {filename}"
            raise FileNotFoundError(err)

        logger.info("Processing PDF: %s : %s", filename, file_path)
        text = self._extract_text_from_pdf(file_path)
        if not text.strip():
            err = "No text could be extracted from the PDF"
            raise ValueError(err)

        chunks = self._split_text(text)
        logger.info("Split PDF into %d chunks", len(chunks))

        embeddings = self._get_embedding(chunks, show_progress_bar=True)
        self._store_document(chunks, embeddings, filename)
        logger.info("Successfully processed %d chunks from %s", len(chunks), filename)

    def answer_query(self, question: str) -> AnswerResult:
        relevant_context = self._get_relevant_context(question, 5)
        context_str = "\n\n".join([chunk[0] for chunk in relevant_context])
        sources = list({chunk[1] for chunk in relevant_context if chunk[1]})

        try:
            response = litellm.completion(
                model="gemini/gemini-2.0-flash",
                messages=[
                    {
                        "role": "system",
                        "content": "You are a helpful assistant that answers questions based on the provided context.",
                    },
                    {
                        "role": "user",
                        "content": self.prompt.format(
                            context=context_str, question=question
                        ),
                    },
                ],
                temperature=0.1,
                max_tokens=500,
            )

            answer_text = response.choices[0].message.content.strip()

            if not answer_text:
                answer_text = "No answer generated"
                sources = ["No sources"]

            return AnswerResult(answer=answer_text, sources=sources)
        except Exception:
            logger.exception("Error generating response")
            return AnswerResult(
                answer="Error generating response", sources=["No sources"]
            )

    def answer_query_stream(self, question: str) -> Generator[str, None, None]:
        """Answer a query using streaming."""
        relevant_context = self._get_relevant_context(question, 5)
        context_str = "\n\n".join([chunk[0] for chunk in relevant_context])
        sources = list({chunk[1] for chunk in relevant_context if chunk[1]})
        prompt = self.prompt.format(context=context_str, question=question)

        try:
            response = litellm.completion(
                model="gemini/gemini-2.0-flash",
                messages=[{"role": "user", "content": prompt}],
                stream=True,
            )

            # First, yield the sources so the UI can display them immediately
            import json

            sources_json = json.dumps(sources)
            yield f'data: {{"sources": {sources_json}}}\n\n'

            # Then, stream the answer tokens
            for chunk in response:
                token = chunk.choices[0].delta.content
                if token:  # Ensure there's content to send
                    # SSE format: data: {"token": "..."}\n\n
                    yield f'data: {{"token": "{json.dumps(token)}"}}\n\n'

            # Signal the end of the stream with a special message
            yield 'data: {"end_of_stream": true}\n\n'
        except Exception:
            logger.exception("Error generating response")
            yield 'data: {"error": "Error generating response"}\n\n'

View File: app/services/vector_stores.py

@@ -0,0 +1,113 @@
import numpy as np
import psycopg2
from psycopg2.extensions import AsIs, register_adapter
from psycopg2.extras import execute_values

from app.core.config import settings
from app.core.interfaces import SearchResult, VectorDB

# Register NumPy array and float32 adapters for psycopg2
register_adapter(np.ndarray, AsIs)
register_adapter(np.float32, AsIs)


class PGVectorStore(VectorDB):
    """PostgreSQL vector store implementation for document storage and retrieval."""

    def __init__(self):
        pass

    def _get_connection(self):
        """Get a new database connection."""
        return psycopg2.connect(
            host=settings.POSTGRES_SERVER,
            port=settings.POSTGRES_PORT,
            user=settings.POSTGRES_USER,
            password=settings.POSTGRES_PASSWORD,
            dbname=settings.POSTGRES_DB,
        )

    def upsert_documents(self, documents: list[dict]) -> None:
        """
        Upsert documents into the vector store.

        Args:
            documents: List of document dictionaries containing 'content', 'embedding', and 'source'.

        Raises:
            ValueError: If required fields are missing from documents.
            psycopg2.Error: For database-related errors.
        """
        if not documents:
            return

        # Validate document structure
        for doc in documents:
            if not all(key in doc for key in ["content", "embedding", "source"]):
                err = "Document must contain 'content', 'embedding', and 'source' keys"
                raise ValueError(err)

        data_to_insert = [
            (doc["content"], np.array(doc["embedding"]), doc["source"])
            for doc in documents
        ]

        query = """
            INSERT INTO documents (content, embedding, source)
            VALUES %s
            ON CONFLICT (content, source) DO UPDATE SET
                embedding = EXCLUDED.embedding,
                updated_at = NOW()
            RETURNING id
        """

        with self._get_connection() as conn, conn.cursor() as cursor:
            try:
                execute_values(
                    cursor,
                    query,
                    data_to_insert,
                    template="(%s, %s::vector, %s)",
                    page_size=100,
                )
                conn.commit()
            except Exception:
                conn.rollback()
                raise
    def search(self, vector: np.ndarray, top_k: int = 5) -> list[SearchResult]:
        """
        Search for similar documents using vector similarity.

        Args:
            vector: The query vector to search with.
            top_k: Maximum number of results to return.

        Returns:
            List of search results with content and source.

        Raises:
            psycopg2.Error: For database-related errors.
        """
        # Avoid `if not vector:`, which raises for multi-element NumPy arrays.
        if vector is None or len(vector) == 0:
            return []

        query = """
            SELECT content, source
            FROM documents
            ORDER BY embedding <-> %s::vector
            LIMIT %s
        """

        with self._get_connection() as conn, conn.cursor() as cursor:
            try:
                cursor.execute(query, (np.array(vector).tolist(), top_k))
                return [
                    SearchResult(content=row[0], source=row[1])
                    for row in cursor.fetchall()
                ]
            except Exception:
                conn.rollback()
                raise
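The store assumes a `documents` table with a pgvector column and, for the `ON CONFLICT (content, source)` clause to apply, a unique constraint on that pair plus an `updated_at` column; none of that DDL appears in this commit. The following one-time setup script is therefore a sketch under those assumptions, with the vector dimension taken from the all-MiniLM-L6-v2 model used elsewhere in the diff:

```python
# Hypothetical setup for the table PGVectorStore expects; the table and column
# names match the queries above, while the constraint and updated_at column are
# inferred from the ON CONFLICT clause.
import psycopg2

from app.core.config import settings

DDL = """
CREATE EXTENSION IF NOT EXISTS vector;

CREATE TABLE IF NOT EXISTS documents (
    id BIGSERIAL PRIMARY KEY,
    content TEXT NOT NULL,
    embedding VECTOR(384),           -- all-MiniLM-L6-v2 produces 384-dim vectors
    source TEXT,
    updated_at TIMESTAMPTZ DEFAULT NOW(),
    UNIQUE (content, source)         -- required by ON CONFLICT (content, source)
);
"""

with psycopg2.connect(
    host=settings.POSTGRES_SERVER,
    port=settings.POSTGRES_PORT,
    user=settings.POSTGRES_USER,
    password=settings.POSTGRES_PASSWORD,
    dbname=settings.POSTGRES_DB,
) as conn, conn.cursor() as cur:
    cur.execute(DDL)
    conn.commit()
```

Also note that psycopg2's connection context manager only ends the transaction; it does not close the connection, so each `PGVectorStore` call opens a connection that is never explicitly closed.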