mirror of
https://github.com/Sosokker/plain-rag.git
synced 2025-12-18 22:44:03 +01:00
refactor: enhance model initialization and validation in ConfigService, add enums for model names
This commit is contained in:
parent
fc0d1a3a16
commit
2f13d8c3ce
@ -95,7 +95,6 @@ def get_settings() -> Settings:
|
|||||||
return Settings()
|
return Settings()
|
||||||
|
|
||||||
|
|
||||||
# Create settings instance
|
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
|
|
||||||
# Set environment variables for third-party libraries
|
# Set environment variables for third-party libraries
|
||||||
|
|||||||
15
app/main.py
15
app/main.py
@ -11,33 +11,32 @@ from app.services.vector_stores import PGVectorStore
|
|||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
# Load environment variables from .env file
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
# Dictionary to hold our application state, including the RAG service instance
|
|
||||||
app_state = {}
|
app_state = {}
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI): # noqa: ARG001
|
||||||
config_service = ConfigService()
|
config_service = ConfigService()
|
||||||
await config_service.initialize_models()
|
await config_service.initialize_models()
|
||||||
app_state["config_service"] = config_service
|
app_state["config_service"] = config_service
|
||||||
|
|
||||||
# This code runs on startup
|
|
||||||
logger.info("Application starting up...")
|
logger.info("Application starting up...")
|
||||||
# Initialize the RAG Service and store it in the app_state
|
embedding_model = config_service.get_current_embedding_model()
|
||||||
|
reranker = config_service.get_current_reranker_model()
|
||||||
|
if embedding_model is None:
|
||||||
|
raise RuntimeError("Embedding model failed to initialize")
|
||||||
app_state["rag_service"] = RAGService(
|
app_state["rag_service"] = RAGService(
|
||||||
embedding_model=config_service.get_current_embedding_model(),
|
embedding_model=embedding_model,
|
||||||
vector_db=PGVectorStore(),
|
vector_db=PGVectorStore(),
|
||||||
reranker=config_service.get_current_reranker_model(),
|
reranker=reranker,
|
||||||
)
|
)
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(lifespan=lifespan)
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
|
||||||
# Include the API router
|
|
||||||
app.include_router(endpoints.router)
|
app.include_router(endpoints.router)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
9
app/schemas/enums.py
Normal file
9
app/schemas/enums.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingModelName(str, Enum):
|
||||||
|
MiniLMEmbeddingModel = "MiniLMEmbeddingModel"
|
||||||
|
|
||||||
|
|
||||||
|
class RerankerModelName(str, Enum):
|
||||||
|
MiniLMReranker = "MiniLMReranker"
|
||||||
@ -3,6 +3,7 @@ import logging
|
|||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.interfaces import EmbeddingModel, Reranker
|
from app.core.interfaces import EmbeddingModel, Reranker
|
||||||
from app.core.registry import embedding_model_registry, reranker_registry
|
from app.core.registry import embedding_model_registry, reranker_registry
|
||||||
|
from app.schemas.enums import EmbeddingModelName, RerankerModelName
|
||||||
from app.services.embedding_providers import MiniLMEmbeddingModel
|
from app.services.embedding_providers import MiniLMEmbeddingModel
|
||||||
from app.services.rerankers import MiniLMReranker
|
from app.services.rerankers import MiniLMReranker
|
||||||
|
|
||||||
@ -26,24 +27,41 @@ class ConfigService:
|
|||||||
reranker_registry.register("MiniLMReranker", MiniLMReranker)
|
reranker_registry.register("MiniLMReranker", MiniLMReranker)
|
||||||
|
|
||||||
async def initialize_models(self):
|
async def initialize_models(self):
|
||||||
|
"""
|
||||||
|
Initialize the embedding and reranker models,
|
||||||
|
falling back to the default one when a configured name is not valid.
|
||||||
|
This is called when the app starts for the first time.
|
||||||
|
"""
|
||||||
logger.info("Initializing default models...")
|
logger.info("Initializing default models...")
|
||||||
# Initialize default embedding model
|
embedding_model_name = settings.EMBEDDING_MODEL
|
||||||
default_embedding_model_name = settings.EMBEDDING_MODEL
|
if embedding_model_name not in EmbeddingModelName.__members__:
|
||||||
await self.set_embedding_model(default_embedding_model_name)
|
logger.warning(
|
||||||
logger.info(
|
"Embedding model '%s' is not valid. Falling back to default '%s'",
|
||||||
"Default embedding model initialized: %s", default_embedding_model_name
|
embedding_model_name,
|
||||||
)
|
EmbeddingModelName.MiniLMEmbeddingModel.value,
|
||||||
|
)
|
||||||
|
embedding_model_name = (
|
||||||
|
EmbeddingModelName.MiniLMEmbeddingModel.value
|
||||||
|
) # use minilm as default
|
||||||
|
await self.set_embedding_model(embedding_model_name)
|
||||||
|
logger.info("Default embedding model initialized: %s", embedding_model_name)
|
||||||
|
|
||||||
# Initialize default reranker model (if any)
|
reranker_model_name = (
|
||||||
# Assuming a default reranker can be set in settings if needed
|
getattr(settings, "RERANKER_MODEL", None)
|
||||||
# For now, let's assume MiniLMReranker is the default if not specified
|
or RerankerModelName.MiniLMReranker.value
|
||||||
default_reranker_model_name = "MiniLMReranker" # Or from settings
|
|
||||||
await self.set_reranker_model(default_reranker_model_name)
|
|
||||||
logger.info(
|
|
||||||
"Default reranker model initialized: %s", default_reranker_model_name
|
|
||||||
)
|
)
|
||||||
|
if reranker_model_name not in RerankerModelName.__members__:
|
||||||
|
logger.warning(
|
||||||
|
"Reranker model '%s' is not valid. Falling back to default '%s'",
|
||||||
|
reranker_model_name,
|
||||||
|
RerankerModelName.MiniLMReranker.value,
|
||||||
|
)
|
||||||
|
reranker_model_name = RerankerModelName.MiniLMReranker.value
|
||||||
|
await self.set_reranker_model(reranker_model_name)
|
||||||
|
logger.info("Default reranker model initialized: %s", reranker_model_name)
|
||||||
|
|
||||||
async def set_embedding_model(self, model_name: str) -> str:
|
async def set_embedding_model(self, model_name: str) -> str:
|
||||||
|
"""Set system embedding model based on provide model_name"""
|
||||||
if (
|
if (
|
||||||
self._current_embedding_model
|
self._current_embedding_model
|
||||||
and self._current_embedding_model.__class__.__name__ == model_name
|
and self._current_embedding_model.__class__.__name__ == model_name
|
||||||
@ -75,6 +93,7 @@ class ConfigService:
|
|||||||
self._loading_status["embedding_model"] = False
|
self._loading_status["embedding_model"] = False
|
||||||
|
|
||||||
async def set_reranker_model(self, model_name: str) -> str:
|
async def set_reranker_model(self, model_name: str) -> str:
|
||||||
|
"""Set system reranker model based on provide model_name"""
|
||||||
if (
|
if (
|
||||||
self._current_reranker_model
|
self._current_reranker_model
|
||||||
and self._current_reranker_model.__class__.__name__ == model_name
|
and self._current_reranker_model.__class__.__name__ == model_name
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user