mirror of
https://github.com/Sosokker/plain-rag.git
synced 2025-12-18 14:34:05 +01:00
67 lines
2.0 KiB
Python
67 lines
2.0 KiB
Python
import shutil
|
|
from pathlib import Path
|
|
from typing import Annotated
|
|
|
|
from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
|
|
from structlog import get_logger
|
|
|
|
from app.schemas.models import IngestResponse, QueryRequest, QueryResponse
|
|
from app.services.rag_service import RAGService
|
|
|
|
# Router carrying the /ingest and /query endpoints declared below.
router = APIRouter()

# structlog logger used by the request handlers in this module.
logger = get_logger()
|
|
|
|
|
|
# Dependency function to get the RAG service instance
|
|
def get_rag_service():
|
|
from app.main import app_state
|
|
|
|
return app_state["rag_service"]
|
|
|
|
|
|
@router.post("/ingest", response_model=IngestResponse)
async def ingest_file(
    background_tasks: BackgroundTasks,
    file: Annotated[UploadFile, File(...)],
    rag_service: Annotated[RAGService, Depends(get_rag_service)],
):
    """Save an uploaded file locally and queue it for background ingestion.

    The upload is written to ``temp_files/<basename>``. Then
    ``rag_service.ingest_document(path, filename)`` is added to
    ``background_tasks``, so the response returns without waiting for
    ingestion to finish.

    Args:
        background_tasks: FastAPI task collection the ingestion job is added to.
        file: The uploaded document.
        rag_service: Service instance resolved via ``get_rag_service``.

    Returns:
        A dict with a status ``message`` and the client-supplied ``filename``,
        validated against ``IngestResponse``.

    Raises:
        HTTPException: 400 if the upload has no file name, or if the name has
            no usable final path component.
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail="File name is required")

    # ``filename`` is client-controlled. Joining it directly onto temp_dir
    # has two risks:
    # - an absolute name ("/etc/passwd") replaces temp_dir entirely;
    # - a relative name ("../../x") climbs out of it.
    # Keep only the final path component, and reject names that resolve to
    # nothing or to a directory reference.
    safe_name = Path(file.filename).name
    if safe_name in {"", ".", ".."}:
        raise HTTPException(status_code=400, detail="Invalid file name")

    temp_dir = Path("temp_files")
    temp_dir.mkdir(exist_ok=True)
    file_path = temp_dir / safe_name

    # NOTE(review): two concurrent uploads with the same name write to the
    # same path. One may overwrite the other before its background ingestion
    # reads it. Consider a unique prefix if that matters.
    with file_path.open("wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    # Ingestion runs after the response is sent.
    background_tasks.add_task(
        rag_service.ingest_document, file_path.as_posix(), file.filename
    )

    return {
        "message": "File upload successful. Ingestion has started in the background.",
        "filename": file.filename,
    }
|
|
|
|
|
|
@router.post("/query", response_model=QueryResponse)
async def query_index(
    request: QueryRequest,
    rag_service: Annotated[RAGService, Depends(get_rag_service)],
):
    """Answer ``request.question`` using the RAG service.

    Args:
        request: Request body carrying the ``question`` to answer.
        rag_service: Service instance resolved via ``get_rag_service``.

    Returns:
        A ``QueryResponse`` built from the ``answer`` and ``sources`` keys of
        the service result. Placeholders are used when either key is missing.

    Raises:
        HTTPException: 500 on any failure. The traceback is logged and a
            generic message is returned to the client.
    """
    from fastapi.concurrency import run_in_threadpool

    try:
        # answer_query is a plain synchronous call. Invoking it directly
        # inside this ``async def`` would block the event loop, and with it
        # every other request, until it returns. Run it on the worker pool.
        result = await run_in_threadpool(rag_service.answer_query, request.question)

        answer = result.get("answer", "No answer generated")
        sources = result.get("sources", ["No sources"])

        return QueryResponse(answer=answer, sources=sources)

    except Exception as exc:
        # Request boundary: log the full traceback, but do not leak internal
        # error details to the client.
        logger.exception("Failed to answer query")
        raise HTTPException(status_code=500, detail="Internal server error") from exc
|