Generator now supports --format {csv,parquet,both}. Parquet mode
uses src.parquet_manager.ParquetManager to write snappy-compressed
files with proper column types (DATE, TIMESTAMP, DOUBLE) and metadata,
so the server no longer needs ad-hoc pandas conversion. In `both` mode,
CSVs are written alongside a `parquet/` subdirectory.
234 lines
9.1 KiB
Python
234 lines
9.1 KiB
Python
"""Tests for the sample data generator."""
|
|
|
|
import csv
|
|
import json
|
|
import pytest
|
|
from pathlib import Path
|
|
|
|
from scripts.generate_sample_data import SampleDataGenerator, SIZE_CONFIGS
|
|
|
|
|
|
@pytest.fixture
def output_dir(tmp_path: Path) -> Path:
    """Per-test ``sample_data`` directory under pytest's ``tmp_path``.

    The fixture only builds the path; it does not create the directory.
    """
    target = tmp_path.joinpath("sample_data")
    return target
|
|
|
|
|
|
class TestSizeConfigs:
    """Verify size configuration integrity."""

    def test_all_sizes_have_required_keys(self):
        """Every size preset must define all keys the generator reads."""
        required = {
            "customers", "products", "campaigns", "web_sessions",
            "web_leads", "orders", "support_tickets", "months",
        }
        for size, cfg in SIZE_CONFIGS.items():
            missing = required - set(cfg.keys())
            assert not missing, f"Size '{size}' missing keys: {missing}"

    def test_sizes_scale_monotonically(self):
        """Row counts must never shrink from one size preset to the next.

        The check is non-decreasing: equal neighbouring values pass. It
        relies on ``SIZE_CONFIGS`` being declared smallest-to-largest, since
        dict iteration follows insertion order.
        """
        sizes = list(SIZE_CONFIGS)
        for key in ["customers", "products", "orders", "web_sessions"]:
            values = [SIZE_CONFIGS[s][key] for s in sizes]
            # Include the per-size values so a failure shows where it broke.
            assert values == sorted(values), (
                f"{key} does not scale monotonically across sizes: "
                f"{dict(zip(sizes, values))}"
            )
|
|
|
|
|
|
class TestXSGeneration:
    """Full generation test with xs size (fast)."""

    # Every table the generator is expected to emit as a CSV file.
    EXPECTED_TABLES = {
        "customers", "products", "campaigns", "web_sessions",
        "web_leads", "orders", "order_items", "payments",
        "support_tickets",
    }

    @pytest.fixture(autouse=True)
    def generate(self, output_dir: Path):
        """Run the xs generator (seed 42) before each test in this class."""
        self.output_dir = output_dir
        gen = SampleDataGenerator(size="xs", seed=42, output_dir=output_dir)
        # run() returns the manifest; tests below read its "tables" counts.
        self.manifest = gen.run()

    def test_all_csv_files_created(self):
        """Exactly one CSV per expected table, and nothing else."""
        csv_files = {p.stem for p in self.output_dir.glob("*.csv")}
        assert csv_files == self.EXPECTED_TABLES

    def test_manifest_created(self):
        """A ``_manifest.json`` describing the run is written to disk."""
        manifest_path = self.output_dir / "_manifest.json"
        assert manifest_path.exists()
        data = json.loads(manifest_path.read_text())
        assert data["size"] == "xs"
        assert "tables" in data
        assert data["total_rows"] > 0

    def test_row_counts_match_config(self):
        """Row counts for directly specified tables should match config."""
        cfg = SIZE_CONFIGS["xs"]
        for table in ["customers", "products", "campaigns", "web_sessions",
                      "web_leads", "orders", "support_tickets"]:
            assert self.manifest["tables"][table] == cfg[table], (
                f"{table}: expected {cfg[table]}, got {self.manifest['tables'][table]}"
            )

    def test_order_items_derived(self):
        """Order items should be > orders (most orders have multiple items)."""
        tables = self.manifest["tables"]
        assert tables["order_items"] > tables["orders"]

    def test_payments_at_least_one_per_order(self):
        """Payments should be >= orders (some have failed retries)."""
        tables = self.manifest["tables"]
        assert tables["payments"] >= tables["orders"]

    def test_csv_headers_not_empty(self):
        """Every CSV should have a header and at least one data row."""
        csv_paths = sorted(self.output_dir.glob("*.csv"))
        # Guard against a vacuous pass when no CSVs were written at all.
        assert csv_paths, f"no CSV files found in {self.output_dir}"
        for csv_path in csv_paths:
            # newline="" is what the csv module requires so quoted fields
            # containing line breaks are parsed correctly.
            with open(csv_path, newline="") as f:
                reader = csv.reader(f)
                # next(..., None) turns a 0-byte file into an assertion
                # failure instead of a stray StopIteration.
                header = next(reader, None)
                assert header, f"{csv_path.name}: empty header"
                first_row = next(reader, None)
                assert first_row is not None, f"{csv_path.name}: no data rows"
