mirror of https://github.com/borbann-platform/backend-api.git (synced 2025-12-18 20:24:05 +01:00)

commit 9e8eabec04 (parent 186c85bfde)

refactor: fix typehint and add test
ingestion/adapters/web_scraper_adapter.py

@@ -4,7 +4,6 @@ Web scraper adapter using crawl4ai to extract structured data.
 
 import asyncio
 import json
-from typing import List, Dict, Any, Optional
 
 from crawl4ai import (
     AsyncWebCrawler,
@@ -23,19 +22,26 @@ from crawl4ai.extraction_strategy import (
 from .base import DataSourceAdapter
 from loguru import logger
+from models.adapters import AdapterRecord
+
+# pyright: reportArgumentType=false
+# pyright: reportAssignmentType=false
+
 
 
 class WebScraperAdapter(DataSourceAdapter):
     """
     Adapter for web scraping using crawl4ai.
     """
 
+    DEFAULT_PROMPT = "Extract all data from the page in as much detailed as possible"
+
     def __init__(
         self,
-        urls: List[str],
-        schema_file: Optional[str] = None,
-        prompt: Optional[str] = None,
-        llm_provider: str = "openai/gpt-4",
-        api_key: Optional[str] = None,
+        urls: list[str],
+        api_key: str,
+        schema_file: str | None = None,
+        prompt: str = DEFAULT_PROMPT,
+        llm_provider: str = "openai/gpt-4o-mini",
         output_format: str = "json",
         verbose: bool = False,
         cache_mode: str = "ENABLED",
@@ -61,9 +67,11 @@ class WebScraperAdapter(DataSourceAdapter):
         self.output_format = output_format
         self.verbose = verbose
         self.cache_mode = cache_mode
-        logger.info(f"Initialized WebScraperAdapter for URLs: {urls}")
+        logger.info(
+            f"Initialized WebScraperAdapter for URLs: {urls} with schema_file={schema_file}, prompt={prompt}, llm_provider={llm_provider}, output_format={output_format}, verbose={verbose}, cache_mode={cache_mode}"
+        )
 
-    def fetch(self) -> List[Dict[str, Any]]:
+    def fetch(self) -> list[AdapterRecord]:
         """
         Synchronously fetch data by running the async crawler.
 
@@ -80,7 +88,7 @@ class WebScraperAdapter(DataSourceAdapter):
             logger.error(f"Web scraping failed: {e}")
             raise RuntimeError(f"Web scraping failed: {e}")
 
-    async def _fetch_async(self) -> List[Dict[str, Any]]:
+    async def _fetch_async(self) -> list[AdapterRecord]:
         """
         Internal async method to perform crawling and extraction.
         """
@@ -92,7 +100,7 @@ class WebScraperAdapter(DataSourceAdapter):
 
         # Prepare extraction strategy
         llm_cfg = LLMConfig(provider=self.llm_provider, api_token=self.api_key)
-        extraction_strategy: Optional[ExtractionStrategy] = None
+        extraction_strategy: ExtractionStrategy | None = None
 
         if self.schema_file:
             try:
@@ -126,7 +134,9 @@ class WebScraperAdapter(DataSourceAdapter):
         try:
             cache_enum = getattr(CacheMode, self.cache_mode.upper())
         except AttributeError:
-            logger.warning(f"Invalid cache mode '{self.cache_mode}', defaulting to ENABLED.")
+            logger.warning(
+                f"Invalid cache mode '{self.cache_mode}', defaulting to ENABLED."
+            )
             cache_enum = CacheMode.ENABLED
 
         run_cfg = CrawlerRunConfig(
@@ -138,22 +148,23 @@ class WebScraperAdapter(DataSourceAdapter):
         # Execute crawl
         try:
             logger.info(f"Crawling URLs: {self.urls}")
-            results: List[CrawlResult] = await crawler.arun_many(
+            results: list[CrawlResult] = await crawler.arun_many(
                 urls=self.urls, config=run_cfg
             )
-            logger.debug(f"Crawling completed. Results: {results}")
+            logger.info("Crawling completed.")
         finally:
             await crawler.close()
 
-        # Process crawl results
-        records: List[Dict[str, Any]] = []
+        adapter_records: list[AdapterRecord] = []
         for res in results:
             if not res.success or not res.extracted_content:
-                logger.warning(f"Skipping failed or empty result for URL: {getattr(res, 'url', None)}")
+                logger.warning(
+                    f"Skipping failed or empty result for URL: {getattr(res, 'url', None)}"
+                )
                 continue
             try:
                 content = json.loads(res.extracted_content)
-                logger.debug(f"Parsed extracted content for URL: {res.url}")
+                logger.info(f"Parsed extracted content for URL: {res.url}")
             except Exception:
                 logger.error(f"Failed to parse extracted content for URL: {res.url}")
                 continue
@@ -164,12 +175,19 @@ class WebScraperAdapter(DataSourceAdapter):
                 for item in content:
                     if isinstance(item, dict):
                         item["source_url"] = res.url
-                records.extend(content)
+                    adapter_records.append(
+                        AdapterRecord(source="scrape", data=item)
+                    )
             elif isinstance(content, dict):
                 content["source_url"] = res.url
-                records.append(content)
+                adapter_records.append(AdapterRecord(source="scrape", data=content))
             else:
-                logger.warning(f"Extracted content for URL {res.url} is not a list or dict: {type(content)}")
+                logger.warning(
+                    f"Extracted content for URL {res.url} is not a list or dict: {type(content)}"
+                )
 
-        logger.info(f"Web scraping completed. Extracted {len(records)} records.")
-        return records
+        logger.info(
+            f"Web scraping completed. Extracted {len(adapter_records)} records."
+        )
+        logger.debug(adapter_records)
+        return adapter_records
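Note: fetch() and _fetch_async() now return list[AdapterRecord] imported from models.adapters, but that module is not part of this commit. A minimal sketch of what the tests below appear to assume, with field names taken from the AdapterRecord(source="scrape", data=...) calls above; the Pydantic base class is an assumption, not something this diff confirms:

# sketch only: models/adapters.py is not shown in this commit
from typing import Any

from pydantic import BaseModel  # assumption; the real model may not be Pydantic-based


class AdapterRecord(BaseModel):
    """One normalized record emitted by a data-source adapter."""

    source: str           # e.g. "scrape" for WebScraperAdapter output
    data: dict[str, Any]  # extracted payload, including the injected "source_url"

Callers that previously indexed plain dicts now go through record.data (e.g. record.data["source_url"]), and the constructor change above makes api_key a required second positional argument.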
pyproject.toml

@@ -11,7 +11,12 @@ dependencies = [
     "loguru>=0.7.3",
     "pandas>=2.2.3",
     "pytest>=8.3.5",
+    "pytest-asyncio>=0.26.0",
     "python-dotenv>=1.1.0",
     "responses>=0.25.7",
     "rich>=14.0.0",
 ]
+
+[tool.pytest.ini_options]
+asyncio_mode = "auto"
+asyncio_default_fixture_loop_scope = "function"
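With asyncio_mode = "auto", pytest-asyncio treats every collected "async def" test as an asyncio test, so the explicit @pytest.mark.asyncio markers in the new test files are technically redundant (though harmless). A minimal sketch of the effect, assuming pytest and pytest-asyncio are installed as declared above:

# sketch: under asyncio_mode = "auto" this async test runs with no marker at all
import asyncio


async def test_event_loop_is_provided():
    await asyncio.sleep(0)  # executes on the event loop pytest-asyncio supplies
    assert True

Setting asyncio_default_fixture_loop_scope = "function" pins async fixtures to a per-test event loop and avoids the warning recent pytest-asyncio releases emit when the option is left unset.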
pipeline/tests/test_scraper_adapter.py (new file, 82 lines)

@@ -0,0 +1,82 @@
import pytest
import json
from unittest.mock import patch, AsyncMock, MagicMock, mock_open

from ingestion.adapters.web_scraper_adapter import WebScraperAdapter
from models.adapters import AdapterRecord


@pytest.mark.asyncio
async def test_fetch_with_llm_extraction():
    """
    Test fetching data using LLM extraction.
    """
    mock_result = MagicMock()
    mock_result.success = True
    mock_result.url = "http://example.com"
    mock_result.extracted_content = json.dumps({"title": "Example"})

    with patch(
        "ingestion.adapters.web_scraper_adapter.AsyncWebCrawler"
    ) as mock_crawler_cls:
        mock_crawler = AsyncMock()
        mock_crawler_cls.return_value = mock_crawler
        mock_crawler.arun_many.return_value = [mock_result]

        adapter = WebScraperAdapter(
            urls=["http://example.com"],
            api_key="fake-key",
            schema_file=None,
            prompt="Extract data",
        )

        records = await adapter._fetch_async()

        assert isinstance(records, list)
        assert isinstance(records[0], AdapterRecord)
        assert records[0].data["title"] == "Example"
        assert records[0].data["source_url"] == "http://example.com"


@pytest.mark.asyncio
async def test_fetch_with_schema_file():
    """
    Test fetching data using schema file.
    """
    schema = {"title": {"selector": "h1"}}
    mock_result = MagicMock()
    mock_result.success = True
    mock_result.url = "http://example.com"
    mock_result.extracted_content = json.dumps({"title": "Example"})

    with patch("builtins.open", mock_open(read_data=json.dumps(schema))):
        with patch(
            "ingestion.adapters.web_scraper_adapter.AsyncWebCrawler"
        ) as mock_crawler_cls:
            mock_crawler = AsyncMock()
            mock_crawler_cls.return_value = mock_crawler
            mock_crawler.arun_many.return_value = [mock_result]

            adapter = WebScraperAdapter(
                urls=["http://example.com"],
                api_key="fake-key",
                schema_file="schema.json",
            )

            records = await adapter._fetch_async()

            assert len(records) == 1
            assert records[0].data["title"] == "Example"
            assert records[0].data["source_url"] == "http://example.com"


def test_fetch_sync_calls_async():
    """
    Test that the sync fetch method calls the async fetch method.
    """
    adapter = WebScraperAdapter(
        urls=["http://example.com"], api_key="fake-key", prompt="Extract data"
    )
    with patch.object(adapter, "_fetch_async", new=AsyncMock(return_value=[])):
        result = adapter.fetch()
        assert result == []
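test_fetch_sync_calls_async patches _fetch_async with an AsyncMock and expects the plain fetch() call to return its value, which only works because fetch() drives the coroutine itself. The body of fetch() lies outside the changed hunks; going by its docstring ("Synchronously fetch data by running the async crawler"), a plausible sketch of the wrapper being exercised, with asyncio.run as an assumption rather than the confirmed implementation:

# sketch of the sync-over-async wrapper the test exercises (not shown in this diff)
import asyncio


class _WrapperSketch:
    async def _fetch_async(self) -> list:
        return []  # stand-in for the real crawl4ai run

    def fetch(self) -> list:
        # Run the coroutine to completion on a fresh event loop,
        # so synchronous callers never touch asyncio directly.
        return asyncio.run(self._fetch_async())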
pipeline/tests/test_scraper_adapter_integration.py (new file, 40 lines)

@@ -0,0 +1,40 @@
import os
import pytest
from ingestion.adapters.web_scraper_adapter import WebScraperAdapter
from models.adapters import AdapterRecord


@pytest.mark.integration
@pytest.mark.asyncio
async def test_web_scraper_adapter_with_llm():
    """
    Integration test for WebScraperAdapter using LLM extraction on books.toscrape.com.
    Requires OPENAI_API_KEY to be set in the environment.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    assert api_key is not None, "OPENAI_API_KEY environment variable must be set."

    test_url = (
        "https://books.toscrape.com/catalogue/a-light-in-the-attic_1000/index.html"
    )

    adapter = WebScraperAdapter(
        urls=[test_url],
        api_key=api_key,
        prompt="Extract book title, price, availability, and description.",
        llm_provider="openai/gpt-4o-mini",
        schema_file=None,
        verbose=True,
    )

    records = await adapter._fetch_async()

    assert isinstance(records, list)
    assert len(records) > 0

    for record in records:
        assert isinstance(record, AdapterRecord)
        assert "source_url" in record.data
        assert record.data["source_url"] == test_url

        print("✅ Extracted data:", record.data)
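The integration test is tagged @pytest.mark.integration, but the [tool.pytest.ini_options] block added above does not register that marker, so pytest will warn about an unknown mark unless it is registered elsewhere in the repo, and the test fails outright when OPENAI_API_KEY is absent. A hypothetical conftest.py sketch (not part of this commit) that registers the marker and skips integration tests when the key is missing:

# hypothetical tests/conftest.py: marker registration plus an opt-out skip
import os

import pytest


def pytest_configure(config):
    config.addinivalue_line(
        "markers", "integration: tests that call live external services"
    )


def pytest_collection_modifyitems(config, items):
    if os.getenv("OPENAI_API_KEY"):
        return  # key available: run integration tests normally
    skip_integration = pytest.mark.skip(reason="OPENAI_API_KEY not set")
    for item in items:
        if "integration" in item.keywords:
            item.add_marker(skip_integration)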
uv.lock

@@ -308,6 +308,7 @@ dependencies = [
     { name = "loguru" },
     { name = "pandas" },
     { name = "pytest" },
+    { name = "pytest-asyncio" },
     { name = "python-dotenv" },
     { name = "responses" },
     { name = "rich" },
@@ -321,6 +322,7 @@ requires-dist = [
     { name = "loguru", specifier = ">=0.7.3" },
     { name = "pandas", specifier = ">=2.2.3" },
     { name = "pytest", specifier = ">=8.3.5" },
+    { name = "pytest-asyncio", specifier = ">=0.26.0" },
     { name = "python-dotenv", specifier = ">=1.1.0" },
     { name = "responses", specifier = ">=0.25.7" },
     { name = "rich", specifier = ">=14.0.0" },
@@ -1367,6 +1369,18 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/30/3d/64ad57c803f1fa1e963a7946b6e0fea4a70df53c1a7fed304586539c2bac/pytest-8.3.5-py3-none-any.whl", hash = "sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820", size = 343634 },
 ]
 
+[[package]]
+name = "pytest-asyncio"
+version = "0.26.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "pytest" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/8e/c4/453c52c659521066969523e87d85d54139bbd17b78f09532fb8eb8cdb58e/pytest_asyncio-0.26.0.tar.gz", hash = "sha256:c4df2a697648241ff39e7f0e4a73050b03f123f760673956cf0d72a4990e312f", size = 54156 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/20/7f/338843f449ace853647ace35870874f69a764d251872ed1b4de9f234822c/pytest_asyncio-0.26.0-py3-none-any.whl", hash = "sha256:7b51ed894f4fbea1340262bdae5135797ebbe21d8638978e35d31c6d19f72fb0", size = 19694 },
+]
+
 [[package]]
 name = "python-dateutil"
 version = "2.9.0.post0"