From 43f37f7ad2dfe73d2362fca1904b17b24f3bb86d Mon Sep 17 00:00:00 2001 From: Sosokker Date: Fri, 27 Jun 2025 23:25:17 +0700 Subject: [PATCH] feat: add FileTypeIngestionError exception and enforce file type validation in RAG service --- app/core/exception.py | 4 ++++ app/services/embedding_providers.py | 2 +- app/services/rag_service.py | 3 +++ app/services/rerankers.py | 2 +- 4 files changed, 9 insertions(+), 2 deletions(-) diff --git a/app/core/exception.py b/app/core/exception.py index 1383495..273e6b9 100644 --- a/app/core/exception.py +++ b/app/core/exception.py @@ -8,3 +8,7 @@ class DocumentExtractionError(Exception): class ModelNotFoundError(Exception): """Exception raised when model is not found.""" + + +class FileTypeIngestionError(Exception): + """Exception raised when user upload unsupported file type.""" diff --git a/app/services/embedding_providers.py b/app/services/embedding_providers.py index 0cbcc70..3efb8cd 100644 --- a/app/services/embedding_providers.py +++ b/app/services/embedding_providers.py @@ -5,7 +5,7 @@ from app.core.interfaces import EmbeddingModel class MiniLMEmbeddingModel(EmbeddingModel): - def __init__(self, model_name: str = "all-MiniLM-L6-v2"): + def __init__(self, model_name: str): self.model = SentenceTransformer(model_name) def embed_documents(self, texts: list[str]) -> list[np.ndarray]: diff --git a/app/services/rag_service.py b/app/services/rag_service.py index df6c965..860185d 100644 --- a/app/services/rag_service.py +++ b/app/services/rag_service.py @@ -9,6 +9,7 @@ from PyPDF2.errors import PyPdfError from structlog import get_logger from app.core.config import settings +from app.core.exception import FileTypeIngestionError from app.core.interfaces import EmbeddingModel, Reranker, VectorDB from app.core.utils import RecursiveCharacterTextSplitter from app.schemas.enums import LLMModelName @@ -74,6 +75,8 @@ Answer:""" path = Path(file_path) ext = path.suffix text = "" + if ext[1:] not in settings.ALLOWED_DOCUMENT_TYPES: + raise FileTypeIngestionError("Only support PDF, MD and TXT files") if ext == ".pdf": try: reader = PdfReader(str(file_path)) diff --git a/app/services/rerankers.py b/app/services/rerankers.py index 6d9986e..27550f2 100644 --- a/app/services/rerankers.py +++ b/app/services/rerankers.py @@ -11,7 +11,7 @@ logger = get_logger() class MiniLMReranker(Reranker): - def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"): + def __init__(self, model_name: str): try: self.model = CrossEncoder(model_name) except Exception as er: