feat: add FileTypeIngestionError exception and enforce file type validation in RAG service

This commit is contained in:
Sosokker 2025-06-27 23:25:17 +07:00
parent aec7ca824c
commit 43f37f7ad2
4 changed files with 9 additions and 2 deletions

View File

@ -8,3 +8,7 @@ class DocumentExtractionError(Exception):
class ModelNotFoundError(Exception): class ModelNotFoundError(Exception):
"""Exception raised when model is not found.""" """Exception raised when model is not found."""
class FileTypeIngestionError(Exception):
"""Exception raised when user upload unsupported file type."""

View File

@ -5,7 +5,7 @@ from app.core.interfaces import EmbeddingModel
class MiniLMEmbeddingModel(EmbeddingModel): class MiniLMEmbeddingModel(EmbeddingModel):
def __init__(self, model_name: str = "all-MiniLM-L6-v2"): def __init__(self, model_name: str):
self.model = SentenceTransformer(model_name) self.model = SentenceTransformer(model_name)
def embed_documents(self, texts: list[str]) -> list[np.ndarray]: def embed_documents(self, texts: list[str]) -> list[np.ndarray]:

View File

@ -9,6 +9,7 @@ from PyPDF2.errors import PyPdfError
from structlog import get_logger from structlog import get_logger
from app.core.config import settings from app.core.config import settings
from app.core.exception import FileTypeIngestionError
from app.core.interfaces import EmbeddingModel, Reranker, VectorDB from app.core.interfaces import EmbeddingModel, Reranker, VectorDB
from app.core.utils import RecursiveCharacterTextSplitter from app.core.utils import RecursiveCharacterTextSplitter
from app.schemas.enums import LLMModelName from app.schemas.enums import LLMModelName
@ -74,6 +75,8 @@ Answer:"""
path = Path(file_path) path = Path(file_path)
ext = path.suffix ext = path.suffix
text = "" text = ""
if ext[1:] not in settings.ALLOWED_DOCUMENT_TYPES:
raise FileTypeIngestionError("Only support PDF, MD and TXT files")
if ext == ".pdf": if ext == ".pdf":
try: try:
reader = PdfReader(str(file_path)) reader = PdfReader(str(file_path))

View File

@ -11,7 +11,7 @@ logger = get_logger()
class MiniLMReranker(Reranker): class MiniLMReranker(Reranker):
def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"): def __init__(self, model_name: str):
try: try:
self.model = CrossEncoder(model_name) self.model = CrossEncoder(model_name)
except Exception as er: except Exception as er: