From 9e8eabec049be5c5fff0d2fa2a3085ced37fd64a Mon Sep 17 00:00:00 2001
From: Sosokker
Date: Mon, 12 May 2025 16:12:28 +0700
Subject: [PATCH] refactor: fix type hints and add tests

---
 .../ingestion/adapters/web_scraper_adapter.py | 62 +++++++++-----
 pipeline/pyproject.toml                       |  5 ++
 pipeline/tests/test_scraper_adapter.py        | 82 +++++++++++++++++++
 .../tests/test_scraper_adapter_integration.py | 40 +++++++++
 pipeline/uv.lock                              | 14 ++++
 5 files changed, 181 insertions(+), 22 deletions(-)
 create mode 100644 pipeline/tests/test_scraper_adapter.py
 create mode 100644 pipeline/tests/test_scraper_adapter_integration.py

diff --git a/pipeline/ingestion/adapters/web_scraper_adapter.py b/pipeline/ingestion/adapters/web_scraper_adapter.py
index 1912413..d0678bb 100644
--- a/pipeline/ingestion/adapters/web_scraper_adapter.py
+++ b/pipeline/ingestion/adapters/web_scraper_adapter.py
@@ -4,7 +4,6 @@ Web scraper adapter using crawl4ai to extract structured data.
 
 import asyncio
 import json
