agnes-the-ai-analyst/tests/test_generate_sample_data.py
Petr 302494b632 Add --format parquet using project's ParquetManager
Generator now supports --format {csv,parquet,both}. Parquet mode
uses src.parquet_manager.ParquetManager for snappy compression,
proper column types (DATE, TIMESTAMP, DOUBLE), and metadata.
No more ad-hoc pandas conversion needed on the server.
2026-03-10 21:46:20 +01:00

234 lines
9.1 KiB
Python

"""Tests for the sample data generator."""
import csv
import json
import pytest
from pathlib import Path
from scripts.generate_sample_data import SampleDataGenerator, SIZE_CONFIGS
@pytest.fixture
def output_dir(tmp_path: Path) -> Path:
    """Return the per-test path used as the generator's output directory.

    Only the path is built here (a ``sample_data`` child of pytest's
    ``tmp_path``); this fixture does not create the directory itself.
    """
    sample_dir = tmp_path / "sample_data"
    return sample_dir
class TestSizeConfigs:
    """Sanity checks on the ``SIZE_CONFIGS`` table of size presets."""

    def test_all_sizes_have_required_keys(self):
        """Every size preset must define each of the required keys."""
        required = {
            "customers", "products", "campaigns", "web_sessions",
            "web_leads", "orders", "support_tickets", "months",
        }
        for size, cfg in SIZE_CONFIGS.items():
            # Iterating a mapping yields its keys, so this is the set of
            # required keys that the preset does not define.
            missing = required.difference(cfg)
            assert not missing, f"Size '{size}' missing keys: {missing}"

    def test_sizes_scale_monotonically(self):
        """Counts must not decrease from one size preset to the next.

        Presets are compared in ``SIZE_CONFIGS`` iteration order. The check
        is "already sorted ascending", so equal neighbouring values pass.
        """
        size_names = list(SIZE_CONFIGS)
        for key in ("customers", "products", "orders", "web_sessions"):
            series = [SIZE_CONFIGS[name][key] for name in size_names]
            assert series == sorted(series), (
                f"{key} does not scale monotonically across sizes"
            )
class TestXSGeneration:
    """Full generation test with xs size (fast)."""

    @pytest.fixture(autouse=True)
    def generate(self, output_dir: Path):
        """Run the generator (size ``xs``, seed 42) before every test.

        Keeps the output directory on ``self.output_dir`` and the value
        returned by ``SampleDataGenerator.run()`` on ``self.manifest``.
        """
        self.output_dir = output_dir
        gen = SampleDataGenerator(size="xs", seed=42, output_dir=output_dir)
        self.manifest = gen.run()

    def test_all_csv_files_created(self):
        """The output directory holds exactly one CSV per expected table."""
        expected = {
            "customers", "products", "campaigns", "web_sessions",
            "web_leads", "orders", "order_items", "payments",
            "support_tickets",
        }
        csv_files = {p.stem for p in self.output_dir.glob("*.csv")}
        assert expected == csv_files

    def test_manifest_created(self):
        """``_manifest.json`` exists and records size, tables and row total."""
        manifest_path = self.output_dir / "_manifest.json"
        assert manifest_path.exists()
        data = json.loads(manifest_path.read_text())
        assert data["size"] == "xs"
        assert "tables" in data
        assert data["total_rows"] > 0

    def test_row_counts_match_config(self):
        """Row counts for directly specified tables should match config."""
        cfg = SIZE_CONFIGS["xs"]
        for table in ["customers", "products", "campaigns", "web_sessions",
                      "web_leads", "orders", "support_tickets"]:
            assert self.manifest["tables"][table] == cfg[table], (
                f"{table}: expected {cfg[table]}, got {self.manifest['tables'][table]}"
            )

    def test_order_items_derived(self):
        """Order items should be > orders (most orders have multiple items)."""
        assert self.manifest["tables"]["order_items"] > self.manifest["tables"]["orders"]

    def test_payments_at_least_one_per_order(self):
        """Payments should be >= orders (some have failed retries)."""
        assert self.manifest["tables"]["payments"] >= self.manifest["tables"]["orders"]

    def test_csv_headers_not_empty(self):
        """Every CSV should have a header and at least one data row."""
        csv_paths = sorted(self.output_dir.glob("*.csv"))
        # Without this guard the loop below would pass vacuously if the
        # generator wrote no CSV files at all.
        assert csv_paths, "no CSV files were generated"
        for csv_path in csv_paths:
            # newline="" is what the csv module asks for on files it reads,
            # so it can handle line endings inside quoted fields itself.
            with open(csv_path, newline="") as f:
                reader = csv.reader(f)
                # A default of None turns a zero-byte file into a readable
                # assertion failure instead of a bare StopIteration.
                header = next(reader, None)
                assert header, f"{csv_path.name}: empty header"
                first_row = next(reader, None)
                assert first_row is not None, f"{csv_path.name}: no data rows"
