plain-rag/app/services/config_service.py

119 lines
5.2 KiB
Python

import logging
from app.core.config import settings
from app.core.interfaces import EmbeddingModel, Reranker
from app.core.registry import embedding_model_registry, reranker_registry
from app.services.embedding_providers import MiniLMEmbeddingModel
from app.services.rerankers import MiniLMReranker
logger = logging.getLogger(__name__)
class ConfigService:
def __init__(self):
self._current_embedding_model: EmbeddingModel | None = None
self._current_reranker_model: Reranker | None = None
self._loading_status: dict[str, bool] = {
"embedding_model": False,
"reranker_model": False,
}
# Register available models
self._register_models()
def _register_models(self):
# Register embedding models
embedding_model_registry.register("MiniLMEmbeddingModel", MiniLMEmbeddingModel)
# Register reranker models
reranker_registry.register("MiniLMReranker", MiniLMReranker)
async def initialize_models(self):
logger.info("Initializing default models...")
# Initialize default embedding model
default_embedding_model_name = settings.EMBEDDING_MODEL
await self.set_embedding_model(default_embedding_model_name)
logger.info(
f"Default embedding model initialized: {default_embedding_model_name}"
)
# Initialize default reranker model (if any)
# Assuming a default reranker can be set in settings if needed
# For now, let's assume MiniLMReranker is the default if not specified
default_reranker_model_name = "MiniLMReranker" # Or from settings
await self.set_reranker_model(default_reranker_model_name)
logger.info(
f"Default reranker model initialized: {default_reranker_model_name}"
)
async def set_embedding_model(self, model_name: str) -> str:
if (
self._current_embedding_model
and self._current_embedding_model.__class__.__name__ == model_name
):
return f"Embedding model '{model_name}' is already in use."
if self._loading_status["embedding_model"]:
return "Another embedding model is currently being loaded. Please wait."
try:
self._loading_status["embedding_model"] = True
logger.info(f"Attempting to load embedding model: {model_name}")
model_constructor = embedding_model_registry.get(model_name)
self._current_embedding_model = model_constructor()
settings.EMBEDDING_MODEL = model_name # Update settings
logger.info(f"Successfully loaded embedding model: {model_name}")
return f"Embedding model set to '{model_name}' successfully."
except KeyError:
logger.warning(f"Embedding model '{model_name}' not found in registry.")
return (
f"Embedding model '{model_name}' not available. "
f"Current model remains '{self._current_embedding_model.__class__.__name__ if self._current_embedding_model else 'None'}'."
)
except Exception as e:
logger.exception(f"Error loading embedding model {model_name}: {e}")
return f"Failed to load embedding model '{model_name}': {e}"
finally:
self._loading_status["embedding_model"] = False
async def set_reranker_model(self, model_name: str) -> str:
if (
self._current_reranker_model
and self._current_reranker_model.__class__.__name__ == model_name
):
return f"Reranker model '{model_name}' is already in use."
if self._loading_status["reranker_model"]:
return "Another reranker model is currently being loaded. Please wait."
try:
self._loading_status["reranker_model"] = True
logger.info(f"Attempting to load reranker model: {model_name}")
model_constructor = reranker_registry.get(model_name)
self._current_reranker_model = model_constructor()
# settings.RERANKER_MODEL = model_name # Add this to settings if you want to persist
logger.info(f"Successfully loaded reranker model: {model_name}")
return f"Reranker model set to '{model_name}' successfully."
except KeyError:
logger.warning(f"Reranker model '{model_name}' not found in registry.")
return (
f"Reranker model '{model_name}' not available. "
f"Current model remains '{self._current_reranker_model.__class__.__name__ if self._current_reranker_model else 'None'}'."
)
except Exception as e:
logger.exception(f"Error loading reranker model {model_name}: {e}")
return f"Failed to load reranker model '{model_name}': {e}"
finally:
self._loading_status["reranker_model"] = False
def get_current_embedding_model(self) -> EmbeddingModel | None:
return self._current_embedding_model
def get_current_reranker_model(self) -> Reranker | None:
return self._current_reranker_model
def get_available_embedding_models(self) -> list[str]:
return embedding_model_registry.list_available()
def get_available_reranker_models(self) -> list[str]:
return reranker_registry.list_available()