plain-rag/app/services/rag_service.py


import json
from collections.abc import Generator
from pathlib import Path
from typing import TypedDict

import litellm
from structlog import get_logger

from app.core.interfaces import EmbeddingModel, VectorDB
from app.core.utils import RecursiveCharacterTextSplitter

logger = get_logger()


class AnswerResult(TypedDict):
    answer: str
    sources: list[str]


class RAGService:
    def __init__(self, embedding_model: EmbeddingModel, vector_db: VectorDB):
        self.embedding_model = embedding_model
        self.vector_db = vector_db
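        # Prompt template for the generation step; {context} is filled with the
        # retrieved chunks and {question} with the user's query.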
self.prompt = """Answer the question based on the following context.
If you don't know the answer, say you don't know. Don't make up an answer.
Context:
{context}
Question: {question}
Answer:"""
def _split_text(
self, text: str, chunk_size: int = 500, chunk_overlap: int = 100
) -> list[str]:
"""
Split text into chunks with specified size and overlap.
Args:
text: Input text to split
chunk_size: Maximum size of each chunk in characters
chunk_overlap: Number of characters to overlap between chunks
Returns:
List of text chunks
"""
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
return text_splitter.split_text(text)
def _ingest_document(self, text_chunks: list[str], source_name: str):
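        # Embed each chunk and upsert it together with its source name so
        # answers can cite where the retrieved text came from.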
        embeddings = self.embedding_model.embed_documents(text_chunks)
        documents_to_upsert = [
            {"content": chunk, "embedding": emb, "source": source_name}
            for chunk, emb in zip(text_chunks, embeddings, strict=False)
        ]
        self.vector_db.upsert_documents(documents_to_upsert)

    def ingest_document(self, file_path: Path, source_name: str):
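        # Read the file, split it into overlapping chunks, then embed and store them.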
        with Path(file_path).open("r", encoding="utf-8") as f:
            text = f.read()
        text_chunks = self._split_text(text)
        self._ingest_document(text_chunks, source_name)

    def answer_query(self, question: str) -> AnswerResult:
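        # Retrieve the top-k chunks most similar to the question, then ask the
        # LLM to answer using only that retrieved context.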
        query_embedding = self.embedding_model.embed_query(question)
        search_results = self.vector_db.search(query_embedding, top_k=5)
        sources = list({chunk["source"] for chunk in search_results if chunk["source"]})
        context_str = "\n\n".join([chunk["content"] for chunk in search_results])
        try:
            response = litellm.completion(
                model="gemini/gemini-2.0-flash",
                messages=[
                    {
                        "role": "system",
                        "content": "You are a helpful assistant that answers questions based on the provided context.",
                    },
                    {
                        "role": "user",
                        "content": self.prompt.format(
                            context=context_str, question=question
                        ),
                    },
                ],
                temperature=0.1,
                max_tokens=500,
            )
            answer_text = response.choices[0].message.content.strip()  # type: ignore
            if not answer_text:
                answer_text = "No answer generated"
                sources = ["No sources"]
            return AnswerResult(answer=answer_text, sources=sources)
        except Exception:
            logger.exception("Error generating response")
            return AnswerResult(
                answer="Error generating response", sources=["No sources"]
            )

    def answer_query_stream(self, question: str) -> Generator[str, None, None]:
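        # Same retrieve-then-generate flow as answer_query, but streamed as
        # Server-Sent Events: each yielded string is one "data: <json>\n\n" frame.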
        query_embedding = self.embedding_model.embed_query(question)
        search_results = self.vector_db.search(query_embedding, top_k=5)
        sources = list({chunk["source"] for chunk in search_results if chunk["source"]})
        context_str = "\n\n".join([chunk["content"] for chunk in search_results])
        try:
            response = litellm.completion(
                model="gemini/gemini-2.0-flash",
                messages=[
                    {
                        "role": "system",
                        "content": "You are a helpful assistant that answers questions based on the provided context.",
                    },
                    {
                        "role": "user",
                        "content": self.prompt.format(
                            context=context_str, question=question
                        ),
                    },
                ],
                temperature=0.1,
                max_tokens=500,
                stream=True,
            )
            # Yield each chunk of the response as it's generated
            for chunk in response:
                if chunk.choices:
                    delta = chunk.choices[0].delta
                    if hasattr(delta, "content") and delta.content:
yield f'data: {{"token": "{json.dumps(delta.content)}"}}\n\n'
            # Yield sources at the end
            yield f'data: {{"sources": {json.dumps(sources)}}}\n\n'
            yield 'data: {"end_of_stream": true}\n\n'
        except Exception:
            logger.exception("Error generating streaming response")
            yield 'data: {"error": "Error generating response"}\n\n'