class TestReferentialIntegrity:
    """Verify foreign key relationships across tables."""

    @pytest.fixture(autouse=True)
    def generate(self, output_dir: Path):
        """Generate the ``xs`` dataset (seed 123) and load every CSV.

        ``self.tables`` maps each CSV file stem to the list of rows that
        ``csv.DictReader`` produced for that file.
        """
        self.output_dir = output_dir
        gen = SampleDataGenerator(size="xs", seed=123, output_dir=output_dir)
        gen.run()
        self.tables = {}
        for csv_path in output_dir.glob("*.csv"):
            # newline="" is what the csv module asks for on files it reads.
            with open(csv_path, newline="") as f:
                self.tables[csv_path.stem] = list(csv.DictReader(f))

    def _get_ids(self, table: str, column: str) -> set[str]:
        """Return every value of ``column`` in ``table``, empty ones included."""
        return {row[column] for row in self.tables[table]}

    def _get_fk_values(self, table: str, column: str) -> set[str]:
        """Return the non-empty values of ``column`` in ``table``.

        Empty cells are skipped so that a blank reference is not reported
        as pointing at a missing row.
        """
        return {row[column] for row in self.tables[table] if row[column]}

    def _orphans(
        self, child: str, fk_column: str, parent: str, pk_column: str,
    ) -> set[str]:
        """Return ``child.fk_column`` values that are absent from ``parent.pk_column``."""
        return self._get_fk_values(child, fk_column) - self._get_ids(parent, pk_column)

    def test_orders_reference_valid_customers(self):
        orphans = self._orphans("orders", "customer_id", "customers", "customer_id")
        assert not orphans, f"Orders reference non-existent customers: {orphans}"

    def test_order_items_reference_valid_orders(self):
        orphans = self._orphans("order_items", "order_id", "orders", "order_id")
        assert not orphans, f"Order items reference non-existent orders: {orphans}"

    def test_order_items_reference_valid_products(self):
        orphans = self._orphans("order_items", "product_id", "products", "product_id")
        assert not orphans, f"Order items reference non-existent products: {orphans}"

    def test_payments_reference_valid_orders(self):
        orphans = self._orphans("payments", "order_id", "orders", "order_id")
        assert not orphans, f"Payments reference non-existent orders: {orphans}"

    def test_support_tickets_reference_valid_customers(self):
        orphans = self._orphans(
            "support_tickets", "customer_id", "customers", "customer_id",
        )
        assert not orphans, f"Tickets reference non-existent customers: {orphans}"
