refactor: fix typehint and add test

This commit is contained in:
Sosokker 2025-05-12 16:12:28 +07:00
parent 186c85bfde
commit 9e8eabec04
5 changed files with 181 additions and 22 deletions

View File

@ -4,7 +4,6 @@ Web scraper adapter using crawl4ai to extract structured data.
import asyncio
import json
from typing import List, Dict, Any, Optional
from crawl4ai import (
AsyncWebCrawler,
@ -23,19 +22,26 @@ from crawl4ai.extraction_strategy import (
from .base import DataSourceAdapter
from loguru import logger
from models.adapters import AdapterRecord
# pyright: reportArgumentType=false
# pyright: reportAssignmentType=false
class WebScraperAdapter(DataSourceAdapter):
"""
Adapter for web scraping using crawl4ai.
"""
DEFAULT_PROMPT = "Extract all data from the page in as much detailed as possible"
def __init__(
self,
urls: List[str],
schema_file: Optional[str] = None,
prompt: Optional[str] = None,
llm_provider: str = "openai/gpt-4",
api_key: Optional[str] = None,
urls: list[str],
api_key: str,
schema_file: str | None = None,
prompt: str = DEFAULT_PROMPT,
llm_provider: str = "openai/gpt-4o-mini",
output_format: str = "json",
verbose: bool = False,
cache_mode: str = "ENABLED",
@ -61,9 +67,11 @@ class WebScraperAdapter(DataSourceAdapter):
self.output_format = output_format
self.verbose = verbose
self.cache_mode = cache_mode
logger.info(f"Initialized WebScraperAdapter for URLs: {urls}")
logger.info(
f"Initialized WebScraperAdapter for URLs: {urls} with schema_file={schema_file}, prompt={prompt}, llm_provider={llm_provider}, output_format={output_format}, verbose={verbose}, cache_mode={cache_mode}"
)
def fetch(self) -> List[Dict[str, Any]]:
def fetch(self) -> list[AdapterRecord]:
"""
Synchronously fetch data by running the async crawler.
@ -80,7 +88,7 @@ class WebScraperAdapter(DataSourceAdapter):
logger.error(f"Web scraping failed: {e}")
raise RuntimeError(f"Web scraping failed: {e}")
async def _fetch_async(self) -> List[Dict[str, Any]]:
async def _fetch_async(self) -> list[AdapterRecord]:
"""
Internal async method to perform crawling and extraction.
"""
@ -92,7 +100,7 @@ class WebScraperAdapter(DataSourceAdapter):
# Prepare extraction strategy
llm_cfg = LLMConfig(provider=self.llm_provider, api_token=self.api_key)
extraction_strategy: Optional[ExtractionStrategy] = None
extraction_strategy: ExtractionStrategy | None = None
if self.schema_file:
try:
@ -126,7 +134,9 @@ class WebScraperAdapter(DataSourceAdapter):
try:
cache_enum = getattr(CacheMode, self.cache_mode.upper())
except AttributeError:
logger.warning(f"Invalid cache mode '{self.cache_mode}', defaulting to ENABLED.")
logger.warning(
f"Invalid cache mode '{self.cache_mode}', defaulting to ENABLED."
)
cache_enum = CacheMode.ENABLED
run_cfg = CrawlerRunConfig(
@ -138,22 +148,23 @@ class WebScraperAdapter(DataSourceAdapter):
# Execute crawl
try:
logger.info(f"Crawling URLs: {self.urls}")
results: List[CrawlResult] = await crawler.arun_many(
results: list[CrawlResult] = await crawler.arun_many(
urls=self.urls, config=run_cfg
)
logger.debug(f"Crawling completed. Results: {results}")
logger.info("Crawling completed.")
finally:
await crawler.close()
# Process crawl results
records: List[Dict[str, Any]] = []
adapter_records: list[AdapterRecord] = []
for res in results:
if not res.success or not res.extracted_content:
logger.warning(f"Skipping failed or empty result for URL: {getattr(res, 'url', None)}")
logger.warning(
f"Skipping failed or empty result for URL: {getattr(res, 'url', None)}"
)
continue
try:
content = json.loads(res.extracted_content)
logger.debug(f"Parsed extracted content for URL: {res.url}")
logger.info(f"Parsed extracted content for URL: {res.url}")
except Exception:
logger.error(f"Failed to parse extracted content for URL: {res.url}")
continue
@ -164,12 +175,19 @@ class WebScraperAdapter(DataSourceAdapter):
for item in content:
if isinstance(item, dict):
item["source_url"] = res.url
records.extend(content)
adapter_records.append(
AdapterRecord(source="scrape", data=item)
)
elif isinstance(content, dict):
content["source_url"] = res.url
records.append(content)
adapter_records.append(AdapterRecord(source="scrape", data=content))
else:
logger.warning(f"Extracted content for URL {res.url} is not a list or dict: {type(content)}")
logger.warning(
f"Extracted content for URL {res.url} is not a list or dict: {type(content)}"
)
logger.info(f"Web scraping completed. Extracted {len(records)} records.")
return records
logger.info(
f"Web scraping completed. Extracted {len(adapter_records)} records."
)
logger.debug(adapter_records)
return adapter_records

View File

@ -11,7 +11,12 @@ dependencies = [
"loguru>=0.7.3",
"pandas>=2.2.3",
"pytest>=8.3.5",
"pytest-asyncio>=0.26.0",
"python-dotenv>=1.1.0",
"responses>=0.25.7",
"rich>=14.0.0",
]
[tool.pytest.ini_options]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"

View File

@ -0,0 +1,82 @@
import pytest
import json
from unittest.mock import patch, AsyncMock, MagicMock, mock_open
from ingestion.adapters.web_scraper_adapter import WebScraperAdapter
from models.adapters import AdapterRecord
@pytest.mark.asyncio
async def test_fetch_with_llm_extraction():
"""
Test fetching data using LLM extraction.
"""
mock_result = MagicMock()
mock_result.success = True
mock_result.url = "http://example.com"
mock_result.extracted_content = json.dumps({"title": "Example"})
with patch(
"ingestion.adapters.web_scraper_adapter.AsyncWebCrawler"
) as mock_crawler_cls:
mock_crawler = AsyncMock()
mock_crawler_cls.return_value = mock_crawler
mock_crawler.arun_many.return_value = [mock_result]
adapter = WebScraperAdapter(
urls=["http://example.com"],
api_key="fake-key",
schema_file=None,
prompt="Extract data",
)
records = await adapter._fetch_async()
assert isinstance(records, list)
assert isinstance(records[0], AdapterRecord)
assert records[0].data["title"] == "Example"
assert records[0].data["source_url"] == "http://example.com"
@pytest.mark.asyncio
async def test_fetch_with_schema_file():
"""
Test fetching data using schema file.
"""
schema = {"title": {"selector": "h1"}}
mock_result = MagicMock()
mock_result.success = True
mock_result.url = "http://example.com"
mock_result.extracted_content = json.dumps({"title": "Example"})
with patch("builtins.open", mock_open(read_data=json.dumps(schema))):
with patch(
"ingestion.adapters.web_scraper_adapter.AsyncWebCrawler"
) as mock_crawler_cls:
mock_crawler = AsyncMock()
mock_crawler_cls.return_value = mock_crawler
mock_crawler.arun_many.return_value = [mock_result]
adapter = WebScraperAdapter(
urls=["http://example.com"],
api_key="fake-key",
schema_file="schema.json",
)
records = await adapter._fetch_async()
assert len(records) == 1
assert records[0].data["title"] == "Example"
assert records[0].data["source_url"] == "http://example.com"
def test_fetch_sync_calls_async():
"""
Test that the sync fetch method calls the async fetch method.
"""
adapter = WebScraperAdapter(
urls=["http://example.com"], api_key="fake-key", prompt="Extract data"
)
with patch.object(adapter, "_fetch_async", new=AsyncMock(return_value=[])):
result = adapter.fetch()
assert result == []

View File

@ -0,0 +1,40 @@
import os
import pytest
from ingestion.adapters.web_scraper_adapter import WebScraperAdapter
from models.adapters import AdapterRecord
@pytest.mark.integration
@pytest.mark.asyncio
async def test_web_scraper_adapter_with_llm():
"""
Integration test for WebScraperAdapter using LLM extraction on books.toscrape.com.
Requires OPENAI_API_KEY to be set in the environment.
"""
api_key = os.getenv("OPENAI_API_KEY")
assert api_key is not None, "OPENAI_API_KEY environment variable must be set."
test_url = (
"https://books.toscrape.com/catalogue/a-light-in-the-attic_1000/index.html"
)
adapter = WebScraperAdapter(
urls=[test_url],
api_key=api_key,
prompt="Extract book title, price, availability, and description.",
llm_provider="openai/gpt-4o-mini",
schema_file=None,
verbose=True,
)
records = await adapter._fetch_async()
assert isinstance(records, list)
assert len(records) > 0
for record in records:
assert isinstance(record, AdapterRecord)
assert "source_url" in record.data
assert record.data["source_url"] == test_url
print("✅ Extracted data:", record.data)

View File

@ -308,6 +308,7 @@ dependencies = [
{ name = "loguru" },
{ name = "pandas" },
{ name = "pytest" },
{ name = "pytest-asyncio" },
{ name = "python-dotenv" },
{ name = "responses" },
{ name = "rich" },
@ -321,6 +322,7 @@ requires-dist = [
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "pandas", specifier = ">=2.2.3" },
{ name = "pytest", specifier = ">=8.3.5" },
{ name = "pytest-asyncio", specifier = ">=0.26.0" },
{ name = "python-dotenv", specifier = ">=1.1.0" },
{ name = "responses", specifier = ">=0.25.7" },
{ name = "rich", specifier = ">=14.0.0" },
@ -1367,6 +1369,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/30/3d/64ad57c803f1fa1e963a7946b6e0fea4a70df53c1a7fed304586539c2bac/pytest-8.3.5-py3-none-any.whl", hash = "sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820", size = 343634 },
]
[[package]]
name = "pytest-asyncio"
version = "0.26.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pytest" },
]
sdist = { url = "https://files.pythonhosted.org/packages/8e/c4/453c52c659521066969523e87d85d54139bbd17b78f09532fb8eb8cdb58e/pytest_asyncio-0.26.0.tar.gz", hash = "sha256:c4df2a697648241ff39e7f0e4a73050b03f123f760673956cf0d72a4990e312f", size = 54156 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/20/7f/338843f449ace853647ace35870874f69a764d251872ed1b4de9f234822c/pytest_asyncio-0.26.0-py3-none-any.whl", hash = "sha256:7b51ed894f4fbea1340262bdae5135797ebbe21d8638978e35d31c6d19f72fb0", size = 19694 },
]
[[package]]
name = "python-dateutil"
version = "2.9.0.post0"