refactor: enhance model initialization and validation in ConfigService, add enums for model names

This commit is contained in:
Sosokker 2025-06-27 22:09:27 +07:00
parent fc0d1a3a16
commit 2f13d8c3ce
4 changed files with 48 additions and 22 deletions

View File

@ -95,7 +95,6 @@ def get_settings() -> Settings:
return Settings()
# Create settings instance
settings = get_settings()
# Set environment variables for third-party libraries

View File

@ -11,33 +11,32 @@ from app.services.vector_stores import PGVectorStore
logger = get_logger()
# Load environment variables from .env file
load_dotenv()
# Dictionary to hold our application state, including the RAG service instance
app_state = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
async def lifespan(app: FastAPI): # noqa: ARG001
config_service = ConfigService()
await config_service.initialize_models()
app_state["config_service"] = config_service
# This code runs on startup
logger.info("Application starting up...")
# Initialize the RAG Service and store it in the app_state
embedding_model = config_service.get_current_embedding_model()
reranker = config_service.get_current_reranker_model()
if embedding_model is None:
raise RuntimeError("Embedding model failed to initialize")
app_state["rag_service"] = RAGService(
embedding_model=config_service.get_current_embedding_model(),
embedding_model=embedding_model,
vector_db=PGVectorStore(),
reranker=config_service.get_current_reranker_model(),
reranker=reranker,
)
yield
app = FastAPI(lifespan=lifespan)
# Include the API router
app.include_router(endpoints.router)

9
app/schemas/enums.py Normal file
View File

@ -0,0 +1,9 @@
from enum import Enum
class EmbeddingModelName(str, Enum):
MiniLMEmbeddingModel = "MiniLMEmbeddingModel"
class RerankerModelName(str, Enum):
MiniLMReranker = "MiniLMReranker"

View File

@ -3,6 +3,7 @@ import logging
from app.core.config import settings
from app.core.interfaces import EmbeddingModel, Reranker
from app.core.registry import embedding_model_registry, reranker_registry
from app.schemas.enums import EmbeddingModelName, RerankerModelName
from app.services.embedding_providers import MiniLMEmbeddingModel
from app.services.rerankers import MiniLMReranker
@ -26,24 +27,41 @@ class ConfigService:
reranker_registry.register("MiniLMReranker", MiniLMReranker)
async def initialize_models(self):
"""
Initialize embedding and reranker mode,
if not a valid name then fallback to default one.
Will get call on first time starting the app.
"""
logger.info("Initializing default models...")
# Initialize default embedding model
default_embedding_model_name = settings.EMBEDDING_MODEL
await self.set_embedding_model(default_embedding_model_name)
logger.info(
"Default embedding model initialized: %s", default_embedding_model_name
)
embedding_model_name = settings.EMBEDDING_MODEL
if embedding_model_name not in EmbeddingModelName.__members__:
logger.warning(
"Embedding model '%s' is not valid. Falling back to default '%s'",
embedding_model_name,
EmbeddingModelName.MiniLMEmbeddingModel.value,
)
embedding_model_name = (
EmbeddingModelName.MiniLMEmbeddingModel.value
) # use minilm as default
await self.set_embedding_model(embedding_model_name)
logger.info("Default embedding model initialized: %s", embedding_model_name)
# Initialize default reranker model (if any)
# Assuming a default reranker can be set in settings if needed
# For now, let's assume MiniLMReranker is the default if not specified
default_reranker_model_name = "MiniLMReranker" # Or from settings
await self.set_reranker_model(default_reranker_model_name)
logger.info(
"Default reranker model initialized: %s", default_reranker_model_name
reranker_model_name = (
getattr(settings, "RERANKER_MODEL", None)
or RerankerModelName.MiniLMReranker.value
)
if reranker_model_name not in RerankerModelName.__members__:
logger.warning(
"Reranker model '%s' is not valid. Falling back to default '%s'",
reranker_model_name,
RerankerModelName.MiniLMReranker.value,
)
reranker_model_name = RerankerModelName.MiniLMReranker.value
await self.set_reranker_model(reranker_model_name)
logger.info("Default reranker model initialized: %s", reranker_model_name)
async def set_embedding_model(self, model_name: str) -> str:
"""Set system embedding model based on provide model_name"""
if (
self._current_embedding_model
and self._current_embedding_model.__class__.__name__ == model_name
@ -75,6 +93,7 @@ class ConfigService:
self._loading_status["embedding_model"] = False
async def set_reranker_model(self, model_name: str) -> str:
"""Set system reranker model based on provide model_name"""
if (
self._current_reranker_model
and self._current_reranker_model.__class__.__name__ == model_name