class TestDeterminism:
    """Verify reproducibility with same seed."""

    def test_same_seed_produces_same_output(self, tmp_path: Path):
        """Two runs with the same seed must write identical CSV files."""
        dir1 = tmp_path / "run1"
        dir2 = tmp_path / "run2"
        gen1 = SampleDataGenerator(size="xs", seed=99, output_dir=dir1)
        gen1.run()
        gen2 = SampleDataGenerator(size="xs", seed=99, output_dir=dir2)
        gen2.run()

        names1 = sorted(p.name for p in dir1.glob("*.csv"))
        names2 = sorted(p.name for p in dir2.glob("*.csv"))
        # Looping over an empty glob would pass without comparing anything.
        assert names1, "no CSV files were generated"
        # Catch files that exist in only one of the two runs; iterating
        # over the first directory alone would miss extras in the second.
        assert names1 == names2, "runs produced different sets of CSV files"

        for name in names1:
            # Compare raw bytes: stricter than decoded text and independent
            # of the platform's default text encoding.
            content1 = (dir1 / name).read_bytes()
            content2 = (dir2 / name).read_bytes()
            assert content1 == content2, f"{name} differs between runs"

    def test_different_seed_produces_different_output(self, tmp_path: Path):
        """Different seeds must produce a different ``customers.csv``."""
        dir1 = tmp_path / "seed1"
        dir2 = tmp_path / "seed2"
        gen1 = SampleDataGenerator(size="xs", seed=1, output_dir=dir1)
        gen1.run()
        gen2 = SampleDataGenerator(size="xs", seed=2, output_dir=dir2)
        gen2.run()
        content1 = (dir1 / "customers.csv").read_bytes()
        content2 = (dir2 / "customers.csv").read_bytes()
        assert content1 != content2
class TestParquetFormat:
    """Test Parquet output format using project's ParquetManager."""

    @staticmethod
    def _column_types(con, parquet_path: Path) -> dict[str, str]:
        """Return ``{column name: type name}`` as DuckDB describes the file.

        ``con`` is an open DuckDB connection; ``parquet_path`` is the file
        handed to ``read_parquet``.
        """
        # Double any single quote so the path stays a valid SQL string
        # literal even when the temp directory name contains an apostrophe.
        quoted = str(parquet_path).replace("'", "''")
        rows = con.execute(
            f"DESCRIBE SELECT * FROM read_parquet('{quoted}')"
        ).fetchall()
        # The first two columns of each DESCRIBE row are name and type.
        return {row[0]: row[1] for row in rows}

    def test_parquet_format_creates_parquet_files(self, tmp_path: Path):
        """--format parquet should produce .parquet files, no CSVs."""
        out = tmp_path / "parquet_out"
        gen = SampleDataGenerator(
            size="xs", seed=42, output_dir=out, output_format="parquet",
        )
        gen.run()
        parquet_files = {p.stem for p in out.glob("*.parquet")}
        csv_files = list(out.glob("*.csv"))
        expected = {
            "customers", "products", "campaigns", "web_sessions",
            "web_leads", "orders", "order_items", "payments",
            "support_tickets",
        }
        assert expected == parquet_files
        assert csv_files == [], "CSV files should be cleaned up in parquet mode"

    def test_parquet_has_correct_types(self, tmp_path: Path):
        """Parquet files should have proper column types from ParquetManager."""
        import duckdb

        out = tmp_path / "typed"
        gen = SampleDataGenerator(
            size="xs", seed=42, output_dir=out, output_format="parquet",
        )
        gen.run()

        con = duckdb.connect()
        try:
            order_types = self._column_types(con, out / "orders.parquet")
            customer_types = self._column_types(con, out / "customers.parquet")
        finally:
            # Close the connection even when a query above fails.
            con.close()

        # orders.created_at should be TIMESTAMP, not VARCHAR
        assert order_types["created_at"] == "TIMESTAMP"
        assert order_types["total_amount"] == "DOUBLE"
        # customers.registration_date should be DATE
        assert customer_types["registration_date"] == "DATE"

    def test_both_format_creates_csv_and_parquet(self, tmp_path: Path):
        """--format both should produce CSVs + parquet/ subdirectory."""
        out = tmp_path / "both_out"
        gen = SampleDataGenerator(
            size="xs", seed=42, output_dir=out, output_format="both",
        )
        gen.run()
        csv_files = list(out.glob("*.csv"))
        parquet_files = list((out / "parquet").glob("*.parquet"))
        assert len(csv_files) == 9
        assert len(parquet_files) == 9