Mirror of https://github.com/Sosokker/plain-rag.git
Synced 2025-12-18 22:44:03 +01:00
119 lines · 5.2 KiB · Python
import asyncio
import logging

from app.core.config import settings
from app.core.interfaces import EmbeddingModel, Reranker
from app.core.registry import embedding_model_registry, reranker_registry
from app.services.embedding_providers import MiniLMEmbeddingModel
from app.services.rerankers import MiniLMReranker
|
|
|
|
# Module-level logger named after this module's import path.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class ConfigService:
    """Own and switch the active embedding and reranker models.

    Model classes are looked up by name in ``embedding_model_registry`` and
    ``reranker_registry``. The async ``set_*_model`` methods report their
    outcome as a human-readable message rather than raising on load errors.
    """

    def __init__(self):
        # Active model instances; None until a model of that kind is loaded.
        self._current_embedding_model: EmbeddingModel | None = None
        self._current_reranker_model: Reranker | None = None
        # Per-kind "load in progress" flags. While one is True, further
        # requests for that kind are refused (not queued).
        self._loading_status: dict[str, bool] = {
            "embedding_model": False,
            "reranker_model": False,
        }
        self._register_models()

    def _register_models(self):
        """Register the built-in MiniLM models under their class names."""
        embedding_model_registry.register("MiniLMEmbeddingModel", MiniLMEmbeddingModel)
        reranker_registry.register("MiniLMReranker", MiniLMReranker)

    async def initialize_models(self):
        """Load the default embedding and reranker models.

        The embedding model name is ``settings.EMBEDDING_MODEL``. The reranker
        name is ``settings.RERANKER_MODEL`` when the settings object defines a
        non-empty value, otherwise ``"MiniLMReranker"``.

        Load failures are logged at ERROR level rather than raised. Startup
        then continues with whatever model (possibly none) was active before.
        """
        logger.info("Initializing default models...")

        embedding_name = settings.EMBEDDING_MODEL
        previous_embedding = self._current_embedding_model
        embedding_result = await self.set_embedding_model(embedding_name)
        self._report_initialization(
            "embedding model",
            embedding_name,
            previous_embedding,
            self._current_embedding_model,
            embedding_result,
        )

        # getattr with a fallback keeps this working when settings has no
        # RERANKER_MODEL field; `or` also covers an empty/None value.
        reranker_name = getattr(settings, "RERANKER_MODEL", None) or "MiniLMReranker"
        previous_reranker = self._current_reranker_model
        reranker_result = await self.set_reranker_model(reranker_name)
        self._report_initialization(
            "reranker model",
            reranker_name,
            previous_reranker,
            self._current_reranker_model,
            reranker_result,
        )

    def _report_initialization(self, label, model_name, previous, current, result) -> None:
        """Log whether initializing the default *label* to *model_name* worked.

        *result* is the message returned by the corresponding ``set_*_model``
        call. It is included in the error log when the model did not take effect.
        """
        if self._took_effect(previous, current, model_name):
            logger.info("Default %s initialized: %s", label, model_name)
        else:
            logger.error("Default %s %s was not initialized: %s", label, model_name, result)

    def _took_effect(self, previous, current, model_name: str) -> bool:
        """Return True if *model_name* is active after a ``set_*_model`` call.

        ``set_*_model`` only replaces the current instance on a successful
        load. A different object therefore means success. An unchanged object
        counts only when it already is *model_name* (the "already in use" case).
        """
        if current is None:
            return False
        return current is not previous or self._is_active(current, model_name)

    @staticmethod
    def _is_active(model, model_name: str) -> bool:
        """Return True if *model* is not None and its class is named *model_name*."""
        return model is not None and model.__class__.__name__ == model_name

    @staticmethod
    def _lookup(registry, model_name: str):
        """Return the constructor registered as *model_name*, or None if unknown.

        A ``KeyError`` from ``registry.get`` and a ``None`` return are both
        treated as "not registered". Any other exception propagates to the
        caller.
        """
        try:
            return registry.get(model_name)
        except KeyError:
            return None

    async def _load_model(
        self,
        model_name: str,
        *,
        label: str,
        status_key: str,
        registry,
        current,
    ) -> tuple[str, object | None]:
        """Construct the model registered as *model_name* without installing it.

        Args:
            model_name: Registry key of the model to load.
            label: Lower-case kind used in messages, e.g. ``"embedding model"``.
            status_key: Key into ``self._loading_status`` guarding this kind.
            registry: Registry whose ``get(name)`` returns the model constructor.
            current: The currently active model of this kind (may be None).

        Returns:
            ``(message, model)``. ``model`` is the new instance on success. It
            is ``None`` when nothing was loaded: the model is already active,
            another load is pending, the name is unknown, or construction failed.
        """
        title = label.capitalize()
        if self._is_active(current, model_name):
            return f"{title} '{model_name}' is already in use.", None
        if self._loading_status[status_key]:
            return f"Another {label} is currently being loaded. Please wait.", None

        self._loading_status[status_key] = True
        try:
            logger.info("Attempting to load %s: %s", label, model_name)
            try:
                model_constructor = self._lookup(registry, model_name)
                if model_constructor is None:
                    logger.warning("%s '%s' not found in registry.", title, model_name)
                    current_name = current.__class__.__name__ if current is not None else "None"
                    return (
                        f"{title} '{model_name}' not available. "
                        f"Current model remains '{current_name}'.",
                        None,
                    )
                # Construction may be slow. Running it in a worker thread keeps
                # the event loop responsive, and because this await happens
                # while the loading flag is set, concurrent callers can see the
                # flag and are refused.
                model = await asyncio.to_thread(model_constructor)
            except Exception as e:
                logger.exception("Error loading %s %s: %s", label, model_name, e)
                return f"Failed to load {label} '{model_name}': {e}", None
            logger.info("Successfully loaded %s: %s", label, model_name)
            return f"{title} set to '{model_name}' successfully.", model
        finally:
            self._loading_status[status_key] = False

    async def set_embedding_model(self, model_name: str) -> str:
        """Switch the active embedding model to the one registered as *model_name*.

        On success the new instance becomes current and
        ``settings.EMBEDDING_MODEL`` is set to *model_name*. Load errors are
        reported in the returned message rather than raised.
        """
        message, model = await self._load_model(
            model_name,
            label="embedding model",
            status_key="embedding_model",
            registry=embedding_model_registry,
            current=self._current_embedding_model,
        )
        if model is not None:
            self._current_embedding_model = model
            settings.EMBEDDING_MODEL = model_name  # keep settings in sync
        return message

    async def set_reranker_model(self, model_name: str) -> str:
        """Switch the active reranker to the one registered as *model_name*.

        On success the new instance becomes current. Unlike the embedding
        model, the choice is not written back to ``settings``. Load errors are
        reported in the returned message rather than raised.
        """
        message, model = await self._load_model(
            model_name,
            label="reranker model",
            status_key="reranker_model",
            registry=reranker_registry,
            current=self._current_reranker_model,
        )
        if model is not None:
            self._current_reranker_model = model
        return message

    def get_current_embedding_model(self) -> EmbeddingModel | None:
        """Return the active embedding model, or None if none is loaded."""
        return self._current_embedding_model

    def get_current_reranker_model(self) -> Reranker | None:
        """Return the active reranker, or None if none is loaded."""
        return self._current_reranker_model

    def get_available_embedding_models(self) -> list[str]:
        """Return the names registered in ``embedding_model_registry``."""
        return embedding_model_registry.list_available()

    def get_available_reranker_models(self) -> list[str]:
        """Return the names registered in ``reranker_registry``."""
        return reranker_registry.list_available()
|