mirror of
https://github.com/Sosokker/plain-rag.git
synced 2025-12-18 14:34:05 +01:00
optimize: add bulk insert
This commit is contained in:
parent
cf3fe50f1a
commit
425c584101
@ -76,7 +76,7 @@ class Settings(BaseSettings):
|
||||
# Caching
|
||||
CACHE_TTL: int = 300 # 5 minutes
|
||||
|
||||
GEMINI_API_KEY: str
|
||||
GEMINI_API_KEY: str = "secret"
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=ENV_FILE,
|
||||
|
||||
@ -6,6 +6,8 @@ import litellm
|
||||
import numpy as np
|
||||
import psycopg2
|
||||
from dotenv import load_dotenv
|
||||
from psycopg2 import extras
|
||||
from psycopg2.extensions import AsIs, register_adapter
|
||||
from PyPDF2 import PdfReader
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from structlog import get_logger
|
||||
@ -14,6 +16,7 @@ from app.core.config import settings
|
||||
from app.core.exception import DocumentExtractionError, DocumentInsertionError
|
||||
from app.core.utils import RecursiveCharacterTextSplitter
|
||||
|
||||
register_adapter(np.ndarray, AsIs) # for psycopg2 adapt
|
||||
logger = get_logger()
|
||||
|
||||
# pyright: reportArgumentType=false
|
||||
@ -79,47 +82,64 @@ Answer:"""
|
||||
)
|
||||
return text_splitter.split_text(text)
|
||||
|
||||
def _get_embedding(self, text: str) -> np.ndarray:
|
||||
def _get_embedding(self, text: str, show_progress_bar: bool = False) -> np.ndarray:
|
||||
"""
|
||||
Generate embedding for a text chunk.
|
||||
|
||||
Args:
|
||||
text: Input text to embed
|
||||
show_progress_bar: Whether to show a progress bar
|
||||
|
||||
Returns:
|
||||
Numpy array containing the embedding vector
|
||||
|
||||
"""
|
||||
return EMBEDDING_MODEL.encode(text, convert_to_numpy=True)
|
||||
return EMBEDDING_MODEL.encode(
|
||||
text, convert_to_numpy=True, show_progress_bar=show_progress_bar
|
||||
)
|
||||
|
||||
def _store_document(self, content: str, embedding: np.ndarray, source: str) -> int:
|
||||
def _store_document(
|
||||
self, contents: list[str], embeddings: list[np.ndarray], source: str
|
||||
) -> int:
|
||||
"""
|
||||
Store a document chunk in the database.
|
||||
|
||||
Args:
|
||||
content: Text content of the chunk
|
||||
embedding: Embedding vector of the chunk
|
||||
contents: List of text content of the chunk
|
||||
embeddings: List of embedding vectors of the chunk
|
||||
source: Source file path
|
||||
|
||||
Returns:
|
||||
ID of the inserted document
|
||||
|
||||
"""
|
||||
data_to_insert = [
|
||||
(chunk, f"[{', '.join(map(str, embedding))}]", source)
|
||||
for chunk, embedding in zip(contents, embeddings, strict=True)
|
||||
]
|
||||
|
||||
query = """
|
||||
INSERT INTO documents (content, embedding, source)
|
||||
VALUES %s
|
||||
RETURNING id
|
||||
"""
|
||||
with self.db_conn.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO documents (content, embedding, source)
|
||||
VALUES (%s, %s, %s)
|
||||
RETURNING id
|
||||
""",
|
||||
(content, embedding.tolist(), source),
|
||||
extras.execute_values(
|
||||
cursor,
|
||||
query,
|
||||
data_to_insert,
|
||||
template="(%s, %s::vector, %s)",
|
||||
page_size=100,
|
||||
)
|
||||
doc_id = cursor.fetchone()
|
||||
if doc_id is None:
|
||||
err = "Failed to insert document into database"
|
||||
raise DocumentInsertionError(err)
|
||||
inserted_ids = [row[0] for row in cursor.fetchall()]
|
||||
self.db_conn.commit()
|
||||
return doc_id[0]
|
||||
|
||||
if not inserted_ids:
|
||||
raise DocumentInsertionError("No documents were inserted.")
|
||||
|
||||
logger.info("Successfully bulk-ingested %d documents", len(inserted_ids))
|
||||
logger.info("Inserted document IDs: %s", inserted_ids)
|
||||
return inserted_ids[0]
|
||||
|
||||
def _extract_text_from_pdf(self, pdf_path: str) -> str:
|
||||
"""
|
||||
@ -182,10 +202,8 @@ Answer:"""
|
||||
chunks = self._split_text(text)
|
||||
logger.info("Split PDF into %d chunks", len(chunks))
|
||||
|
||||
for i, chunk in enumerate(chunks, 1):
|
||||
logger.info("Processing chunk %d/%d", i, len(chunks))
|
||||
embedding = self._get_embedding(chunk)
|
||||
self._store_document(chunk, embedding, filename)
|
||||
embeddings = self._get_embedding(chunks, show_progress_bar=True)
|
||||
self._store_document(chunks, embeddings, filename)
|
||||
|
||||
logger.info("Successfully processed %d chunks from %s", len(chunks), filename)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user