agnes-the-ai-analyst/tests/test_pull_chunked.py
ZdenekSrotyr 77d88014df fix: devil's advocate R3 — reap PID-suffixed leftovers from dead processes
R3 final pass surfaced one issue, addressed:

R2#2 introduced PID-suffixed <target>.{pid}.tmp / .{pid}.partN to
prevent concurrent agnes pull invocations from yanking each other's
in-progress writes. The pre-clean inside _download_chunked /
_download_single_stream only deletes leftovers from the CURRENT
process's PID — files from a SIGKILL'd or crashed prior pull (any
other PID) are never touched and accumulate on disk forever.

Add _reap_dead_pid_leftovers(target_path) called at the start of both
download paths. Globs <target>.*.tmp / <target>.*.partN, extracts the
embedded PID, calls os.kill(pid, 0) to test liveness (POSIX standard
no-op probe), and unlinks files whose process no longer exists.
Permission-denied = process is alive but owned by another user → keep
the file (conservative). Windows users get the conservative 'keep'
default.

Two new tests pin the behavior — live-PID file preserved, dead-PID
.tmp + .partN reaped, bare-name (legacy) untouched, garbage filenames
skipped without raise.
2026-05-06 14:04:47 +02:00

399 lines
16 KiB
Python

"""Tests for range-based chunked download in cli/client.py:stream_download.
Background — the previous diagnosis measured `agnes pull` on a single 5.1 GB
materialized parquet at 0.29 MB/s on a corp VPN with per-flow rate-limiting;
4 parallel range requests over the same connection sustained 1.65 MB/s
aggregate. Existing `AGNES_PULL_PARALLELISM=4` parallelizes across files,
not within a file, so a manifest with 1 large materialized parquet + 10
remote tables yields 1 active worker = single-stream throughput.
These tests exercise the chunking code path: HEAD probe, Range-request
splitting, fallback when the server doesn't honor ranges, cleanup on
chunk failure, and the small-file bypass.
"""
from __future__ import annotations
import os
import threading
from pathlib import Path
from unittest.mock import patch
import pytest
# ── Fake HTTP layer ─────────────────────────────────────────────────────
# The real httpx Client / AsyncClient surface is large; we mock at the
# client-method level. Our `stream_download` should:
# 1. Call HEAD to learn `content-length` + `accept-ranges`.
# 2. If ranges supported and size > threshold, issue N parallel
# `GET` with `Range: bytes=A-B`, each returning 206 + body chunk.
# 3. Concatenate part files into the destination.
class _FakeResponse:
def __init__(self, status_code: int, headers: dict | None = None,
body: bytes = b""):
self.status_code = status_code
self.headers = headers or {}
self._body = body
def raise_for_status(self):
if self.status_code >= 400:
import httpx
raise httpx.HTTPStatusError(
f"HTTP {self.status_code}", request=None, response=self,
)
def iter_bytes(self, chunk_size: int = 65536):
# Yield in chunk_size pieces so the sink loop runs realistically.
for i in range(0, len(self._body), chunk_size):
yield self._body[i:i + chunk_size]
def __enter__(self):
return self
def __exit__(self, *a):
return False
class _FakeClient:
"""Captures calls + returns canned responses."""
def __init__(self, *, body: bytes, accept_ranges: bool = True,
reject_range_with_200: bool = False,
fail_chunk_indices: tuple[int, ...] = (),
head_status: int = 200):
self._body = body
self._accept_ranges = accept_ranges
self._reject_range_with_200 = reject_range_with_200
self._fail_chunk_indices = set(fail_chunk_indices)
self._head_status = head_status
self.head_calls = 0
self.range_calls: list[tuple[int, int]] = []
self.full_get_calls = 0
self._lock = threading.Lock()
self._chunk_attempt_counts: dict[tuple[int, int], int] = {}
# `stream_download` calls `client.head(path)` once to probe.
def head(self, path: str, **kwargs):
with self._lock:
self.head_calls += 1
if self._head_status >= 400:
return _FakeResponse(self._head_status)
headers = {"content-length": str(len(self._body))}
if self._accept_ranges:
headers["accept-ranges"] = "bytes"
return _FakeResponse(200, headers=headers)
# `stream_download` uses `client.stream("GET", path, headers=...)`
# for both the chunked and full-file paths. Range header presence
# tells us which one.
def stream(self, method: str, path: str, *, headers: dict | None = None,
**kwargs):
rng = (headers or {}).get("Range") or (headers or {}).get("range")
if rng:
# bytes=START-END
spec = rng.split("=", 1)[1]
start_s, end_s = spec.split("-", 1)
start = int(start_s)
end = int(end_s)
with self._lock:
self.range_calls.append((start, end))
key = (start, end)
attempt = self._chunk_attempt_counts.get(key, 0)
self._chunk_attempt_counts[key] = attempt + 1
# Determine chunk index (in order of unique starts).
# We map by start to a stable index for fail-injection.
chunk_idx = self._chunk_index_for_start(start)
# Should this attempt fail? Fail only on first attempt for
# listed indices — retry succeeds.
if chunk_idx in self._fail_chunk_indices and attempt == 0:
import httpx
raise httpx.ReadError("simulated chunk failure")
if self._reject_range_with_200:
# Server ignored Range — returns full body with 200.
return _FakeResponse(200, body=self._body)
piece = self._body[start:end + 1]
return _FakeResponse(
206,
headers={"content-range": f"bytes {start}-{end}/{len(self._body)}"},
body=piece,
)
# Full-file GET (single-stream fallback).
with self._lock:
self.full_get_calls += 1
return _FakeResponse(200, body=self._body)
def _chunk_index_for_start(self, start: int) -> int:
# Unique sorted starts so fail_chunk_indices is deterministic.
starts = sorted({s for s, _ in self.range_calls})
try:
return starts.index(start)
except ValueError:
return -1
def __enter__(self):
return self
def __exit__(self, *a):
return False
def close(self):
pass
# ── Test fixtures ───────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def _isolate_config_dir(tmp_path, monkeypatch):
    """Point ``AGNES_CONFIG_DIR`` at a fresh ``_cfg`` directory under the
    per-test temporary path."""
    config_dir = tmp_path.joinpath("_cfg")
    config_dir.mkdir()
    monkeypatch.setenv("AGNES_CONFIG_DIR", str(config_dir))