|
|
|
|
|
|
class TestReferentialIntegrity:
    """Verify foreign key relationships across tables."""

    @pytest.fixture(autouse=True)
    def generate(self, output_dir: Path):
        """Generate xs data (seed 123) and load each CSV as a list of row dicts.

        ``self.tables`` maps the CSV file stem (table name) to its rows.
        """
        self.output_dir = output_dir
        gen = SampleDataGenerator(size="xs", seed=123, output_dir=output_dir)
        gen.run()
        self.tables = {}
        for csv_path in output_dir.glob("*.csv"):
            # newline="" per the csv module docs, so quoted fields with
            # embedded line breaks are read as a single value.
            with open(csv_path, newline="") as f:
                self.tables[csv_path.stem] = list(csv.DictReader(f))

    def _get_ids(self, table: str, column: str) -> set[str]:
        """Return every value of *column* in *table* (the referenced side)."""
        return {row[column] for row in self.tables[table]}

    def _get_fk_values(self, table: str, column: str) -> set[str]:
        """Return the non-empty values of *column* in *table*.

        Blank cells are skipped, i.e. treated as a NULL foreign key.
        """
        return {row[column] for row in self.tables[table] if row[column]}

    def _orphans(self, child: str, fk: str, parent: str, pk: str) -> set[str]:
        """Return FK values in ``child.fk`` with no matching ``parent.pk``."""
        return self._get_fk_values(child, fk) - self._get_ids(parent, pk)

    def test_orders_reference_valid_customers(self):
        orphans = self._orphans("orders", "customer_id", "customers", "customer_id")
        assert not orphans, f"Orders reference non-existent customers: {orphans}"

    def test_order_items_reference_valid_orders(self):
        orphans = self._orphans("order_items", "order_id", "orders", "order_id")
        assert not orphans, f"Order items reference non-existent orders: {orphans}"

    def test_order_items_reference_valid_products(self):
        orphans = self._orphans("order_items", "product_id", "products", "product_id")
        assert not orphans, f"Order items reference non-existent products: {orphans}"

    def test_payments_reference_valid_orders(self):
        orphans = self._orphans("payments", "order_id", "orders", "order_id")
        assert not orphans, f"Payments reference non-existent orders: {orphans}"

    def test_support_tickets_reference_valid_customers(self):
        orphans = self._orphans(
            "support_tickets", "customer_id", "customers", "customer_id",
        )
        assert not orphans, f"Tickets reference non-existent customers: {orphans}"
