Mirror of https://github.com/Sosokker/plain-rag.git (synced 2025-12-19 14:54:05 +01:00)

refactor: integrate ConfigService for model management and add configuration update endpoint

commit 011de22885 (parent cf7d1e8218)
app/api/endpoints.py

@@ -6,7 +6,13 @@ from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, Up
 from fastapi.responses import StreamingResponse
 from structlog import get_logger
 
-from app.schemas.models import IngestResponse, QueryRequest, QueryResponse
+from app.schemas.models import (
+    ConfigUpdateRequest,
+    IngestResponse,
+    QueryRequest,
+    QueryResponse,
+)
+from app.services.config_service import ConfigService
 from app.services.rag_service import RAGService
 
 logger = get_logger()
@@ -21,6 +27,12 @@ def get_rag_service():
     return app_state["rag_service"]
 
 
+def get_config_service():
+    from app.main import app_state
+
+    return app_state["config_service"]
+
+
 @router.post("/ingest", response_model=IngestResponse)
 async def ingest_file(
     background_tasks: BackgroundTasks,
@@ -79,3 +91,23 @@ async def query_index_stream(
     except Exception:
         logger.exception("Failed to answer query")
         raise HTTPException(status_code=500, detail="Internal server error")
+
+
+@router.post("/config")
+async def update_configuration(
+    request: ConfigUpdateRequest,
+    config_service: Annotated[ConfigService, Depends(get_config_service)],
+):
+    responses = []
+    if request.embedding_model:
+        response = await config_service.set_embedding_model(request.embedding_model)
+        responses.append(response)
+    if request.reranker_model:
+        response = await config_service.set_reranker_model(request.reranker_model)
+        responses.append(response)
+    # Add similar logic for LLM models and providers when implemented
+
+    if not responses:
+        return {"message": "No configuration changes requested."}
+
+    return {"message": " ".join(responses)}
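To exercise the new endpoint, a minimal sketch, assuming the API runs locally at http://localhost:8000 and the router is mounted without a prefix (neither is shown in this diff):

# Sketch: hot-swap models via the new /config endpoint.
# Assumes a local server at http://localhost:8000 and no router prefix.
import httpx

payload = {"embedding_model": "MiniLMEmbeddingModel", "reranker_model": "MiniLMReranker"}
response = httpx.post("http://localhost:8000/config", json=payload)
print(response.json())  # e.g. {"message": "Embedding model set to 'MiniLMEmbeddingModel' successfully. ..."}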
app/core/registry.py (new file, 48 lines)

@@ -0,0 +1,48 @@
+from collections.abc import Callable
+from typing import Any, Generic, TypeVar
+
+T = TypeVar("T")
+
+
+class Registry(Generic[T]):
+    """A generic registry to store and retrieve objects by name."""
+
+    def __init__(self):
+        self._items: dict[str, T] = {}
+
+    def register(self, name: str, item: T):
+        """Registers an item with a given name."""
+        if not isinstance(name, str) or not name:
+            raise ValueError("Name must be a non-empty string.")
+        if name in self._items:
+            raise ValueError(f"Item with name '{name}' already registered.")
+        self._items[name] = item
+
+    def get(self, name: str) -> T:
+        """Retrieves an item by its name."""
+        if name not in self._items:
+            raise KeyError(f"Item with name '{name}' not found in registry.")
+        return self._items[name]
+
+    def unregister(self, name: str):
+        """Unregisters an item by its name."""
+        if name not in self._items:
+            raise KeyError(f"Item with name '{name}' not found in registry.")
+        del self._items[name]
+
+    def list_available(self) -> list[str]:
+        """Lists all available item names in the registry."""
+        return list(self._items.keys())
+
+
+class EmbeddingModelRegistry(Registry[Callable[..., Any]]):
+    """Registry specifically for embedding model constructors."""
+
+
+class RerankerRegistry(Registry[Callable[..., Any]]):
+    """Registry specifically for reranker constructors."""
+
+
+# Global instances of the registries
+embedding_model_registry = EmbeddingModelRegistry()
+reranker_registry = RerankerRegistry()
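The registry is a plain name-to-object map with guards for duplicate and missing names; a self-contained sketch of its behavior:

# Sketch: Registry semantics — register once, look up by name, list, unregister.
from app.core.registry import Registry

registry: Registry[type] = Registry()
registry.register("dict", dict)
assert registry.get("dict") is dict
assert registry.list_available() == ["dict"]
registry.unregister("dict")
# A second registry.get("dict") would now raise KeyError,
# and registering the same name twice raises ValueError.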
app/main.py (11 lines changed)

@@ -5,7 +5,7 @@ from fastapi import FastAPI
 from structlog import get_logger
 
 from app.api import endpoints
-from app.services.embedding_providers import MiniLMEmbeddingModel
+from app.services.config_service import ConfigService
 from app.services.rag_service import RAGService
 from app.services.vector_stores import PGVectorStore
 
@@ -20,14 +20,17 @@ app_state = {}
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    embedding_provider = MiniLMEmbeddingModel()
-    vector_store_provider = PGVectorStore()
+    config_service = ConfigService()
+    await config_service.initialize_models()
+    app_state["config_service"] = config_service
 
     # This code runs on startup
     logger.info("Application starting up...")
     # Initialize the RAG Service and store it in the app_state
     app_state["rag_service"] = RAGService(
-        embedding_model=embedding_provider, vector_db=vector_store_provider
+        embedding_model=config_service.get_current_embedding_model(),
+        vector_db=PGVectorStore(),
+        reranker=config_service.get_current_reranker_model(),
     )
     yield
 
app/schemas/models.py

@@ -1,4 +1,5 @@
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
+from typing import Optional
 
 
 class QueryRequest(BaseModel):
@@ -13,3 +14,16 @@ class QueryResponse(BaseModel):
 class IngestResponse(BaseModel):
     message: str
     filename: str
+
+
+class ConfigUpdateRequest(BaseModel):
+    embedding_model: Optional[str] = Field(
+        None, description="Name of the embedding model to use"
+    )
+    reranker_model: Optional[str] = Field(
+        None, description="Name of the reranker model to use"
+    )
+    llm_model: Optional[str] = Field(None, description="Name of the LLM model to use")
+    llm_provider: Optional[str] = Field(
+        None, description="Name of the LLM provider to use"
+    )
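Since every field of ConfigUpdateRequest is optional, partial payloads validate; a brief sketch:

# Sketch: partial ConfigUpdateRequest payloads are valid; omitted fields are None.
from app.schemas.models import ConfigUpdateRequest

req = ConfigUpdateRequest(embedding_model="MiniLMEmbeddingModel")
assert req.reranker_model is None and req.llm_provider is None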
app/services/config_service.py (new file, 118 lines)

@@ -0,0 +1,118 @@
+import logging
+
+from app.core.config import settings
+from app.core.interfaces import EmbeddingModel, Reranker
+from app.core.registry import embedding_model_registry, reranker_registry
+from app.services.embedding_providers import MiniLMEmbeddingModel
+from app.services.rerankers import MiniLMReranker
+
+logger = logging.getLogger(__name__)
+
+
+class ConfigService:
+    def __init__(self):
+        self._current_embedding_model: EmbeddingModel | None = None
+        self._current_reranker_model: Reranker | None = None
+        self._loading_status: dict[str, bool] = {
+            "embedding_model": False,
+            "reranker_model": False,
+        }
+
+        # Register available models
+        self._register_models()
+
+    def _register_models(self):
+        # Register embedding models
+        embedding_model_registry.register("MiniLMEmbeddingModel", MiniLMEmbeddingModel)
+        # Register reranker models
+        reranker_registry.register("MiniLMReranker", MiniLMReranker)
+
+    async def initialize_models(self):
+        logger.info("Initializing default models...")
+        # Initialize default embedding model
+        default_embedding_model_name = settings.EMBEDDING_MODEL
+        await self.set_embedding_model(default_embedding_model_name)
+        logger.info(
+            f"Default embedding model initialized: {default_embedding_model_name}"
+        )
+
+        # Initialize default reranker model (if any)
+        # Assuming a default reranker can be set in settings if needed
+        # For now, let's assume MiniLMReranker is the default if not specified
+        default_reranker_model_name = "MiniLMReranker"  # Or from settings
+        await self.set_reranker_model(default_reranker_model_name)
+        logger.info(
+            f"Default reranker model initialized: {default_reranker_model_name}"
+        )
+
+    async def set_embedding_model(self, model_name: str) -> str:
+        if (
+            self._current_embedding_model
+            and self._current_embedding_model.__class__.__name__ == model_name
+        ):
+            return f"Embedding model '{model_name}' is already in use."
+
+        if self._loading_status["embedding_model"]:
+            return "Another embedding model is currently being loaded. Please wait."
+
+        try:
+            self._loading_status["embedding_model"] = True
+            logger.info(f"Attempting to load embedding model: {model_name}")
+            model_constructor = embedding_model_registry.get(model_name)
+            self._current_embedding_model = model_constructor()
+            settings.EMBEDDING_MODEL = model_name  # Update settings
+            logger.info(f"Successfully loaded embedding model: {model_name}")
+            return f"Embedding model set to '{model_name}' successfully."
+        except KeyError:
+            logger.warning(f"Embedding model '{model_name}' not found in registry.")
+            return (
+                f"Embedding model '{model_name}' not available. "
+                f"Current model remains '{self._current_embedding_model.__class__.__name__ if self._current_embedding_model else 'None'}'."
+            )
+        except Exception as e:
+            logger.exception(f"Error loading embedding model {model_name}: {e}")
+            return f"Failed to load embedding model '{model_name}': {e}"
+        finally:
+            self._loading_status["embedding_model"] = False
+
+    async def set_reranker_model(self, model_name: str) -> str:
+        if (
+            self._current_reranker_model
+            and self._current_reranker_model.__class__.__name__ == model_name
+        ):
+            return f"Reranker model '{model_name}' is already in use."
+
+        if self._loading_status["reranker_model"]:
+            return "Another reranker model is currently being loaded. Please wait."
+
+        try:
+            self._loading_status["reranker_model"] = True
+            logger.info(f"Attempting to load reranker model: {model_name}")
+            model_constructor = reranker_registry.get(model_name)
+            self._current_reranker_model = model_constructor()
+            # settings.RERANKER_MODEL = model_name  # Add this to settings if you want to persist
+            logger.info(f"Successfully loaded reranker model: {model_name}")
+            return f"Reranker model set to '{model_name}' successfully."
+        except KeyError:
+            logger.warning(f"Reranker model '{model_name}' not found in registry.")
+            return (
+                f"Reranker model '{model_name}' not available. "
+                f"Current model remains '{self._current_reranker_model.__class__.__name__ if self._current_reranker_model else 'None'}'."
+            )
+        except Exception as e:
+            logger.exception(f"Error loading reranker model {model_name}: {e}")
+            return f"Failed to load reranker model '{model_name}': {e}"
+        finally:
+            self._loading_status["reranker_model"] = False
+
+    def get_current_embedding_model(self) -> EmbeddingModel | None:
+        return self._current_embedding_model
+
+    def get_current_reranker_model(self) -> Reranker | None:
+        return self._current_reranker_model
+
+    def get_available_embedding_models(self) -> list[str]:
+        return embedding_model_registry.list_available()
+
+    def get_available_reranker_models(self) -> list[str]:
+        return reranker_registry.list_available()
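A minimal sketch of the swap semantics from an async context (the model names are the ones registered in _register_models above):

# Sketch: ConfigService loads defaults, then swaps by registered name;
# unknown names leave the current model in place and return a message.
import asyncio

from app.services.config_service import ConfigService


async def main():
    service = ConfigService()
    await service.initialize_models()
    print(service.get_available_embedding_models())  # ['MiniLMEmbeddingModel']
    print(await service.set_embedding_model("DoesNotExist"))  # "... not available ..."


asyncio.run(main())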
app/services/rag_service.py

@@ -6,7 +6,7 @@ from typing import TypedDict
 import litellm
 from structlog import get_logger
 
-from app.core.interfaces import EmbeddingModel, VectorDB
+from app.core.interfaces import EmbeddingModel, Reranker, VectorDB
 from app.core.utils import RecursiveCharacterTextSplitter
 
 logger = get_logger()
@@ -18,9 +18,15 @@ class AnswerResult(TypedDict):
 
 
 class RAGService:
-    def __init__(self, embedding_model: EmbeddingModel, vector_db: VectorDB):
+    def __init__(
+        self,
+        embedding_model: EmbeddingModel,
+        vector_db: VectorDB,
+        reranker: Reranker | None = None,
+    ):
         self.embedding_model = embedding_model
         self.vector_db = vector_db
+        self.reranker = reranker
         self.prompt = """Answer the question based on the following context.
 If you don't know the answer, say you don't know. Don't make up an answer.
 
@@ -69,8 +75,12 @@ Answer:"""
     def answer_query(self, question: str) -> AnswerResult:
         query_embedding = self.embedding_model.embed_query(question)
         search_results = self.vector_db.search(query_embedding, top_k=5)
-        sources = list({chunk["source"] for chunk in search_results if chunk["source"]})
+
+        if self.reranker:
+            logger.info("Reranking search results...")
+            search_results = self.reranker.rerank(search_results, question)
 
+        sources = list({chunk["source"] for chunk in search_results if chunk["source"]})
         context_str = "\n\n".join([chunk["content"] for chunk in search_results])
 
         try:
@@ -92,7 +102,7 @@ Answer:"""
             max_tokens=500,
         )
 
-        answer_text = response.choices[0].message.content.strip()  # type: ignore
+        answer_text = response.choices[0].message.content.strip()
 
         if not answer_text:
             answer_text = "No answer generated"
@@ -108,6 +118,11 @@ Answer:"""
     def answer_query_stream(self, question: str) -> Generator[str, None, None]:
         query_embedding = self.embedding_model.embed_query(question)
         search_results = self.vector_db.search(query_embedding, top_k=5)
+
+        if self.reranker:
+            logger.info("Reranking search results...")
+            search_results = self.reranker.rerank(search_results, question)
+
         sources = list({chunk["source"] for chunk in search_results if chunk["source"]})
         context_str = "\n\n".join([chunk["content"] for chunk in search_results])
 
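For context, a sketch of what a conforming reranker looks like from RAGService's side. The real Reranker interface lives in app.core.interfaces and MiniLMReranker in app.services.rerankers, neither of which appears in this diff, so the method shape here is inferred from the call site self.reranker.rerank(search_results, question):

# Sketch of a reranker conforming to the call site above. The exact
# Reranker protocol is defined in app.core.interfaces (not shown in this
# diff); this shape is inferred from how RAGService invokes it.
from typing import Any


class IdentityReranker:
    def rerank(
        self, search_results: list[dict[str, Any]], question: str
    ) -> list[dict[str, Any]]:
        # A real implementation would score each chunk's "content" against
        # the question and sort; this placeholder preserves the input order.
        return search_results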