@pytest.fixture(autouse=True)
def _reset_shared_client(monkeypatch):
    """Clear ``cli.client._SHARED_CLIENT`` before and after every test.

    The attribute is only touched when the module actually defines it.
    Tests that need a fake client additionally replace the client
    factories through ``_inject_fake_client``.
    """
    import cli.client as cc

    def _clear_if_present():
        if hasattr(cc, "_SHARED_CLIENT"):
            monkeypatch.setattr(cc, "_SHARED_CLIENT", None, raising=False)

    _clear_if_present()
    yield
    _clear_if_present()
def _inject_fake_client(monkeypatch, fake):
"""Patch both client factories to return the same fake. Tests target
`_get_shared_client` (the path stream_download actually takes) and
also `get_client` so the fallback path also lands on the fake."""
monkeypatch.setattr("cli.client.get_client", lambda timeout=300.0: fake)
monkeypatch.setattr("cli.client._get_shared_client",
lambda: fake, raising=False)
# ── Tests ───────────────────────────────────────────────────────────────
def test_chunked_download_success(tmp_path, monkeypatch):
    """Happy path: ranges are advertised and the body exceeds the threshold.

    Expects four distinct range requests, exactly one HEAD probe, no
    full-body GET, a byte-identical target file, progress deltas summing
    to the body size, and no part files left in the directory.
    """
    payload = bytes(range(256)) * 2048  # 512 KB
    for name, value in (
        # 1 KB threshold, so the 512 KB payload counts as "large".
        ("AGNES_PULL_CHUNK_THRESHOLD_BYTES", str(1024)),
        ("AGNES_PULL_CHUNK_PARALLELISM", "4"),
    ):
        monkeypatch.setenv(name, value)

    client = _FakeClient(body=payload, accept_ranges=True)
    _inject_fake_client(monkeypatch, client)

    from cli.client import stream_download

    destination = tmp_path / "out.parquet"
    deltas = []

    def _on_progress(n):
        deltas.append(n)

    written = stream_download(
        "/api/data/x/download", str(destination),
        progress_callback=_on_progress,
    )

    expected_size = len(payload)
    assert written == expected_size
    assert destination.read_bytes() == payload
    # Distinct (start, end) pairs — repeated requests for one span collapse.
    assert len(set(client.range_calls)) == 4
    assert client.head_calls == 1
    assert client.full_get_calls == 0
    # Every byte was reported through the progress callback.
    assert sum(deltas) == expected_size
    # No chunk part files survive a successful download.
    orphans = list(tmp_path.glob("*.part*"))
    assert orphans == [], f"orphan part files: {orphans}"
