refactor: integrate ConfigService for model management and add configuration update endpoint

Sosokker 2025-06-26 18:02:51 +07:00
parent cf7d1e8218
commit 011de22885
7 changed files with 244 additions and 14 deletions


@@ -6,7 +6,13 @@ from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from structlog import get_logger
from app.schemas.models import IngestResponse, QueryRequest, QueryResponse
from app.schemas.models import (
ConfigUpdateRequest,
IngestResponse,
QueryRequest,
QueryResponse,
)
from app.services.config_service import ConfigService
from app.services.rag_service import RAGService
logger = get_logger()
@@ -21,6 +27,12 @@ def get_rag_service():
return app_state["rag_service"]
def get_config_service():
from app.main import app_state
return app_state["config_service"]
@router.post("/ingest", response_model=IngestResponse)
async def ingest_file(
background_tasks: BackgroundTasks,
@@ -79,3 +91,23 @@ async def query_index_stream(
except Exception:
logger.exception("Failed to answer query")
raise HTTPException(status_code=500, detail="Internal server error")
@router.post("/config")
async def update_configuration(
request: ConfigUpdateRequest,
config_service: Annotated[ConfigService, Depends(get_config_service)],
):
responses = []
if request.embedding_model:
response = await config_service.set_embedding_model(request.embedding_model)
responses.append(response)
if request.reranker_model:
response = await config_service.set_reranker_model(request.reranker_model)
responses.append(response)
# Add similar logic for LLM models and providers when implemented
if not responses:
return {"message": "No configuration changes requested."}
return {"message": " ".join(responses)}

app/core/registry.py Normal file

@@ -0,0 +1,48 @@
from collections.abc import Callable
from typing import Any, Generic, TypeVar
T = TypeVar("T")
class Registry(Generic[T]):
"""A generic registry to store and retrieve objects by name."""
def __init__(self):
self._items: dict[str, T] = {}
def register(self, name: str, item: T):
"""Registers an item with a given name."""
if not isinstance(name, str) or not name:
raise ValueError("Name must be a non-empty string.")
if name in self._items:
raise ValueError(f"Item with name '{name}' already registered.")
self._items[name] = item
def get(self, name: str) -> T:
"""Retrieves an item by its name."""
if name not in self._items:
raise KeyError(f"Item with name '{name}' not found in registry.")
return self._items[name]
def unregister(self, name: str):
"""Unregisters an item by its name."""
if name not in self._items:
raise KeyError(f"Item with name '{name}' not found in registry.")
del self._items[name]
def list_available(self) -> list[str]:
"""Lists all available item names in the registry."""
return list(self._items.keys())
class EmbeddingModelRegistry(Registry[Callable[..., Any]]):
"""Registry specifically for embedding model constructors."""
class RerankerRegistry(Registry[Callable[..., Any]]):
"""Registry specifically for reranker constructors."""
# Global instances of the registries
embedding_model_registry = EmbeddingModelRegistry()
reranker_registry = RerankerRegistry()
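A minimal, self-contained sketch of the Registry contract (a throwaway string registry, not part of the commit):

from app.core.registry import Registry

# Names must be unique: a second register() under the same name raises
# ValueError, and get()/unregister() on a missing name raise KeyError.
greetings: Registry[str] = Registry()
greetings.register("en", "hello")
greetings.register("th", "sawasdee")
assert greetings.list_available() == ["en", "th"]
assert greetings.get("en") == "hello"
greetings.unregister("th")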


@@ -5,7 +5,7 @@ from fastapi import FastAPI
from structlog import get_logger
from app.api import endpoints
from app.services.embedding_providers import MiniLMEmbeddingModel
from app.services.config_service import ConfigService
from app.services.rag_service import RAGService
from app.services.vector_stores import PGVectorStore
@@ -20,14 +20,17 @@ app_state = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
embedding_provider = MiniLMEmbeddingModel()
vector_store_provider = PGVectorStore()
config_service = ConfigService()
await config_service.initialize_models()
app_state["config_service"] = config_service
# This code runs on startup
logger.info("Application starting up...")
# Initialize the RAG Service and store it in the app_state
app_state["rag_service"] = RAGService(
embedding_model=embedding_provider, vector_db=vector_store_provider
embedding_model=config_service.get_current_embedding_model(),
vector_db=PGVectorStore(),
reranker=config_service.get_current_reranker_model(),
)
yield
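The lifespan wiring can be exercised with FastAPI's TestClient, which runs the startup block on entering the context manager (a sketch that assumes the app object lives in app.main, the router is mounted at the root, and a Postgres instance is reachable for PGVectorStore):

from fastapi.testclient import TestClient

from app.main import app

# Entering the context runs lifespan startup, so config_service and
# rag_service are already populated in app_state by the time we post.
with TestClient(app) as client:
    resp = client.post("/config", json={"reranker_model": "MiniLMReranker"})
    assert resp.status_code == 200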


@@ -1,4 +1,5 @@
from pydantic import BaseModel
from pydantic import BaseModel, Field
from typing import Optional
class QueryRequest(BaseModel):
@@ -13,3 +14,16 @@ class QueryResponse(BaseModel):
class IngestResponse(BaseModel):
message: str
filename: str
class ConfigUpdateRequest(BaseModel):
embedding_model: Optional[str] = Field(
None, description="Name of the embedding model to use"
)
reranker_model: Optional[str] = Field(
None, description="Name of the reranker model to use"
)
llm_model: Optional[str] = Field(None, description="Name of the LLM model to use")
llm_provider: Optional[str] = Field(
None, description="Name of the LLM provider to use"
)
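Every field is optional, so a request may update any subset; a sketch of a partial update (model_dump assumes Pydantic v2):

from app.schemas.models import ConfigUpdateRequest

# Unset fields stay None and are skipped by the endpoint's
# `if request.<field>` checks.
req = ConfigUpdateRequest(embedding_model="MiniLMEmbeddingModel")
print(req.model_dump(exclude_none=True))  # {'embedding_model': 'MiniLMEmbeddingModel'}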


@@ -0,0 +1,118 @@
import logging
from app.core.config import settings
from app.core.interfaces import EmbeddingModel, Reranker
from app.core.registry import embedding_model_registry, reranker_registry
from app.services.embedding_providers import MiniLMEmbeddingModel
from app.services.rerankers import MiniLMReranker
logger = logging.getLogger(__name__)
class ConfigService:
def __init__(self):
self._current_embedding_model: EmbeddingModel | None = None
self._current_reranker_model: Reranker | None = None
self._loading_status: dict[str, bool] = {
"embedding_model": False,
"reranker_model": False,
}
# Register available models
self._register_models()
def _register_models(self):
# Register embedding models
embedding_model_registry.register("MiniLMEmbeddingModel", MiniLMEmbeddingModel)
# Register reranker models
reranker_registry.register("MiniLMReranker", MiniLMReranker)
async def initialize_models(self):
logger.info("Initializing default models...")
# Initialize default embedding model
default_embedding_model_name = settings.EMBEDDING_MODEL
await self.set_embedding_model(default_embedding_model_name)
logger.info(
f"Default embedding model initialized: {default_embedding_model_name}"
)
# Initialize default reranker model (if any)
# Assuming a default reranker can be set in settings if needed
# For now, let's assume MiniLMReranker is the default if not specified
default_reranker_model_name = "MiniLMReranker" # Or from settings
await self.set_reranker_model(default_reranker_model_name)
logger.info(
f"Default reranker model initialized: {default_reranker_model_name}"
)
async def set_embedding_model(self, model_name: str) -> str:
if (
self._current_embedding_model
and self._current_embedding_model.__class__.__name__ == model_name
):
return f"Embedding model '{model_name}' is already in use."
if self._loading_status["embedding_model"]:
return "Another embedding model is currently being loaded. Please wait."
try:
self._loading_status["embedding_model"] = True
logger.info(f"Attempting to load embedding model: {model_name}")
model_constructor = embedding_model_registry.get(model_name)
self._current_embedding_model = model_constructor()
settings.EMBEDDING_MODEL = model_name # Update settings
logger.info(f"Successfully loaded embedding model: {model_name}")
return f"Embedding model set to '{model_name}' successfully."
except KeyError:
logger.warning(f"Embedding model '{model_name}' not found in registry.")
return (
f"Embedding model '{model_name}' not available. "
f"Current model remains '{self._current_embedding_model.__class__.__name__ if self._current_embedding_model else 'None'}'."
)
except Exception as e:
logger.exception(f"Error loading embedding model {model_name}: {e}")
return f"Failed to load embedding model '{model_name}': {e}"
finally:
self._loading_status["embedding_model"] = False
async def set_reranker_model(self, model_name: str) -> str:
if (
self._current_reranker_model
and self._current_reranker_model.__class__.__name__ == model_name
):
return f"Reranker model '{model_name}' is already in use."
if self._loading_status["reranker_model"]:
return "Another reranker model is currently being loaded. Please wait."
try:
self._loading_status["reranker_model"] = True
logger.info(f"Attempting to load reranker model: {model_name}")
model_constructor = reranker_registry.get(model_name)
self._current_reranker_model = model_constructor()
# settings.RERANKER_MODEL = model_name # Add this to settings if you want to persist
logger.info(f"Successfully loaded reranker model: {model_name}")
return f"Reranker model set to '{model_name}' successfully."
except KeyError:
logger.warning(f"Reranker model '{model_name}' not found in registry.")
return (
f"Reranker model '{model_name}' not available. "
f"Current model remains '{self._current_reranker_model.__class__.__name__ if self._current_reranker_model else 'None'}'."
)
except Exception as e:
logger.exception(f"Error loading reranker model {model_name}: {e}")
return f"Failed to load reranker model '{model_name}': {e}"
finally:
self._loading_status["reranker_model"] = False
def get_current_embedding_model(self) -> EmbeddingModel | None:
return self._current_embedding_model
def get_current_reranker_model(self) -> Reranker | None:
return self._current_reranker_model
def get_available_embedding_models(self) -> list[str]:
return embedding_model_registry.list_available()
def get_available_reranker_models(self) -> list[str]:
return reranker_registry.list_available()
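Because the setters are async and report outcomes as strings instead of raising, a driver can call them directly and surface every message (a sketch; note that the constructors load real model weights, and that a second ConfigService in the same process would trip the registries' duplicate-name check):

import asyncio

from app.services.config_service import ConfigService

async def main() -> None:
    service = ConfigService()          # registers the two MiniLM constructors
    await service.initialize_models()  # loads the defaults
    print(service.get_available_embedding_models())  # ['MiniLMEmbeddingModel']
    # Unknown names are reported back, not raised:
    print(await service.set_embedding_model("DoesNotExist"))

asyncio.run(main())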


@@ -6,7 +6,7 @@ from typing import TypedDict
import litellm
from structlog import get_logger
from app.core.interfaces import EmbeddingModel, VectorDB
from app.core.interfaces import EmbeddingModel, Reranker, VectorDB
from app.core.utils import RecursiveCharacterTextSplitter
logger = get_logger()
@@ -18,10 +18,16 @@ class AnswerResult(TypedDict):
class RAGService:
def __init__(self, embedding_model: EmbeddingModel, vector_db: VectorDB):
def __init__(
self,
embedding_model: EmbeddingModel,
vector_db: VectorDB,
reranker: Reranker | None = None,
):
self.embedding_model = embedding_model
self.vector_db = vector_db
self.prompt = """Answer the question based on the following context.
self.reranker = reranker
self.prompt = """Answer the question based on the following context.
If you don't know the answer, say you don't know. Don't make up an answer.
Context:
@ -69,8 +75,12 @@ Answer:"""
def answer_query(self, question: str) -> AnswerResult:
query_embedding = self.embedding_model.embed_query(question)
search_results = self.vector_db.search(query_embedding, top_k=5)
sources = list({chunk["source"] for chunk in search_results if chunk["source"]})
if self.reranker:
logger.info("Reranking search results...")
search_results = self.reranker.rerank(search_results, question)
sources = list({chunk["source"] for chunk in search_results if chunk["source"]})
context_str = "\n\n".join([chunk["content"] for chunk in search_results])
try:
@ -92,7 +102,7 @@ Answer:"""
max_tokens=500,
)
answer_text = response.choices[0].message.content.strip() # type: ignore
answer_text = response.choices[0].message.content.strip()
if not answer_text:
answer_text = "No answer generated"
@ -108,6 +118,11 @@ Answer:"""
def answer_query_stream(self, question: str) -> Generator[str, None, None]:
query_embedding = self.embedding_model.embed_query(question)
search_results = self.vector_db.search(query_embedding, top_k=5)
if self.reranker:
logger.info("Reranking search results...")
search_results = self.reranker.rerank(search_results, question)
sources = list({chunk["source"] for chunk in search_results if chunk["source"]})
context_str = "\n\n".join([chunk["content"] for chunk in search_results])
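The Reranker interface itself is not shown in this diff, but the call sites pin down its contract: rerank(search_results, question) returning the (reordered) chunk list. A hypothetical stub matching that shape:

from typing import Any

class NoOpReranker:
    """Hypothetical stand-in satisfying the rerank() call sites above."""

    def rerank(
        self, search_results: list[dict[str, Any]], question: str
    ) -> list[dict[str, Any]]:
        # A real implementation (e.g. MiniLMReranker) would score each
        # chunk's content against the question and reorder; this stub
        # preserves the vector-search ordering.
        return search_results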


@ -175,9 +175,9 @@ Answer:"""
with self.db_conn.cursor() as cursor:
cursor.execute(
"""
SELECT content, source
FROM documents
ORDER BY embedding <-> %s::vector
SELECT content, source
FROM documents
ORDER BY embedding <-> %s::vector
LIMIT %s
""",
(question_embedding.tolist(), top_k),
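For reference, <-> is pgvector's L2 (Euclidean) distance operator, so this query returns the top_k rows whose embeddings are nearest to the question embedding. A hedged sketch of issuing the same query directly (the DSN and the 384-dimension placeholder vector, MiniLM's embedding size, are assumptions, not values from this diff):

import psycopg  # v3 shown; psycopg2 has the same shape for this call

with psycopg.connect("postgresql://localhost/rag") as conn, conn.cursor() as cur:
    cur.execute(
        """
        SELECT content, source
        FROM documents
        ORDER BY embedding <-> %s::vector
        LIMIT %s
        """,
        ([0.0] * 384, 5),  # placeholder query vector, top_k = 5
    )
    rows = cur.fetchall()  # list of (content, source) tuples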