mirror of
https://github.com/Sosokker/plain-rag.git
synced 2025-12-18 14:34:05 +01:00
feat: introduce VectorStoreRegistry and PGVectorStore integration in ConfigService
This commit is contained in:
parent
ccb4db50b1
commit
ddabaed1c4
@ -43,6 +43,11 @@ class RerankerRegistry(Registry[Callable[..., Any]]):
|
||||
"""Registry specifically for reranker constructors."""
|
||||
|
||||
|
||||
class VectorStoreRegistry(Registry[Callable[..., Any]]):
|
||||
"""Registry specifically for vector store constructors."""
|
||||
|
||||
|
||||
# Global instances of the registries
|
||||
embedding_model_registry = EmbeddingModelRegistry()
|
||||
reranker_registry = RerankerRegistry()
|
||||
vector_store_registry = VectorStoreRegistry()
|
||||
|
||||
@ -7,3 +7,7 @@ class EmbeddingModelName(str, Enum):
|
||||
|
||||
class RerankerModelName(str, Enum):
|
||||
MiniLMReranker = "MiniLMReranker"
|
||||
|
||||
|
||||
class LLMModelName(str, Enum):
|
||||
GeminiFlash = "gemini/gemini-2.0-flash"
|
||||
|
||||
@ -1,11 +1,16 @@
|
||||
import logging
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.interfaces import EmbeddingModel, Reranker
|
||||
from app.core.registry import embedding_model_registry, reranker_registry
|
||||
from app.core.interfaces import EmbeddingModel, Reranker, VectorDB
|
||||
from app.core.registry import (
|
||||
embedding_model_registry,
|
||||
reranker_registry,
|
||||
vector_store_registry,
|
||||
)
|
||||
from app.schemas.enums import EmbeddingModelName, RerankerModelName
|
||||
from app.services.embedding_providers import MiniLMEmbeddingModel
|
||||
from app.services.rerankers import MiniLMReranker
|
||||
from app.services.vector_stores import PGVectorStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -14,9 +19,11 @@ class ConfigService:
|
||||
def __init__(self):
|
||||
self._current_embedding_model: EmbeddingModel | None = None
|
||||
self._current_reranker_model: Reranker | None = None
|
||||
self._current_vector_store: VectorDB | None = None
|
||||
self._loading_status: dict[str, bool] = {
|
||||
"embedding_model": False,
|
||||
"reranker_model": False,
|
||||
"vector_store": False,
|
||||
}
|
||||
|
||||
self._register_models()
|
||||
@ -25,6 +32,7 @@ class ConfigService:
|
||||
"""Register all default models"""
|
||||
embedding_model_registry.register("MiniLMEmbeddingModel", MiniLMEmbeddingModel)
|
||||
reranker_registry.register("MiniLMReranker", MiniLMReranker)
|
||||
vector_store_registry.register("PGVectorStore", PGVectorStore)
|
||||
|
||||
async def initialize_models(self):
|
||||
"""
|
||||
@ -60,6 +68,19 @@ class ConfigService:
|
||||
await self.set_reranker_model(reranker_model_name)
|
||||
logger.info("Default reranker model initialized: %s", reranker_model_name)
|
||||
|
||||
vector_store_name = (
|
||||
getattr(settings, "VECTOR_STORE_TYPE", None) or "PGVectorStore"
|
||||
)
|
||||
if vector_store_name not in vector_store_registry.list_available():
|
||||
logger.warning(
|
||||
"Vector store '%s' is not valid. Falling back to default '%s'",
|
||||
vector_store_name,
|
||||
"PGVectorStore",
|
||||
)
|
||||
vector_store_name = "PGVectorStore"
|
||||
await self.set_vector_store(vector_store_name)
|
||||
logger.info("Default vector store initialized: %s", vector_store_name)
|
||||
|
||||
async def set_embedding_model(self, model_name: str) -> str:
|
||||
"""Set system embedding model based on provide model_name"""
|
||||
if (
|
||||
@ -124,14 +145,52 @@ class ConfigService:
|
||||
finally:
|
||||
self._loading_status["reranker_model"] = False
|
||||
|
||||
async def set_vector_store(self, store_name: str) -> str:
|
||||
"""Set system vector store based on provided store_name"""
|
||||
if (
|
||||
self._current_vector_store
|
||||
and self._current_vector_store.__class__.__name__ == store_name
|
||||
):
|
||||
return f"Vector store '{store_name}' is already in use."
|
||||
|
||||
if self._loading_status["vector_store"]:
|
||||
return "Another vector store is currently being loaded. Please wait."
|
||||
|
||||
try:
|
||||
self._loading_status["vector_store"] = True
|
||||
logger.info("Attempting to load vector store: %s", store_name)
|
||||
store_constructor = vector_store_registry.get(store_name)
|
||||
self._current_vector_store = store_constructor()
|
||||
settings.VECTOR_STORE_TYPE = store_name # Update settings
|
||||
except KeyError:
|
||||
logger.warning("Vector store '%s' not found in registry.", store_name)
|
||||
return (
|
||||
f"Vector store '{store_name}' not available. "
|
||||
f"Current store remains '{self._current_vector_store.__class__.__name__ if self._current_vector_store else 'None'}'."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Error loading vector store %s: %s", store_name, e)
|
||||
return f"Failed to load vector store '{store_name}': {e}"
|
||||
else:
|
||||
logger.info("Successfully loaded vector store: %s", store_name)
|
||||
return f"Vector store set to '{store_name}' successfully."
|
||||
finally:
|
||||
self._loading_status["vector_store"] = False
|
||||
|
||||
def get_current_embedding_model(self) -> EmbeddingModel | None:
|
||||
return self._current_embedding_model
|
||||
|
||||
def get_current_reranker_model(self) -> Reranker | None:
|
||||
return self._current_reranker_model
|
||||
|
||||
def get_current_vector_store(self) -> VectorDB | None:
|
||||
return self._current_vector_store
|
||||
|
||||
def get_available_embedding_models(self) -> list[str]:
|
||||
return embedding_model_registry.list_available()
|
||||
|
||||
def get_available_reranker_models(self) -> list[str]:
|
||||
return reranker_registry.list_available()
|
||||
|
||||
def get_available_vector_stores(self) -> list[str]:
|
||||
return vector_store_registry.list_available()
|
||||
|
||||
@ -1,12 +1,11 @@
|
||||
import logging
|
||||
|
||||
# pyright: reportArgumentType=false
|
||||
from sentence_transformers import CrossEncoder
|
||||
from structlog import get_logger
|
||||
|
||||
from app.core.exception import ModelNotFoundError
|
||||
from app.core.interfaces import Reranker, SearchResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger()
|
||||
|
||||
# pyright: reportCallIssue=false
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@ import numpy as np
|
||||
import psycopg2
|
||||
from psycopg2.extensions import AsIs, register_adapter
|
||||
from psycopg2.extras import execute_values
|
||||
from structlog import get_logger
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.interfaces import SearchResult, VectorDB
|
||||
@ -10,6 +11,8 @@ from app.core.interfaces import SearchResult, VectorDB
|
||||
register_adapter(np.ndarray, AsIs)
|
||||
register_adapter(np.float32, AsIs)
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class PGVectorStore(VectorDB):
|
||||
"""PostgreSQL vector store implementation for document storage and retrieval."""
|
||||
@ -40,12 +43,14 @@ class PGVectorStore(VectorDB):
|
||||
|
||||
"""
|
||||
if not documents:
|
||||
logger.warning("No documents provided for upsert.")
|
||||
return
|
||||
|
||||
# Validate document structure
|
||||
for doc in documents:
|
||||
if not all(key in doc for key in ["content", "embedding", "source"]):
|
||||
err = "Document must contain 'content', 'embedding', and 'source' keys"
|
||||
logger.error(f"Invalid document structure: {doc}")
|
||||
raise ValueError(err)
|
||||
|
||||
data_to_insert = [
|
||||
@ -62,8 +67,8 @@ class PGVectorStore(VectorDB):
|
||||
RETURNING id
|
||||
"""
|
||||
|
||||
with self._get_connection() as conn, conn.cursor() as cursor:
|
||||
try:
|
||||
try:
|
||||
with self._get_connection() as conn, conn.cursor() as cursor:
|
||||
execute_values(
|
||||
cursor,
|
||||
query,
|
||||
@ -72,11 +77,14 @@ class PGVectorStore(VectorDB):
|
||||
page_size=100,
|
||||
)
|
||||
conn.commit()
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
except psycopg2.Error as db_err:
|
||||
logger.exception(f"Database error during upsert: {db_err}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error during upsert: {e}")
|
||||
raise
|
||||
|
||||
def search(self, vector: np.ndarray, top_k: int = 5) -> list[SearchResult]:
|
||||
def search(self, vector: list, top_k: int = 5) -> list[SearchResult]:
|
||||
"""
|
||||
Search for similar documents using vector similarity.
|
||||
|
||||
@ -91,7 +99,8 @@ class PGVectorStore(VectorDB):
|
||||
psycopg2.Error: For database-related errors.
|
||||
|
||||
"""
|
||||
if not vector:
|
||||
if len(vector) == 0:
|
||||
logger.warning("Empty vector provided for search.")
|
||||
return []
|
||||
|
||||
query = """
|
||||
@ -101,13 +110,17 @@ class PGVectorStore(VectorDB):
|
||||
LIMIT %s
|
||||
"""
|
||||
|
||||
with self._get_connection() as conn, conn.cursor() as cursor:
|
||||
try:
|
||||
try:
|
||||
with self._get_connection() as conn, conn.cursor() as cursor:
|
||||
cursor.execute(query, (np.array(vector).tolist(), top_k))
|
||||
return [
|
||||
results = [
|
||||
SearchResult(content=row[0], source=row[1])
|
||||
for row in cursor.fetchall()
|
||||
]
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
return results
|
||||
except psycopg2.Error as db_err:
|
||||
logger.exception(f"Database error during search: {db_err}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error during search: {e}")
|
||||
raise
|
||||
|
||||
Loading…
Reference in New Issue
Block a user