Mirror of https://github.com/borbann-platform/backend-api.git (synced 2025-12-18 12:14:05 +01:00)

commit 9e8eabec04 (parent 186c85bfde)
refactor: fix typehint and add test
@@ -4,7 +4,6 @@ Web scraper adapter using crawl4ai to extract structured data.
 
 import asyncio
 import json
-from typing import List, Dict, Any, Optional
 
 from crawl4ai import (
     AsyncWebCrawler,
@@ -23,19 +22,26 @@ from crawl4ai.extraction_strategy import (
 from .base import DataSourceAdapter
 from loguru import logger
 
+from models.adapters import AdapterRecord
+
+# pyright: reportArgumentType=false
+# pyright: reportAssignmentType=false
+
 
 class WebScraperAdapter(DataSourceAdapter):
     """
     Adapter for web scraping using crawl4ai.
     """
 
+    DEFAULT_PROMPT = "Extract all data from the page in as much detail as possible"
+
     def __init__(
         self,
-        urls: List[str],
-        schema_file: Optional[str] = None,
-        prompt: Optional[str] = None,
-        llm_provider: str = "openai/gpt-4",
-        api_key: Optional[str] = None,
+        urls: list[str],
+        api_key: str,
+        schema_file: str | None = None,
+        prompt: str = DEFAULT_PROMPT,
+        llm_provider: str = "openai/gpt-4o-mini",
         output_format: str = "json",
         verbose: bool = False,
         cache_mode: str = "ENABLED",
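For context, constructing the adapter under the revised signature looks roughly like the sketch below; the import path comes from the new tests, while the URL and API key are placeholders, not values from the repository.

    # Example usage of the revised constructor; URL and API key are placeholders.
    from ingestion.adapters.web_scraper_adapter import WebScraperAdapter

    adapter = WebScraperAdapter(
        urls=["https://example.com"],
        api_key="sk-...",                     # now a required argument
        prompt="Extract the page title and main body text",
        llm_provider="openai/gpt-4o-mini",    # new default provider
    )
    records = adapter.fetch()                 # -> list[AdapterRecord]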
@@ -61,9 +67,11 @@ class WebScraperAdapter(DataSourceAdapter):
         self.output_format = output_format
         self.verbose = verbose
         self.cache_mode = cache_mode
-        logger.info(f"Initialized WebScraperAdapter for URLs: {urls}")
+        logger.info(
+            f"Initialized WebScraperAdapter for URLs: {urls} with schema_file={schema_file}, prompt={prompt}, llm_provider={llm_provider}, output_format={output_format}, verbose={verbose}, cache_mode={cache_mode}"
+        )
 
-    def fetch(self) -> List[Dict[str, Any]]:
+    def fetch(self) -> list[AdapterRecord]:
         """
         Synchronously fetch data by running the async crawler.
 
@@ -80,7 +88,7 @@ class WebScraperAdapter(DataSourceAdapter):
             logger.error(f"Web scraping failed: {e}")
             raise RuntimeError(f"Web scraping failed: {e}")
 
-    async def _fetch_async(self) -> List[Dict[str, Any]]:
+    async def _fetch_async(self) -> list[AdapterRecord]:
         """
         Internal async method to perform crawling and extraction.
         """
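Only the except-branch of fetch() is visible in the hunk above; presumably it drives _fetch_async() through an event loop. A sketch of the assumed shape inside the class (asyncio.run is a guess, not shown in this diff):

    # Assumed shape of fetch(); only the logging and raise lines appear in the hunk.
    def fetch(self) -> list[AdapterRecord]:
        """Synchronously fetch data by running the async crawler."""
        try:
            return asyncio.run(self._fetch_async())  # assumption: asyncio.run
        except Exception as e:
            logger.error(f"Web scraping failed: {e}")
            raise RuntimeError(f"Web scraping failed: {e}")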
@@ -92,7 +100,7 @@ class WebScraperAdapter(DataSourceAdapter):
 
         # Prepare extraction strategy
         llm_cfg = LLMConfig(provider=self.llm_provider, api_token=self.api_key)
-        extraction_strategy: Optional[ExtractionStrategy] = None
+        extraction_strategy: ExtractionStrategy | None = None
 
         if self.schema_file:
             try:
@@ -126,7 +134,9 @@ class WebScraperAdapter(DataSourceAdapter):
         try:
             cache_enum = getattr(CacheMode, self.cache_mode.upper())
         except AttributeError:
-            logger.warning(f"Invalid cache mode '{self.cache_mode}', defaulting to ENABLED.")
+            logger.warning(
+                f"Invalid cache mode '{self.cache_mode}', defaulting to ENABLED."
+            )
             cache_enum = CacheMode.ENABLED
 
         run_cfg = CrawlerRunConfig(
@@ -138,22 +148,23 @@ class WebScraperAdapter(DataSourceAdapter):
         # Execute crawl
         try:
             logger.info(f"Crawling URLs: {self.urls}")
-            results: List[CrawlResult] = await crawler.arun_many(
+            results: list[CrawlResult] = await crawler.arun_many(
                 urls=self.urls, config=run_cfg
             )
-            logger.debug(f"Crawling completed. Results: {results}")
+            logger.info("Crawling completed.")
         finally:
             await crawler.close()
 
         # Process crawl results
-        records: List[Dict[str, Any]] = []
+        adapter_records: list[AdapterRecord] = []
         for res in results:
             if not res.success or not res.extracted_content:
-                logger.warning(f"Skipping failed or empty result for URL: {getattr(res, 'url', None)}")
+                logger.warning(
+                    f"Skipping failed or empty result for URL: {getattr(res, 'url', None)}"
+                )
                 continue
             try:
                 content = json.loads(res.extracted_content)
-                logger.debug(f"Parsed extracted content for URL: {res.url}")
+                logger.info(f"Parsed extracted content for URL: {res.url}")
             except Exception:
                 logger.error(f"Failed to parse extracted content for URL: {res.url}")
                 continue
@@ -164,12 +175,19 @@ class WebScraperAdapter(DataSourceAdapter):
                 for item in content:
                     if isinstance(item, dict):
                         item["source_url"] = res.url
-                records.extend(content)
+                        adapter_records.append(
+                            AdapterRecord(source="scrape", data=item)
+                        )
             elif isinstance(content, dict):
                 content["source_url"] = res.url
-                records.append(content)
+                adapter_records.append(AdapterRecord(source="scrape", data=content))
             else:
-                logger.warning(f"Extracted content for URL {res.url} is not a list or dict: {type(content)}")
+                logger.warning(
+                    f"Extracted content for URL {res.url} is not a list or dict: {type(content)}"
+                )
 
-        logger.info(f"Web scraping completed. Extracted {len(records)} records.")
-        return records
+        logger.info(
+            f"Web scraping completed. Extracted {len(adapter_records)} records."
+        )
+        logger.debug(adapter_records)
+        return adapter_records
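The new return type relies on AdapterRecord from models.adapters, whose definition is not part of this diff. A minimal stand-in consistent with how it is used here (the real model may well be a pydantic class with extra fields):

    # Hypothetical sketch of models/adapters.py; the actual definition is not shown.
    from dataclasses import dataclass
    from typing import Any


    @dataclass
    class AdapterRecord:
        source: str            # e.g. "scrape" for this adapter
        data: dict[str, Any]   # extracted payload, including "source_url"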
@@ -11,7 +11,12 @@ dependencies = [
     "loguru>=0.7.3",
     "pandas>=2.2.3",
     "pytest>=8.3.5",
+    "pytest-asyncio>=0.26.0",
     "python-dotenv>=1.1.0",
     "responses>=0.25.7",
     "rich>=14.0.0",
 ]
+
+[tool.pytest.ini_options]
+asyncio_mode = "auto"
+asyncio_default_fixture_loop_scope = "function"
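With asyncio_mode = "auto", pytest-asyncio collects coroutine test functions without an explicit marker, so the @pytest.mark.asyncio decorators in the new tests below are redundant but harmless. A minimal illustration, assuming pytest-asyncio >= 0.26 as pinned above:

    # Runs under pytest with asyncio_mode = "auto" even though no marker is applied.
    import asyncio


    async def test_event_loop_is_available():
        assert await asyncio.sleep(0) is None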
pipeline/tests/test_scraper_adapter.py (new file, 82 lines)
@@ -0,0 +1,82 @@
import pytest
import json
from unittest.mock import patch, AsyncMock, MagicMock, mock_open

from ingestion.adapters.web_scraper_adapter import WebScraperAdapter
from models.adapters import AdapterRecord


@pytest.mark.asyncio
async def test_fetch_with_llm_extraction():
    """
    Test fetching data using LLM extraction.
    """
    mock_result = MagicMock()
    mock_result.success = True
    mock_result.url = "http://example.com"
    mock_result.extracted_content = json.dumps({"title": "Example"})

    with patch(
        "ingestion.adapters.web_scraper_adapter.AsyncWebCrawler"
    ) as mock_crawler_cls:
        mock_crawler = AsyncMock()
        mock_crawler_cls.return_value = mock_crawler
        mock_crawler.arun_many.return_value = [mock_result]

        adapter = WebScraperAdapter(
            urls=["http://example.com"],
            api_key="fake-key",
            schema_file=None,
            prompt="Extract data",
        )

        records = await adapter._fetch_async()

        assert isinstance(records, list)
        assert isinstance(records[0], AdapterRecord)
        assert records[0].data["title"] == "Example"
        assert records[0].data["source_url"] == "http://example.com"


@pytest.mark.asyncio
async def test_fetch_with_schema_file():
    """
    Test fetching data using schema file.
    """
    schema = {"title": {"selector": "h1"}}
    mock_result = MagicMock()
    mock_result.success = True
    mock_result.url = "http://example.com"
    mock_result.extracted_content = json.dumps({"title": "Example"})

    with patch("builtins.open", mock_open(read_data=json.dumps(schema))):
        with patch(
            "ingestion.adapters.web_scraper_adapter.AsyncWebCrawler"
        ) as mock_crawler_cls:
            mock_crawler = AsyncMock()
            mock_crawler_cls.return_value = mock_crawler
            mock_crawler.arun_many.return_value = [mock_result]

            adapter = WebScraperAdapter(
                urls=["http://example.com"],
                api_key="fake-key",
                schema_file="schema.json",
            )

            records = await adapter._fetch_async()

            assert len(records) == 1
            assert records[0].data["title"] == "Example"
            assert records[0].data["source_url"] == "http://example.com"


def test_fetch_sync_calls_async():
    """
    Test that the sync fetch method calls the async fetch method.
    """
    adapter = WebScraperAdapter(
        urls=["http://example.com"], api_key="fake-key", prompt="Extract data"
    )
    with patch.object(adapter, "_fetch_async", new=AsyncMock(return_value=[])):
        result = adapter.fetch()
        assert result == []
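A note on why these mocks are sufficient: AsyncMock creates awaitable child mocks on attribute access, so the adapter's awaited calls to arun_many and close resolve against the configured return values without further setup. For example:

    # AsyncMock attributes are themselves async mocks, so awaiting them just works.
    import asyncio
    from unittest.mock import AsyncMock

    crawler = AsyncMock()
    crawler.arun_many.return_value = ["fake-result"]
    assert asyncio.run(crawler.arun_many()) == ["fake-result"]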
pipeline/tests/test_scraper_adapter_integration.py (new file, 40 lines)
@@ -0,0 +1,40 @@
import os
import pytest
from ingestion.adapters.web_scraper_adapter import WebScraperAdapter
from models.adapters import AdapterRecord


@pytest.mark.integration
@pytest.mark.asyncio
async def test_web_scraper_adapter_with_llm():
    """
    Integration test for WebScraperAdapter using LLM extraction on books.toscrape.com.
    Requires OPENAI_API_KEY to be set in the environment.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    assert api_key is not None, "OPENAI_API_KEY environment variable must be set."

    test_url = (
        "https://books.toscrape.com/catalogue/a-light-in-the-attic_1000/index.html"
    )

    adapter = WebScraperAdapter(
        urls=[test_url],
        api_key=api_key,
        prompt="Extract book title, price, availability, and description.",
        llm_provider="openai/gpt-4o-mini",
        schema_file=None,
        verbose=True,
    )

    records = await adapter._fetch_async()

    assert isinstance(records, list)
    assert len(records) > 0

    for record in records:
        assert isinstance(record, AdapterRecord)
        assert "source_url" in record.data
        assert record.data["source_url"] == test_url

        print("✅ Extracted data:", record.data)
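One caveat: pytest.mark.integration is used above, but the [tool.pytest.ini_options] block added in this commit does not declare a markers list; unless the marker is registered elsewhere, pytest will emit an unknown-mark warning. A possible registration (an assumption, not part of this commit) would be:

    [tool.pytest.ini_options]
    markers = [
        "integration: live-network tests that require OPENAI_API_KEY",
    ]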
@@ -308,6 +308,7 @@ dependencies = [
     { name = "loguru" },
     { name = "pandas" },
     { name = "pytest" },
+    { name = "pytest-asyncio" },
     { name = "python-dotenv" },
     { name = "responses" },
     { name = "rich" },
@@ -321,6 +322,7 @@ requires-dist = [
     { name = "loguru", specifier = ">=0.7.3" },
     { name = "pandas", specifier = ">=2.2.3" },
     { name = "pytest", specifier = ">=8.3.5" },
+    { name = "pytest-asyncio", specifier = ">=0.26.0" },
     { name = "python-dotenv", specifier = ">=1.1.0" },
     { name = "responses", specifier = ">=0.25.7" },
     { name = "rich", specifier = ">=14.0.0" },
@@ -1367,6 +1369,18 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/30/3d/64ad57c803f1fa1e963a7946b6e0fea4a70df53c1a7fed304586539c2bac/pytest-8.3.5-py3-none-any.whl", hash = "sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820", size = 343634 },
 ]
 
+[[package]]
+name = "pytest-asyncio"
+version = "0.26.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "pytest" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/8e/c4/453c52c659521066969523e87d85d54139bbd17b78f09532fb8eb8cdb58e/pytest_asyncio-0.26.0.tar.gz", hash = "sha256:c4df2a697648241ff39e7f0e4a73050b03f123f760673956cf0d72a4990e312f", size = 54156 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/20/7f/338843f449ace853647ace35870874f69a764d251872ed1b4de9f234822c/pytest_asyncio-0.26.0-py3-none-any.whl", hash = "sha256:7b51ed894f4fbea1340262bdae5135797ebbe21d8638978e35d31c6d19f72fb0", size = 19694 },
+]
+
 [[package]]
 name = "python-dateutil"
 version = "2.9.0.post0"