plain-rag/app/services/rag_service.py

import json
from collections.abc import Generator
from pathlib import Path
from typing import TypedDict

import litellm
from PyPDF2 import PdfReader
from PyPDF2.errors import PyPdfError
from structlog import get_logger

from app.core.interfaces import EmbeddingModel, Reranker, VectorDB
from app.core.utils import RecursiveCharacterTextSplitter

logger = get_logger()


class AnswerResult(TypedDict):
    answer: str
    sources: list[str]


class RAGService:
    def __init__(
        self,
        embedding_model: EmbeddingModel,
        vector_db: VectorDB,
        reranker: Reranker | None = None,
    ) -> None:
        self.embedding_model = embedding_model
        self.vector_db = vector_db
        self.reranker = reranker
        self.prompt = """Answer the question based on the following context.
If you don't know the answer, say you don't know. Don't make up an answer.

Context:
{context}

Question: {question}

Answer:"""

    def _split_text(
        self, text: str, chunk_size: int = 500, chunk_overlap: int = 100
    ) -> list[str]:
        """
        Split text into chunks with specified size and overlap.

        Args:
            text: Input text to split
            chunk_size: Maximum size of each chunk in characters
            chunk_overlap: Number of characters to overlap between chunks

        Returns:
            List of text chunks
        """
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
        return text_splitter.split_text(text)
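
    # Illustrative behavior (assuming the splitter mirrors LangChain's
    # RecursiveCharacterTextSplitter): a 1,200-character input yields chunks of
    # at most 500 characters, with consecutive chunks sharing up to 100
    # characters of context so sentences are not cut mid-thought.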

    def _ingest_document(self, text_chunks: list[str], source_name: str) -> None:
        embeddings = self.embedding_model.embed_documents(text_chunks)
        # Each chunk must pair with exactly one embedding; strict=True surfaces
        # a length mismatch instead of silently dropping trailing chunks.
        documents_to_upsert = [
            {"content": chunk, "embedding": emb, "source": source_name}
            for chunk, emb in zip(text_chunks, embeddings, strict=True)
        ]
        self.vector_db.upsert_documents(documents_to_upsert)
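
    # One upserted record, with illustrative values:
    #   {"content": "…chunk text…", "embedding": [0.12, -0.03, ...], "source": "handbook"}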

    def ingest_document(self, file_path: Path, source_name: str) -> None:
        path = Path(file_path)
        if path.suffix.lower() == ".pdf":
            try:
                reader = PdfReader(str(path))
                text = "\n".join(page.extract_text() or "" for page in reader.pages)
            except PyPdfError as e:
                logger.exception("PDF processing error", file_path=str(path))
                raise ValueError(
                    f"Failed to extract text from PDF due to a PDF processing error: {e}"
                ) from e
            except Exception as e:
                logger.exception(
                    "Unexpected error during PDF processing", file_path=str(path)
                )
                raise RuntimeError(
                    f"An unexpected error occurred during PDF processing: {e}"
                ) from e
        else:
            text = path.read_text(encoding="utf-8")
        text_chunks = self._split_text(text)
        self._ingest_document(text_chunks, source_name)
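
    # Usage sketch (hypothetical wiring; the concrete EmbeddingModel and
    # VectorDB implementations are injected via app.core.interfaces):
    #   service = RAGService(embedding_model=embedder, vector_db=store)
    #   service.ingest_document(Path("docs/handbook.pdf"), source_name="handbook")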

    def answer_query(self, question: str) -> AnswerResult:
        query_embedding = self.embedding_model.embed_query(question)
        search_results = self.vector_db.search(query_embedding, top_k=5)
        if self.reranker:
            logger.info("Reranking search results...")
            search_results = self.reranker.rerank(search_results, question)
        sources = list({chunk["source"] for chunk in search_results if chunk["source"]})
        context_str = "\n\n".join(chunk["content"] for chunk in search_results)
        try:
            response = litellm.completion(
                model="gemini/gemini-2.0-flash",
                messages=[
                    {
                        "role": "system",
                        "content": "You are a helpful assistant that answers questions based on the provided context.",
                    },
                    {
                        "role": "user",
                        "content": self.prompt.format(
                            context=context_str, question=question
                        ),
                    },
                ],
                temperature=0.1,
                max_tokens=500,
            )
            # litellm returns an OpenAI-style response object; guard each
            # attribute in case a provider omits choices or content.
            answer_text = None
            choices = getattr(response, "choices", None)
            if choices and len(choices) > 0:
                message = getattr(choices[0], "message", None)
                content = getattr(message, "content", None)
                if content:
                    answer_text = content.strip()
            if not answer_text:
                answer_text = "No answer generated"
                sources = ["No sources"]
            return AnswerResult(answer=answer_text, sources=sources)
        except Exception:
            logger.exception("Error generating response")
            return AnswerResult(
                answer="Error generating response", sources=["No sources"]
            )
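
    # Example return value (illustrative):
    #   {"answer": "Cats sleep roughly 12-16 hours a day.", "sources": ["handbook"]}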

    def answer_query_stream(self, question: str) -> Generator[str, None, None]:
        query_embedding = self.embedding_model.embed_query(question)
        search_results = self.vector_db.search(query_embedding, top_k=5)
        if self.reranker:
            logger.info("Reranking search results...")
            search_results = self.reranker.rerank(search_results, question)
        sources = list({chunk["source"] for chunk in search_results if chunk["source"]})
        context_str = "\n\n".join(chunk["content"] for chunk in search_results)
        try:
            response = litellm.completion(
                model="gemini/gemini-2.0-flash",
                messages=[
                    {
                        "role": "system",
                        "content": "You are a helpful assistant that answers questions based on the provided context.",
                    },
                    {
                        "role": "user",
                        "content": self.prompt.format(
                            context=context_str, question=question
                        ),
                    },
                ],
                temperature=0.1,
                max_tokens=500,
                stream=True,
            )
            for chunk in response:
                choices = getattr(chunk, "choices", None)
                if choices and len(choices) > 0:
                    delta = getattr(choices[0], "delta", None)
                    content = getattr(delta, "content", None)
                    if content:
                        yield f'data: {{"token": {json.dumps(content)}}}\n\n'
            # Yield sources at the end, followed by a sentinel event the
            # client can use to close the connection.
            yield f'data: {{"sources": {json.dumps(sources)}}}\n\n'
            yield 'data: {"end_of_stream": true}\n\n'
        except Exception:
            logger.exception("Error generating streaming response")
            yield 'data: {"error": "Error generating response"}\n\n'