refactor: integrate ConfigService for model management and add configuration update endpoint

Sosokker 2025-06-26 18:02:51 +07:00
parent cf7d1e8218
commit 011de22885
7 changed files with 244 additions and 14 deletions

app/api/endpoints.py

@@ -6,7 +6,13 @@ from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, Up
 from fastapi.responses import StreamingResponse
 from structlog import get_logger

-from app.schemas.models import IngestResponse, QueryRequest, QueryResponse
+from app.schemas.models import (
+    ConfigUpdateRequest,
+    IngestResponse,
+    QueryRequest,
+    QueryResponse,
+)
+from app.services.config_service import ConfigService
 from app.services.rag_service import RAGService

 logger = get_logger()
@@ -21,6 +27,12 @@ def get_rag_service():
     return app_state["rag_service"]


+def get_config_service():
+    from app.main import app_state
+
+    return app_state["config_service"]
+
+
 @router.post("/ingest", response_model=IngestResponse)
 async def ingest_file(
     background_tasks: BackgroundTasks,
@@ -79,3 +91,23 @@ async def query_index_stream(
     except Exception:
         logger.exception("Failed to answer query")
         raise HTTPException(status_code=500, detail="Internal server error")
+
+
+@router.post("/config")
+async def update_configuration(
+    request: ConfigUpdateRequest,
+    config_service: Annotated[ConfigService, Depends(get_config_service)],
+):
+    responses = []
+    if request.embedding_model:
+        response = await config_service.set_embedding_model(request.embedding_model)
+        responses.append(response)
+    if request.reranker_model:
+        response = await config_service.set_reranker_model(request.reranker_model)
+        responses.append(response)
+    # Add similar logic for LLM models and providers when implemented
+    if not responses:
+        return {"message": "No configuration changes requested."}
+    return {"message": " ".join(responses)}
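
For reference, a minimal sketch of exercising the new endpoint. The base URL, port, and the httpx client are assumptions for illustration, not part of this commit:

# Sketch only: assumes the router is mounted at the application root on localhost:8000.
# The model names are the ones ConfigService registers in this commit.
import httpx

payload = {"embedding_model": "MiniLMEmbeddingModel", "reranker_model": "MiniLMReranker"}
response = httpx.post("http://localhost:8000/config", json=payload)
print(response.json())  # {"message": "..."} with one status string per requested change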

app/core/registry.py (new file, 48 lines)

@@ -0,0 +1,48 @@
from collections.abc import Callable
from typing import Any, Generic, TypeVar

T = TypeVar("T")


class Registry(Generic[T]):
    """A generic registry to store and retrieve objects by name."""

    def __init__(self):
        self._items: dict[str, T] = {}

    def register(self, name: str, item: T):
        """Registers an item with a given name."""
        if not isinstance(name, str) or not name:
            raise ValueError("Name must be a non-empty string.")
        if name in self._items:
            raise ValueError(f"Item with name '{name}' already registered.")
        self._items[name] = item

    def get(self, name: str) -> T:
        """Retrieves an item by its name."""
        if name not in self._items:
            raise KeyError(f"Item with name '{name}' not found in registry.")
        return self._items[name]

    def unregister(self, name: str):
        """Unregisters an item by its name."""
        if name not in self._items:
            raise KeyError(f"Item with name '{name}' not found in registry.")
        del self._items[name]

    def list_available(self) -> list[str]:
        """Lists all available item names in the registry."""
        return list(self._items.keys())


class EmbeddingModelRegistry(Registry[Callable[..., Any]]):
    """Registry specifically for embedding model constructors."""


class RerankerRegistry(Registry[Callable[..., Any]]):
    """Registry specifically for reranker constructors."""


# Global instances of the registries
embedding_model_registry = EmbeddingModelRegistry()
reranker_registry = RerankerRegistry()
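
A quick illustration of the registry contract; DummyEmbeddingModel is a hypothetical name used only to show the API:

# Hypothetical usage of the Registry API above; constructors (not instances) are stored.
from app.core.registry import embedding_model_registry

class DummyEmbeddingModel:
    pass

embedding_model_registry.register("DummyEmbeddingModel", DummyEmbeddingModel)
print(embedding_model_registry.list_available())  # includes "DummyEmbeddingModel"
model = embedding_model_registry.get("DummyEmbeddingModel")()  # call the constructor to instantiate
embedding_model_registry.unregister("DummyEmbeddingModel")
# Registering a duplicate name raises ValueError; getting a missing name raises KeyError.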

app/main.py

@@ -5,7 +5,7 @@ from fastapi import FastAPI
 from structlog import get_logger

 from app.api import endpoints
-from app.services.embedding_providers import MiniLMEmbeddingModel
+from app.services.config_service import ConfigService
 from app.services.rag_service import RAGService
 from app.services.vector_stores import PGVectorStore

@@ -20,14 +20,17 @@ app_state = {}
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    embedding_provider = MiniLMEmbeddingModel()
-    vector_store_provider = PGVectorStore()
+    config_service = ConfigService()
+    await config_service.initialize_models()
+    app_state["config_service"] = config_service

     # This code runs on startup
     logger.info("Application starting up...")

     # Initialize the RAG Service and store it in the app_state
     app_state["rag_service"] = RAGService(
-        embedding_model=embedding_provider, vector_db=vector_store_provider
+        embedding_model=config_service.get_current_embedding_model(),
+        vector_db=PGVectorStore(),
+        reranker=config_service.get_current_reranker_model(),
     )

     yield
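
Note that RAGService captures the model instances ConfigService holds at startup, so a model swapped later via POST /config does not reach the already-constructed RAGService. A possible way to keep them in sync, sketched as a hypothetical adapter rather than anything in this commit:

# Hypothetical adapter: resolves the active model through ConfigService on each call,
# so swaps made via POST /config take effect without rebuilding RAGService.
class LiveEmbeddingModel:
    def __init__(self, config_service):
        self._config_service = config_service

    def embed_query(self, text: str):
        # Delegate to whichever embedding model is currently active.
        return self._config_service.get_current_embedding_model().embed_query(text)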

app/schemas/models.py

@@ -1,4 +1,5 @@
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
+from typing import Optional


 class QueryRequest(BaseModel):

@@ -13,3 +14,16 @@ class QueryResponse(BaseModel):
 class IngestResponse(BaseModel):
     message: str
     filename: str
+
+
+class ConfigUpdateRequest(BaseModel):
+    embedding_model: Optional[str] = Field(
+        None, description="Name of the embedding model to use"
+    )
+    reranker_model: Optional[str] = Field(
+        None, description="Name of the reranker model to use"
+    )
+    llm_model: Optional[str] = Field(None, description="Name of the LLM model to use")
+    llm_provider: Optional[str] = Field(
+        None, description="Name of the LLM provider to use"
+    )
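
Every field is optional, so a request may update any subset of settings; unset fields stay None and are skipped by the endpoint. For example:

# Partial update: only the embedding model is requested to change.
req = ConfigUpdateRequest(embedding_model="MiniLMEmbeddingModel")
assert req.reranker_model is None and req.llm_model is None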

app/services/config_service.py (new file, 118 lines)

@@ -0,0 +1,118 @@
import logging

from app.core.config import settings
from app.core.interfaces import EmbeddingModel, Reranker
from app.core.registry import embedding_model_registry, reranker_registry
from app.services.embedding_providers import MiniLMEmbeddingModel
from app.services.rerankers import MiniLMReranker

logger = logging.getLogger(__name__)


class ConfigService:
    def __init__(self):
        self._current_embedding_model: EmbeddingModel | None = None
        self._current_reranker_model: Reranker | None = None
        self._loading_status: dict[str, bool] = {
            "embedding_model": False,
            "reranker_model": False,
        }
        # Register available models
        self._register_models()

    def _register_models(self):
        # Register embedding models
        embedding_model_registry.register("MiniLMEmbeddingModel", MiniLMEmbeddingModel)
        # Register reranker models
        reranker_registry.register("MiniLMReranker", MiniLMReranker)

    async def initialize_models(self):
        logger.info("Initializing default models...")
        # Initialize default embedding model
        default_embedding_model_name = settings.EMBEDDING_MODEL
        await self.set_embedding_model(default_embedding_model_name)
        logger.info(
            f"Default embedding model initialized: {default_embedding_model_name}"
        )
        # Initialize default reranker model (if any)
        # Assuming a default reranker can be set in settings if needed
        # For now, let's assume MiniLMReranker is the default if not specified
        default_reranker_model_name = "MiniLMReranker"  # Or from settings
        await self.set_reranker_model(default_reranker_model_name)
        logger.info(
            f"Default reranker model initialized: {default_reranker_model_name}"
        )

    async def set_embedding_model(self, model_name: str) -> str:
        if (
            self._current_embedding_model
            and self._current_embedding_model.__class__.__name__ == model_name
        ):
            return f"Embedding model '{model_name}' is already in use."
        if self._loading_status["embedding_model"]:
            return "Another embedding model is currently being loaded. Please wait."
        try:
            self._loading_status["embedding_model"] = True
            logger.info(f"Attempting to load embedding model: {model_name}")
            model_constructor = embedding_model_registry.get(model_name)
            self._current_embedding_model = model_constructor()
            settings.EMBEDDING_MODEL = model_name  # Update settings
            logger.info(f"Successfully loaded embedding model: {model_name}")
            return f"Embedding model set to '{model_name}' successfully."
        except KeyError:
            logger.warning(f"Embedding model '{model_name}' not found in registry.")
            return (
                f"Embedding model '{model_name}' not available. "
                f"Current model remains '{self._current_embedding_model.__class__.__name__ if self._current_embedding_model else 'None'}'."
            )
        except Exception as e:
            logger.exception(f"Error loading embedding model {model_name}: {e}")
            return f"Failed to load embedding model '{model_name}': {e}"
        finally:
            self._loading_status["embedding_model"] = False

    async def set_reranker_model(self, model_name: str) -> str:
        if (
            self._current_reranker_model
            and self._current_reranker_model.__class__.__name__ == model_name
        ):
            return f"Reranker model '{model_name}' is already in use."
        if self._loading_status["reranker_model"]:
            return "Another reranker model is currently being loaded. Please wait."
        try:
            self._loading_status["reranker_model"] = True
            logger.info(f"Attempting to load reranker model: {model_name}")
            model_constructor = reranker_registry.get(model_name)
            self._current_reranker_model = model_constructor()
            # settings.RERANKER_MODEL = model_name  # Add this to settings if you want to persist
            logger.info(f"Successfully loaded reranker model: {model_name}")
            return f"Reranker model set to '{model_name}' successfully."
        except KeyError:
            logger.warning(f"Reranker model '{model_name}' not found in registry.")
            return (
                f"Reranker model '{model_name}' not available. "
                f"Current model remains '{self._current_reranker_model.__class__.__name__ if self._current_reranker_model else 'None'}'."
            )
        except Exception as e:
            logger.exception(f"Error loading reranker model {model_name}: {e}")
            return f"Failed to load reranker model '{model_name}': {e}"
        finally:
            self._loading_status["reranker_model"] = False

    def get_current_embedding_model(self) -> EmbeddingModel | None:
        return self._current_embedding_model

    def get_current_reranker_model(self) -> Reranker | None:
        return self._current_reranker_model

    def get_available_embedding_models(self) -> list[str]:
        return embedding_model_registry.list_available()

    def get_available_reranker_models(self) -> list[str]:
        return reranker_registry.list_available()
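
Adding another backend later should only require registering a constructor before the service initializes. A sketch under that assumption; OpenAIEmbeddingModel is a placeholder name, not a class in this repository:

# Hypothetical: register an additional embedding backend, then select it at runtime
# with POST /config and {"embedding_model": "OpenAIEmbeddingModel"}.
from app.core.registry import embedding_model_registry

class OpenAIEmbeddingModel:
    def embed_query(self, text: str) -> list[float]:
        raise NotImplementedError  # a real implementation would call an embeddings API

embedding_model_registry.register("OpenAIEmbeddingModel", OpenAIEmbeddingModel)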

app/services/rag_service.py

@@ -6,7 +6,7 @@ from typing import TypedDict
 import litellm
 from structlog import get_logger

-from app.core.interfaces import EmbeddingModel, VectorDB
+from app.core.interfaces import EmbeddingModel, Reranker, VectorDB
 from app.core.utils import RecursiveCharacterTextSplitter

 logger = get_logger()

@@ -18,9 +18,15 @@ class AnswerResult(TypedDict):
 class RAGService:
-    def __init__(self, embedding_model: EmbeddingModel, vector_db: VectorDB):
+    def __init__(
+        self,
+        embedding_model: EmbeddingModel,
+        vector_db: VectorDB,
+        reranker: Reranker | None = None,
+    ):
         self.embedding_model = embedding_model
         self.vector_db = vector_db
+        self.reranker = reranker

         self.prompt = """Answer the question based on the following context.
 If you don't know the answer, say you don't know. Don't make up an answer.

@@ -69,8 +75,12 @@ Answer:"""
     def answer_query(self, question: str) -> AnswerResult:
         query_embedding = self.embedding_model.embed_query(question)
         search_results = self.vector_db.search(query_embedding, top_k=5)
-        sources = list({chunk["source"] for chunk in search_results if chunk["source"]})

+        if self.reranker:
+            logger.info("Reranking search results...")
+            search_results = self.reranker.rerank(search_results, question)
+
+        sources = list({chunk["source"] for chunk in search_results if chunk["source"]})
         context_str = "\n\n".join([chunk["content"] for chunk in search_results])

         try:

@@ -92,7 +102,7 @@ Answer:"""
                 max_tokens=500,
             )

-            answer_text = response.choices[0].message.content.strip()  # type: ignore
+            answer_text = response.choices[0].message.content.strip()
             if not answer_text:
                 answer_text = "No answer generated"

@@ -108,6 +118,11 @@ Answer:"""
     def answer_query_stream(self, question: str) -> Generator[str, None, None]:
         query_embedding = self.embedding_model.embed_query(question)
         search_results = self.vector_db.search(query_embedding, top_k=5)
+
+        if self.reranker:
+            logger.info("Reranking search results...")
+            search_results = self.reranker.rerank(search_results, question)
+
         sources = list({chunk["source"] for chunk in search_results if chunk["source"]})
         context_str = "\n\n".join([chunk["content"] for chunk in search_results])
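
The Reranker interface itself is not part of this diff; from the call sites it needs a rerank(search_results, question) method that returns the reordered chunks. A minimal conforming sketch, with placeholder lexical-overlap scoring rather than the commit's MiniLMReranker:

# Illustrative Reranker matching the rerank(search_results, question) call sites above.
class KeywordOverlapReranker:
    def rerank(self, search_results: list[dict], question: str) -> list[dict]:
        terms = set(question.lower().split())

        def overlap(chunk: dict) -> int:
            # Count shared lowercase tokens between the question and the chunk text.
            return len(terms & set(chunk["content"].lower().split()))

        return sorted(search_results, key=overlap, reverse=True)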