From 80af71935f714a82634aefd57a62d477c21570db Mon Sep 17 00:00:00 2001
From: Sosokker
Date: Tue, 24 Jun 2025 16:03:45 +0700
Subject: [PATCH] feat: add streaming output when querying

---
 app/api/endpoints.py        | 17 +++++++++++++++++
 app/services/rag_service.py | 36 ++++++++++++++++++++++++++++++++++++
 2 files changed, 53 insertions(+)

diff --git a/app/api/endpoints.py b/app/api/endpoints.py
index f05d384..2ccaca3 100644
--- a/app/api/endpoints.py
+++ b/app/api/endpoints.py
@@ -3,6 +3,7 @@ from pathlib import Path
 from typing import Annotated
 
 from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
+from fastapi.responses import StreamingResponse
 from structlog import get_logger
 
 from app.schemas.models import IngestResponse, QueryRequest, QueryResponse
@@ -64,3 +65,19 @@ async def query_index(
     except Exception:
         logger.exception("Failed to answer query")
         raise HTTPException(status_code=500, detail="Internal server error")
+
+
+@router.post("/query/stream")
+async def query_index_stream(
+    request: QueryRequest,
+    rag_service: Annotated[RAGService, Depends(get_rag_service)],
+):
+    try:
+        return StreamingResponse(
+            rag_service.answer_query_stream(request.question),
+            media_type="text/event-stream",
+        )
+
+    except Exception:
+        logger.exception("Failed to answer query")
+        raise HTTPException(status_code=500, detail="Internal server error")
diff --git a/app/services/rag_service.py b/app/services/rag_service.py
index c8c23c2..4874594 100644
--- a/app/services/rag_service.py
+++ b/app/services/rag_service.py
@@ -1,4 +1,6 @@
+import json
 import os
+from collections.abc import Generator
 from pathlib import Path
 from typing import TypedDict
 
@@ -17,6 +19,7 @@ from app.core.exception import DocumentExtractionError, DocumentInsertionError
 from app.core.utils import RecursiveCharacterTextSplitter
 
 register_adapter(np.ndarray, AsIs)  # for psycopg2 adapt
+register_adapter(np.float32, AsIs)  # for psycopg2 adapt
 
 logger = get_logger()
 # pyright: reportArgumentType=false
@@ -244,3 +247,36 @@ Answer:"""
             return AnswerResult(
                 answer="Error generating response", sources=["No sources"]
             )
+
+    def answer_query_stream(self, question: str) -> Generator[str, None, None]:
+        """Answer a query using streaming."""
+        relevant_context = self._get_relevant_context(question, 5)
+        context_str = "\n\n".join([chunk[0] for chunk in relevant_context])
+        sources = list({chunk[1] for chunk in relevant_context if chunk[1]})
+
+        prompt = self.prompt.format(context=context_str, question=question)
+
+        try:
+            response = litellm.completion(
+                model="gemini/gemini-2.0-flash",
+                messages=[{"role": "user", "content": prompt}],
+                stream=True,
+            )
+
+            # First, yield the sources so the UI can display them immediately
+            sources_json = json.dumps(sources)
+            yield f'data: {{"sources": {sources_json}}}\n\n'
+
+            # Then, stream the answer tokens
+            for chunk in response:
+                token = chunk.choices[0].delta.content
+                if token:  # Ensure there's content to send
+                    # SSE format: data: {"token": "..."}\n\n
+                    yield f'data: {{"token": {json.dumps(token)}}}\n\n'
+
+            # Signal the end of the stream with a special message
+            yield 'data: {"end_of_stream": true}\n\n'
+
+        except Exception:
+            logger.exception("Error generating response")
+            yield 'data: {"error": "Error generating response"}\n\n'
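
On the wire, the new endpoint produces a server-sent-event stream with one
JSON object per "data:" frame: a sources event first, then one event per
token, then a terminator. For illustration only (the source name and token
text below are made up), a successful response looks like:

data: {"sources": ["report.pdf"]}

data: {"token": "The"}

data: {"token": " answer"}

data: {"end_of_stream": true}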
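
A minimal client sketch for consuming the stream, assuming the router is
mounted at the application root on http://localhost:8000 (adjust the URL to
your deployment); httpx is used here for illustration and is not a dependency
added by this patch:

import json

import httpx

# Open the response as a stream so tokens are handled as they arrive.
with httpx.stream(
    "POST",
    "http://localhost:8000/query/stream",
    json={"question": "What is this document about?"},
    timeout=None,  # the model may take a while before the first token
) as response:
    for line in response.iter_lines():
        # SSE frames are separated by blank lines; only "data:" lines matter.
        if not line.startswith("data: "):
            continue
        event = json.loads(line[len("data: "):])
        if "sources" in event:
            print("sources:", event["sources"])
        elif "token" in event:
            print(event["token"], end="", flush=True)
        elif event.get("end_of_stream"):
            print()
            break
        elif "error" in event:
            raise RuntimeError(event["error"])

Parsing each frame with json.loads works because the server serializes every
token with json.dumps: quotes and newlines inside a token are escaped, so they
can break neither the JSON payload nor the blank-line SSE framing.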