Mirror of https://github.com/Sosokker/plain-rag.git (synced 2025-12-19 14:54:05 +01:00)

optimize: add bulk insert

This commit is contained in:
commit 425c584101
parent cf3fe50f1a
@@ -76,7 +76,7 @@ class Settings(BaseSettings):
     # Caching
     CACHE_TTL: int = 300  # 5 minutes
 
-    GEMINI_API_KEY: str
+    GEMINI_API_KEY: str = "secret"
 
     model_config = SettingsConfigDict(
         env_file=ENV_FILE,
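For readers unfamiliar with pydantic-settings: the new "secret" value is only a fallback, because BaseSettings resolves each field from the process environment and the configured env_file before using the class default. A minimal standalone sketch of that precedence (not from the repo):

    import os
    from pydantic_settings import BaseSettings, SettingsConfigDict

    class Settings(BaseSettings):
        # Fallback used only when neither the environment nor .env provides a value
        GEMINI_API_KEY: str = "secret"

        model_config = SettingsConfigDict(env_file=".env")

    os.environ["GEMINI_API_KEY"] = "real-key"
    print(Settings().GEMINI_API_KEY)  # "real-key": env vars take precedence over the default

The practical effect of the change is that the app can now start without GEMINI_API_KEY set, at the cost of failing later with a placeholder key rather than immediately at startup.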
@@ -6,6 +6,8 @@ import litellm
 import numpy as np
 import psycopg2
 from dotenv import load_dotenv
+from psycopg2 import extras
+from psycopg2.extensions import AsIs, register_adapter
 from PyPDF2 import PdfReader
 from sentence_transformers import SentenceTransformer
 from structlog import get_logger
@@ -14,6 +16,7 @@ from app.core.config import settings
 from app.core.exception import DocumentExtractionError, DocumentInsertionError
 from app.core.utils import RecursiveCharacterTextSplitter
 
+register_adapter(np.ndarray, AsIs)  # for psycopg2 adapt
 logger = get_logger()
 
 # pyright: reportArgumentType=false
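register_adapter plus AsIs is psycopg2's escape hatch for types it cannot adapt natively: the adapter function returns an object whose getquoted() output is spliced into the SQL, and AsIs simply pastes str(obj) in verbatim. A self-contained sketch of the mechanism with a hypothetical Vector wrapper (the class is not from the repo):

    from psycopg2.extensions import AsIs, register_adapter

    class Vector:
        """Hypothetical wrapper whose str() form is a quoted pgvector literal."""

        def __init__(self, values: list[float]) -> None:
            self.values = values

        def __str__(self) -> str:
            return "'[" + ",".join(map(str, self.values)) + "]'"

    # AsIs injects str(obj) into the statement as-is, so the quoting above matters.
    register_adapter(Vector, lambda v: AsIs(str(v)))

    # cursor.mogrify("INSERT INTO documents (embedding) VALUES (%s::vector)", (Vector([0.1, 0.2]),))
    # -> b"INSERT INTO documents (embedding) VALUES ('[0.1,0.2]'::vector)"

Note that str() of a NumPy array is space-separated rather than comma-separated, which is presumably why the commit also formats embeddings as "[...]" strings by hand in _store_document below instead of relying on the adapter alone.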
@@ -79,47 +82,64 @@ Answer:"""
         )
         return text_splitter.split_text(text)
 
-    def _get_embedding(self, text: str) -> np.ndarray:
+    def _get_embedding(self, text: str, show_progress_bar: bool = False) -> np.ndarray:
         """
         Generate embedding for a text chunk.
 
         Args:
             text: Input text to embed
+            show_progress_bar: Whether to show a progress bar
 
         Returns:
             Numpy array containing the embedding vector
 
         """
-        return EMBEDDING_MODEL.encode(text, convert_to_numpy=True)
+        return EMBEDDING_MODEL.encode(
+            text, convert_to_numpy=True, show_progress_bar=show_progress_bar
+        )
 
-    def _store_document(self, content: str, embedding: np.ndarray, source: str) -> int:
+    def _store_document(
+        self, contents: list[str], embeddings: list[np.ndarray], source: str
+    ) -> int:
         """
         Store a document chunk in the database.
 
         Args:
-            content: Text content of the chunk
-            embedding: Embedding vector of the chunk
+            contents: List of text content of the chunk
+            embeddings: List of embedding vectors of the chunk
             source: Source file path
 
         Returns:
             ID of the inserted document
 
         """
-        with self.db_conn.cursor() as cursor:
-            cursor.execute(
-                """
+        data_to_insert = [
+            (chunk, f"[{', '.join(map(str, embedding))}]", source)
+            for chunk, embedding in zip(contents, embeddings, strict=True)
+        ]
+
+        query = """
             INSERT INTO documents (content, embedding, source)
-            VALUES (%s, %s, %s)
+            VALUES %s
             RETURNING id
-            """,
-                (content, embedding.tolist(), source),
+        """
+
+        with self.db_conn.cursor() as cursor:
+            extras.execute_values(
+                cursor,
+                query,
+                data_to_insert,
+                template="(%s, %s::vector, %s)",
+                page_size=100,
             )
-            doc_id = cursor.fetchone()
-            if doc_id is None:
-                err = "Failed to insert document into database"
-                raise DocumentInsertionError(err)
+            inserted_ids = [row[0] for row in cursor.fetchall()]
         self.db_conn.commit()
-        return doc_id[0]
+
+        if not inserted_ids:
+            raise DocumentInsertionError("No documents were inserted.")
+
+        logger.info("Successfully bulk-ingested %d documents", len(inserted_ids))
+        logger.info("Inserted document IDs: %s", inserted_ids)
+        return inserted_ids[0]
 
     def _extract_text_from_pdf(self, pdf_path: str) -> str:
         """
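execute_values expands the single VALUES %s placeholder into batches of at most page_size row tuples, so one-row-per-INSERT round trips collapse into a handful of multi-row statements. One caveat worth knowing: with a RETURNING clause, cursor.fetchall() only surfaces rows from the last batch executed; execute_values accepts fetch=True to collect RETURNING rows across every batch. A minimal standalone sketch (DSN, table, and data are assumptions):

    import psycopg2
    from psycopg2 import extras

    conn = psycopg2.connect("dbname=rag")  # assumed DSN
    rows = [
        ("chunk one", "[0.1,0.2,0.3]", "paper.pdf"),
        ("chunk two", "[0.4,0.5,0.6]", "paper.pdf"),
    ]

    with conn.cursor() as cursor:
        # fetch=True aggregates RETURNING rows across all pages, not just the last one
        returned = extras.execute_values(
            cursor,
            "INSERT INTO documents (content, embedding, source) VALUES %s RETURNING id",
            rows,
            template="(%s, %s::vector, %s)",
            page_size=100,
            fetch=True,
        )
    conn.commit()
    inserted_ids = [row[0] for row in returned]

With page_size=100, 500 chunks go over the wire as 5 multi-row statements instead of 500 single-row ones.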
@@ -182,10 +202,8 @@ Answer:"""
         chunks = self._split_text(text)
         logger.info("Split PDF into %d chunks", len(chunks))
 
-        for i, chunk in enumerate(chunks, 1):
-            logger.info("Processing chunk %d/%d", i, len(chunks))
-            embedding = self._get_embedding(chunk)
-            self._store_document(chunk, embedding, filename)
+        embeddings = self._get_embedding(chunks, show_progress_bar=True)
+        self._store_document(chunks, embeddings, filename)
 
         logger.info("Successfully processed %d chunks from %s", len(chunks), filename)
 
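On the embedding side, sentence-transformers' encode accepts a list of strings as well as a single string; passing the whole chunk list runs one batched forward pass and returns a (num_chunks, dim) array, which is what lets the per-chunk loop above collapse to two calls. A minimal sketch (the model name is an assumption; the repo's EMBEDDING_MODEL may differ):

    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("all-MiniLM-L6-v2")  # assumed model
    chunks = ["first chunk of text", "second chunk", "third chunk"]

    # One batched call instead of len(chunks) separate ones
    embeddings = model.encode(chunks, convert_to_numpy=True, show_progress_bar=True)
    assert embeddings.shape == (len(chunks), model.get_sentence_embedding_dimension())

Note that the unchanged type hint on _get_embedding still reads text: str even though the new call site passes a list; encode handles both, and iterating the returned 2-D array in zip(contents, embeddings) yields one 1-D vector per chunk.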
|