-from typing import List, Dict, Any, Optional
 
 from crawl4ai import (
     AsyncWebCrawler,
@@ -23,19 +22,26 @@ from crawl4ai.extraction_strategy import (
 from .base import DataSourceAdapter
 from loguru import logger
 
+from models.adapters import AdapterRecord
+
+# pyright: reportArgumentType=false
+# pyright: reportAssignmentType=false
+
 
 class WebScraperAdapter(DataSourceAdapter):
     """
     Adapter for web scraping using crawl4ai.
     """
 
+    DEFAULT_PROMPT = "Extract all data from the page in as much detail as possible"
+
     def __init__(
         self,
-        urls: List[str],
-        schema_file: Optional[str] = None,
-        prompt: Optional[str] = None,
-        llm_provider: str = "openai/gpt-4",
-        api_key: Optional[str] = None,
+        urls: list[str],
+        api_key: str,
+        schema_file: str | None = None,
+        prompt: str = DEFAULT_PROMPT,
+        llm_provider: str = "openai/gpt-4o-mini",
         output_format: str = "json",
         verbose: bool = False,
         cache_mode: str = "ENABLED",
@@ -61,9 +67,11 @@ class WebScraperAdapter(DataSourceAdapter):
         self.output_format = output_format
         self.verbose = verbose
         self.cache_mode = cache_mode
-        logger.info(f"Initialized WebScraperAdapter for URLs: {urls}")
+        logger.info(
+            f"Initialized WebScraperAdapter for URLs: {urls} with schema_file={schema_file}, prompt={prompt}, llm_provider={llm_provider}, output_format={output_format}, verbose={verbose}, cache_mode={cache_mode}"
+        )
 
-    def fetch(self) -> List[Dict[str, Any]]:
+    def fetch(self) -> list[AdapterRecord]:
         """
         Synchronously fetch data by running the async crawler.
 
@@ -80,7 +88,7 @@ class WebScraperAdapter(DataSourceAdapter):
             logger.error(f"Web scraping failed: {e}")
             raise RuntimeError(f"Web scraping failed: {e}")
 
-    async def _fetch_async(self) -> List[Dict[str, Any]]:
+    async def _fetch_async(self) -> list[AdapterRecord]:
         """
         Internal async method to perform crawling and extraction.
         """
@@ -92,7 +100,7 @@ class WebScraperAdapter(DataSourceAdapter):
 
         # Prepare extraction strategy
         llm_cfg = LLMConfig(provider=self.llm_provider, api_token=self.api_key)
-        extraction_strategy: Optional[ExtractionStrategy] = None
+        extraction_strategy: ExtractionStrategy | None = None
 
         if self.schema_file:
             try:
@@ -126,7 +134,9 @@ class WebScraperAdapter(DataSourceAdapter):
         try:
            cache_enum = getattr(CacheMode, self.cache_mode.upper())
         except AttributeError:
-            logger.warning(f"Invalid cache mode '{self.cache_mode}', defaulting to ENABLED.")
+            logger.warning(
+                f"Invalid cache mode '{self.cache_mode}', defaulting to ENABLED."
+            )
             cache_enum = CacheMode.ENABLED
 
         run_cfg = CrawlerRunConfig(
@@ -138,22 +148,23 @@ class WebScraperAdapter(DataSourceAdapter):
         # Execute crawl
         try:
             logger.info(f"Crawling URLs: {self.urls}")
-            results: List[CrawlResult] = await crawler.arun_many(
+            results: list[CrawlResult] = await crawler.arun_many(
                 urls=self.urls, config=run_cfg
             )
-            logger.debug(f"Crawling completed. Results: {results}")
+            logger.info("Crawling completed.")
         finally:
             await crawler.close()
 
-        # Process crawl results
-        records: List[Dict[str, Any]] = []
+        adapter_records: list[AdapterRecord] = []
         for res in results:
             if not res.success or not res.extracted_content:
-                logger.warning(f"Skipping failed or empty result for URL: {getattr(res, 'url', None)}")
+                logger.warning(
+                    f"Skipping failed or empty result for URL: {getattr(res, 'url', None)}"
+                )
                 continue
             try:
                 content = json.loads(res.extracted_content)
-                logger.debug(f"Parsed extracted content for URL: {res.url}")
+                logger.info(f"Parsed extracted content for URL: {res.url}")
             except Exception:
                 logger.error(f"Failed to parse extracted content for URL: {res.url}")
                 continue
@@ -164,12 +175,19 @@ class WebScraperAdapter(DataSourceAdapter):
                 for item in content:
                     if isinstance(item, dict):
                         item["source_url"] = res.url
-                records.extend(content)
+                    adapter_records.append(
+                        AdapterRecord(source="scrape", data=item)
+                    )
             elif isinstance(content, dict):
                 content["source_url"] = res.url
-                records.append(content)
+                adapter_records.append(AdapterRecord(source="scrape", data=content))
             else:
-                logger.warning(f"Extracted content for URL {res.url} is not a list or dict: {type(content)}")
+                logger.warning(
+                    f"Extracted content for URL {res.url} is not a list or dict: {type(content)}"
+                )
 
-        logger.info(f"Web scraping completed. Extracted {len(records)} records.")
-        return records
\ No newline at end of file
+        logger.info(
+            f"Web scraping completed. Extracted {len(adapter_records)} records."
+        )
+        logger.debug(adapter_records)
+        return adapter_records
diff --git a/pipeline/pyproject.toml b/pipeline/pyproject.toml
index 7b3e243..da36760 100644
--- a/pipeline/pyproject.toml
+++ b/pipeline/pyproject.toml
@@ -11,7 +11,12 @@ dependencies = [
     "loguru>=0.7.3",
     "pandas>=2.2.3",
     "pytest>=8.3.5",
+    "pytest-asyncio>=0.26.0",
     "python-dotenv>=1.1.0",
     "responses>=0.25.7",
     "rich>=14.0.0",
 ]
+
+[tool.pytest.ini_options]
+asyncio_mode = "auto"
+asyncio_default_fixture_loop_scope = "function"
\ No newline at end of file
diff --git a/pipeline/tests/test_scraper_adapter.py b/pipeline/tests/test_scraper_adapter.py
new file mode 100644
index 0000000..3c50a8c
--- /dev/null
+++ b/pipeline/tests/test_scraper_adapter.py
@@ -0,0 +1,82 @@
+import pytest
+import json
+from unittest.mock import patch, AsyncMock, MagicMock, mock_open
+
+from ingestion.adapters.web_scraper_adapter import WebScraperAdapter
+from models.adapters import AdapterRecord
+
+
+@pytest.mark.asyncio
+async def test_fetch_with_llm_extraction():
+    """
+    Test fetching data using LLM extraction.
+    """
+    mock_result = MagicMock()
+    mock_result.success = True
+    mock_result.url = "http://example.com"
+    mock_result.extracted_content = json.dumps({"title": "Example"})
+
+    with patch(
+        "ingestion.adapters.web_scraper_adapter.AsyncWebCrawler"
+    ) as mock_crawler_cls:
+        mock_crawler = AsyncMock()
+        mock_crawler_cls.return_value = mock_crawler
+        mock_crawler.arun_many.return_value = [mock_result]
+
+        adapter = WebScraperAdapter(
+            urls=["http://example.com"],
+            api_key="fake-key",
+            schema_file=None,
+            prompt="Extract data",
+        )
+
+        records = await adapter._fetch_async()
+
+        assert isinstance(records, list)
+        assert isinstance(records[0], AdapterRecord)
+        assert records[0].data["title"] == "Example"
+        assert records[0].data["source_url"] == "http://example.com"
+
+
+@pytest.mark.asyncio
+async def test_fetch_with_schema_file():
+    """
+    Test fetching data using schema file.
+    """
+    schema = {"title": {"selector": "h1"}}
+    mock_result = MagicMock()
+    mock_result.success = True
+    mock_result.url = "http://example.com"
+    mock_result.extracted_content = json.dumps({"title": "Example"})
+
+    with patch("builtins.open", mock_open(read_data=json.dumps(schema))):
+        with patch(
+            "ingestion.adapters.web_scraper_adapter.AsyncWebCrawler"
+        ) as mock_crawler_cls:
+            mock_crawler = AsyncMock()
+            mock_crawler_cls.return_value = mock_crawler
+            mock_crawler.arun_many.return_value = [mock_result]
+
+            adapter = WebScraperAdapter(
+                urls=["http://example.com"],
+                api_key="fake-key",
+                schema_file="schema.json",
+            )
+
+            records = await adapter._fetch_async()
+
+            assert len(records) == 1
+            assert records[0].data["title"] == "Example"
+            assert records[0].data["source_url"] == "http://example.com"
+
+
+def test_fetch_sync_calls_async():
+    """
+    Test that the sync fetch method calls the async fetch method.
+    """
+    adapter = WebScraperAdapter(
+        urls=["http://example.com"], api_key="fake-key", prompt="Extract data"
+    )
+    with patch.object(adapter, "_fetch_async", new=AsyncMock(return_value=[])):
+        result = adapter.fetch()
+        assert result == []
diff --git a/pipeline/tests/test_scraper_adapter_integration.py b/pipeline/tests/test_scraper_adapter_integration.py
new file mode 100644
index 0000000..069fec8
--- /dev/null
+++ b/pipeline/tests/test_scraper_adapter_integration.py
@@ -0,0 +1,40 @@
+import os
+import pytest
+from ingestion.adapters.web_scraper_adapter import WebScraperAdapter
+from models.adapters import AdapterRecord
+
+
+@pytest.mark.integration
+@pytest.mark.asyncio
+async def test_web_scraper_adapter_with_llm():
+    """
+    Integration test for WebScraperAdapter using LLM extraction on books.toscrape.com.
+    Requires OPENAI_API_KEY to be set in the environment.
+    """
+    api_key = os.getenv("OPENAI_API_KEY")
+    assert api_key is not None, "OPENAI_API_KEY environment variable must be set."
+ + test_url = ( + "https://books.toscrape.com/catalogue/a-light-in-the-attic_1000/index.html" + ) + + adapter = WebScraperAdapter( + urls=[test_url], + api_key=api_key, + prompt="Extract book title, price, availability, and description.", + llm_provider="openai/gpt-4o-mini", + schema_file=None, + verbose=True, + ) + + records = await adapter._fetch_async() + + assert isinstance(records, list) + assert len(records) > 0 + + for record in records: + assert isinstance(record, AdapterRecord) + assert "source_url" in record.data + assert record.data["source_url"] == test_url + + print("✅ Extracted data:", record.data) diff --git a/pipeline/uv.lock b/pipeline/uv.lock index e765972..89a02c1 100644 --- a/pipeline/uv.lock +++ b/pipeline/uv.lock @@ -308,6 +308,7 @@ dependencies = [ { name = "loguru" }, { name = "pandas" }, { name = "pytest" }, + { name = "pytest-asyncio" }, { name = "python-dotenv" }, { name = "responses" }, { name = "rich" }, @@ -321,6 +322,7 @@ requires-dist = [ { name = "loguru", specifier = ">=0.7.3" }, { name = "pandas", specifier = ">=2.2.3" }, { name = "pytest", specifier = ">=8.3.5" }, + { name = "pytest-asyncio", specifier = ">=0.26.0" }, { name = "python-dotenv", specifier = ">=1.1.0" }, { name = "responses", specifier = ">=0.25.7" }, { name = "rich", specifier = ">=14.0.0" }, @@ -1367,6 +1369,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/30/3d/64ad57c803f1fa1e963a7946b6e0fea4a70df53c1a7fed304586539c2bac/pytest-8.3.5-py3-none-any.whl", hash = "sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820", size = 343634 }, ] +[[package]] +name = "pytest-asyncio" +version = "0.26.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/8e/c4/453c52c659521066969523e87d85d54139bbd17b78f09532fb8eb8cdb58e/pytest_asyncio-0.26.0.tar.gz", hash = "sha256:c4df2a697648241ff39e7f0e4a73050b03f123f760673956cf0d72a4990e312f", size = 54156 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/20/7f/338843f449ace853647ace35870874f69a764d251872ed1b4de9f234822c/pytest_asyncio-0.26.0-py3-none-any.whl", hash = "sha256:7b51ed894f4fbea1340262bdae5135797ebbe21d8638978e35d31c6d19f72fb0", size = 19694 }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0"