diff --git a/app/core/config.py b/app/core/config.py index c7425f1..3ba1998 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -95,7 +95,6 @@ def get_settings() -> Settings: return Settings() -# Create settings instance settings = get_settings() # Set environment variables for third-party libraries diff --git a/app/main.py b/app/main.py index f303ca2..3d95da2 100644 --- a/app/main.py +++ b/app/main.py @@ -11,33 +11,32 @@ from app.services.vector_stores import PGVectorStore logger = get_logger() -# Load environment variables from .env file load_dotenv() -# Dictionary to hold our application state, including the RAG service instance app_state = {} @asynccontextmanager -async def lifespan(app: FastAPI): +async def lifespan(app: FastAPI): # noqa: ARG001 config_service = ConfigService() await config_service.initialize_models() app_state["config_service"] = config_service - # This code runs on startup logger.info("Application starting up...") - # Initialize the RAG Service and store it in the app_state + embedding_model = config_service.get_current_embedding_model() + reranker = config_service.get_current_reranker_model() + if embedding_model is None: + raise RuntimeError("Embedding model failed to initialize") app_state["rag_service"] = RAGService( - embedding_model=config_service.get_current_embedding_model(), + embedding_model=embedding_model, vector_db=PGVectorStore(), - reranker=config_service.get_current_reranker_model(), + reranker=reranker, ) yield app = FastAPI(lifespan=lifespan) -# Include the API router app.include_router(endpoints.router) diff --git a/app/schemas/enums.py b/app/schemas/enums.py new file mode 100644 index 0000000..acfc3bb --- /dev/null +++ b/app/schemas/enums.py @@ -0,0 +1,9 @@ +from enum import Enum + + +class EmbeddingModelName(str, Enum): + MiniLMEmbeddingModel = "MiniLMEmbeddingModel" + + +class RerankerModelName(str, Enum): + MiniLMReranker = "MiniLMReranker" diff --git a/app/services/config_service.py 
b/app/services/config_service.py index d7d06a5..f9a8470 100644 --- a/app/services/config_service.py +++ b/app/services/config_service.py @@ -3,6 +3,7 @@ import logging from app.core.config import settings from app.core.interfaces import EmbeddingModel, Reranker from app.core.registry import embedding_model_registry, reranker_registry +from app.schemas.enums import EmbeddingModelName, RerankerModelName from app.services.embedding_providers import MiniLMEmbeddingModel from app.services.rerankers import MiniLMReranker @@ -26,24 +27,41 @@ class ConfigService: reranker_registry.register("MiniLMReranker", MiniLMReranker) async def initialize_models(self): + """ + Initialize the embedding and reranker models, + falling back to the default model when a configured name is not valid. + Called when the app first starts up. + """ logger.info("Initializing default models...") - # Initialize default embedding model - default_embedding_model_name = settings.EMBEDDING_MODEL - await self.set_embedding_model(default_embedding_model_name) - logger.info( - "Default embedding model initialized: %s", default_embedding_model_name - ) + embedding_model_name = settings.EMBEDDING_MODEL + if embedding_model_name not in EmbeddingModelName.__members__: + logger.warning( + "Embedding model '%s' is not valid. 
Falling back to default '%s'", + embedding_model_name, + EmbeddingModelName.MiniLMEmbeddingModel.value, + ) + embedding_model_name = ( + EmbeddingModelName.MiniLMEmbeddingModel.value + ) # use minilm as default + await self.set_embedding_model(embedding_model_name) + logger.info("Default embedding model initialized: %s", embedding_model_name) - # Initialize default reranker model (if any) - # Assuming a default reranker can be set in settings if needed - # For now, let's assume MiniLMReranker is the default if not specified - default_reranker_model_name = "MiniLMReranker" # Or from settings - await self.set_reranker_model(default_reranker_model_name) - logger.info( - "Default reranker model initialized: %s", default_reranker_model_name + reranker_model_name = ( + getattr(settings, "RERANKER_MODEL", None) + or RerankerModelName.MiniLMReranker.value ) + if reranker_model_name not in RerankerModelName.__members__: + logger.warning( + "Reranker model '%s' is not valid. Falling back to default '%s'", + reranker_model_name, + RerankerModelName.MiniLMReranker.value, + ) + reranker_model_name = RerankerModelName.MiniLMReranker.value + await self.set_reranker_model(reranker_model_name) + logger.info("Default reranker model initialized: %s", reranker_model_name) async def set_embedding_model(self, model_name: str) -> str: + """Set the system embedding model based on the provided model_name""" if ( self._current_embedding_model and self._current_embedding_model.__class__.__name__ == model_name @@ -75,6 +93,7 @@ class ConfigService: self._loading_status["embedding_model"] = False async def set_reranker_model(self, model_name: str) -> str: + """Set the system reranker model based on the provided model_name""" if ( self._current_reranker_model and self._current_reranker_model.__class__.__name__ == model_name