mirror of
https://github.com/Sosokker/plain-rag.git
synced 2025-12-18 22:44:03 +01:00
feat: add steraming output when query
This commit is contained in:
parent
425c584101
commit
80af71935f
@ -3,6 +3,7 @@ from pathlib import Path
|
|||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
|
from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
from structlog import get_logger
|
from structlog import get_logger
|
||||||
|
|
||||||
from app.schemas.models import IngestResponse, QueryRequest, QueryResponse
|
from app.schemas.models import IngestResponse, QueryRequest, QueryResponse
|
||||||
@ -64,3 +65,19 @@ async def query_index(
|
|||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to answer query")
|
logger.exception("Failed to answer query")
|
||||||
raise HTTPException(status_code=500, detail="Internal server error")
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/query/stream")
|
||||||
|
async def query_index_stream(
|
||||||
|
request: QueryRequest,
|
||||||
|
rag_service: Annotated[RAGService, Depends(get_rag_service)],
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
return StreamingResponse(
|
||||||
|
rag_service.answer_query_stream(request.question),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to answer query")
|
||||||
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
from collections.abc import Generator
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TypedDict
|
from typing import TypedDict
|
||||||
|
|
||||||
@ -17,6 +18,7 @@ from app.core.exception import DocumentExtractionError, DocumentInsertionError
|
|||||||
from app.core.utils import RecursiveCharacterTextSplitter
|
from app.core.utils import RecursiveCharacterTextSplitter
|
||||||
|
|
||||||
register_adapter(np.ndarray, AsIs) # for psycopg2 adapt
|
register_adapter(np.ndarray, AsIs) # for psycopg2 adapt
|
||||||
|
register_adapter(np.float32, AsIs) # for psycopg2 adapt
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
# pyright: reportArgumentType=false
|
# pyright: reportArgumentType=false
|
||||||
@ -244,3 +246,38 @@ Answer:"""
|
|||||||
return AnswerResult(
|
return AnswerResult(
|
||||||
answer="Error generating response", sources=["No sources"]
|
answer="Error generating response", sources=["No sources"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def answer_query_stream(self, question: str) -> Generator[str, None, None]:
|
||||||
|
"""Answer a query using streaming."""
|
||||||
|
relevant_context = self._get_relevant_context(question, 5)
|
||||||
|
context_str = "\n\n".join([chunk[0] for chunk in relevant_context])
|
||||||
|
sources = list({chunk[1] for chunk in relevant_context if chunk[1]})
|
||||||
|
|
||||||
|
prompt = self.prompt.format(context=context_str, question=question)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = litellm.completion(
|
||||||
|
model="gemini/gemini-2.0-flash",
|
||||||
|
messages=[{"role": "user", "content": prompt}],
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# First, yield the sources so the UI can display them immediately
|
||||||
|
import json
|
||||||
|
|
||||||
|
sources_json = json.dumps(sources)
|
||||||
|
yield f'data: {{"sources": {sources_json}}}\n\n'
|
||||||
|
|
||||||
|
# Then, stream the answer tokens
|
||||||
|
for chunk in response:
|
||||||
|
token = chunk.choices[0].delta.content
|
||||||
|
if token: # Ensure there's content to send
|
||||||
|
# SSE format: data: {"token": "..."}\n\n
|
||||||
|
yield f'data: {{"token": "{json.dumps(token)}"}}\n\n'
|
||||||
|
|
||||||
|
# Signal the end of the stream with a special message
|
||||||
|
yield 'data: {"end_of_stream": true}\n\n'
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error generating response")
|
||||||
|
yield 'data: {"error": "Error generating response"}\n\n'
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user