def test_chunked_download_fallback_when_server_ignores_range(
    tmp_path, monkeypatch,
):
    """HEAD advertises ranges but Range GETs come back as 200 + full body.

    The download must still produce a correct file, and at least one
    plain full-body GET must have been issued.
    """
    payload = b"X" * 200_000
    for name, value in (
        ("AGNES_PULL_CHUNK_THRESHOLD_BYTES", "1024"),
        ("AGNES_PULL_CHUNK_PARALLELISM", "4"),
    ):
        monkeypatch.setenv(name, value)

    # accept_ranges=True makes HEAD claim support; reject_range_with_200
    # makes every Range GET answer 200 with the whole body.
    client = _FakeClient(
        body=payload, accept_ranges=True, reject_range_with_200=True,
    )
    _inject_fake_client(monkeypatch, client)

    from cli.client import stream_download

    destination = tmp_path / "out.bin"
    written = stream_download("/api/data/x/download", str(destination))

    assert written == len(payload)
    assert destination.read_bytes() == payload
    # A non-Range GET was used to fetch the file.
    assert client.full_get_calls >= 1
def test_small_file_uses_single_stream_path(tmp_path, monkeypatch):
    """A 500-byte body under a 10 KB threshold is fetched with a plain GET:
    no Range requests at all, and at least one full-body GET."""
    payload = b"x" * 500  # tiny
    for name, value in (
        ("AGNES_PULL_CHUNK_THRESHOLD_BYTES", "10000"),  # 10 KB
        ("AGNES_PULL_CHUNK_PARALLELISM", "4"),
    ):
        monkeypatch.setenv(name, value)

    client = _FakeClient(body=payload, accept_ranges=True)
    _inject_fake_client(monkeypatch, client)

    from cli.client import stream_download

    destination = tmp_path / "out.bin"
    written = stream_download("/api/data/x/download", str(destination))

    assert written == len(payload)
    assert destination.read_bytes() == payload
    assert client.range_calls == [], "small file must not split into ranges"
    assert client.full_get_calls >= 1
def test_chunked_download_no_accept_ranges_falls_back(tmp_path, monkeypatch):
    """When HEAD omits ``accept-ranges``, a body above the threshold is still
    downloaded without any Range request, via a full-body GET."""
    payload = b"y" * 200_000
    for name, value in (
        ("AGNES_PULL_CHUNK_THRESHOLD_BYTES", "1024"),
        ("AGNES_PULL_CHUNK_PARALLELISM", "4"),
    ):
        monkeypatch.setenv(name, value)

    client = _FakeClient(body=payload, accept_ranges=False)
    _inject_fake_client(monkeypatch, client)

    from cli.client import stream_download

    destination = tmp_path / "out.bin"
    written = stream_download("/api/data/x/download", str(destination))

    assert written == len(payload)
    assert destination.read_bytes() == payload
    assert client.range_calls == []
    assert client.full_get_calls >= 1
def test_chunked_download_one_chunk_retries_then_succeeds(
    tmp_path, monkeypatch,
):
    """A transient chunk failure with retries enabled still yields a
    complete, correct file and leaves no part files behind."""
    payload = bytes(range(256)) * 1024  # 256 KB
    for name, value in (
        ("AGNES_PULL_CHUNK_THRESHOLD_BYTES", "1024"),
        ("AGNES_PULL_CHUNK_PARALLELISM", "4"),
        ("AGNES_STREAM_RETRIES", "2"),
    ):
        monkeypatch.setenv(name, value)

    # Index 1 is armed for fail-injection: the fake raises once, on the
    # first attempt of a range whose start ranks second among the starts
    # it has recorded at that moment.
    client = _FakeClient(
        body=payload, accept_ranges=True, fail_chunk_indices=(1,),
    )
    _inject_fake_client(monkeypatch, client)

    from cli.client import stream_download

    destination = tmp_path / "out.bin"
    written = stream_download("/api/data/x/download", str(destination))

    assert written == len(payload)
    assert destination.read_bytes() == payload
    # All part files were cleaned up.
    assert list(tmp_path.glob("*.part*")) == []
def test_chunked_download_failure_cleans_up_part_files(tmp_path, monkeypatch):
    """With retries disabled, a chunk that always fails aborts the download.

    ``stream_download`` must raise, and afterwards the directory must hold
    no destination file, no part files, and no temp files — whether named
    ``<target>.tmp`` or PID-suffixed ``<target>.{pid}.tmp``.
    """
    body = b"z" * 200_000
    monkeypatch.setenv("AGNES_PULL_CHUNK_THRESHOLD_BYTES", "1024")
    monkeypatch.setenv("AGNES_PULL_CHUNK_PARALLELISM", "4")
    # retries=0 → the first failure of a chunk is fatal.
    monkeypatch.setenv("AGNES_STREAM_RETRIES", "0")

    class _ChronicFail(_FakeClient):
        """Fake that fails every Range GET starting in the second quarter
        of the body (boundaries included), on every attempt."""

        def stream(self, method, path, *, headers=None, **kwargs):
            sent = headers or {}
            # Accept either header casing, matching _FakeClient.stream.
            rng = sent.get("Range") or sent.get("range")
            if rng:
                spec = rng.split("=", 1)[1]
                start = int(spec.split("-", 1)[0])
                # Any range whose start lies in [len/4, len/2] never succeeds.
                if len(body) // 4 <= start <= len(body) // 2:
                    import httpx
                    raise httpx.ReadError("permanent")
            return super().stream(method, path, headers=headers, **kwargs)

    fake = _ChronicFail(body=body, accept_ranges=True)
    _inject_fake_client(monkeypatch, fake)

    from cli.client import stream_download

    target = tmp_path / "out.bin"
    with pytest.raises(Exception):
        stream_download("/api/data/x/download", str(target))

    assert not target.exists(), "no destination file after total failure"
    # No orphan parts (the glob also matches PID-suffixed `.{pid}.partN`).
    assert list(tmp_path.glob("*.part*")) == []
    # No orphan temp files. Checking only the bare `out.bin.tmp` would miss
    # a PID-suffixed `out.bin.{pid}.tmp`, so glob for both shapes.
    assert list(tmp_path.glob("*.tmp")) == []
