mirror of
https://github.com/Sosokker/plain-rag.git
synced 2025-12-18 14:34:05 +01:00
refactor: integrate ConfigService for model management and add configuration update endpoint
This commit is contained in:
parent
cf7d1e8218
commit
011de22885
@ -6,7 +6,13 @@ from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, Up
|
||||
from fastapi.responses import StreamingResponse
|
||||
from structlog import get_logger
|
||||
|
||||
from app.schemas.models import IngestResponse, QueryRequest, QueryResponse
|
||||
from app.schemas.models import (
|
||||
ConfigUpdateRequest,
|
||||
IngestResponse,
|
||||
QueryRequest,
|
||||
QueryResponse,
|
||||
)
|
||||
from app.services.config_service import ConfigService
|
||||
from app.services.rag_service import RAGService
|
||||
|
||||
logger = get_logger()
|
||||
@ -21,6 +27,12 @@ def get_rag_service():
|
||||
return app_state["rag_service"]
|
||||
|
||||
|
||||
def get_config_service():
    """Fetch the shared ConfigService instance from the application state."""
    # Local import so app.main is only loaded at call time
    # (presumably avoids a circular import — same pattern as get_rag_service).
    from app.main import app_state

    service = app_state["config_service"]
    return service
|
||||
|
||||
|
||||
@router.post("/ingest", response_model=IngestResponse)
|
||||
async def ingest_file(
|
||||
background_tasks: BackgroundTasks,
|
||||
@ -79,3 +91,23 @@ async def query_index_stream(
|
||||
except Exception:
|
||||
logger.exception("Failed to answer query")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.post("/config")
|
||||
async def update_configuration(
|
||||
request: ConfigUpdateRequest,
|
||||
config_service: Annotated[ConfigService, Depends(get_config_service)],
|
||||
):
|
||||
responses = []
|
||||
if request.embedding_model:
|
||||
response = await config_service.set_embedding_model(request.embedding_model)
|
||||
responses.append(response)
|
||||
if request.reranker_model:
|
||||
response = await config_service.set_reranker_model(request.reranker_model)
|
||||
responses.append(response)
|
||||
# Add similar logic for LLM models and providers when implemented
|
||||
|
||||
if not responses:
|
||||
return {"message": "No configuration changes requested."}
|
||||
|
||||
return {"message": " ".join(responses)}
|
||||
|
||||
48
app/core/registry.py
Normal file
48
app/core/registry.py
Normal file
@ -0,0 +1,48 @@
|
||||
from collections.abc import Callable
from typing import Any, Generic, TypeVar

T = TypeVar("T")


class Registry(Generic[T]):
    """Generic name-keyed store for arbitrary objects.

    Items are registered under unique, non-empty string names and can later
    be looked up, removed, or enumerated.
    """

    def __init__(self):
        # Backing mapping from registered name to the stored object.
        self._items: dict[str, T] = {}

    def register(self, name: str, item: T):
        """Registers an item with a given name."""
        if not (isinstance(name, str) and name):
            raise ValueError("Name must be a non-empty string.")
        if name in self._items:
            raise ValueError(f"Item with name '{name}' already registered.")
        self._items[name] = item

    def get(self, name: str) -> T:
        """Retrieves an item by its name."""
        if name in self._items:
            return self._items[name]
        raise KeyError(f"Item with name '{name}' not found in registry.")

    def unregister(self, name: str):
        """Unregisters an item by its name."""
        if name in self._items:
            del self._items[name]
        else:
            raise KeyError(f"Item with name '{name}' not found in registry.")

    def list_available(self) -> list[str]:
        """Lists all available item names in the registry."""
        return [*self._items]


class EmbeddingModelRegistry(Registry[Callable[..., Any]]):
    """Registry holding embedding-model constructors keyed by name."""


class RerankerRegistry(Registry[Callable[..., Any]]):
    """Registry holding reranker constructors keyed by name."""


# Global instances of the registries
embedding_model_registry = EmbeddingModelRegistry()
reranker_registry = RerankerRegistry()
|
||||
11
app/main.py
11
app/main.py
@ -5,7 +5,7 @@ from fastapi import FastAPI
|
||||
from structlog import get_logger
|
||||
|
||||
from app.api import endpoints
|
||||
from app.services.embedding_providers import MiniLMEmbeddingModel
|
||||
from app.services.config_service import ConfigService
|
||||
from app.services.rag_service import RAGService
|
||||
from app.services.vector_stores import PGVectorStore
|
||||
|
||||
@ -20,14 +20,17 @@ app_state = {}
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
embedding_provider = MiniLMEmbeddingModel()
|
||||
vector_store_provider = PGVectorStore()
|
||||
config_service = ConfigService()
|
||||
await config_service.initialize_models()
|
||||
app_state["config_service"] = config_service
|
||||
|
||||
# This code runs on startup
|
||||
logger.info("Application starting up...")
|
||||
# Initialize the RAG Service and store it in the app_state
|
||||
app_state["rag_service"] = RAGService(
|
||||
embedding_model=embedding_provider, vector_db=vector_store_provider
|
||||
embedding_model=config_service.get_current_embedding_model(),
|
||||
vector_db=PGVectorStore(),
|
||||
reranker=config_service.get_current_reranker_model(),
|
||||
)
|
||||
yield
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class QueryRequest(BaseModel):
|
||||
@ -13,3 +14,16 @@ class QueryResponse(BaseModel):
|
||||
class IngestResponse(BaseModel):
    """Response body returned by the /ingest endpoint."""

    message: str  # Human-readable status of the ingest request.
    filename: str  # Name of the file that was submitted for ingestion.
|
||||
|
||||
|
||||
class ConfigUpdateRequest(BaseModel):
    """Request body for the /config endpoint.

    All fields are optional: only the models that are explicitly supplied
    are switched; omitted (None) fields leave the current configuration
    untouched.
    """

    # PEP 604 unions (`str | None`) for consistency with the rest of the
    # codebase (e.g. ConfigService, RAGService), replacing Optional[str].
    embedding_model: str | None = Field(
        None, description="Name of the embedding model to use"
    )
    reranker_model: str | None = Field(
        None, description="Name of the reranker model to use"
    )
    llm_model: str | None = Field(None, description="Name of the LLM model to use")
    llm_provider: str | None = Field(
        None, description="Name of the LLM provider to use"
    )
|
||||
118
app/services/config_service.py
Normal file
118
app/services/config_service.py
Normal file
@ -0,0 +1,118 @@
|
||||
import logging
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.interfaces import EmbeddingModel, Reranker
|
||||
from app.core.registry import embedding_model_registry, reranker_registry
|
||||
from app.services.embedding_providers import MiniLMEmbeddingModel
|
||||
from app.services.rerankers import MiniLMReranker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConfigService:
    """Manages which embedding and reranker model implementations are active.

    On construction it registers the bundled model constructors in the global
    registries; thereafter it swaps the live model instances on request and
    returns human-readable status strings that the API layer forwards to
    clients (failures are reported in the string, never raised).
    """

    def __init__(self):
        # Live model instances; None until initialize_models()/set_* succeeds.
        self._current_embedding_model: EmbeddingModel | None = None
        self._current_reranker_model: Reranker | None = None
        # Per-slot "load in progress" flags used to reject overlapping swaps.
        self._loading_status: dict[str, bool] = {
            "embedding_model": False,
            "reranker_model": False,
        }

        # Register available models
        self._register_models()

    def _register_models(self):
        """Populate the global registries with the bundled model constructors."""
        # Register embedding models
        embedding_model_registry.register("MiniLMEmbeddingModel", MiniLMEmbeddingModel)
        # Register reranker models
        reranker_registry.register("MiniLMReranker", MiniLMReranker)

    async def initialize_models(self):
        """Load the default embedding and reranker models at startup."""
        logger.info("Initializing default models...")
        # Initialize default embedding model
        default_embedding_model_name = settings.EMBEDDING_MODEL
        await self.set_embedding_model(default_embedding_model_name)
        logger.info(
            f"Default embedding model initialized: {default_embedding_model_name}"
        )

        # Initialize default reranker model (if any)
        # Assuming a default reranker can be set in settings if needed
        # For now, let's assume MiniLMReranker is the default if not specified
        default_reranker_model_name = "MiniLMReranker"  # Or from settings
        await self.set_reranker_model(default_reranker_model_name)
        logger.info(
            f"Default reranker model initialized: {default_reranker_model_name}"
        )

    async def set_embedding_model(self, model_name: str) -> str:
        """Switch the active embedding model to *model_name*.

        Returns a status message describing the outcome; never raises —
        lookup and construction failures are reported in the returned string.
        """
        # No-op if the requested model is already active. NOTE(review): this
        # assumes registry names equal the constructor's class name, which
        # holds for the models registered in _register_models — confirm for
        # any models registered elsewhere.
        if (
            self._current_embedding_model
            and self._current_embedding_model.__class__.__name__ == model_name
        ):
            return f"Embedding model '{model_name}' is already in use."

        if self._loading_status["embedding_model"]:
            return "Another embedding model is currently being loaded. Please wait."

        try:
            self._loading_status["embedding_model"] = True
            logger.info(f"Attempting to load embedding model: {model_name}")
            model_constructor = embedding_model_registry.get(model_name)
            # NOTE(review): constructing the model is presumably the expensive
            # step (weight loading) — confirm against the model classes.
            self._current_embedding_model = model_constructor()
            settings.EMBEDDING_MODEL = model_name  # Update settings
            logger.info(f"Successfully loaded embedding model: {model_name}")
            return f"Embedding model set to '{model_name}' successfully."
        except KeyError:
            # Unknown name: keep the current model and report what it still is.
            logger.warning(f"Embedding model '{model_name}' not found in registry.")
            return (
                f"Embedding model '{model_name}' not available. "
                f"Current model remains '{self._current_embedding_model.__class__.__name__ if self._current_embedding_model else 'None'}'."
            )
        except Exception as e:
            logger.exception(f"Error loading embedding model {model_name}: {e}")
            return f"Failed to load embedding model '{model_name}': {e}"
        finally:
            # Always clear the in-progress flag, even on failure.
            self._loading_status["embedding_model"] = False

    async def set_reranker_model(self, model_name: str) -> str:
        """Switch the active reranker model to *model_name*.

        Mirrors set_embedding_model: returns a status message and never
        raises; failures leave the current reranker in place.
        """
        # No-op if the requested model is already active (same class-name
        # assumption as in set_embedding_model).
        if (
            self._current_reranker_model
            and self._current_reranker_model.__class__.__name__ == model_name
        ):
            return f"Reranker model '{model_name}' is already in use."

        if self._loading_status["reranker_model"]:
            return "Another reranker model is currently being loaded. Please wait."

        try:
            self._loading_status["reranker_model"] = True
            logger.info(f"Attempting to load reranker model: {model_name}")
            model_constructor = reranker_registry.get(model_name)
            self._current_reranker_model = model_constructor()
            # settings.RERANKER_MODEL = model_name # Add this to settings if you want to persist
            logger.info(f"Successfully loaded reranker model: {model_name}")
            return f"Reranker model set to '{model_name}' successfully."
        except KeyError:
            # Unknown name: keep the current reranker and report it.
            logger.warning(f"Reranker model '{model_name}' not found in registry.")
            return (
                f"Reranker model '{model_name}' not available. "
                f"Current model remains '{self._current_reranker_model.__class__.__name__ if self._current_reranker_model else 'None'}'."
            )
        except Exception as e:
            logger.exception(f"Error loading reranker model {model_name}: {e}")
            return f"Failed to load reranker model '{model_name}': {e}"
        finally:
            # Always clear the in-progress flag, even on failure.
            self._loading_status["reranker_model"] = False

    def get_current_embedding_model(self) -> EmbeddingModel | None:
        """Return the active embedding model instance, or None if unset."""
        return self._current_embedding_model

    def get_current_reranker_model(self) -> Reranker | None:
        """Return the active reranker instance, or None if unset."""
        return self._current_reranker_model

    def get_available_embedding_models(self) -> list[str]:
        """List the names of all registered embedding models."""
        return embedding_model_registry.list_available()

    def get_available_reranker_models(self) -> list[str]:
        """List the names of all registered reranker models."""
        return reranker_registry.list_available()
|
||||
@ -6,7 +6,7 @@ from typing import TypedDict
|
||||
import litellm
|
||||
from structlog import get_logger
|
||||
|
||||
from app.core.interfaces import EmbeddingModel, VectorDB
|
||||
from app.core.interfaces import EmbeddingModel, Reranker, VectorDB
|
||||
from app.core.utils import RecursiveCharacterTextSplitter
|
||||
|
||||
logger = get_logger()
|
||||
@ -18,10 +18,16 @@ class AnswerResult(TypedDict):
|
||||
|
||||
|
||||
class RAGService:
|
||||
def __init__(self, embedding_model: EmbeddingModel, vector_db: VectorDB):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_model: EmbeddingModel,
|
||||
vector_db: VectorDB,
|
||||
reranker: Reranker | None = None,
|
||||
):
|
||||
self.embedding_model = embedding_model
|
||||
self.vector_db = vector_db
|
||||
self.prompt = """Answer the question based on the following context.
|
||||
self.reranker = reranker
|
||||
self.prompt = """Answer the question based on the following context.
|
||||
If you don't know the answer, say you don't know. Don't make up an answer.
|
||||
|
||||
Context:
|
||||
@ -69,8 +75,12 @@ Answer:"""
|
||||
def answer_query(self, question: str) -> AnswerResult:
|
||||
query_embedding = self.embedding_model.embed_query(question)
|
||||
search_results = self.vector_db.search(query_embedding, top_k=5)
|
||||
sources = list({chunk["source"] for chunk in search_results if chunk["source"]})
|
||||
|
||||
if self.reranker:
|
||||
logger.info("Reranking search results...")
|
||||
search_results = self.reranker.rerank(search_results, question)
|
||||
|
||||
sources = list({chunk["source"] for chunk in search_results if chunk["source"]})
|
||||
context_str = "\n\n".join([chunk["content"] for chunk in search_results])
|
||||
|
||||
try:
|
||||
@ -92,7 +102,7 @@ Answer:"""
|
||||
max_tokens=500,
|
||||
)
|
||||
|
||||
answer_text = response.choices[0].message.content.strip() # type: ignore
|
||||
answer_text = response.choices[0].message.content.strip()
|
||||
|
||||
if not answer_text:
|
||||
answer_text = "No answer generated"
|
||||
@ -108,6 +118,11 @@ Answer:"""
|
||||
def answer_query_stream(self, question: str) -> Generator[str, None, None]:
|
||||
query_embedding = self.embedding_model.embed_query(question)
|
||||
search_results = self.vector_db.search(query_embedding, top_k=5)
|
||||
|
||||
if self.reranker:
|
||||
logger.info("Reranking search results...")
|
||||
search_results = self.reranker.rerank(search_results, question)
|
||||
|
||||
sources = list({chunk["source"] for chunk in search_results if chunk["source"]})
|
||||
context_str = "\n\n".join([chunk["content"] for chunk in search_results])
|
||||
|
||||
|
||||
@ -175,9 +175,9 @@ Answer:"""
|
||||
with self.db_conn.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT content, source
|
||||
FROM documents
|
||||
ORDER BY embedding <-> %s::vector
|
||||
SELECT content, source
|
||||
FROM documents
|
||||
ORDER BY embedding <-> %s::vector
|
||||
LIMIT %s
|
||||
""",
|
||||
(question_embedding.tolist(), top_k),
|
||||
|
||||
Loading…
Reference in New Issue
Block a user