mirror of
https://github.com/Sosokker/plain-rag.git
synced 2025-12-18 14:34:05 +01:00
refactor: enhance model initialization and validation in ConfigService, add enums for model names
This commit is contained in:
parent
fc0d1a3a16
commit
2f13d8c3ce
@ -95,7 +95,6 @@ def get_settings() -> Settings:
|
||||
return Settings()
|
||||
|
||||
|
||||
# Create settings instance
|
||||
settings = get_settings()
|
||||
|
||||
# Set environment variables for third-party libraries
|
||||
|
||||
15
app/main.py
15
app/main.py
@ -11,33 +11,32 @@ from app.services.vector_stores import PGVectorStore
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
# Dictionary to hold our application state, including the RAG service instance
|
||||
app_state = {}
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
async def lifespan(app: FastAPI): # noqa: ARG001
|
||||
config_service = ConfigService()
|
||||
await config_service.initialize_models()
|
||||
app_state["config_service"] = config_service
|
||||
|
||||
# This code runs on startup
|
||||
logger.info("Application starting up...")
|
||||
# Initialize the RAG Service and store it in the app_state
|
||||
embedding_model = config_service.get_current_embedding_model()
|
||||
reranker = config_service.get_current_reranker_model()
|
||||
if embedding_model is None:
|
||||
raise RuntimeError("Embedding model failed to initialize")
|
||||
app_state["rag_service"] = RAGService(
|
||||
embedding_model=config_service.get_current_embedding_model(),
|
||||
embedding_model=embedding_model,
|
||||
vector_db=PGVectorStore(),
|
||||
reranker=config_service.get_current_reranker_model(),
|
||||
reranker=reranker,
|
||||
)
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
# Include the API router
|
||||
app.include_router(endpoints.router)
|
||||
|
||||
|
||||
|
||||
9
app/schemas/enums.py
Normal file
9
app/schemas/enums.py
Normal file
@ -0,0 +1,9 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class EmbeddingModelName(str, Enum):
|
||||
MiniLMEmbeddingModel = "MiniLMEmbeddingModel"
|
||||
|
||||
|
||||
class RerankerModelName(str, Enum):
|
||||
MiniLMReranker = "MiniLMReranker"
|
||||
@ -3,6 +3,7 @@ import logging
|
||||
from app.core.config import settings
|
||||
from app.core.interfaces import EmbeddingModel, Reranker
|
||||
from app.core.registry import embedding_model_registry, reranker_registry
|
||||
from app.schemas.enums import EmbeddingModelName, RerankerModelName
|
||||
from app.services.embedding_providers import MiniLMEmbeddingModel
|
||||
from app.services.rerankers import MiniLMReranker
|
||||
|
||||
@ -26,24 +27,41 @@ class ConfigService:
|
||||
reranker_registry.register("MiniLMReranker", MiniLMReranker)
|
||||
|
||||
async def initialize_models(self):
|
||||
"""
|
||||
Initialize embedding and reranker mode,
|
||||
if not a valid name then fallback to default one.
|
||||
Will get call on first time starting the app.
|
||||
"""
|
||||
logger.info("Initializing default models...")
|
||||
# Initialize default embedding model
|
||||
default_embedding_model_name = settings.EMBEDDING_MODEL
|
||||
await self.set_embedding_model(default_embedding_model_name)
|
||||
logger.info(
|
||||
"Default embedding model initialized: %s", default_embedding_model_name
|
||||
)
|
||||
embedding_model_name = settings.EMBEDDING_MODEL
|
||||
if embedding_model_name not in EmbeddingModelName.__members__:
|
||||
logger.warning(
|
||||
"Embedding model '%s' is not valid. Falling back to default '%s'",
|
||||
embedding_model_name,
|
||||
EmbeddingModelName.MiniLMEmbeddingModel.value,
|
||||
)
|
||||
embedding_model_name = (
|
||||
EmbeddingModelName.MiniLMEmbeddingModel.value
|
||||
) # use minilm as default
|
||||
await self.set_embedding_model(embedding_model_name)
|
||||
logger.info("Default embedding model initialized: %s", embedding_model_name)
|
||||
|
||||
# Initialize default reranker model (if any)
|
||||
# Assuming a default reranker can be set in settings if needed
|
||||
# For now, let's assume MiniLMReranker is the default if not specified
|
||||
default_reranker_model_name = "MiniLMReranker" # Or from settings
|
||||
await self.set_reranker_model(default_reranker_model_name)
|
||||
logger.info(
|
||||
"Default reranker model initialized: %s", default_reranker_model_name
|
||||
reranker_model_name = (
|
||||
getattr(settings, "RERANKER_MODEL", None)
|
||||
or RerankerModelName.MiniLMReranker.value
|
||||
)
|
||||
if reranker_model_name not in RerankerModelName.__members__:
|
||||
logger.warning(
|
||||
"Reranker model '%s' is not valid. Falling back to default '%s'",
|
||||
reranker_model_name,
|
||||
RerankerModelName.MiniLMReranker.value,
|
||||
)
|
||||
reranker_model_name = RerankerModelName.MiniLMReranker.value
|
||||
await self.set_reranker_model(reranker_model_name)
|
||||
logger.info("Default reranker model initialized: %s", reranker_model_name)
|
||||
|
||||
async def set_embedding_model(self, model_name: str) -> str:
|
||||
"""Set system embedding model based on provide model_name"""
|
||||
if (
|
||||
self._current_embedding_model
|
||||
and self._current_embedding_model.__class__.__name__ == model_name
|
||||
@ -75,6 +93,7 @@ class ConfigService:
|
||||
self._loading_status["embedding_model"] = False
|
||||
|
||||
async def set_reranker_model(self, model_name: str) -> str:
|
||||
"""Set system reranker model based on provide model_name"""
|
||||
if (
|
||||
self._current_reranker_model
|
||||
and self._current_reranker_model.__class__.__name__ == model_name
|
||||
|
||||
Loading…
Reference in New Issue
Block a user