feat: add streaming output for queries

This commit is contained in:
Sosokker 2025-06-24 16:03:45 +07:00
parent 425c584101
commit 80af71935f
2 changed files with 54 additions and 0 deletions

View File

@ -3,6 +3,7 @@ from pathlib import Path
from typing import Annotated
from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from structlog import get_logger
from app.schemas.models import IngestResponse, QueryRequest, QueryResponse
@ -64,3 +65,19 @@ async def query_index(
except Exception:
logger.exception("Failed to answer query")
raise HTTPException(status_code=500, detail="Internal server error")
@router.post("/query/stream")
async def query_index_stream(
    request: QueryRequest,
    rag_service: Annotated[RAGService, Depends(get_rag_service)],
) -> StreamingResponse:
    """Stream an answer to ``request.question`` as Server-Sent Events.

    The generator returned by ``answer_query_stream`` is evaluated lazily,
    only after the response headers have been sent.  A failure inside it can
    therefore never be converted into an HTTP error status here, so errors
    are reported in-band as an SSE ``error`` event by the generator itself.
    The previous ``try/except HTTPException(500)`` around the constructor
    was dead code (constructing a ``StreamingResponse`` does not run the
    generator) and has been removed.
    """
    return StreamingResponse(
        rag_service.answer_query_stream(request.question),
        media_type="text/event-stream",
    )

View File

@ -1,4 +1,5 @@
import os
from collections.abc import Generator
from pathlib import Path
from typing import TypedDict
@ -17,6 +18,7 @@ from app.core.exception import DocumentExtractionError, DocumentInsertionError
from app.core.utils import RecursiveCharacterTextSplitter
register_adapter(np.ndarray, AsIs)  # for psycopg2 adapt
register_adapter(np.float32, AsIs)  # for psycopg2 adapt
logger = get_logger()
# pyright: reportArgumentType=false
@ -244,3 +246,38 @@ Answer:"""
return AnswerResult(
answer="Error generating response", sources=["No sources"]
)
def answer_query_stream(self, question: str) -> Generator[str, None, None]:
    """Answer *question* as a stream of Server-Sent Events.

    Yields SSE-formatted strings (``data: <json>\\n\\n``), in order:

    1. a single ``{"sources": [...]}`` event so the UI can show sources
       immediately,
    2. one ``{"token": "..."}`` event per generated token,
    3. a final ``{"end_of_stream": true}`` sentinel event.

    On any failure a single ``{"error": ...}`` event is yielded instead of
    raising, because by the time this generator runs the HTTP status line
    has already been sent to the client.
    """
    import json  # local import, as in the original; keeps module top-level unchanged

    relevant_context = self._get_relevant_context(question, 5)
    context_str = "\n\n".join([chunk[0] for chunk in relevant_context])
    # De-duplicate source identifiers, dropping falsy entries.
    sources = list({chunk[1] for chunk in relevant_context if chunk[1]})

    prompt = self.prompt.format(context=context_str, question=question)
    try:
        response = litellm.completion(
            model="gemini/gemini-2.0-flash",
            messages=[{"role": "user", "content": prompt}],
            stream=True,
        )

        # First, yield the sources so the UI can display them immediately.
        yield f"data: {json.dumps({'sources': sources})}\n\n"

        # Then stream the answer tokens.
        for chunk in response:
            token = chunk.choices[0].delta.content
            if token:  # skip empty/None deltas
                # BUG FIX: json.dumps(token) already includes the surrounding
                # quotes, so the old f'{{"token": "{json.dumps(token)}"}}'
                # emitted invalid JSON (doubled quotes) and broke on any token
                # containing quotes or newlines. Serialize the whole event
                # object instead so escaping is always correct.
                yield f"data: {json.dumps({'token': token})}\n\n"

        # Signal the end of the stream with a special sentinel message.
        yield 'data: {"end_of_stream": true}\n\n'
    except Exception:
        logger.exception("Error generating response")
        yield 'data: {"error": "Error generating response"}\n\n'