mirror of
https://github.com/Sosokker/plain-rag.git
synced 2025-12-18 22:44:03 +01:00
119 lines
5.2 KiB
Python
119 lines
5.2 KiB
Python
import logging
|
|
|
|
from app.core.config import settings
|
|
from app.core.interfaces import EmbeddingModel, Reranker
|
|
from app.core.registry import embedding_model_registry, reranker_registry
|
|
from app.services.embedding_providers import MiniLMEmbeddingModel
|
|
from app.services.rerankers import MiniLMReranker
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ConfigService:
|
|
def __init__(self):
|
|
self._current_embedding_model: EmbeddingModel | None = None
|
|
self._current_reranker_model: Reranker | None = None
|
|
self._loading_status: dict[str, bool] = {
|
|
"embedding_model": False,
|
|
"reranker_model": False,
|
|
}
|
|
|
|
# Register available models
|
|
self._register_models()
|
|
|
|
def _register_models(self):
|
|
# Register embedding models
|
|
embedding_model_registry.register("MiniLMEmbeddingModel", MiniLMEmbeddingModel)
|
|
# Register reranker models
|
|
reranker_registry.register("MiniLMReranker", MiniLMReranker)
|
|
|
|
async def initialize_models(self):
|
|
logger.info("Initializing default models...")
|
|
# Initialize default embedding model
|
|
default_embedding_model_name = settings.EMBEDDING_MODEL
|
|
await self.set_embedding_model(default_embedding_model_name)
|
|
logger.info(
|
|
f"Default embedding model initialized: {default_embedding_model_name}"
|
|
)
|
|
|
|
# Initialize default reranker model (if any)
|
|
# Assuming a default reranker can be set in settings if needed
|
|
# For now, let's assume MiniLMReranker is the default if not specified
|
|
default_reranker_model_name = "MiniLMReranker" # Or from settings
|
|
await self.set_reranker_model(default_reranker_model_name)
|
|
logger.info(
|
|
f"Default reranker model initialized: {default_reranker_model_name}"
|
|
)
|
|
|
|
async def set_embedding_model(self, model_name: str) -> str:
|
|
if (
|
|
self._current_embedding_model
|
|
and self._current_embedding_model.__class__.__name__ == model_name
|
|
):
|
|
return f"Embedding model '{model_name}' is already in use."
|
|
|
|
if self._loading_status["embedding_model"]:
|
|
return "Another embedding model is currently being loaded. Please wait."
|
|
|
|
try:
|
|
self._loading_status["embedding_model"] = True
|
|
logger.info(f"Attempting to load embedding model: {model_name}")
|
|
model_constructor = embedding_model_registry.get(model_name)
|
|
self._current_embedding_model = model_constructor()
|
|
settings.EMBEDDING_MODEL = model_name # Update settings
|
|
logger.info(f"Successfully loaded embedding model: {model_name}")
|
|
return f"Embedding model set to '{model_name}' successfully."
|
|
except KeyError:
|
|
logger.warning(f"Embedding model '{model_name}' not found in registry.")
|
|
return (
|
|
f"Embedding model '{model_name}' not available. "
|
|
f"Current model remains '{self._current_embedding_model.__class__.__name__ if self._current_embedding_model else 'None'}'."
|
|
)
|
|
except Exception as e:
|
|
logger.exception(f"Error loading embedding model {model_name}: {e}")
|
|
return f"Failed to load embedding model '{model_name}': {e}"
|
|
finally:
|
|
self._loading_status["embedding_model"] = False
|
|
|
|
async def set_reranker_model(self, model_name: str) -> str:
|
|
if (
|
|
self._current_reranker_model
|
|
and self._current_reranker_model.__class__.__name__ == model_name
|
|
):
|
|
return f"Reranker model '{model_name}' is already in use."
|
|
|
|
if self._loading_status["reranker_model"]:
|
|
return "Another reranker model is currently being loaded. Please wait."
|
|
|
|
try:
|
|
self._loading_status["reranker_model"] = True
|
|
logger.info(f"Attempting to load reranker model: {model_name}")
|
|
model_constructor = reranker_registry.get(model_name)
|
|
self._current_reranker_model = model_constructor()
|
|
# settings.RERANKER_MODEL = model_name # Add this to settings if you want to persist
|
|
logger.info(f"Successfully loaded reranker model: {model_name}")
|
|
return f"Reranker model set to '{model_name}' successfully."
|
|
except KeyError:
|
|
logger.warning(f"Reranker model '{model_name}' not found in registry.")
|
|
return (
|
|
f"Reranker model '{model_name}' not available. "
|
|
f"Current model remains '{self._current_reranker_model.__class__.__name__ if self._current_reranker_model else 'None'}'."
|
|
)
|
|
except Exception as e:
|
|
logger.exception(f"Error loading reranker model {model_name}: {e}")
|
|
return f"Failed to load reranker model '{model_name}': {e}"
|
|
finally:
|
|
self._loading_status["reranker_model"] = False
|
|
|
|
def get_current_embedding_model(self) -> EmbeddingModel | None:
|
|
return self._current_embedding_model
|
|
|
|
def get_current_reranker_model(self) -> Reranker | None:
|
|
return self._current_reranker_model
|
|
|
|
def get_available_embedding_models(self) -> list[str]:
|
|
return embedding_model_registry.list_available()
|
|
|
|
def get_available_reranker_models(self) -> list[str]:
|
|
return reranker_registry.list_available()
|