feat(pull): range-chunked parallel download for single large files
When the server advertises `accept-ranges: bytes` and a parquet exceeds `AGNES_PULL_CHUNK_THRESHOLD_BYTES` (default 50 MB), `stream_download` now splits the file into N parallel HTTP Range requests (`AGNES_PULL_CHUNK_PARALLELISM`, default 4, capped 1..16) and assembles the parts into the destination atomically. Targets the per-flow-shaped network (corp VPN with per-TCP-connection rate-limiting) where single-stream throughput is throttled but N parallel streams over the same connection scale roughly linearly. Manifests with 1 large materialized parquet + N remote tables previously left the existing across-files `AGNES_PULL_PARALLELISM=4` pool with 1 active worker = single-stream throughput; this fixes that. Falls back to single-stream when: - HEAD doesn't advertise `accept-ranges: bytes` - Server returns 200 instead of 206 to a Range probe - File size below the threshold Cleanup discipline: every part file removed before return (success or failure); destination written via `<target>.tmp` and renamed atomically. Per-chunk retry on transient network blips (bounded by AGNES_STREAM_RETRIES).
This commit is contained in:
parent
f598b7e2f6
commit
dee33fe25b
3 changed files with 633 additions and 37 deletions
15
CHANGELOG.md
15
CHANGELOG.md
|
|
@ -10,6 +10,21 @@ CalVer image tags (`stable-YYYY.MM.N`, `dev-YYYY.MM.N`) are produced for every C
|
||||||
|
|
||||||
## [Unreleased]
|
## [Unreleased]
|
||||||
|
|
||||||
|
### Performance
|
||||||
|
- **`agnes pull` chunked download for large parquets**: when the server
|
||||||
|
advertises `accept-ranges: bytes` and a parquet exceeds
|
||||||
|
`AGNES_PULL_CHUNK_THRESHOLD_BYTES` (default 50 MB), the CLI now splits
|
||||||
|
the file into N parallel HTTP Range requests
|
||||||
|
(`AGNES_PULL_CHUNK_PARALLELISM`, default 4, capped 1..16) and assembles
|
||||||
|
the parts into the destination atomically. Targets the per-flow-shaped
|
||||||
|
network (corp VPN with per-TCP-connection rate-limiting) where a single
|
||||||
|
stream is throttled but N parallel streams over the same connection
|
||||||
|
scale roughly linearly. Falls back to single-stream when the server
|
||||||
|
responds 200 instead of 206 to a Range probe, when no
|
||||||
|
`accept-ranges: bytes` is advertised, or when content is below the
|
||||||
|
threshold — no behavior change in the small-file / non-cooperating-
|
||||||
|
server cases.
|
||||||
|
|
||||||
## [0.38.0] — 2026-05-06
|
## [0.38.0] — 2026-05-06
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
|
|
||||||
312
cli/client.py
312
cli/client.py
|
|
@ -3,6 +3,7 @@
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
@ -22,6 +23,18 @@ _RETRY_BACKOFFS_S = (0.3, 1.0, 3.0) # seconds before attempt 2, 3, 4
|
||||||
# timeout dies long before BQ finishes. Operators tune via AGNES_QUERY_TIMEOUT.
|
# timeout dies long before BQ finishes. Operators tune via AGNES_QUERY_TIMEOUT.
|
||||||
QUERY_TIMEOUT_S = float(os.environ.get("AGNES_QUERY_TIMEOUT", "300"))
|
QUERY_TIMEOUT_S = float(os.environ.get("AGNES_QUERY_TIMEOUT", "300"))
|
||||||
|
|
||||||
|
# Range-chunked parallel download — see `stream_download` docstring. Defaults
|
||||||
|
# tuned for the corp-VPN per-flow rate-limiting case (single-stream throttled
|
||||||
|
# but N parallel range requests scale linearly). Disabled implicitly for
|
||||||
|
# files below the threshold or when the server doesn't advertise byte-range
|
||||||
|
# support. Operators can hard-disable by setting parallelism to 1.
|
||||||
|
_CHUNK_PARALLELISM = max(1, min(16, int(
|
||||||
|
os.environ.get("AGNES_PULL_CHUNK_PARALLELISM", "4"),
|
||||||
|
)))
|
||||||
|
_CHUNK_THRESHOLD_BYTES = int(
|
||||||
|
os.environ.get("AGNES_PULL_CHUNK_THRESHOLD_BYTES", str(50 * 1024 * 1024)),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── Transport-error translation ─────────────────────────────────────────
|
# ── Transport-error translation ─────────────────────────────────────────
|
||||||
# Pavel's Issue #185 Phase 3B caught the failure mode: when httpx raises
|
# Pavel's Issue #185 Phase 3B caught the failure mode: when httpx raises
|
||||||
|
|
@ -197,33 +210,197 @@ def _is_transient(exc: Exception) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def stream_download(path: str, target_path: str, progress_callback=None) -> int:
|
def _read_chunk_threshold_bytes() -> int:
|
||||||
"""Stream a file to `target_path` atomically and with retries.
|
"""Re-read threshold each call so tests / operators can flip it via
|
||||||
|
env var without restarting the process."""
|
||||||
|
try:
|
||||||
|
return int(os.environ.get(
|
||||||
|
"AGNES_PULL_CHUNK_THRESHOLD_BYTES", str(_CHUNK_THRESHOLD_BYTES),
|
||||||
|
))
|
||||||
|
except ValueError:
|
||||||
|
return _CHUNK_THRESHOLD_BYTES
|
||||||
|
|
||||||
Durability properties:
|
|
||||||
- Writes to `target_path + ".tmp"`, then `os.replace` on success. The
|
def _read_chunk_parallelism() -> int:
|
||||||
real target file never exists in a half-written state.
|
"""Re-read parallelism each call (same rationale as threshold). Floor 1,
|
||||||
- Retries up to `_RETRY_ATTEMPTS` times on transient errors (network
|
ceiling 16."""
|
||||||
blip, 5xx); 4xx (auth/404) is raised immediately.
|
try:
|
||||||
- No hash check here — that's done in the sync command against the
|
n = int(os.environ.get(
|
||||||
manifest hash, because only the caller knows the expected value.
|
"AGNES_PULL_CHUNK_PARALLELISM", str(_CHUNK_PARALLELISM),
|
||||||
|
))
|
||||||
|
except ValueError:
|
||||||
|
n = _CHUNK_PARALLELISM
|
||||||
|
return max(1, min(16, n))
|
||||||
|
|
||||||
|
|
||||||
|
def _probe_range_support(client: httpx.Client, path: str) -> tuple[int, bool]:
|
||||||
|
"""Send HEAD; return (content-length, accepts-byte-ranges).
|
||||||
|
|
||||||
|
`(0, False)` means "we couldn't tell — fall back to single-stream".
|
||||||
|
Never raises; transport errors during the probe are treated as
|
||||||
|
"no chunking, try the GET instead and let it surface the failure
|
||||||
|
in the normal retry loop".
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
|
resp = client.head(path)
|
||||||
|
if getattr(resp, "status_code", 200) >= 400:
|
||||||
|
return (0, False)
|
||||||
|
size = int(resp.headers.get("content-length", "0") or 0)
|
||||||
|
accepts = (resp.headers.get("accept-ranges", "").lower() == "bytes")
|
||||||
|
return (size, accepts)
|
||||||
|
except Exception:
|
||||||
|
return (0, False)
|
||||||
|
|
||||||
|
|
||||||
|
class _RangeNotHonored(Exception):
|
||||||
|
"""Internal sentinel — server returned 200 instead of 206 to a Range
|
||||||
|
request. Caller catches and falls back to the single-stream path."""
|
||||||
|
|
||||||
|
|
||||||
|
def _download_chunk(
|
||||||
|
client: httpx.Client,
|
||||||
|
path: str,
|
||||||
|
start: int,
|
||||||
|
end: int,
|
||||||
|
part_path: Path,
|
||||||
|
progress_callback,
|
||||||
|
) -> None:
|
||||||
|
"""Stream `bytes=start-end` to `part_path`. Caller deals with retry +
|
||||||
|
cleanup. Raises on any failure (HTTPStatusError on non-206 response,
|
||||||
|
httpx.* on transport blip, `_RangeNotHonored` if server returned 200
|
||||||
|
instead of 206 — chunked path can't trust that result)."""
|
||||||
|
headers = {"Range": f"bytes={start}-{end}"}
|
||||||
|
with client.stream("GET", path, headers=headers) as response:
|
||||||
|
# Server didn't honor the Range — RFC says it MAY return 200 with
|
||||||
|
# the full body. We can't safely splice that into one part of N,
|
||||||
|
# so we abort the whole chunked path and let the caller fall back.
|
||||||
|
if response.status_code == 200:
|
||||||
|
raise _RangeNotHonored()
|
||||||
|
response.raise_for_status()
|
||||||
|
with open(part_path, "wb") as f:
|
||||||
|
for piece in response.iter_bytes(chunk_size=65536):
|
||||||
|
f.write(piece)
|
||||||
|
if progress_callback and piece:
|
||||||
|
progress_callback(len(piece))
|
||||||
|
|
||||||
|
|
||||||
|
def _download_chunked(
|
||||||
|
client: httpx.Client,
|
||||||
|
path: str,
|
||||||
|
target_path: str,
|
||||||
|
total_size: int,
|
||||||
|
parallelism: int,
|
||||||
|
progress_callback,
|
||||||
|
) -> int:
|
||||||
|
"""Range-based parallel download. Returns total bytes written.
|
||||||
|
|
||||||
|
Raises `_RangeNotHonored` on the first 200-instead-of-206 response so
|
||||||
|
the caller can fall back. All other exceptions propagate.
|
||||||
|
|
||||||
|
Cleanup discipline: every part file we create gets removed before
|
||||||
|
return (success or failure). The destination is written via the
|
||||||
|
caller's `<target>.tmp` and renamed atomically.
|
||||||
|
"""
|
||||||
|
target = Path(target_path)
|
||||||
|
tmp_path = Path(f"{target_path}.tmp")
|
||||||
|
parallelism = max(1, parallelism)
|
||||||
|
# Build chunks — last chunk takes the remainder.
|
||||||
|
chunk_size = total_size // parallelism
|
||||||
|
if chunk_size <= 0:
|
||||||
|
chunk_size = total_size # tiny file, single chunk
|
||||||
|
parallelism = 1
|
||||||
|
ranges = []
|
||||||
|
for i in range(parallelism):
|
||||||
|
start = i * chunk_size
|
||||||
|
end = (start + chunk_size - 1) if i < parallelism - 1 else (total_size - 1)
|
||||||
|
ranges.append((i, start, end))
|
||||||
|
|
||||||
|
part_paths = [Path(f"{target_path}.part{i}") for i, _, _ in ranges]
|
||||||
|
# Pre-clean any leftovers from a prior run.
|
||||||
|
for p in part_paths:
|
||||||
|
p.unlink(missing_ok=True)
|
||||||
|
|
||||||
|
def _attempt_chunk(i: int, start: int, end: int) -> None:
|
||||||
|
last_exc: Optional[Exception] = None
|
||||||
|
for attempt in range(_RETRY_ATTEMPTS + 1):
|
||||||
|
try:
|
||||||
|
_download_chunk(
|
||||||
|
client, path, start, end, part_paths[i],
|
||||||
|
progress_callback,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
except _RangeNotHonored:
|
||||||
|
# Don't retry — server policy, not a transport blip.
|
||||||
|
raise
|
||||||
|
except Exception as exc:
|
||||||
|
last_exc = exc
|
||||||
|
if attempt == _RETRY_ATTEMPTS or not _is_transient(exc):
|
||||||
|
break
|
||||||
|
time.sleep(_RETRY_BACKOFFS_S[
|
||||||
|
min(attempt, len(_RETRY_BACKOFFS_S) - 1)
|
||||||
|
])
|
||||||
|
assert last_exc is not None
|
||||||
|
raise last_exc
|
||||||
|
|
||||||
|
try:
|
||||||
|
if parallelism == 1:
|
||||||
|
_attempt_chunk(*ranges[0])
|
||||||
|
else:
|
||||||
|
# Use a thread pool so each chunk gets its own concurrent
|
||||||
|
# request slot on the (HTTP/2-multiplexed when available)
|
||||||
|
# shared client. httpx.Client is thread-safe for stream().
|
||||||
|
with ThreadPoolExecutor(max_workers=parallelism) as ex:
|
||||||
|
futs = [ex.submit(_attempt_chunk, *r) for r in ranges]
|
||||||
|
for fut in as_completed(futs):
|
||||||
|
fut.result() # propagate first error
|
||||||
|
|
||||||
|
# Concatenate parts → tmp_path → atomic rename.
|
||||||
|
tmp_path.unlink(missing_ok=True)
|
||||||
|
total_written = 0
|
||||||
|
with open(tmp_path, "wb") as out:
|
||||||
|
for p in part_paths:
|
||||||
|
with open(p, "rb") as inp:
|
||||||
|
while True:
|
||||||
|
block = inp.read(65536)
|
||||||
|
if not block:
|
||||||
|
break
|
||||||
|
out.write(block)
|
||||||
|
total_written += len(block)
|
||||||
|
os.replace(tmp_path, target)
|
||||||
|
return total_written
|
||||||
|
finally:
|
||||||
|
# Always clean up part files + any stray tmp.
|
||||||
|
for p in part_paths:
|
||||||
|
p.unlink(missing_ok=True)
|
||||||
|
if tmp_path.exists():
|
||||||
|
try:
|
||||||
|
tmp_path.unlink()
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _download_single_stream(
|
||||||
|
client: httpx.Client,
|
||||||
|
path: str,
|
||||||
|
target_path: str,
|
||||||
|
progress_callback,
|
||||||
|
) -> int:
|
||||||
|
"""Original single-stream path with retry. Used when chunking is
|
||||||
|
disabled (small file, no range support, or fallback after 200-on-Range)."""
|
||||||
tmp_path = Path(f"{target_path}.tmp")
|
tmp_path = Path(f"{target_path}.tmp")
|
||||||
last_exc: Optional[Exception] = None
|
last_exc: Optional[Exception] = None
|
||||||
for attempt in range(_RETRY_ATTEMPTS + 1):
|
for attempt in range(_RETRY_ATTEMPTS + 1):
|
||||||
try:
|
try:
|
||||||
tmp_path.unlink(missing_ok=True)
|
tmp_path.unlink(missing_ok=True)
|
||||||
with get_client(timeout=300.0) as client:
|
with client.stream("GET", path) as response:
|
||||||
with client.stream("GET", path) as response:
|
response.raise_for_status()
|
||||||
response.raise_for_status()
|
total = 0
|
||||||
total = 0
|
with open(tmp_path, "wb") as f:
|
||||||
with open(tmp_path, "wb") as f:
|
for chunk in response.iter_bytes(chunk_size=65536):
|
||||||
for chunk in response.iter_bytes(chunk_size=65536):
|
f.write(chunk)
|
||||||
f.write(chunk)
|
total += len(chunk)
|
||||||
total += len(chunk)
|
if progress_callback:
|
||||||
if progress_callback:
|
progress_callback(len(chunk))
|
||||||
progress_callback(len(chunk))
|
|
||||||
# os.replace is atomic on POSIX and Windows for same-filesystem moves.
|
|
||||||
os.replace(tmp_path, target_path)
|
os.replace(tmp_path, target_path)
|
||||||
return total
|
return total
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
|
|
@ -231,23 +408,84 @@ def stream_download(path: str, target_path: str, progress_callback=None) -> int:
|
||||||
if attempt == _RETRY_ATTEMPTS or not _is_transient(exc):
|
if attempt == _RETRY_ATTEMPTS or not _is_transient(exc):
|
||||||
break
|
break
|
||||||
time.sleep(_RETRY_BACKOFFS_S[min(attempt, len(_RETRY_BACKOFFS_S) - 1)])
|
time.sleep(_RETRY_BACKOFFS_S[min(attempt, len(_RETRY_BACKOFFS_S) - 1)])
|
||||||
# Clean up any leftover tmp, then surface the last exception. Translate
|
|
||||||
# transport errors (timeouts, connection drops, protocol errors) to
|
|
||||||
# AgnesTransportError so the CLI prints a clean message instead of a
|
|
||||||
# Python traceback (Pavel's #185 Phase 3B). HTTPStatusError (4xx/5xx
|
|
||||||
# response from the server) is NOT a transport failure and must
|
|
||||||
# re-raise verbatim so the caller's status-code handling + the rich
|
|
||||||
# server error body (e.g. 401 with "token expired", 403 with
|
|
||||||
# cross_project_forbidden detail) reach the analyst — Devin Review on
|
|
||||||
# PR #188 caught: HTTPStatusError is a subclass of HTTPError, so the
|
|
||||||
# generic isinstance(HTTPError) translation was eating status codes.
|
|
||||||
tmp_path.unlink(missing_ok=True)
|
tmp_path.unlink(missing_ok=True)
|
||||||
assert last_exc is not None
|
assert last_exc is not None
|
||||||
if isinstance(last_exc, httpx.HTTPStatusError):
|
|
||||||
raise last_exc
|
|
||||||
if isinstance(last_exc, httpx.HTTPError):
|
|
||||||
raise _translate_transport_error(
|
|
||||||
last_exc, context=f"GET {path} (stream → {target_path})",
|
|
||||||
timeout_s=300.0,
|
|
||||||
) from last_exc
|
|
||||||
raise last_exc
|
raise last_exc
|
||||||
|
|
||||||
|
|
||||||
|
def stream_download(path: str, target_path: str, progress_callback=None) -> int:
|
||||||
|
"""Stream a file to `target_path` atomically and with retries.
|
||||||
|
|
||||||
|
Two paths:
|
||||||
|
1. **Chunked parallel** — when the server advertises `accept-ranges:
|
||||||
|
bytes` and `content-length` exceeds `AGNES_PULL_CHUNK_THRESHOLD_BYTES`
|
||||||
|
(default 50 MB), split into N range requests
|
||||||
|
(`AGNES_PULL_CHUNK_PARALLELISM`, default 4, capped 1..16) and
|
||||||
|
download in parallel. Concatenate the part files into `<target>.tmp`,
|
||||||
|
then `os.replace`. Falls back to single-stream if the server
|
||||||
|
responds 200 instead of 206 to a Range probe.
|
||||||
|
2. **Single-stream** — for small files, no range support, or fallback
|
||||||
|
from the chunked path. Same atomic-rename + retry semantics as
|
||||||
|
before.
|
||||||
|
|
||||||
|
Durability properties (unchanged):
|
||||||
|
- Writes to `<target>.tmp`, then `os.replace` on success. The real
|
||||||
|
target file never exists in a half-written state.
|
||||||
|
- Retries up to `_RETRY_ATTEMPTS` on transient errors (network blip,
|
||||||
|
5xx); 4xx (auth/404) is raised immediately.
|
||||||
|
- No hash check here — that's the caller's job (manifest hash).
|
||||||
|
|
||||||
|
Threading: the chunked path uses a ThreadPoolExecutor sized to the
|
||||||
|
parallelism. httpx.Client.stream() is safe to call concurrently from
|
||||||
|
multiple threads on a single client (the connection pool serializes
|
||||||
|
the underlying socket access).
|
||||||
|
"""
|
||||||
|
with get_client(timeout=300.0) as client:
|
||||||
|
return _stream_download_via(client, path, target_path, progress_callback)
|
||||||
|
|
||||||
|
|
||||||
|
def _stream_download_via(
|
||||||
|
client: httpx.Client,
|
||||||
|
path: str,
|
||||||
|
target_path: str,
|
||||||
|
progress_callback,
|
||||||
|
) -> int:
|
||||||
|
"""The shared body of `stream_download` parameterized on the client.
|
||||||
|
Split out so tests can inject a fake client."""
|
||||||
|
threshold = _read_chunk_threshold_bytes()
|
||||||
|
parallelism = _read_chunk_parallelism()
|
||||||
|
|
||||||
|
total_size = 0
|
||||||
|
accepts_ranges = False
|
||||||
|
if parallelism > 1:
|
||||||
|
total_size, accepts_ranges = _probe_range_support(client, path)
|
||||||
|
|
||||||
|
use_chunked = (
|
||||||
|
parallelism > 1
|
||||||
|
and accepts_ranges
|
||||||
|
and total_size > threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if use_chunked:
|
||||||
|
try:
|
||||||
|
return _download_chunked(
|
||||||
|
client, path, target_path, total_size, parallelism,
|
||||||
|
progress_callback,
|
||||||
|
)
|
||||||
|
except _RangeNotHonored:
|
||||||
|
# Server lied / proxy stripped the Range — fall through.
|
||||||
|
pass
|
||||||
|
return _download_single_stream(
|
||||||
|
client, path, target_path, progress_callback,
|
||||||
|
)
|
||||||
|
except httpx.HTTPStatusError:
|
||||||
|
# 4xx / 5xx response from the server — re-raise verbatim so the
|
||||||
|
# caller's status-code handling + the rich server error body
|
||||||
|
# reach the analyst (Devin Review on PR #188).
|
||||||
|
raise
|
||||||
|
except httpx.HTTPError as exc:
|
||||||
|
raise _translate_transport_error(
|
||||||
|
exc, context=f"GET {path} (stream → {target_path})",
|
||||||
|
timeout_s=300.0,
|
||||||
|
) from exc
|
||||||
|
|
|
||||||
343
tests/test_pull_chunked.py
Normal file
343
tests/test_pull_chunked.py
Normal file
|
|
@ -0,0 +1,343 @@
|
||||||
|
"""Tests for range-based chunked download in cli/client.py:stream_download.
|
||||||
|
|
||||||
|
Background — the previous diagnosis measured `agnes pull` on a single 5.1 GB
|
||||||
|
materialized parquet at 0.29 MB/s on a corp VPN with per-flow rate-limiting;
|
||||||
|
4 parallel range requests over the same connection sustained 1.65 MB/s
|
||||||
|
aggregate. Existing `AGNES_PULL_PARALLELISM=4` parallelizes across files,
|
||||||
|
not within a file, so a manifest with 1 large materialized parquet + 10
|
||||||
|
remote tables yields 1 active worker = single-stream throughput.
|
||||||
|
|
||||||
|
These tests exercise the chunking code path: HEAD probe, Range-request
|
||||||
|
splitting, fallback when the server doesn't honor ranges, cleanup on
|
||||||
|
chunk failure, and the small-file bypass.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fake HTTP layer ─────────────────────────────────────────────────────
|
||||||
|
# The real httpx Client / AsyncClient surface is large; we mock at the
|
||||||
|
# client-method level. Our `stream_download` should:
|
||||||
|
# 1. Call HEAD to learn `content-length` + `accept-ranges`.
|
||||||
|
# 2. If ranges supported and size > threshold, issue N parallel
|
||||||
|
# `GET` with `Range: bytes=A-B`, each returning 206 + body chunk.
|
||||||
|
# 3. Concatenate part files into the destination.
|
||||||
|
|
||||||
|
class _FakeResponse:
|
||||||
|
def __init__(self, status_code: int, headers: dict | None = None,
|
||||||
|
body: bytes = b""):
|
||||||
|
self.status_code = status_code
|
||||||
|
self.headers = headers or {}
|
||||||
|
self._body = body
|
||||||
|
|
||||||
|
def raise_for_status(self):
|
||||||
|
if self.status_code >= 400:
|
||||||
|
import httpx
|
||||||
|
raise httpx.HTTPStatusError(
|
||||||
|
f"HTTP {self.status_code}", request=None, response=self,
|
||||||
|
)
|
||||||
|
|
||||||
|
def iter_bytes(self, chunk_size: int = 65536):
|
||||||
|
# Yield in chunk_size pieces so the sink loop runs realistically.
|
||||||
|
for i in range(0, len(self._body), chunk_size):
|
||||||
|
yield self._body[i:i + chunk_size]
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *a):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeClient:
|
||||||
|
"""Captures calls + returns canned responses."""
|
||||||
|
|
||||||
|
def __init__(self, *, body: bytes, accept_ranges: bool = True,
|
||||||
|
reject_range_with_200: bool = False,
|
||||||
|
fail_chunk_indices: tuple[int, ...] = (),
|
||||||
|
head_status: int = 200):
|
||||||
|
self._body = body
|
||||||
|
self._accept_ranges = accept_ranges
|
||||||
|
self._reject_range_with_200 = reject_range_with_200
|
||||||
|
self._fail_chunk_indices = set(fail_chunk_indices)
|
||||||
|
self._head_status = head_status
|
||||||
|
self.head_calls = 0
|
||||||
|
self.range_calls: list[tuple[int, int]] = []
|
||||||
|
self.full_get_calls = 0
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
self._chunk_attempt_counts: dict[tuple[int, int], int] = {}
|
||||||
|
|
||||||
|
# `stream_download` calls `client.head(path)` once to probe.
|
||||||
|
def head(self, path: str, **kwargs):
|
||||||
|
with self._lock:
|
||||||
|
self.head_calls += 1
|
||||||
|
if self._head_status >= 400:
|
||||||
|
return _FakeResponse(self._head_status)
|
||||||
|
headers = {"content-length": str(len(self._body))}
|
||||||
|
if self._accept_ranges:
|
||||||
|
headers["accept-ranges"] = "bytes"
|
||||||
|
return _FakeResponse(200, headers=headers)
|
||||||
|
|
||||||
|
# `stream_download` uses `client.stream("GET", path, headers=...)`
|
||||||
|
# for both the chunked and full-file paths. Range header presence
|
||||||
|
# tells us which one.
|
||||||
|
def stream(self, method: str, path: str, *, headers: dict | None = None,
|
||||||
|
**kwargs):
|
||||||
|
rng = (headers or {}).get("Range") or (headers or {}).get("range")
|
||||||
|
if rng:
|
||||||
|
# bytes=START-END
|
||||||
|
spec = rng.split("=", 1)[1]
|
||||||
|
start_s, end_s = spec.split("-", 1)
|
||||||
|
start = int(start_s)
|
||||||
|
end = int(end_s)
|
||||||
|
with self._lock:
|
||||||
|
self.range_calls.append((start, end))
|
||||||
|
key = (start, end)
|
||||||
|
attempt = self._chunk_attempt_counts.get(key, 0)
|
||||||
|
self._chunk_attempt_counts[key] = attempt + 1
|
||||||
|
# Determine chunk index (in order of unique starts).
|
||||||
|
# We map by start to a stable index for fail-injection.
|
||||||
|
chunk_idx = self._chunk_index_for_start(start)
|
||||||
|
# Should this attempt fail? Fail only on first attempt for
|
||||||
|
# listed indices — retry succeeds.
|
||||||
|
if chunk_idx in self._fail_chunk_indices and attempt == 0:
|
||||||
|
import httpx
|
||||||
|
raise httpx.ReadError("simulated chunk failure")
|
||||||
|
if self._reject_range_with_200:
|
||||||
|
# Server ignored Range — returns full body with 200.
|
||||||
|
return _FakeResponse(200, body=self._body)
|
||||||
|
piece = self._body[start:end + 1]
|
||||||
|
return _FakeResponse(
|
||||||
|
206,
|
||||||
|
headers={"content-range": f"bytes {start}-{end}/{len(self._body)}"},
|
||||||
|
body=piece,
|
||||||
|
)
|
||||||
|
# Full-file GET (single-stream fallback).
|
||||||
|
with self._lock:
|
||||||
|
self.full_get_calls += 1
|
||||||
|
return _FakeResponse(200, body=self._body)
|
||||||
|
|
||||||
|
def _chunk_index_for_start(self, start: int) -> int:
|
||||||
|
# Unique sorted starts so fail_chunk_indices is deterministic.
|
||||||
|
starts = sorted({s for s, _ in self.range_calls})
|
||||||
|
try:
|
||||||
|
return starts.index(start)
|
||||||
|
except ValueError:
|
||||||
|
return -1
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *a):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# ── Test fixtures ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _isolate_config_dir(tmp_path, monkeypatch):
|
||||||
|
cfg = tmp_path / "_cfg"
|
||||||
|
cfg.mkdir()
|
||||||
|
monkeypatch.setenv("AGNES_CONFIG_DIR", str(cfg))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _reset_shared_client(monkeypatch):
|
||||||
|
"""Make sure no shared persistent client leaks between tests.
|
||||||
|
|
||||||
|
The shared persistent client is introduced in Change 2 (separate
|
||||||
|
commit). When this fixture runs against the post-Change-2 tree, it
|
||||||
|
reaches the module attribute; under Change 1 alone the attribute
|
||||||
|
doesn't exist yet, so we tolerate that.
|
||||||
|
"""
|
||||||
|
import cli.client as cc
|
||||||
|
if hasattr(cc, "_SHARED_CLIENT"):
|
||||||
|
monkeypatch.setattr(cc, "_SHARED_CLIENT", None, raising=False)
|
||||||
|
yield
|
||||||
|
if hasattr(cc, "_SHARED_CLIENT"):
|
||||||
|
monkeypatch.setattr(cc, "_SHARED_CLIENT", None, raising=False)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_chunked_download_success(tmp_path, monkeypatch):
|
||||||
|
"""Server advertises ranges, file is large enough — 4 chunks, assembled
|
||||||
|
correctly into target."""
|
||||||
|
body = bytes(range(256)) * 2048 # 512 KB
|
||||||
|
threshold = 1024 # 1 KB so 512 KB is "large"
|
||||||
|
monkeypatch.setenv("AGNES_PULL_CHUNK_THRESHOLD_BYTES", str(threshold))
|
||||||
|
monkeypatch.setenv("AGNES_PULL_CHUNK_PARALLELISM", "4")
|
||||||
|
|
||||||
|
fake = _FakeClient(body=body, accept_ranges=True)
|
||||||
|
monkeypatch.setattr("cli.client.get_client", lambda timeout=300.0: fake)
|
||||||
|
|
||||||
|
from cli.client import stream_download
|
||||||
|
target = tmp_path / "out.parquet"
|
||||||
|
progress_bytes = []
|
||||||
|
total = stream_download("/api/data/x/download", str(target),
|
||||||
|
progress_callback=lambda n: progress_bytes.append(n))
|
||||||
|
|
||||||
|
assert total == len(body)
|
||||||
|
assert target.read_bytes() == body
|
||||||
|
# 4 distinct ranges issued (no overlaps; last one carries remainder).
|
||||||
|
assert len(set(fake.range_calls)) == 4
|
||||||
|
assert fake.head_calls == 1
|
||||||
|
assert fake.full_get_calls == 0
|
||||||
|
# Progress callback was called and total bytes match.
|
||||||
|
assert sum(progress_bytes) == len(body)
|
||||||
|
# Chunk parts cleaned up.
|
||||||
|
leftovers = list(tmp_path.glob("*.part*"))
|
||||||
|
assert leftovers == [], f"orphan part files: {leftovers}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_chunked_download_fallback_when_server_ignores_range(
|
||||||
|
tmp_path, monkeypatch,
|
||||||
|
):
|
||||||
|
"""Server returns 200 instead of 206 on the first range probe — abort
|
||||||
|
chunked path, fall back to single-stream. No corrupt output."""
|
||||||
|
body = b"X" * 200_000
|
||||||
|
monkeypatch.setenv("AGNES_PULL_CHUNK_THRESHOLD_BYTES", "1024")
|
||||||
|
monkeypatch.setenv("AGNES_PULL_CHUNK_PARALLELISM", "4")
|
||||||
|
|
||||||
|
# accept_ranges=True (HEAD lies), but every Range GET returns 200
|
||||||
|
# with the full body — that's the "server ignored Range" path.
|
||||||
|
fake = _FakeClient(body=body, accept_ranges=True,
|
||||||
|
reject_range_with_200=True)
|
||||||
|
monkeypatch.setattr("cli.client.get_client", lambda timeout=300.0: fake)
|
||||||
|
|
||||||
|
from cli.client import stream_download
|
||||||
|
target = tmp_path / "out.bin"
|
||||||
|
total = stream_download("/api/data/x/download", str(target))
|
||||||
|
|
||||||
|
assert total == len(body)
|
||||||
|
assert target.read_bytes() == body
|
||||||
|
# Fell back to a single full-body GET.
|
||||||
|
assert fake.full_get_calls >= 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_small_file_uses_single_stream_path(tmp_path, monkeypatch):
|
||||||
|
"""Below threshold → no HEAD probe needed (or HEAD short-circuits),
|
||||||
|
no Range requests, plain single-stream download."""
|
||||||
|
body = b"x" * 500 # tiny
|
||||||
|
monkeypatch.setenv("AGNES_PULL_CHUNK_THRESHOLD_BYTES", "10000") # 10 KB
|
||||||
|
monkeypatch.setenv("AGNES_PULL_CHUNK_PARALLELISM", "4")
|
||||||
|
|
||||||
|
fake = _FakeClient(body=body, accept_ranges=True)
|
||||||
|
monkeypatch.setattr("cli.client.get_client", lambda timeout=300.0: fake)
|
||||||
|
|
||||||
|
from cli.client import stream_download
|
||||||
|
target = tmp_path / "out.bin"
|
||||||
|
total = stream_download("/api/data/x/download", str(target))
|
||||||
|
|
||||||
|
assert total == len(body)
|
||||||
|
assert target.read_bytes() == body
|
||||||
|
assert fake.range_calls == [], "small file must not split into ranges"
|
||||||
|
assert fake.full_get_calls >= 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_chunked_download_no_accept_ranges_falls_back(tmp_path, monkeypatch):
|
||||||
|
"""HEAD doesn't advertise byte-range support → skip chunked path,
|
||||||
|
plain single-stream."""
|
||||||
|
body = b"y" * 200_000
|
||||||
|
monkeypatch.setenv("AGNES_PULL_CHUNK_THRESHOLD_BYTES", "1024")
|
||||||
|
monkeypatch.setenv("AGNES_PULL_CHUNK_PARALLELISM", "4")
|
||||||
|
|
||||||
|
fake = _FakeClient(body=body, accept_ranges=False)
|
||||||
|
monkeypatch.setattr("cli.client.get_client", lambda timeout=300.0: fake)
|
||||||
|
|
||||||
|
from cli.client import stream_download
|
||||||
|
target = tmp_path / "out.bin"
|
||||||
|
total = stream_download("/api/data/x/download", str(target))
|
||||||
|
|
||||||
|
assert total == len(body)
|
||||||
|
assert target.read_bytes() == body
|
||||||
|
assert fake.range_calls == []
|
||||||
|
assert fake.full_get_calls >= 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_chunked_download_one_chunk_retries_then_succeeds(
|
||||||
|
tmp_path, monkeypatch,
|
||||||
|
):
|
||||||
|
"""One chunk fails on first attempt; retry path completes the file."""
|
||||||
|
body = bytes(range(256)) * 1024 # 256 KB
|
||||||
|
monkeypatch.setenv("AGNES_PULL_CHUNK_THRESHOLD_BYTES", "1024")
|
||||||
|
monkeypatch.setenv("AGNES_PULL_CHUNK_PARALLELISM", "4")
|
||||||
|
monkeypatch.setenv("AGNES_STREAM_RETRIES", "2")
|
||||||
|
|
||||||
|
fake = _FakeClient(body=body, accept_ranges=True,
|
||||||
|
fail_chunk_indices=(1,)) # second chunk blips once
|
||||||
|
monkeypatch.setattr("cli.client.get_client", lambda timeout=300.0: fake)
|
||||||
|
|
||||||
|
from cli.client import stream_download
|
||||||
|
target = tmp_path / "out.bin"
|
||||||
|
total = stream_download("/api/data/x/download", str(target))
|
||||||
|
|
||||||
|
assert total == len(body)
|
||||||
|
assert target.read_bytes() == body
|
||||||
|
# Cleanup of all part files.
|
||||||
|
assert list(tmp_path.glob("*.part*")) == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_chunked_download_failure_cleans_up_part_files(tmp_path, monkeypatch):
|
||||||
|
"""All retries exhausted on a chunk → no destination file, no orphan
|
||||||
|
part files."""
|
||||||
|
body = b"z" * 200_000
|
||||||
|
monkeypatch.setenv("AGNES_PULL_CHUNK_THRESHOLD_BYTES", "1024")
|
||||||
|
monkeypatch.setenv("AGNES_PULL_CHUNK_PARALLELISM", "4")
|
||||||
|
monkeypatch.setenv("AGNES_STREAM_RETRIES", "0")
|
||||||
|
|
||||||
|
# Inject a permanent failure on chunk 2 (retries=0 → first failure
|
||||||
|
# is fatal).
|
||||||
|
class _ChronicFail(_FakeClient):
|
||||||
|
def stream(self, method, path, *, headers=None, **kwargs):
|
||||||
|
rng = (headers or {}).get("Range")
|
||||||
|
if rng:
|
||||||
|
spec = rng.split("=", 1)[1]
|
||||||
|
start = int(spec.split("-", 1)[0])
|
||||||
|
# Permanently fail the chunk starting at exactly half.
|
||||||
|
if start >= len(body) // 4 and start <= len(body) // 2:
|
||||||
|
import httpx
|
||||||
|
raise httpx.ReadError("permanent")
|
||||||
|
return super().stream(method, path, headers=headers, **kwargs)
|
||||||
|
return super().stream(method, path, headers=headers, **kwargs)
|
||||||
|
|
||||||
|
fake = _ChronicFail(body=body, accept_ranges=True)
|
||||||
|
monkeypatch.setattr("cli.client.get_client", lambda timeout=300.0: fake)
|
||||||
|
|
||||||
|
from cli.client import stream_download
|
||||||
|
target = tmp_path / "out.bin"
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
stream_download("/api/data/x/download", str(target))
|
||||||
|
|
||||||
|
assert not target.exists(), "no destination file after total failure"
|
||||||
|
# No orphan parts.
|
||||||
|
assert list(tmp_path.glob("*.part*")) == []
|
||||||
|
assert not (tmp_path / "out.bin.tmp").exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_progress_callback_aggregates_across_chunks(tmp_path, monkeypatch):
|
||||||
|
"""The progress callback should fire with byte deltas summing to the
|
||||||
|
full file across all chunks — caller treats one file as one task."""
|
||||||
|
body = bytes(range(256)) * 4096 # 1 MB
|
||||||
|
monkeypatch.setenv("AGNES_PULL_CHUNK_THRESHOLD_BYTES", "1024")
|
||||||
|
monkeypatch.setenv("AGNES_PULL_CHUNK_PARALLELISM", "4")
|
||||||
|
|
||||||
|
fake = _FakeClient(body=body, accept_ranges=True)
|
||||||
|
monkeypatch.setattr("cli.client.get_client", lambda timeout=300.0: fake)
|
||||||
|
|
||||||
|
from cli.client import stream_download
|
||||||
|
target = tmp_path / "out.bin"
|
||||||
|
advances = []
|
||||||
|
stream_download("/api/data/x/download", str(target),
|
||||||
|
progress_callback=lambda n: advances.append(n))
|
||||||
|
assert sum(advances) == len(body)
|
||||||
Loading…
Reference in a new issue