mirror of
https://github.com/Sosokker/plain-rag.git
synced 2025-12-18 14:34:05 +01:00
114 lines
3.2 KiB
Python
114 lines
3.2 KiB
Python
import shutil
|
|
import tempfile
|
|
from pathlib import Path
|
|
from typing import Annotated
|
|
|
|
from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
|
|
from fastapi.responses import StreamingResponse
|
|
from structlog import get_logger
|
|
|
|
from app.schemas.models import (
|
|
ConfigUpdateRequest,
|
|
IngestResponse,
|
|
QueryRequest,
|
|
QueryResponse,
|
|
)
|
|
from app.services.config_service import ConfigService
|
|
from app.services.rag_service import RAGService
|
|
|
|
logger = get_logger()
|
|
|
|
router = APIRouter()
|
|
|
|
|
|
# Dependency function to get the RAG service instance
|
|
def get_rag_service():
|
|
from app.main import app_state
|
|
|
|
return app_state["rag_service"]
|
|
|
|
|
|
def get_config_service():
|
|
from app.main import app_state
|
|
|
|
return app_state["config_service"]
|
|
|
|
|
|
@router.post("/ingest", response_model=IngestResponse)
|
|
async def ingest_file(
|
|
background_tasks: BackgroundTasks,
|
|
file: Annotated[UploadFile, File(...)],
|
|
rag_service: Annotated[RAGService, Depends(get_rag_service)],
|
|
):
|
|
if not file.filename:
|
|
raise HTTPException(status_code=400, detail="File name is required")
|
|
|
|
with tempfile.NamedTemporaryFile(
|
|
delete=False, suffix=file.filename, delete_on_close=True
|
|
) as tmp:
|
|
file.file.seek(0)
|
|
shutil.copyfileobj(file.file, tmp)
|
|
tmp.flush()
|
|
background_tasks.add_task(
|
|
rag_service.ingest_document, Path(tmp.name), file.filename
|
|
)
|
|
|
|
return {
|
|
"message": "File upload successful. Ingestion has started in the background.",
|
|
"filename": file.filename,
|
|
}
|
|
|
|
|
|
@router.post("/query", response_model=QueryResponse)
|
|
async def query_index(
|
|
request: QueryRequest,
|
|
rag_service: Annotated[RAGService, Depends(get_rag_service)],
|
|
):
|
|
try:
|
|
result = rag_service.answer_query(request.question)
|
|
|
|
answer = result.get("answer", "No answer generated")
|
|
sources = result.get("sources", ["No sources"])
|
|
|
|
return QueryResponse(answer=answer, sources=sources)
|
|
|
|
except Exception:
|
|
logger.exception("Failed to answer query")
|
|
raise HTTPException(status_code=500, detail="Internal server error")
|
|
|
|
|
|
@router.post("/query/stream")
|
|
async def query_index_stream(
|
|
request: QueryRequest,
|
|
rag_service: Annotated[RAGService, Depends(get_rag_service)],
|
|
):
|
|
try:
|
|
return StreamingResponse(
|
|
rag_service.answer_query_stream(request.question),
|
|
media_type="text/event-stream",
|
|
)
|
|
|
|
except Exception:
|
|
logger.exception("Failed to answer query")
|
|
raise HTTPException(status_code=500, detail="Internal server error")
|
|
|
|
|
|
@router.post("/config")
|
|
async def update_configuration(
|
|
request: ConfigUpdateRequest,
|
|
config_service: Annotated[ConfigService, Depends(get_config_service)],
|
|
):
|
|
responses = []
|
|
if request.embedding_model:
|
|
response = await config_service.set_embedding_model(request.embedding_model)
|
|
responses.append(response)
|
|
if request.reranker_model:
|
|
response = await config_service.set_reranker_model(request.reranker_model)
|
|
responses.append(response)
|
|
# Add similar logic for LLM models and providers when implemented
|
|
|
|
if not responses:
|
|
return {"message": "No configuration changes requested."}
|
|
|
|
return {"message": " ".join(responses)}
|