def test_progress_callback_aggregates_across_chunks(tmp_path, monkeypatch):
    """The byte deltas passed to the progress callback must add up to the
    full body size for a download above the chunking threshold."""
    payload = bytes(range(256)) * 4096  # 1 MB
    for name, value in (
        ("AGNES_PULL_CHUNK_THRESHOLD_BYTES", "1024"),
        ("AGNES_PULL_CHUNK_PARALLELISM", "4"),
    ):
        monkeypatch.setenv(name, value)

    client = _FakeClient(body=payload, accept_ranges=True)
    _inject_fake_client(monkeypatch, client)

    from cli.client import stream_download

    destination = tmp_path / "out.bin"
    deltas = []

    def _on_progress(n):
        deltas.append(n)

    stream_download(
        "/api/data/x/download", str(destination),
        progress_callback=_on_progress,
    )

    assert sum(deltas) == len(payload)
@pytest.mark.skipif(
    os.name == "nt",
    reason="PID liveness probing is POSIX-only; on Windows the reaper "
           "keeps every leftover, so nothing is removed",
)
def test_dead_pid_leftovers_are_reaped(tmp_path, monkeypatch):
    """Devil's-advocate R3 finding #1: PID-suffixed `<target>.{pid}.tmp`
    and `.partN` files from a SIGKILL'd previous pull must be reaped on
    the next pull, otherwise they accumulate on disk indefinitely.

    PID 1 (init) is always alive, so a file with pid=1 must NOT be
    reaped. PID 99999999 (~10⁸) is essentially guaranteed not-alive on
    any modern Linux/macOS — used as the dead-PID marker.

    Skipped on Windows: the assertions below require dead-PID files to be
    deleted, which only happens where the liveness probe is available.
    """
    target = tmp_path / "out.bin"

    # Live-PID leftover (pid=1 = init, always alive). Must NOT be reaped.
    live_path = tmp_path / "out.bin.1.tmp"
    live_path.write_bytes(b"live process leftover")

    # Dead-PID leftovers — both .tmp and .part0 forms.
    dead_tmp = tmp_path / "out.bin.99999999.tmp"
    dead_tmp.write_bytes(b"dead process leftover tmp")
    dead_part = tmp_path / "out.bin.99999999.part0"
    dead_part.write_bytes(b"dead process leftover part")

    # Bare-name leftover (no PID suffix) — pre-existing pattern, NOT
    # touched by the reaper, which only matches `.{digits}.tmp` /
    # `.{digits}.partN` exactly.
    bare_tmp = tmp_path / "out.bin.tmp"
    bare_tmp.write_bytes(b"bare leftover")

    from cli.client import _reap_dead_pid_leftovers

    _reap_dead_pid_leftovers(str(target))

    assert live_path.exists(), "live-PID leftover must be preserved"
    assert not dead_tmp.exists(), "dead-PID .tmp must be reaped"
    assert not dead_part.exists(), "dead-PID .partN must be reaped"
    assert bare_tmp.exists(), "bare-name leftover is out of scope for the reaper"
def test_reap_handles_garbage_in_filename(tmp_path):
    """A neighbouring file that looks like a leftover but carries no
    integer PID (``out.bin.garbage.tmp``) must survive the reaper, and the
    reaper must not raise on it."""
    destination = tmp_path / "out.bin"
    odd_file = tmp_path / "out.bin.garbage.tmp"
    odd_file.write_bytes(b"x")

    from cli.client import _reap_dead_pid_leftovers

    # The call itself is part of the check: it must return normally even
    # though "garbage" cannot be parsed as a PID.
    _reap_dead_pid_leftovers(str(destination))

    assert odd_file.exists(), "non-PID-shaped file must not be reaped"