|
|
|
|
|
|
class TestDeterminism:
    """Verify reproducibility with same seed."""

    @staticmethod
    def _generate(output_dir: Path, seed: int) -> None:
        """Run the xs-size generator with *seed*, writing into *output_dir*."""
        SampleDataGenerator(size="xs", seed=seed, output_dir=output_dir).run()

    def test_same_seed_produces_same_output(self, tmp_path: Path):
        """Two runs with the same seed produce the same CSV files, byte for byte."""
        dir1 = tmp_path / "run1"
        dir2 = tmp_path / "run2"
        self._generate(dir1, seed=99)
        self._generate(dir2, seed=99)

        names1 = sorted(p.name for p in dir1.glob("*.csv"))
        names2 = sorted(p.name for p in dir2.glob("*.csv"))
        # Without this, an empty run1 would make the loop below pass vacuously.
        assert names1, f"no CSV files generated in {dir1}"
        # Extra or missing files in the second run are also non-determinism.
        assert names1 == names2, "runs produced different sets of CSV files"

        for name in names1:
            # Compare raw bytes so newline/encoding differences are caught too
            # (read_text() would normalise line endings).
            assert (dir1 / name).read_bytes() == (dir2 / name).read_bytes(), (
                f"{name} differs between runs"
            )

    def test_different_seed_produces_different_output(self, tmp_path: Path):
        """Different seeds must yield different customer data."""
        dir1 = tmp_path / "seed1"
        dir2 = tmp_path / "seed2"
        self._generate(dir1, seed=1)
        self._generate(dir2, seed=2)

        content1 = (dir1 / "customers.csv").read_text()
        content2 = (dir2 / "customers.csv").read_text()
        assert content1 != content2
|
|
|
|
|
|
class TestParquetFormat:
    """Test Parquet output format using project's ParquetManager."""

    # One output file per generated table, regardless of output format.
    EXPECTED_TABLES = {
        "customers", "products", "campaigns", "web_sessions",
        "web_leads", "orders", "order_items", "payments",
        "support_tickets",
    }

    @staticmethod
    def _column_types(con, parquet_path: Path) -> dict[str, str]:
        """Return ``{column_name: type_name}`` for *parquet_path* via DuckDB DESCRIBE.

        *con* is a DuckDB connection supplied by the caller.
        """
        # Double any single quote so the path stays a valid SQL string literal.
        literal = parquet_path.as_posix().replace("'", "''")
        rows = con.execute(
            f"DESCRIBE SELECT * FROM read_parquet('{literal}')"
        ).fetchall()
        # DESCRIBE rows start with (column_name, column_type, ...).
        return {row[0]: row[1] for row in rows}

    def test_parquet_format_creates_parquet_files(self, tmp_path: Path):
        """--format parquet should produce .parquet files, no CSVs."""
        out = tmp_path / "parquet_out"
        gen = SampleDataGenerator(
            size="xs", seed=42, output_dir=out, output_format="parquet",
        )
        gen.run()

        parquet_files = {p.stem for p in out.glob("*.parquet")}
        csv_files = list(out.glob("*.csv"))
        assert parquet_files == self.EXPECTED_TABLES
        assert csv_files == [], "CSV files should be cleaned up in parquet mode"

    def test_parquet_has_correct_types(self, tmp_path: Path):
        """Parquet files should have proper column types from ParquetManager."""
        import duckdb

        out = tmp_path / "typed"
        gen = SampleDataGenerator(
            size="xs", seed=42, output_dir=out, output_format="parquet",
        )
        gen.run()

        con = duckdb.connect()
        try:
            # orders.created_at should be TIMESTAMP, not VARCHAR
            col_types = self._column_types(con, out / "orders.parquet")
            assert col_types["created_at"] == "TIMESTAMP"
            assert col_types["total_amount"] == "DOUBLE"

            # customers.registration_date should be DATE
            col_types = self._column_types(con, out / "customers.parquet")
            assert col_types["registration_date"] == "DATE"
        finally:
            # Release the connection even when an assertion above fails.
            con.close()

    def test_both_format_creates_csv_and_parquet(self, tmp_path: Path):
        """--format both should produce CSVs + parquet/ subdirectory."""
        out = tmp_path / "both_out"
        gen = SampleDataGenerator(
            size="xs", seed=42, output_dir=out, output_format="both",
        )
        gen.run()

        # Compare table names, not just counts, so a missing table paired
        # with an unexpected extra file cannot slip through.
        csv_tables = {p.stem for p in out.glob("*.csv")}
        parquet_tables = {p.stem for p in (out / "parquet").glob("*.parquet")}
        assert csv_tables == self.EXPECTED_TABLES
        assert parquet_tables == self.EXPECTED_TABLES
|