Mirror of https://github.com/Sosokker/plain-rag.git, synced 2025-12-18 14:34:05 +01:00
feat: add streaming output when querying
This commit is contained in:
parent
425c584101
commit
80af71935f
@@ -3,6 +3,7 @@ from pathlib import Path
from typing import Annotated

from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from structlog import get_logger

from app.schemas.models import IngestResponse, QueryRequest, QueryResponse
@@ -64,3 +65,19 @@ async def query_index(
    except Exception:
        logger.exception("Failed to answer query")
        raise HTTPException(status_code=500, detail="Internal server error")


@router.post("/query/stream")
async def query_index_stream(
    request: QueryRequest,
    rag_service: Annotated[RAGService, Depends(get_rag_service)],
):
    try:
        return StreamingResponse(
            rag_service.answer_query_stream(request.question),
            media_type="text/event-stream",
        )

    except Exception:
        logger.exception("Failed to answer query")
        raise HTTPException(status_code=500, detail="Internal server error")
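As a usage sketch (not part of this commit), a client could consume the new endpoint with httpx; the base URL and the question are hypothetical:

import httpx

# Stream the SSE response from the new endpoint (assumes the app is
# served at http://localhost:8000; adjust path/host to your deployment).
with httpx.stream(
    "POST",
    "http://localhost:8000/query/stream",
    json={"question": "What does this repo do?"},
    timeout=None,
) as response:
    for line in response.iter_lines():
        if line.startswith("data: "):
            print(line[len("data: "):])

Note that because the route returns a StreamingResponse, any error raised inside the generator after the first chunk is sent cannot become a 500 response; that is why the service-level generator below emits an in-band "error" event instead.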
@@ -1,4 +1,5 @@
import os
from collections.abc import Generator
from pathlib import Path
from typing import TypedDict

@@ -17,6 +18,7 @@ from app.core.exception import DocumentExtractionError, DocumentInsertionError
from app.core.utils import RecursiveCharacterTextSplitter

register_adapter(np.ndarray, AsIs)  # for psycopg2 adapt
register_adapter(np.float32, AsIs)  # for psycopg2 adapt
logger = get_logger()

# pyright: reportArgumentType=false
@ -244,3 +246,38 @@ Answer:"""
|
||||
return AnswerResult(
|
||||
answer="Error generating response", sources=["No sources"]
|
||||
)
|
||||
|
||||
def answer_query_stream(self, question: str) -> Generator[str, None, None]:
|
||||
"""Answer a query using streaming."""
|
||||
relevant_context = self._get_relevant_context(question, 5)
|
||||
context_str = "\n\n".join([chunk[0] for chunk in relevant_context])
|
||||
sources = list({chunk[1] for chunk in relevant_context if chunk[1]})
|
||||
|
||||
prompt = self.prompt.format(context=context_str, question=question)
|
||||
|
||||
try:
|
||||
response = litellm.completion(
|
||||
model="gemini/gemini-2.0-flash",
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# First, yield the sources so the UI can display them immediately
|
||||
import json
|
||||
|
||||
sources_json = json.dumps(sources)
|
||||
yield f'data: {{"sources": {sources_json}}}\n\n'
|
||||
|
||||
# Then, stream the answer tokens
|
||||
for chunk in response:
|
||||
token = chunk.choices[0].delta.content
|
||||
if token: # Ensure there's content to send
|
||||
# SSE format: data: {"token": "..."}\n\n
|
||||
yield f'data: {{"token": "{json.dumps(token)}"}}\n\n'
|
||||
|
||||
# Signal the end of the stream with a special message
|
||||
yield 'data: {"end_of_stream": true}\n\n'
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error generating response")
|
||||
yield 'data: {"error": "Error generating response"}\n\n'
|
||||
|
||||
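For reference, a minimal sketch of decoding the event payloads this generator emits (sources, token, end_of_stream, error); the helper name is illustrative, not part of the commit:

import json

def handle_sse_data(line: str) -> None:
    """Decode one 'data: ...' line from the /query/stream response."""
    payload = json.loads(line[len("data: "):])
    if "sources" in payload:
        print("sources:", payload["sources"])
    elif "token" in payload:
        print(payload["token"], end="", flush=True)
    elif payload.get("end_of_stream"):
        print()  # answer complete
    elif "error" in payload:
        print("\n[error]", payload["error"])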