diff --git a/app/api/endpoints.py b/app/api/endpoints.py
index 76cf97b..02d3f4b 100644
--- a/app/api/endpoints.py
+++ b/app/api/endpoints.py
@@ -6,7 +6,13 @@ from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, Up
 from fastapi.responses import StreamingResponse
 from structlog import get_logger
 
-from app.schemas.models import IngestResponse, QueryRequest, QueryResponse
+from app.schemas.models import (
+    ConfigUpdateRequest,
+    IngestResponse,
+    QueryRequest,
+    QueryResponse,
+)
+from app.services.config_service import ConfigService
 from app.services.rag_service import RAGService
 
 logger = get_logger()
@@ -21,6 +27,12 @@ def get_rag_service():
     return app_state["rag_service"]
 
 
+def get_config_service():
+    from app.main import app_state
+
+    return app_state["config_service"]
+
+
 @router.post("/ingest", response_model=IngestResponse)
 async def ingest_file(
     background_tasks: BackgroundTasks,
@@ -79,3 +91,23 @@ async def query_index_stream(
     except Exception:
         logger.exception("Failed to answer query")
         raise HTTPException(status_code=500, detail="Internal server error")
+
+
+@router.post("/config")
+async def update_configuration(
+    request: ConfigUpdateRequest,
+    config_service: Annotated[ConfigService, Depends(get_config_service)],
+):
+    responses = []
+    if request.embedding_model:
+        response = await config_service.set_embedding_model(request.embedding_model)
+        responses.append(response)
+    if request.reranker_model:
+        response = await config_service.set_reranker_model(request.reranker_model)
+        responses.append(response)
+    # Add similar logic for LLM models and providers when implemented
+
+    if not responses:
+        return {"message": "No configuration changes requested."}
+
+    return {"message": " ".join(responses)}
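
For reviewers, a minimal client call against the new endpoint (illustrative, not part of the diff; it assumes the router is mounted at the application root and served on localhost:8000 — adjust for any API prefix):

import httpx

# Only the fields actually sent are acted on; omitted fields stay None.
resp = httpx.post(
    "http://localhost:8000/config",
    json={"embedding_model": "MiniLMEmbeddingModel"},
)
print(resp.json())
# e.g. {"message": "Embedding model 'MiniLMEmbeddingModel' is already in use."}
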
diff --git a/app/core/registry.py b/app/core/registry.py
new file mode 100644
index 0000000..1c9aa76
--- /dev/null
+++ b/app/core/registry.py
@@ -0,0 +1,48 @@
+from collections.abc import Callable
+from typing import Any, Generic, TypeVar
+
+T = TypeVar("T")
+
+
+class Registry(Generic[T]):
+    """A generic registry to store and retrieve objects by name."""
+
+    def __init__(self):
+        self._items: dict[str, T] = {}
+
+    def register(self, name: str, item: T):
+        """Registers an item with a given name."""
+        if not isinstance(name, str) or not name:
+            raise ValueError("Name must be a non-empty string.")
+        if name in self._items:
+            raise ValueError(f"Item with name '{name}' already registered.")
+        self._items[name] = item
+
+    def get(self, name: str) -> T:
+        """Retrieves an item by its name."""
+        if name not in self._items:
+            raise KeyError(f"Item with name '{name}' not found in registry.")
+        return self._items[name]
+
+    def unregister(self, name: str):
+        """Unregisters an item by its name."""
+        if name not in self._items:
+            raise KeyError(f"Item with name '{name}' not found in registry.")
+        del self._items[name]
+
+    def list_available(self) -> list[str]:
+        """Lists all available item names in the registry."""
+        return list(self._items.keys())
+
+
+class EmbeddingModelRegistry(Registry[Callable[..., Any]]):
+    """Registry specifically for embedding model constructors."""
+
+
+class RerankerRegistry(Registry[Callable[..., Any]]):
+    """Registry specifically for reranker constructors."""
+
+
+# Global instances of the registries
+embedding_model_registry = EmbeddingModelRegistry()
+reranker_registry = RerankerRegistry()
diff --git a/app/main.py b/app/main.py
index f56cfe0..f303ca2 100644
--- a/app/main.py
+++ b/app/main.py
@@ -5,7 +5,7 @@
 from fastapi import FastAPI
 from structlog import get_logger
 
 from app.api import endpoints
-from app.services.embedding_providers import MiniLMEmbeddingModel
+from app.services.config_service import ConfigService
 from app.services.rag_service import RAGService
 from app.services.vector_stores import PGVectorStore
@@ -20,14 +20,17 @@ app_state = {}
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    embedding_provider = MiniLMEmbeddingModel()
-    vector_store_provider = PGVectorStore()
+    config_service = ConfigService()
+    await config_service.initialize_models()
+    app_state["config_service"] = config_service
     # This code runs on startup
     logger.info("Application starting up...")
 
     # Initialize the RAG Service and store it in the app_state
     app_state["rag_service"] = RAGService(
-        embedding_model=embedding_provider, vector_db=vector_store_provider
+        embedding_model=config_service.get_current_embedding_model(),
+        vector_db=PGVectorStore(),
+        reranker=config_service.get_current_reranker_model(),
     )
 
     yield
diff --git a/app/schemas/models.py b/app/schemas/models.py
index d492e8a..2a43a63 100644
--- a/app/schemas/models.py
+++ b/app/schemas/models.py
@@ -1,4 +1,5 @@
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
+from typing import Optional
 
 
 class QueryRequest(BaseModel):
@@ -13,3 +14,16 @@
 class IngestResponse(BaseModel):
     message: str
     filename: str
+
+
+class ConfigUpdateRequest(BaseModel):
+    embedding_model: Optional[str] = Field(
+        None, description="Name of the embedding model to use"
+    )
+    reranker_model: Optional[str] = Field(
+        None, description="Name of the reranker model to use"
+    )
+    llm_model: Optional[str] = Field(None, description="Name of the LLM model to use")
+    llm_provider: Optional[str] = Field(
+        None, description="Name of the LLM provider to use"
+    )
\ No newline at end of file
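
A quick sketch of how ConfigUpdateRequest behaves (illustrative, not part of the diff; assumes Pydantic v2 for model_dump):

from app.schemas.models import ConfigUpdateRequest

# Omitted fields default to None, so a request can update a single model.
req = ConfigUpdateRequest(embedding_model="MiniLMEmbeddingModel")
print(req.model_dump(exclude_none=True))  # {'embedding_model': 'MiniLMEmbeddingModel'}
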
diff --git a/app/services/config_service.py b/app/services/config_service.py
new file mode 100644
index 0000000..2b03cc6
--- /dev/null
+++ b/app/services/config_service.py
@@ -0,0 +1,118 @@
+import logging
+
+from app.core.config import settings
+from app.core.interfaces import EmbeddingModel, Reranker
+from app.core.registry import embedding_model_registry, reranker_registry
+from app.services.embedding_providers import MiniLMEmbeddingModel
+from app.services.rerankers import MiniLMReranker
+
+logger = logging.getLogger(__name__)
+
+
+class ConfigService:
+    def __init__(self):
+        self._current_embedding_model: EmbeddingModel | None = None
+        self._current_reranker_model: Reranker | None = None
+        self._loading_status: dict[str, bool] = {
+            "embedding_model": False,
+            "reranker_model": False,
+        }
+
+        # Register available models
+        self._register_models()
+
+    def _register_models(self):
+        # Register embedding models
+        embedding_model_registry.register("MiniLMEmbeddingModel", MiniLMEmbeddingModel)
+        # Register reranker models
+        reranker_registry.register("MiniLMReranker", MiniLMReranker)
+
+    async def initialize_models(self):
+        logger.info("Initializing default models...")
+        # Initialize the default embedding model
+        default_embedding_model_name = settings.EMBEDDING_MODEL
+        await self.set_embedding_model(default_embedding_model_name)
+        logger.info(
+            f"Default embedding model initialized: {default_embedding_model_name}"
+        )
+
+        # Initialize the default reranker model. No reranker setting exists
+        # yet, so MiniLMReranker is used as the default; read this from
+        # settings once a RERANKER_MODEL setting is added.
+        default_reranker_model_name = "MiniLMReranker"
+        await self.set_reranker_model(default_reranker_model_name)
+        logger.info(
+            f"Default reranker model initialized: {default_reranker_model_name}"
+        )
+
+    async def set_embedding_model(self, model_name: str) -> str:
+        if (
+            self._current_embedding_model
+            and self._current_embedding_model.__class__.__name__ == model_name
+        ):
+            return f"Embedding model '{model_name}' is already in use."
+
+        if self._loading_status["embedding_model"]:
+            return "Another embedding model is currently being loaded. Please wait."
+
+        try:
+            self._loading_status["embedding_model"] = True
+            logger.info(f"Attempting to load embedding model: {model_name}")
+            model_constructor = embedding_model_registry.get(model_name)
+            self._current_embedding_model = model_constructor()
+            settings.EMBEDDING_MODEL = model_name  # Update settings
+            logger.info(f"Successfully loaded embedding model: {model_name}")
+            return f"Embedding model set to '{model_name}' successfully."
+        except KeyError:
+            logger.warning(f"Embedding model '{model_name}' not found in registry.")
+            return (
+                f"Embedding model '{model_name}' not available. "
+                f"Current model remains '{self._current_embedding_model.__class__.__name__ if self._current_embedding_model else 'None'}'."
+            )
+        except Exception as e:
+            logger.exception(f"Error loading embedding model {model_name}: {e}")
+            return f"Failed to load embedding model '{model_name}': {e}"
+        finally:
+            self._loading_status["embedding_model"] = False
+
+    async def set_reranker_model(self, model_name: str) -> str:
+        if (
+            self._current_reranker_model
+            and self._current_reranker_model.__class__.__name__ == model_name
+        ):
+            return f"Reranker model '{model_name}' is already in use."
+
+        if self._loading_status["reranker_model"]:
+            return "Another reranker model is currently being loaded. Please wait."
+
+        try:
+            self._loading_status["reranker_model"] = True
+            logger.info(f"Attempting to load reranker model: {model_name}")
+            model_constructor = reranker_registry.get(model_name)
+            self._current_reranker_model = model_constructor()
+            # settings.RERANKER_MODEL = model_name  # Add to settings to persist
+            logger.info(f"Successfully loaded reranker model: {model_name}")
+            return f"Reranker model set to '{model_name}' successfully."
+        except KeyError:
+            logger.warning(f"Reranker model '{model_name}' not found in registry.")
+            return (
+                f"Reranker model '{model_name}' not available. "
+                f"Current model remains '{self._current_reranker_model.__class__.__name__ if self._current_reranker_model else 'None'}'."
+            )
+        except Exception as e:
+            logger.exception(f"Error loading reranker model {model_name}: {e}")
+            return f"Failed to load reranker model '{model_name}': {e}"
+        finally:
+            self._loading_status["reranker_model"] = False
+
+    def get_current_embedding_model(self) -> EmbeddingModel | None:
+        return self._current_embedding_model
+
+    def get_current_reranker_model(self) -> Reranker | None:
+        return self._current_reranker_model
+
+    def get_available_embedding_models(self) -> list[str]:
+        return embedding_model_registry.list_available()
+
+    def get_available_reranker_models(self) -> list[str]:
+        return reranker_registry.list_available()
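
A standalone sketch of the hot-swap flow (illustrative, not part of the diff; assumes settings.EMBEDDING_MODEL names a registered model and a fresh process, since ConfigService registers into the module-level global registries):

import asyncio

from app.services.config_service import ConfigService


async def main() -> None:
    service = ConfigService()  # registers the MiniLM constructors
    await service.initialize_models()
    print(service.get_available_embedding_models())  # ['MiniLMEmbeddingModel']
    # Unknown names fail soft: the current model is kept and a message returned.
    print(await service.set_embedding_model("NoSuchModel"))


asyncio.run(main())
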
diff --git a/app/services/rag_service.py b/app/services/rag_service.py
index 8da7a42..7d69acd 100644
--- a/app/services/rag_service.py
+++ b/app/services/rag_service.py
@@ -6,7 +6,7 @@ from typing import TypedDict
 import litellm
 from structlog import get_logger
 
-from app.core.interfaces import EmbeddingModel, VectorDB
+from app.core.interfaces import EmbeddingModel, Reranker, VectorDB
 from app.core.utils import RecursiveCharacterTextSplitter
 
 logger = get_logger()
@@ -18,10 +18,16 @@ class AnswerResult(TypedDict):
 
 
 class RAGService:
-    def __init__(self, embedding_model: EmbeddingModel, vector_db: VectorDB):
+    def __init__(
+        self,
+        embedding_model: EmbeddingModel,
+        vector_db: VectorDB,
+        reranker: Reranker | None = None,
+    ):
         self.embedding_model = embedding_model
         self.vector_db = vector_db
-        self.prompt = """Answer the question based on the following context. 
+        self.reranker = reranker
+        self.prompt = """Answer the question based on the following context.
 If you don't know the answer, say you don't know. Don't make up an answer.
 
 Context:
@@ -69,8 +75,12 @@ Answer:"""
     def answer_query(self, question: str) -> AnswerResult:
         query_embedding = self.embedding_model.embed_query(question)
         search_results = self.vector_db.search(query_embedding, top_k=5)
-        sources = list({chunk["source"] for chunk in search_results if chunk["source"]})
 
+        if self.reranker:
+            logger.info("Reranking search results...")
+            search_results = self.reranker.rerank(search_results, question)
+
+        sources = list({chunk["source"] for chunk in search_results if chunk["source"]})
         context_str = "\n\n".join([chunk["content"] for chunk in search_results])
 
         try:
@@ -92,7 +102,7 @@
             max_tokens=500,
         )
 
-        answer_text = response.choices[0].message.content.strip()  # type: ignore
+        answer_text = response.choices[0].message.content.strip()
 
         if not answer_text:
             answer_text = "No answer generated"
@@ -108,6 +118,11 @@
     def answer_query_stream(self, question: str) -> Generator[str, None, None]:
         query_embedding = self.embedding_model.embed_query(question)
         search_results = self.vector_db.search(query_embedding, top_k=5)
+
+        if self.reranker:
+            logger.info("Reranking search results...")
+            search_results = self.reranker.rerank(search_results, question)
+
         sources = list({chunk["source"] for chunk in search_results if chunk["source"]})
         context_str = "\n\n".join([chunk["content"] for chunk in search_results])
 
diff --git a/app/services/rag_service_v1.py b/app/services/rag_service_v1.py
index 4874594..48ec0db 100644
--- a/app/services/rag_service_v1.py
+++ b/app/services/rag_service_v1.py
@@ -175,9 +175,9 @@
         with self.db_conn.cursor() as cursor:
            cursor.execute(
                 """
-                SELECT content, source 
-                FROM documents 
-                ORDER BY embedding <-> %s::vector 
+                SELECT content, source
+                FROM documents
+                ORDER BY embedding <-> %s::vector
                 LIMIT %s
                 """,
                 (question_embedding.tolist(), top_k),
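
RAGService only exercises the reranker through reranker.rerank(search_results, question), so a toy implementation for local testing might look like this (illustrative; KeywordOverlapReranker is not part of this PR, and the real Reranker interface lives in app/core/interfaces):

class KeywordOverlapReranker:
    """Orders chunks by naive keyword overlap with the question."""

    def rerank(self, search_results: list[dict], question: str) -> list[dict]:
        terms = set(question.lower().split())

        def overlap(chunk: dict) -> int:
            # Chunks carry "content" and "source" keys, as used in RAGService.
            return len(terms & set(chunk["content"].lower().split()))

        return sorted(search_results, key=overlap, reverse=True)
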