fix: make all adapter fetch() methods async

Sosokker 2025-05-14 18:16:52 +07:00
parent 2073befd68
commit 6d5d0e3148
19 changed files with 1397 additions and 1024 deletions

pipeline/.gitignore vendored (11 lines changed)
View File

@ -12,3 +12,14 @@ wheels/
.env
/ingestion/data
# Playwright
node_modules/
/test-results/
/playwright-report/
/blob-report/
/playwright/.cache/
data
data/*

View File

@ -17,6 +17,10 @@ class AppSettings(BaseSettings):
Loads values from environment variables or a .env file.
"""
USE_SERVER_API_KEY: bool = True
OPENAI_API_KEY: str | None = None
GEMINI_API_KEY: str | None = None
# Application settings
APP_NAME: str = "PipelineRunnerApp"
LOG_LEVEL: str = "DEBUG" # Logging level (e.g., DEBUG, INFO, WARNING)
@ -32,7 +36,7 @@ class AppSettings(BaseSettings):
# Ingestion Defaults
DEFAULT_API_TIMEOUT: int = 30
DEFAULT_SCRAPER_LLM_PROVIDER: str = "openai/gpt-4o-mini"
DEFAULT_SCRAPER_LLM_PROVIDER: str = "gemini/gemini-1.5-pro"
DEFAULT_SCRAPER_CACHE_MODE: str = "ENABLED"
DEFAULT_SCRAPER_PROMPT: str = (
"Extract all data from the page in as much detailed as possible"

View File

@ -62,7 +62,7 @@ class ApiAdapter(DataSourceAdapter):
logger.debug("HTTP session initialized with retry strategy.")
return session
def fetch(self) -> list[AdapterRecord]:
async def fetch(self) -> list[AdapterRecord]:
"""
Perform a GET request and return JSON data as a list of records.

View File

@ -11,7 +11,7 @@ class DataSourceAdapter(Protocol):
Protocol for data source adapters.
"""
def fetch(self) -> list[AdapterRecord]:
async def fetch(self) -> list[AdapterRecord]:
"""
Fetch data from the source.
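
Since the protocol method is now a coroutine, every call site has to await it. A minimal, self-contained sketch of the pattern (DummyAdapter and the reduced AdapterRecord are illustrative stand-ins, not part of the codebase):

import asyncio
from dataclasses import dataclass
from typing import Any, Protocol

@dataclass
class AdapterRecord:  # reduced stand-in for models.ingestion.AdapterRecord
    source: str
    data: dict[str, Any]

class DataSourceAdapter(Protocol):
    async def fetch(self) -> list[AdapterRecord]: ...

class DummyAdapter:
    async def fetch(self) -> list[AdapterRecord]:
        return [AdapterRecord(source="dummy", data={"id": 1})]

async def ingest(adapter: DataSourceAdapter) -> list[AdapterRecord]:
    return await adapter.fetch()  # fetch() is awaited, never called as a blocking function

if __name__ == "__main__":
    print(asyncio.run(ingest(DummyAdapter())))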

View File

@ -30,7 +30,7 @@ class FileAdapter(DataSourceAdapter):
f"Initialized FileAdapter for upload: {upload.filename}, format: {upload.content_type}"
)
def fetch(self) -> list[AdapterRecord]:
async def fetch(self) -> list[AdapterRecord]:
"""
Read and parse the file, returning a list of records.
Supports both path-based and uploaded file-like inputs.

View File

@ -2,7 +2,6 @@
Web scraper adapter using crawl4ai to extract structured data.
"""
import asyncio
import json
from config import settings
@ -21,10 +20,12 @@ from crawl4ai.extraction_strategy import (
ExtractionStrategy,
)
from .base import DataSourceAdapter
from loguru import logger
from models.ingestion import AdapterRecord
from models.crawler import HouseSchema
# pyright: reportArgumentType=false
# pyright: reportAssignmentType=false
@ -38,13 +39,13 @@ class WebScraperAdapter(DataSourceAdapter):
def __init__(
self,
urls: list[str],
api_key: str,
api_key: str | None = None,
schema_file: str | None = None,
prompt: str = settings.DEFAULT_SCRAPER_PROMPT,
llm_provider: str = settings.DEFAULT_SCRAPER_LLM_PROVIDER,
output_format: str = "json",
verbose: bool = False,
cache_mode: str = settings.DEFAULT_SCRAPER_CACHE_MODE,
verbose: bool = True,
cache_mode: str = "BYPASS",
):
"""
Initialize the scraper adapter.
@ -71,28 +72,23 @@ class WebScraperAdapter(DataSourceAdapter):
f"Initialized WebScraperAdapter for URLs: {urls} with schema_file={schema_file}, prompt={prompt}, llm_provider={llm_provider}, output_format={output_format}, verbose={verbose}, cache_mode={cache_mode}"
)
if not self.api_key:
logger.error(
"API Key is required for WebScraperAdapter but was not provided."
)
raise ValueError("API Key is required for WebScraperAdapter.")
if settings.USE_SERVER_API_KEY:
if llm_provider == "openai/gpt-4o-mini":
self.api_key = settings.OPENAI_API_KEY
elif llm_provider == "gemini/gemini-1.5-pro":
self.api_key = settings.GEMINI_API_KEY
def fetch(self) -> list[AdapterRecord]:
if not self.api_key:
raise ValueError("API key is required")
async def fetch(self) -> list[AdapterRecord]:
"""
Synchronously fetch data by running the async crawler.
Perform web scraping and return extracted records.
Returns:
List of extracted records.
Raises:
RuntimeError: On failure during crawling or extraction.
List of AdapterRecord objects.
"""
logger.info("Starting synchronous fetch for web scraping.")
try:
return asyncio.run(self._fetch_async())
except Exception as e:
logger.error(f"Web scraping failed: {e}")
raise RuntimeError(f"Web scraping failed: {e}")
return await self._fetch_async()
async def _fetch_async(self) -> list[AdapterRecord]:
"""
@ -125,10 +121,13 @@ class WebScraperAdapter(DataSourceAdapter):
elif self.prompt:
extraction_strategy = LLMExtractionStrategy(
llm_config=llm_cfg,
instruction=self.prompt, # Use the instance's prompt
instruction=self.prompt,
schema=HouseSchema.schema(),
extraction_type="schema",
chunk_token_threshold=1200,
apply_chunking=True,
verbose=self.verbose,
extra_args={"max_tokens": 1500},
)
logger.debug("Using LLM extraction strategy.")
else:
@ -163,34 +162,45 @@ class WebScraperAdapter(DataSourceAdapter):
adapter_records: list[AdapterRecord] = []
for res in results:
logger.debug(res)
if not res.success or not res.extracted_content:
logger.warning(
f"Skipping failed or empty result for URL: {getattr(res, 'url', None)}"
)
continue
try:
content = json.loads(res.extracted_content)
logger.info(f"Parsed extracted content for URL: {res.url}")
except Exception:
logger.error(f"Failed to parse extracted content for URL: {res.url}")
continue
if content is None:
logger.warning(f"Extracted content is None for URL: {res.url}")
continue
if isinstance(content, list):
for item in content:
if isinstance(item, dict):
item["source_url"] = res.url
adapter_records.append(
AdapterRecord(source="scrape", data=item)
)
elif isinstance(content, dict):
content["source_url"] = res.url
adapter_records.append(AdapterRecord(source="scrape", data=content))
else:
logger.warning(
f"Extracted content for URL {res.url} is not a list or dict: {type(content)}"
adapter_records.append(
AdapterRecord(
source="scrape",
data={"content": res.metadata, "source_url": res.url},
)
)
# try:
# content = json.loads(res.extracted_content)
# logger.info(f"Parsed extracted content for URL: {res.url}")
# except Exception:
# logger.error(f"Failed to parse extracted content for URL: {res.url}")
# continue
# if content is None:
# logger.warning(f"Extracted content is None for URL: {res.url}")
# continue
# if isinstance(content, list):
# for item in content:
# if isinstance(item, dict):
# item["source_url"] = res.url
# adapter_records.append(
# AdapterRecord(source="scrape", data=item)
# )
# elif isinstance(content, dict):
# content["source_url"] = res.url
# adapter_records.append(AdapterRecord(source="scrape", data=content))
# else:
# logger.warning(
# f"Extracted content for URL {res.url} is not a list or dict: {type(content)}"
# )
logger.info(
f"Web scraping completed. Extracted {len(adapter_records)} records."

View File

@ -16,7 +16,9 @@ class Ingestor:
"""
@staticmethod
def run(sources: list[IngestSourceConfig], strategy: str = "simple") -> OutputData:
async def run(
sources: list[IngestSourceConfig], strategy: str = "simple"
) -> OutputData:
strategies: dict[str, IngestionMethod] = {
"simple": SimpleIngestionStrategy(),
"ml": MLIngestionStrategy(),
@ -25,4 +27,4 @@ class Ingestor:
if strategy not in strategies:
raise ValueError(f"Unsupported strategy: {strategy}")
return strategies[strategy].run(sources)
return await strategies[strategy].run(sources)
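
Because run() is now a coroutine, the async-ness propagates upward: the strategy's run(), the Ingestor, and the pipeline service all await each other, with a single asyncio.run() (or the FastAPI event loop) at the very top. A reduced, hypothetical sketch of that chain:

import asyncio

class DummyStrategy:
    async def run(self, sources: list[str]) -> list[str]:
        await asyncio.sleep(0)  # stand-in for awaiting adapter.fetch()
        return [f"record from {s}" for s in sources]

class Ingestor:
    @staticmethod
    async def run(sources: list[str]) -> list[str]:
        return await DummyStrategy().run(sources)

class PipelineService:
    async def execute_ingestion(self, sources: list[str]) -> list[str]:
        return await Ingestor.run(sources)

if __name__ == "__main__":
    print(asyncio.run(PipelineService().execute_ingestion(["api", "file", "scrape"])))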

View File

@ -4,5 +4,5 @@ from models.ingestion import IngestSourceConfig, OutputData
class IngestionMethod(ABC):
@abstractmethod
def run(self, sources: list[IngestSourceConfig]) -> OutputData:
async def run(self, sources: list[IngestSourceConfig]) -> OutputData:
pass

View File

@ -3,7 +3,7 @@ from models.ingestion import IngestSourceConfig, OutputData
class MLIngestionStrategy(IngestionMethod):
def run(self, sources: list[IngestSourceConfig]) -> OutputData:
async def run(self, sources: list[IngestSourceConfig]) -> OutputData:
# TODO: Add ML-based logic (e.g., deduplication, entity linking, classification)
return OutputData(
records=[], # Placeholder

View File

@ -17,32 +17,38 @@ from loguru import logger
class SimpleIngestionStrategy(IngestionMethod):
def run(self, sources: list[IngestSourceConfig]) -> OutputData:
async def run(self, sources: list[IngestSourceConfig]) -> OutputData:
results: list[AdapterRecord] = []
# TODO: find better way to check config type and property
for source in sources:
try:
match source.type:
case SourceType.API:
config = source.config
assert isinstance(config, ApiConfig)
config = source.parsed_config
assert isinstance(config, ApiConfig), (
f"Wrong config type for source {source.type}: {config}, get type {type(config)}"
)
adapter = ApiAdapter(
url=config.url,
headers=config.headers,
timeout=config.timeout or settings.DEFAULT_API_TIMEOUT,
token=config.token,
)
records = adapter.fetch()
records = await adapter.fetch()
case SourceType.FILE:
config = source.config
assert isinstance(config, FileConfig)
config = source.parsed_config
assert isinstance(config, FileConfig), (
f"Wrong config type for source {source.type}: {config}, get type {type(config)}"
)
adapter = FileAdapter(upload=config.upload)
records = adapter.fetch()
records = await adapter.fetch()
case SourceType.SCRAPE:
config = source.config
assert isinstance(config, ScrapeConfig)
config = source.parsed_config
assert isinstance(config, ScrapeConfig), (
f"Wrong config type for source {source.type}: {config}, get type {type(config)}"
)
adapter = WebScraperAdapter(
urls=config.urls,
api_key=config.api_key,
@ -55,7 +61,7 @@ class SimpleIngestionStrategy(IngestionMethod):
cache_mode=config.cache_mode
or settings.DEFAULT_SCRAPER_CACHE_MODE,
)
records = adapter.fetch()
records = await adapter.fetch()
results.extend(records)

View File

@ -1,3 +1,4 @@
import sys
import platform
import asyncio
@ -17,9 +18,10 @@ from routers.logs import router as logs_router
sse_queue = asyncio.Queue(maxsize=settings.SSE_LOG_QUEUE_MAX_SIZE)
# ! Windows-specific asyncio policy
if platform.system() == "Windows":
logger.info("Setting WindowsProactorEventLoopPolicy for asyncio.")
asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
if platform.system() == "Windows" or sys.platform == "win32":
logger.info("Setting WindowsSelectorEventLoopPolicy for asyncio.")
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
# --- Resource Initialization ---
pipeline_store: PipelineStore = InMemoryPipelineStore()
@ -85,6 +87,12 @@ async def read_root():
if __name__ == "__main__":
import uvicorn
if platform.system() == "Windows":
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
import asyncio
# asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
logger.info("Starting Uvicorn server...")
# ! use reload=True only for development
uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True, loop="asyncio")
uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)

View File

@ -0,0 +1,190 @@
# models/canonical.py (Create this new file)
from pydantic import BaseModel, Field, HttpUrl, field_validator
from typing import Optional, List, Dict, Any, Literal
from datetime import datetime, timezone
from uuid import uuid4
class Address(BaseModel):
street_address: Optional[str] = Field(None, description="Street name and number")
city: Optional[str] = Field(None, description="City name")
state_province: Optional[str] = Field(
None, description="State or province abbreviation/name"
)
postal_code: Optional[str] = Field(None, description="Zip or postal code")
country: Optional[str] = Field(
"USA", description="Country code or name"
) # Example default
class PropertyFeatures(BaseModel):
bedrooms: Optional[int] = Field(None, description="Number of bedrooms")
bathrooms: Optional[float] = Field(
None, description="Number of bathrooms (float for half baths)"
)
area_sqft: Optional[float] = Field(None, description="Total area in square feet")
lot_size_sqft: Optional[float] = Field(None, description="Lot size in square feet")
year_built: Optional[int] = Field(None, description="Year the property was built")
property_type: Optional[str] = Field(
None,
description="e.g., Single Family House, Condo, Townhouse, Land, Multi-Family",
)
has_pool: Optional[bool] = None
has_garage: Optional[bool] = None
stories: Optional[int] = None
class ListingDetails(BaseModel):
price: Optional[float] = Field(None, description="Listing price")
currency: Optional[str] = Field("USD", description="Currency code")
listing_status: Optional[
Literal["For Sale", "For Rent", "Sold", "Pending", "Off Market", "Unknown"]
] = Field("Unknown", description="Current status of the listing")
listing_type: Optional[Literal["Sale", "Rent"]] = Field(
None, description="Whether the property is for sale or rent"
)
listed_date: Optional[datetime] = Field(
None, description="Date the property was listed (UTC)"
)
last_updated_date: Optional[datetime] = Field(
None, description="Date the listing was last updated (UTC)"
)
listing_url: Optional[HttpUrl] = Field(
None, description="URL of the original listing"
)
mls_id: Optional[str] = Field(
None, description="Multiple Listing Service ID, if available"
)
class AgentContact(BaseModel):
name: Optional[str] = Field(None, description="Listing agent or contact name")
phone: Optional[str] = Field(None, description="Contact phone number")
email: Optional[str] = Field(None, description="Contact email address")
brokerage_name: Optional[str] = Field(
None, description="Real estate brokerage name"
)
class CanonicalRecord(BaseModel):
"""
Represents a unified Real Estate Listing record after mapping.
Target schema for the ML mapping model.
"""
# --- Core Identifier & Provenance ---
canonical_record_id: str = Field(
default_factory=lambda: f"cre-{uuid4()}",
description="Unique identifier for this canonical record.",
examples=[f"cre-{uuid4()}"],
)
original_source_identifier: str = Field(
...,
description="Identifier of the original source (e.g., URL, filename + row index).",
)
original_source_type: str = Field(
...,
description="Type of the original source adapter ('api', 'file', 'scrape').",
)
entity_type: Literal["RealEstateListing", "NewsArticle", "Other"] = Field(
"Other", description="Classification of the source entity."
)
mapping_model_version: Optional[str] = Field(
None, description="Version identifier of the ML model used for mapping."
)
mapping_timestamp: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
description="Timestamp (UTC) when the mapping was performed.",
)
# --- Real Estate Specific Fields ---
address: Optional[Address] = Field(
default=None, description="Structured address details."
)
features: Optional[PropertyFeatures] = Field(
default=None, description="Details about the property itself."
)
listing: Optional[ListingDetails] = Field(
default=None, description="Information about the listing status and price."
)
agent: Optional[AgentContact] = Field(
default=None, description="Listing agent or contact information."
)
description: Optional[str] = Field(
None, description="Textual description from the listing."
)
image_urls: Optional[List[HttpUrl]] = Field(
default=None, description="List of URLs for property images."
)
# --- Common Fields ---
raw_source_data: Optional[Dict[str, Any]] = Field( # Changed name for clarity
default=None, description="Original source data record (JSON representation)."
)
@field_validator("listing", "features", "address", "agent")
def check_fields_for_real_estate(cls, v, info):
if info.data.get("entity_type") == "RealEstateListing" and v is None:
# NOTE: Depending on strictness, might raise ValueError or just allow it
# print(f"Warning: RealEstateListing has None for {info.field_name}")
pass
return v
class Config:
# Example for documentation
schema_extra = {
"example": {
"canonical_record_id": f"cre-{uuid4()}",
"original_source_identifier": "https://some.realestate.site/listing/123",
"original_source_type": "scrape",
"entity_type": "RealEstateListing",
"mapping_model_version": "realestate-mapper-v1.0",
"mapping_timestamp": "2025-04-29T12:00:00Z",
"address": {
"street_address": "123 Main St",
"city": "Anytown",
"state_province": "CA",
"postal_code": "90210",
"country": "USA",
},
"features": {
"bedrooms": 3,
"bathrooms": 2.5,
"area_sqft": 1850.0,
"lot_size_sqft": 5500.0,
"year_built": 1995,
"property_type": "Single Family House",
"has_pool": True,
"has_garage": True,
"stories": 2,
},
"listing": {
"price": 750000.0,
"currency": "USD",
"listing_status": "For Sale",
"listing_type": "Sale",
"listed_date": "2025-04-15T00:00:00Z",
"last_updated_date": "2025-04-28T00:00:00Z",
"listing_url": "https://some.realestate.site/listing/123",
"mls_id": "MLS123456",
},
"agent": {
"name": "Jane Doe",
"phone": "555-123-4567",
"email": "jane.doe@email.com",
"brokerage_name": "Best Realty",
},
"description": "Beautiful 3 bed, 2.5 bath home in a great neighborhood. Recently updated kitchen, spacious backyard with pool.",
"image_urls": [
"https://images.site/123/1.jpg",
"https://images.site/123/2.jpg",
],
"raw_source_data": {
"title": "Charming Home For Sale",
"price_str": "$750,000",
"sqft": "1,850",
"...": "...",
},
}
}
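
To make the required-versus-defaulted split of the new model concrete, a reduced, self-contained sketch (CanonicalRecordSketch and the sample identifier are illustrative): only the source identifier and source type must be supplied, the id and mapping timestamp are generated, and the domain-specific fields stay optional.

from datetime import datetime, timezone
from typing import Literal
from uuid import uuid4

from pydantic import BaseModel, Field

class CanonicalRecordSketch(BaseModel):
    canonical_record_id: str = Field(default_factory=lambda: f"cre-{uuid4()}")
    original_source_identifier: str
    original_source_type: str
    entity_type: Literal["RealEstateListing", "NewsArticle", "Other"] = "Other"
    mapping_timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    description: str | None = None

if __name__ == "__main__":
    record = CanonicalRecordSketch(
        original_source_identifier="https://some.realestate.site/listing/123",
        original_source_type="scrape",
        entity_type="RealEstateListing",
    )
    print(record.model_dump_json(indent=2))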

View File

@ -0,0 +1,22 @@
from pydantic import BaseModel, Field
class HouseSchema(BaseModel):
url: str | None = Field(None, description="House URL")
title: str | None = Field(None, description="House title")
price: str | None = Field(None, description="House price")
address: str | None = Field(None, description="House address")
city: str | None = Field(None, description="House city")
state: str | None = Field(None, description="House state")
postal_code: str | None = Field(None, description="House postal code")
description: str | None = Field(None, description="House description")
features: str | None = Field(None, description="House features")
beds: int | None = Field(None, description="House beds")
baths: float | None = Field(None, description="House baths")
sqft: int | None = Field(None, description="House sqft")
lot_size: str | None = Field(None, description="House lot size")
year_built: int | None = Field(None, description="House year built")
type: str | None = Field(None, description="House type")
provider: str | None = Field(None, description="House provider")
image_url: str | None = Field(None, description="House image URL")
details: str | None = Field(None, description="House details")
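
For reference, the JSON schema handed to the LLM extraction strategy in the scraper adapter can be generated straight from this model; a reduced sketch (two fields only) using model_json_schema(), the pydantic v2 spelling of the older .schema() call seen in the adapter:

import json

from pydantic import BaseModel, Field

class HouseSchemaSketch(BaseModel):
    title: str | None = Field(None, description="House title")
    price: str | None = Field(None, description="House price")

if __name__ == "__main__":
    # The resulting dict is what the extraction strategy receives as its target schema.
    print(json.dumps(HouseSchemaSketch.model_json_schema(), indent=2))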

View File

@ -54,7 +54,7 @@ class FileConfig(BaseModel):
class ScrapeConfig(BaseModel):
urls: list[str]
api_key: str
api_key: str | None = None
schema_file: str | None = None
prompt: str | None = None
llm_provider: str | None = None
@ -75,6 +75,17 @@ class IngestSourceConfig(BaseModel):
..., description="Configuration for the adapter"
)
@property
def parsed_config(self) -> ApiConfig | FileConfig | ScrapeConfig:
if self.type == SourceType.API:
return ApiConfig(**self.config.model_dump())
elif self.type == SourceType.FILE:
return FileConfig(**self.config.model_dump())
elif self.type == SourceType.SCRAPE:
return ScrapeConfig(**self.config.model_dump())
else:
raise ValueError(f"Unsupported type: {self.type}")
class IngestorInput(BaseModel):
"""

View File

@ -7,11 +7,12 @@ requires-python = ">=3.12"
dependencies = [
"apscheduler>=3.11.0",
"crawl4ai>=0.5.0.post8",
"fastapi[standard]>=0.115.12",
"fastapi[all,standard]>=0.115.12",
"freezegun>=1.5.1",
"inquirer>=3.4.0",
"loguru>=0.7.3",
"pandas>=2.2.3",
"playwright>=1.51.0",
"pydantic-settings>=2.9.1",
"pytest>=8.3.5",
"pytest-asyncio>=0.26.0",
@ -20,6 +21,7 @@ dependencies = [
"responses>=0.25.7",
"rich>=14.0.0",
"sse-starlette>=2.3.4",
"uvicorn[standard]>=0.34.1",
]
[tool.pytest.ini_options]

View File

@ -365,7 +365,7 @@ class PipelineService:
"""
try:
logger.info(f"Executing ingestion with config: {config}")
results: OutputData = Ingestor.run(config.sources)
results: OutputData = await Ingestor.run(config.sources)
logger.info(
f"Ingestion completed successfully. Records count: {len(results.records)}"
)

View File

@ -38,20 +38,20 @@ def test_fetch_dict_response(single_product):
assert adapter_result[0].data == expected_data
def test_fetch_list_response(multiple_product):
async def test_fetch_list_response(multiple_product):
"""Test fetching a list of records from a JSON API endpoint."""
response = httpx.get(multiple_product, timeout=10)
response.raise_for_status()
expected_data = response.json()
adapter = ApiAdapter(url=multiple_product)
adapter_result = adapter.fetch()
adapter_result = await adapter.fetch()
assert adapter_result[0].data == expected_data
@responses.activate
def test_fetch_http_error(single_product):
async def test_fetch_http_error(single_product):
"""Test handling HTTP errors and validate graceful failure."""
for _ in range(4):
responses.add(responses.GET, single_product, status=500)
@ -59,20 +59,20 @@ def test_fetch_http_error(single_product):
adapter = ApiAdapter(url=single_product)
with pytest.raises(RuntimeError) as exc_info:
adapter.fetch()
await adapter.fetch()
assert "API request failed" in str(exc_info.value)
@responses.activate
def test_fetch_json_decode_error(single_product):
async def test_fetch_json_decode_error(single_product):
"""Test handling JSON decode errors."""
responses.add(responses.GET, single_product, body="not-a-json", status=200)
adapter = ApiAdapter(url=single_product)
with pytest.raises(RuntimeError) as exc_info:
adapter.fetch()
await adapter.fetch()
assert "Failed to parse JSON response" in str(exc_info.value)

View File

@ -11,11 +11,11 @@ def make_upload_file(content: str, filename: str) -> UploadFile:
)
def test_file_adapter_csv():
async def test_file_adapter_csv():
csv_content = "id,name,price\n001,Apple,12\n002,Orange,10\n003,Banana,8"
upload = make_upload_file(csv_content, "test.csv")
adapter = FileAdapter(upload)
records = adapter.fetch()
records = await adapter.fetch()
assert len(records) == 3
assert records[0].data["name"] == "Apple"
@ -23,7 +23,7 @@ def test_file_adapter_csv():
assert records[2].data["id"] == 3
def test_file_adapter_json():
async def test_file_adapter_json():
json_content = """
[{"id": "001", "name": "Apple", "price": 12},
{"id": "002", "name": "Orange", "price": 10},
@ -31,7 +31,7 @@ def test_file_adapter_json():
"""
upload = make_upload_file(json_content, "test.json")
adapter = FileAdapter(upload)
records = adapter.fetch()
records = await adapter.fetch()
assert len(records) == 3
assert records[0].data["name"] == "Apple"
@ -39,7 +39,7 @@ def test_file_adapter_json():
assert records[2].data["id"] == 3
def test_file_adapter_missing_filename():
async def test_file_adapter_missing_filename():
upload = UploadFile(
filename="",
file=io.BytesIO("id,name,price\n001,Apple,12".encode("utf-8")),
@ -47,6 +47,6 @@ def test_file_adapter_missing_filename():
adapter = FileAdapter(upload)
with pytest.raises(ValueError) as excinfo:
adapter.fetch()
await adapter.fetch()
assert "File name is required" in str(excinfo.value)

File diff suppressed because it is too large.