plain-rag/app/services/config_service.py

119 lines
5.2 KiB
Python

import logging
from app.core.config import settings
from app.core.interfaces import EmbeddingModel, Reranker
from app.core.registry import embedding_model_registry, reranker_registry
from app.services.embedding_providers import MiniLMEmbeddingModel
from app.services.rerankers import MiniLMReranker
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class ConfigService:
    """Track the active embedding and reranker models and switch them by name.

    Model classes are resolved through ``embedding_model_registry`` and
    ``reranker_registry``. The default implementations are registered when the
    service is constructed; no model is instantiated until
    ``initialize_models`` or one of the ``set_*_model`` methods is called.
    """

    def __init__(self):
        """Start with no active models and register the default model classes."""
        self._current_embedding_model: EmbeddingModel | None = None
        self._current_reranker_model: Reranker | None = None
        # One flag per model kind, raised while a load for that kind is running.
        self._loading_status: dict[str, bool] = {
            "embedding_model": False,
            "reranker_model": False,
        }
        self._register_models()

    def _register_models(self):
        """Register all default models"""
        embedding_model_registry.register("MiniLMEmbeddingModel", MiniLMEmbeddingModel)
        reranker_registry.register("MiniLMReranker", MiniLMReranker)

    async def initialize_models(self):
        """Load the default embedding and reranker models.

        The embedding default is read from ``settings.EMBEDDING_MODEL``; the
        reranker default is hard-coded. The status message returned by each
        setter is logged verbatim, so a failed load is never reported as a
        successful initialization.
        """
        logger.info("Initializing default models...")

        default_embedding_model_name = settings.EMBEDDING_MODEL
        embedding_status = await self.set_embedding_model(default_embedding_model_name)
        logger.info(
            "Default embedding model '%s': %s",
            default_embedding_model_name,
            embedding_status,
        )

        # TODO: read the default reranker from settings once such a setting exists.
        default_reranker_model_name = "MiniLMReranker"
        reranker_status = await self.set_reranker_model(default_reranker_model_name)
        logger.info(
            "Default reranker model '%s': %s",
            default_reranker_model_name,
            reranker_status,
        )

    def _switch_model(self, *, kind: str, registry, model_name: str, current, commit) -> str:
        """Resolve ``model_name`` in ``registry``, build it and activate it.

        Shared implementation behind ``set_embedding_model`` and
        ``set_reranker_model``.

        Args:
            kind: ``"embedding"`` or ``"reranker"``; selects the
                ``_loading_status`` key and the wording of messages.
            registry: Registry queried with ``registry.get(model_name)``.
            model_name: Name the model class was registered under.
            current: The currently active model of this kind, or ``None``.
            commit: Callable taking the new model instance and storing it as
                the active one. It runs inside the guarded section, so any
                exception it raises is reported as a load failure.

        Returns:
            A human-readable status message. This method does not raise for
            unknown names or failing constructors; those are reported in the
            returned message.
        """
        label = kind.capitalize()
        status_key = f"{kind}_model"

        # The active model is identified by its class name, so this shortcut
        # only applies when the registry name equals the class name.
        if current is not None and current.__class__.__name__ == model_name:
            return f"{label} model '{model_name}' is already in use."

        # NOTE(review): nothing between raising and clearing this flag awaits,
        # so callers on a single event loop cannot observe it as True; it only
        # matters if loading is ever moved off the loop or made awaitable.
        if self._loading_status[status_key]:
            return f"Another {kind} model is currently being loaded. Please wait."

        self._loading_status[status_key] = True
        try:
            logger.info("Attempting to load %s model: %s", kind, model_name)

            # Only the lookup is treated as "not registered"; a KeyError raised
            # later by the constructor must not be mistaken for a missing name.
            try:
                model_constructor = registry.get(model_name)
            except KeyError:
                logger.warning("%s model '%s' not found in registry.", label, model_name)
                current_name = current.__class__.__name__ if current is not None else "None"
                return (
                    f"{label} model '{model_name}' not available. "
                    f"Current model remains '{current_name}'."
                )

            try:
                commit(model_constructor())
            except Exception as e:
                logger.exception("Error loading %s model %s: %s", kind, model_name, e)
                return f"Failed to load {kind} model '{model_name}': {e}"
        finally:
            self._loading_status[status_key] = False

        logger.info("Successfully loaded %s model: %s", kind, model_name)
        return f"{label} model set to '{model_name}' successfully."

    async def set_embedding_model(self, model_name: str) -> str:
        """Make the embedding model registered as ``model_name`` the active one.

        On success ``settings.EMBEDDING_MODEL`` is updated as well.

        Args:
            model_name: Name the model class was registered under.

        Returns:
            A status message describing success, a no-op (already active or a
            load in progress), an unknown name, or a load failure.
        """

        def commit(model: EmbeddingModel) -> None:
            # Write settings first: if that fails, the previous model stays active.
            settings.EMBEDDING_MODEL = model_name
            self._current_embedding_model = model

        return self._switch_model(
            kind="embedding",
            registry=embedding_model_registry,
            model_name=model_name,
            current=self._current_embedding_model,
            commit=commit,
        )

    async def set_reranker_model(self, model_name: str) -> str:
        """Make the reranker registered as ``model_name`` the active one.

        Unlike the embedding setter, the choice is not written back to
        ``settings``.

        Args:
            model_name: Name the reranker class was registered under.

        Returns:
            A status message describing success, a no-op (already active or a
            load in progress), an unknown name, or a load failure.
        """

        def commit(model: Reranker) -> None:
            self._current_reranker_model = model

        return self._switch_model(
            kind="reranker",
            registry=reranker_registry,
            model_name=model_name,
            current=self._current_reranker_model,
            commit=commit,
        )

    def get_current_embedding_model(self) -> EmbeddingModel | None:
        """Return the active embedding model, or ``None`` if none is loaded."""
        return self._current_embedding_model

    def get_current_reranker_model(self) -> Reranker | None:
        """Return the active reranker, or ``None`` if none is loaded."""
        return self._current_reranker_model

    def get_available_embedding_models(self) -> list[str]:
        """Return what ``embedding_model_registry.list_available()`` reports."""
        return embedding_model_registry.list_available()

    def get_available_reranker_models(self) -> list[str]:
        """Return what ``reranker_registry.list_available()`` reports."""
        return reranker_registry.list_available()