diff --git a/.env.example b/.env.example index 31132ba..bc9e44c 100644 --- a/.env.example +++ b/.env.example @@ -1,19 +1,20 @@ -# API Configuration -API_PORT=8001 - -# security -SECRET_KEY=your-secret-key-here - -# LLM -OPENAI_API_KEY=your-openai-key - - -# Database Configuration +# Database configuration +DB_SERVER=localhost +DB_USER=postgres +DB_PASSWORD=yourpassword +DB_NAME=chat_hub DB_PORT=5432 -POSTGRES_USER=user -POSTGRES_PASSWORD=password -POSTGRES_DB=mydatabase -# Environment -ENVIRONMENT=production -DEBUG=False \ No newline at end of file +# Vector Store +VECTOR_STORE_TYPE=pgvector +EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2 +RERANKER_MODEL=cross-encoder/ms-marco-MiniLM-L-6-v2 + +# LLM Models +GEMINI_API_KEY=secret + +# Prompt +SYSTEM_PROMPT=You are a helpful assistant that answers questions based on the provided context. + +# API +API_PORT=8000 \ No newline at end of file diff --git a/Makefile b/Makefile index 6c65037..d0e1da0 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: create-tables help +.PHONY: install-deps create-tables help start help: @echo "Available targets:" diff --git a/README.md b/README.md index e69de29..11510ba 100644 --- a/README.md +++ b/README.md @@ -0,0 +1,53 @@ +# PlainRAG + +PlainRAG is a RAG application built without LLM orchestration frameworks such as LangChain or Haystack; it uses `LiteLLM`, `Transformers`, and `FastAPI` directly. + +## Quick Start + +0. Install `uv` first. +1. Copy `.env.example` to `.env` and fill in your values. +2. Run the following commands to build and start the services: + +```bash +make install-deps # or uv sync +make create-tables +make start +``` + +or + +```bash +docker compose up -d +``` + +This will use Docker Compose to start the API and database services. 
+ +## Environment Variables + +- `DB_SERVER`: Database server hostname (default: localhost) +- `DB_USER`: Database username (default: postgres) +- `DB_PASSWORD`: Database password +- `DB_NAME`: Database name (default: chat_hub) +- `DB_PORT`: Database port (default: 5432) + +Other variables are documented in `.env.example`. + +## Components + +- **API** (`app/api/endpoints.py`): FastAPI endpoints for file ingestion, querying, and configuration. +- **Services** (`app/services/`): + - `rag_service.py`: Core RAG pipeline (ingest, query, stream answers). + - `config_service.py`: Manages model and vector store selection. + - `embedding_providers.py`: Embedding model integration (e.g., MiniLM). + - `rerankers.py`: Reranker model integration (e.g., CrossEncoder). + - `vector_stores.py`: Vector database integration (PostgreSQL/pgvector). +- **Core** (`app/core/`): + - `config.py`: Application settings and environment management. + - `registry.py`: Registry pattern for models and stores. + - `utils.py`: Text splitting utilities. + - `exception.py`: Custom exceptions. + - `interfaces.py`: Abstract interfaces for models and stores. +- **Schemas** (`app/schemas/`): + - `models.py`: Pydantic models for API requests/responses. + - `enums.py`: Enum definitions for model/store selection. 
+ diff --git a/app/core/config.py b/app/core/config.py index 5179151..adf258e 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -1,83 +1,38 @@ import os -import secrets from functools import lru_cache from pathlib import Path -from pydantic import PostgresDsn, model_validator +from pydantic import Field, ValidationError from pydantic_settings import BaseSettings, SettingsConfigDict +from structlog import get_logger ROOT = Path(__file__).resolve().parent.parent.parent ENV_FILE = ROOT / ".env" +logger = get_logger() + class Settings(BaseSettings): - # Project - PROJECT_NAME: str = "Chat Hub" - VERSION: str = "0.1.0" - API_V1_STR: str = "/api/v1" - SECRET_KEY: str = secrets.token_urlsafe(32) - ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8 # 8 days - - # Security - ALGORITHM: str = "HS256" - # Database - POSTGRES_SERVER: str = "localhost" - POSTGRES_USER: str = "postgres" - POSTGRES_PASSWORD: str = "" - POSTGRES_DB: str = "chat_hub" - POSTGRES_PORT: str = "5432" - DATABASE_URI: PostgresDsn = PostgresDsn.build( - scheme="postgresql+asyncpg", - username=POSTGRES_USER, - password=POSTGRES_PASSWORD, - host=POSTGRES_SERVER, - port=int(POSTGRES_PORT), - path=f"/{POSTGRES_DB or ''}", - ) - - @model_validator(mode="after") - def assemble_db_connection(self) -> "Settings": - if self.DATABASE_URI is None: - self.DATABASE_URI = PostgresDsn.build( - scheme="postgresql+asyncpg", - username=self.POSTGRES_USER, - password=self.POSTGRES_PASSWORD, - host=self.POSTGRES_SERVER, - port=int(self.POSTGRES_PORT), - path=f"/{self.POSTGRES_DB or ''}", - ) - return self - - # LLM Configuration - OPENAI_API_KEY: str | None = None - ANTHROPIC_API_KEY: str | None = None + DB_SERVER: str = "localhost" + DB_USER: str = "postgres" + DB_PASSWORD: str = "" + DB_NAME: str = "chat_hub" + DB_PORT: str = "5432" # Vector Store VECTOR_STORE_TYPE: str = "pgvector" # or "chroma", "faiss", etc. 
- EMBEDDING_MODEL: str = "sentence-transformers/all-mpnet-base-v2" + EMBEDDING_MODEL: str = "sentence-transformers/all-MiniLM-L6-v2" + RERANKER_MODEL: str = "cross-encoder/ms-marco-MiniLM-L-6-v2" # File uploads - UPLOAD_DIR: str = "uploads" # Relative to project root - MAX_UPLOAD_SIZE: int = 100 * 1024 * 1024 # 100MB - ALLOWED_DOCUMENT_TYPES: list[str] = ["pdf", "txt", "md"] - - # Logging - LOG_LEVEL: str = "INFO" - - # Environment - ENVIRONMENT: str = "development" - DEBUG: bool = False - TESTING: bool = False - - # Rate Limiting - RATE_LIMIT: str = "100/minute" - - # Caching - CACHE_TTL: int = 300 # 5 minutes + ALLOWED_DOCUMENT_TYPES: list[str] = Field(default=["pdf", "txt", "md"]) + # LLM Models GEMINI_API_KEY: str = "secret" + # Prompt + # TODO: maybe move this into a dedicated prompts module (e.g. prompt.py) SYSTEM_PROMPT: str = "You are a helpful assistant that answers questions based on the provided context." model_config = SettingsConfigDict( @@ -94,7 +49,11 @@ def get_settings() -> Settings: """ Get cached settings instance. This function uses lru_cache to prevent re-reading the environment on each call. 
""" - return Settings() + try: + return Settings() + except ValidationError: + logger.exception("Error loading settings") + raise settings = get_settings() diff --git a/app/schemas/enums.py b/app/schemas/enums.py index e329a99..201c46d 100644 --- a/app/schemas/enums.py +++ b/app/schemas/enums.py @@ -2,12 +2,16 @@ from enum import Enum class EmbeddingModelName(str, Enum): - MiniLMEmbeddingModel = "MiniLMEmbeddingModel" + MiniLMEmbeddingModel = "sentence-transformers/all-MiniLM-L6-v2" class RerankerModelName(str, Enum): - MiniLMReranker = "MiniLMReranker" + MiniLMReranker = "cross-encoder/ms-marco-MiniLM-L-6-v2" class LLMModelName(str, Enum): GeminiFlash = "gemini/gemini-2.0-flash" + + +class VectorStoreType(str, Enum): + PGVECTOR = "pgvector" diff --git a/app/services/config_service.py b/app/services/config_service.py index 524d32d..cb26769 100644 --- a/app/services/config_service.py +++ b/app/services/config_service.py @@ -7,7 +7,7 @@ from app.core.registry import ( reranker_registry, vector_store_registry, ) -from app.schemas.enums import EmbeddingModelName, RerankerModelName +from app.schemas.enums import EmbeddingModelName, RerankerModelName, VectorStoreType from app.services.embedding_providers import MiniLMEmbeddingModel from app.services.rerankers import MiniLMReranker from app.services.vector_stores import PGVectorStore @@ -30,9 +30,11 @@ class ConfigService: def _register_models(self): """Register all default models""" - embedding_model_registry.register("MiniLMEmbeddingModel", MiniLMEmbeddingModel) - reranker_registry.register("MiniLMReranker", MiniLMReranker) - vector_store_registry.register("PGVectorStore", PGVectorStore) + embedding_model_registry.register( + EmbeddingModelName.MiniLMEmbeddingModel, MiniLMEmbeddingModel + ) + reranker_registry.register(RerankerModelName.MiniLMReranker, MiniLMReranker) + vector_store_registry.register(VectorStoreType.PGVECTOR, PGVectorStore) async def initialize_models(self): """ @@ -69,15 +71,15 @@ class ConfigService: 
logger.info("Default reranker model initialized: %s", reranker_model_name) vector_store_name = ( - getattr(settings, "VECTOR_STORE_TYPE", None) or "PGVectorStore" + getattr(settings, "VECTOR_STORE_TYPE", None) or VectorStoreType.PGVECTOR ) if vector_store_name not in vector_store_registry.list_available(): logger.warning( "Vector store '%s' is not valid. Falling back to default '%s'", vector_store_name, - "PGVectorStore", + VectorStoreType.PGVECTOR, ) - vector_store_name = "PGVectorStore" + vector_store_name = VectorStoreType.PGVECTOR await self.set_vector_store(vector_store_name) logger.info("Default vector store initialized: %s", vector_store_name) @@ -96,7 +98,7 @@ class ConfigService: self._loading_status["embedding_model"] = True logger.info("Attempting to load embedding model: %s", model_name) model_constructor = embedding_model_registry.get(model_name) - self._current_embedding_model = model_constructor() + self._current_embedding_model = model_constructor(model_name) settings.EMBEDDING_MODEL = model_name # Update settings except KeyError: logger.warning("Embedding model '%s' not found in registry.", model_name) @@ -128,7 +130,7 @@ class ConfigService: self._loading_status["reranker_model"] = True logger.info("Attempting to load reranker model: %s", model_name) model_constructor = reranker_registry.get(model_name) - self._current_reranker_model = model_constructor() + self._current_reranker_model = model_constructor(model_name) # settings.RERANKER_MODEL = model_name except KeyError: logger.warning("Reranker model '%s' not found in registry.", model_name) diff --git a/app/services/vector_stores.py b/app/services/vector_stores.py index 374f489..903085b 100644 --- a/app/services/vector_stores.py +++ b/app/services/vector_stores.py @@ -23,11 +23,11 @@ class PGVectorStore(VectorDB): def _get_connection(self): """Get a new database connection.""" return psycopg2.connect( - host=settings.POSTGRES_SERVER, - port=settings.POSTGRES_PORT, - user=settings.POSTGRES_USER, 
- password=settings.POSTGRES_PASSWORD, - dbname=settings.POSTGRES_DB, + host=settings.DB_SERVER, + port=settings.DB_PORT, + user=settings.DB_USER, + password=settings.DB_PASSWORD, + dbname=settings.DB_NAME, ) def upsert_documents(self, documents: list[dict]) -> None: diff --git a/docker-compose.yml b/docker-compose.yml index 5f0e2eb..f746747 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -4,13 +4,13 @@ services: context: . dockerfile: Dockerfile ports: - - "${API_PORT:-8001}:8000" + - "${API_PORT:-8000}:8000" volumes: - ./app:/app env_file: - .env environment: - - DATABASE_URL=postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@db:5432/${POSTGRES_DB} + - DATABASE_URL=postgresql://${DB_USER}:${DB_PASSWORD}@db:5432/${DB_NAME} depends_on: - db @@ -19,9 +19,9 @@ services: env_file: - .env environment: - - POSTGRES_USER=${POSTGRES_USER:-user} - - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password} - - POSTGRES_DB=${POSTGRES_DB:-mydatabase} + - POSTGRES_USER=${DB_USER:-user} # the postgres image only reads POSTGRES_* names; map them from the DB_* values in .env + - POSTGRES_PASSWORD=${DB_PASSWORD:-password} + - POSTGRES_DB=${DB_NAME:-mydatabase} volumes: - db_data:/var/lib/postgresql/data diff --git a/scripts/create_tables.py b/scripts/create_tables.py index 81287c2..39a9d4a 100644 --- a/scripts/create_tables.py +++ b/scripts/create_tables.py @@ -39,7 +39,7 @@ def get_db_config() -> dict: """ return { - "host": os.getenv("DB_HOST", "localhost"), + "host": os.getenv("DB_SERVER", "localhost"), "port": os.getenv("DB_PORT", "5432"), "user": os.getenv("DB_USER", "user"), "password": os.getenv("DB_PASSWORD", "password"),