mirror of
https://github.com/Sosokker/plain-rag.git
synced 2025-12-19 14:54:05 +01:00
feat: introduce VectorStoreRegistry and PGVectorStore integration in ConfigService
This commit is contained in:
parent
ccb4db50b1
commit
ddabaed1c4
@ -43,6 +43,11 @@ class RerankerRegistry(Registry[Callable[..., Any]]):
|
|||||||
"""Registry specifically for reranker constructors."""
|
"""Registry specifically for reranker constructors."""
|
||||||
|
|
||||||
|
|
||||||
|
class VectorStoreRegistry(Registry[Callable[..., Any]]):
|
||||||
|
"""Registry specifically for vector store constructors."""
|
||||||
|
|
||||||
|
|
||||||
# Global instances of the registries
|
# Global instances of the registries
|
||||||
embedding_model_registry = EmbeddingModelRegistry()
|
embedding_model_registry = EmbeddingModelRegistry()
|
||||||
reranker_registry = RerankerRegistry()
|
reranker_registry = RerankerRegistry()
|
||||||
|
vector_store_registry = VectorStoreRegistry()
|
||||||
|
|||||||
@ -7,3 +7,7 @@ class EmbeddingModelName(str, Enum):
|
|||||||
|
|
||||||
class RerankerModelName(str, Enum):
|
class RerankerModelName(str, Enum):
|
||||||
MiniLMReranker = "MiniLMReranker"
|
MiniLMReranker = "MiniLMReranker"
|
||||||
|
|
||||||
|
|
||||||
|
class LLMModelName(str, Enum):
|
||||||
|
GeminiFlash = "gemini/gemini-2.0-flash"
|
||||||
|
|||||||
@ -1,11 +1,16 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.interfaces import EmbeddingModel, Reranker
|
from app.core.interfaces import EmbeddingModel, Reranker, VectorDB
|
||||||
from app.core.registry import embedding_model_registry, reranker_registry
|
from app.core.registry import (
|
||||||
|
embedding_model_registry,
|
||||||
|
reranker_registry,
|
||||||
|
vector_store_registry,
|
||||||
|
)
|
||||||
from app.schemas.enums import EmbeddingModelName, RerankerModelName
|
from app.schemas.enums import EmbeddingModelName, RerankerModelName
|
||||||
from app.services.embedding_providers import MiniLMEmbeddingModel
|
from app.services.embedding_providers import MiniLMEmbeddingModel
|
||||||
from app.services.rerankers import MiniLMReranker
|
from app.services.rerankers import MiniLMReranker
|
||||||
|
from app.services.vector_stores import PGVectorStore
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -14,9 +19,11 @@ class ConfigService:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._current_embedding_model: EmbeddingModel | None = None
|
self._current_embedding_model: EmbeddingModel | None = None
|
||||||
self._current_reranker_model: Reranker | None = None
|
self._current_reranker_model: Reranker | None = None
|
||||||
|
self._current_vector_store: VectorDB | None = None
|
||||||
self._loading_status: dict[str, bool] = {
|
self._loading_status: dict[str, bool] = {
|
||||||
"embedding_model": False,
|
"embedding_model": False,
|
||||||
"reranker_model": False,
|
"reranker_model": False,
|
||||||
|
"vector_store": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
self._register_models()
|
self._register_models()
|
||||||
@ -25,6 +32,7 @@ class ConfigService:
|
|||||||
"""Register all default models"""
|
"""Register all default models"""
|
||||||
embedding_model_registry.register("MiniLMEmbeddingModel", MiniLMEmbeddingModel)
|
embedding_model_registry.register("MiniLMEmbeddingModel", MiniLMEmbeddingModel)
|
||||||
reranker_registry.register("MiniLMReranker", MiniLMReranker)
|
reranker_registry.register("MiniLMReranker", MiniLMReranker)
|
||||||
|
vector_store_registry.register("PGVectorStore", PGVectorStore)
|
||||||
|
|
||||||
async def initialize_models(self):
|
async def initialize_models(self):
|
||||||
"""
|
"""
|
||||||
@ -60,6 +68,19 @@ class ConfigService:
|
|||||||
await self.set_reranker_model(reranker_model_name)
|
await self.set_reranker_model(reranker_model_name)
|
||||||
logger.info("Default reranker model initialized: %s", reranker_model_name)
|
logger.info("Default reranker model initialized: %s", reranker_model_name)
|
||||||
|
|
||||||
|
vector_store_name = (
|
||||||
|
getattr(settings, "VECTOR_STORE_TYPE", None) or "PGVectorStore"
|
||||||
|
)
|
||||||
|
if vector_store_name not in vector_store_registry.list_available():
|
||||||
|
logger.warning(
|
||||||
|
"Vector store '%s' is not valid. Falling back to default '%s'",
|
||||||
|
vector_store_name,
|
||||||
|
"PGVectorStore",
|
||||||
|
)
|
||||||
|
vector_store_name = "PGVectorStore"
|
||||||
|
await self.set_vector_store(vector_store_name)
|
||||||
|
logger.info("Default vector store initialized: %s", vector_store_name)
|
||||||
|
|
||||||
async def set_embedding_model(self, model_name: str) -> str:
|
async def set_embedding_model(self, model_name: str) -> str:
|
||||||
"""Set system embedding model based on provide model_name"""
|
"""Set system embedding model based on provide model_name"""
|
||||||
if (
|
if (
|
||||||
@ -124,14 +145,52 @@ class ConfigService:
|
|||||||
finally:
|
finally:
|
||||||
self._loading_status["reranker_model"] = False
|
self._loading_status["reranker_model"] = False
|
||||||
|
|
||||||
|
async def set_vector_store(self, store_name: str) -> str:
|
||||||
|
"""Set system vector store based on provided store_name"""
|
||||||
|
if (
|
||||||
|
self._current_vector_store
|
||||||
|
and self._current_vector_store.__class__.__name__ == store_name
|
||||||
|
):
|
||||||
|
return f"Vector store '{store_name}' is already in use."
|
||||||
|
|
||||||
|
if self._loading_status["vector_store"]:
|
||||||
|
return "Another vector store is currently being loaded. Please wait."
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._loading_status["vector_store"] = True
|
||||||
|
logger.info("Attempting to load vector store: %s", store_name)
|
||||||
|
store_constructor = vector_store_registry.get(store_name)
|
||||||
|
self._current_vector_store = store_constructor()
|
||||||
|
settings.VECTOR_STORE_TYPE = store_name # Update settings
|
||||||
|
except KeyError:
|
||||||
|
logger.warning("Vector store '%s' not found in registry.", store_name)
|
||||||
|
return (
|
||||||
|
f"Vector store '{store_name}' not available. "
|
||||||
|
f"Current store remains '{self._current_vector_store.__class__.__name__ if self._current_vector_store else 'None'}'."
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Error loading vector store %s: %s", store_name, e)
|
||||||
|
return f"Failed to load vector store '{store_name}': {e}"
|
||||||
|
else:
|
||||||
|
logger.info("Successfully loaded vector store: %s", store_name)
|
||||||
|
return f"Vector store set to '{store_name}' successfully."
|
||||||
|
finally:
|
||||||
|
self._loading_status["vector_store"] = False
|
||||||
|
|
||||||
def get_current_embedding_model(self) -> EmbeddingModel | None:
|
def get_current_embedding_model(self) -> EmbeddingModel | None:
|
||||||
return self._current_embedding_model
|
return self._current_embedding_model
|
||||||
|
|
||||||
def get_current_reranker_model(self) -> Reranker | None:
|
def get_current_reranker_model(self) -> Reranker | None:
|
||||||
return self._current_reranker_model
|
return self._current_reranker_model
|
||||||
|
|
||||||
|
def get_current_vector_store(self) -> VectorDB | None:
|
||||||
|
return self._current_vector_store
|
||||||
|
|
||||||
def get_available_embedding_models(self) -> list[str]:
|
def get_available_embedding_models(self) -> list[str]:
|
||||||
return embedding_model_registry.list_available()
|
return embedding_model_registry.list_available()
|
||||||
|
|
||||||
def get_available_reranker_models(self) -> list[str]:
|
def get_available_reranker_models(self) -> list[str]:
|
||||||
return reranker_registry.list_available()
|
return reranker_registry.list_available()
|
||||||
|
|
||||||
|
def get_available_vector_stores(self) -> list[str]:
|
||||||
|
return vector_store_registry.list_available()
|
||||||
|
|||||||
@ -1,12 +1,11 @@
|
|||||||
import logging
|
|
||||||
|
|
||||||
# pyright: reportArgumentType=false
|
# pyright: reportArgumentType=false
|
||||||
from sentence_transformers import CrossEncoder
|
from sentence_transformers import CrossEncoder
|
||||||
|
from structlog import get_logger
|
||||||
|
|
||||||
from app.core.exception import ModelNotFoundError
|
from app.core.exception import ModelNotFoundError
|
||||||
from app.core.interfaces import Reranker, SearchResult
|
from app.core.interfaces import Reranker, SearchResult
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = get_logger()
|
||||||
|
|
||||||
# pyright: reportCallIssue=false
|
# pyright: reportCallIssue=false
|
||||||
|
|
||||||
|
|||||||
@ -2,6 +2,7 @@ import numpy as np
|
|||||||
import psycopg2
|
import psycopg2
|
||||||
from psycopg2.extensions import AsIs, register_adapter
|
from psycopg2.extensions import AsIs, register_adapter
|
||||||
from psycopg2.extras import execute_values
|
from psycopg2.extras import execute_values
|
||||||
|
from structlog import get_logger
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.interfaces import SearchResult, VectorDB
|
from app.core.interfaces import SearchResult, VectorDB
|
||||||
@ -10,6 +11,8 @@ from app.core.interfaces import SearchResult, VectorDB
|
|||||||
register_adapter(np.ndarray, AsIs)
|
register_adapter(np.ndarray, AsIs)
|
||||||
register_adapter(np.float32, AsIs)
|
register_adapter(np.float32, AsIs)
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
class PGVectorStore(VectorDB):
|
class PGVectorStore(VectorDB):
|
||||||
"""PostgreSQL vector store implementation for document storage and retrieval."""
|
"""PostgreSQL vector store implementation for document storage and retrieval."""
|
||||||
@ -40,12 +43,14 @@ class PGVectorStore(VectorDB):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
if not documents:
|
if not documents:
|
||||||
|
logger.warning("No documents provided for upsert.")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Validate document structure
|
# Validate document structure
|
||||||
for doc in documents:
|
for doc in documents:
|
||||||
if not all(key in doc for key in ["content", "embedding", "source"]):
|
if not all(key in doc for key in ["content", "embedding", "source"]):
|
||||||
err = "Document must contain 'content', 'embedding', and 'source' keys"
|
err = "Document must contain 'content', 'embedding', and 'source' keys"
|
||||||
|
logger.error(f"Invalid document structure: {doc}")
|
||||||
raise ValueError(err)
|
raise ValueError(err)
|
||||||
|
|
||||||
data_to_insert = [
|
data_to_insert = [
|
||||||
@ -62,8 +67,8 @@ class PGVectorStore(VectorDB):
|
|||||||
RETURNING id
|
RETURNING id
|
||||||
"""
|
"""
|
||||||
|
|
||||||
with self._get_connection() as conn, conn.cursor() as cursor:
|
try:
|
||||||
try:
|
with self._get_connection() as conn, conn.cursor() as cursor:
|
||||||
execute_values(
|
execute_values(
|
||||||
cursor,
|
cursor,
|
||||||
query,
|
query,
|
||||||
@ -72,11 +77,14 @@ class PGVectorStore(VectorDB):
|
|||||||
page_size=100,
|
page_size=100,
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
except Exception:
|
except psycopg2.Error as db_err:
|
||||||
conn.rollback()
|
logger.exception(f"Database error during upsert: {db_err}")
|
||||||
raise
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Unexpected error during upsert: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
def search(self, vector: np.ndarray, top_k: int = 5) -> list[SearchResult]:
|
def search(self, vector: list, top_k: int = 5) -> list[SearchResult]:
|
||||||
"""
|
"""
|
||||||
Search for similar documents using vector similarity.
|
Search for similar documents using vector similarity.
|
||||||
|
|
||||||
@ -91,7 +99,8 @@ class PGVectorStore(VectorDB):
|
|||||||
psycopg2.Error: For database-related errors.
|
psycopg2.Error: For database-related errors.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if not vector:
|
if len(vector) == 0:
|
||||||
|
logger.warning("Empty vector provided for search.")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
query = """
|
query = """
|
||||||
@ -101,13 +110,17 @@ class PGVectorStore(VectorDB):
|
|||||||
LIMIT %s
|
LIMIT %s
|
||||||
"""
|
"""
|
||||||
|
|
||||||
with self._get_connection() as conn, conn.cursor() as cursor:
|
try:
|
||||||
try:
|
with self._get_connection() as conn, conn.cursor() as cursor:
|
||||||
cursor.execute(query, (np.array(vector).tolist(), top_k))
|
cursor.execute(query, (np.array(vector).tolist(), top_k))
|
||||||
return [
|
results = [
|
||||||
SearchResult(content=row[0], source=row[1])
|
SearchResult(content=row[0], source=row[1])
|
||||||
for row in cursor.fetchall()
|
for row in cursor.fetchall()
|
||||||
]
|
]
|
||||||
except Exception:
|
return results
|
||||||
conn.rollback()
|
except psycopg2.Error as db_err:
|
||||||
raise
|
logger.exception(f"Database error during search: {db_err}")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Unexpected error during search: {e}")
|
||||||
|
raise
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user