stack-orchestrator/scripts/agave-container/snapshot_download.py

642 lines
23 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
"""Download Solana snapshots using aria2c for parallel multi-connection downloads.
Discovers snapshot sources by querying getClusterNodes for all RPCs in the
cluster, probing each for available snapshots, benchmarking download speed,
and downloading from the fastest source using aria2c (16 connections by default).
Based on the discovery approach from etcusr/solana-snapshot-finder but replaces
the single-connection wget download with aria2c parallel chunked downloads.
Usage:
# Download to /srv/kind/solana/snapshots (mainnet, 16 connections)
./snapshot_download.py -o /srv/kind/solana/snapshots
# Dry run — find best source, print URL
./snapshot_download.py --dry-run
# Custom RPC for cluster discovery + 32 connections
./snapshot_download.py -r https://api.mainnet-beta.solana.com -n 32
# Testnet
./snapshot_download.py -c testnet -o /data/snapshots
# Programmatic use from entrypoint.py:
from snapshot_download import download_best_snapshot
ok = download_best_snapshot("/data/snapshots")
Requirements:
- aria2c (apt install aria2)
- python3 >= 3.10 (stdlib only, no pip dependencies)
"""
from __future__ import annotations
import argparse
import concurrent.futures
import json
import logging
import os
import re
import shutil
import subprocess
import sys
import time
import urllib.error
import urllib.request
from dataclasses import dataclass, field
from http.client import HTTPResponse
from pathlib import Path
from urllib.request import Request
# Named logger shared by every function in this module.
log: logging.Logger = logging.getLogger("snapshot-download")

# Public JSON-RPC endpoint used for each cluster when no explicit RPC URL is given.
CLUSTER_RPC: dict[str, str] = {
    "mainnet-beta": "https://api.mainnet-beta.solana.com",
    "testnet": "https://api.testnet.solana.com",
    "devnet": "https://api.devnet.solana.com",
}

# Snapshot archive filename layouts (a .tar.bz2 suffix is accepted as well):
#   full:        snapshot-<slot>-<hash>.tar.zst
#   incremental: incremental-snapshot-<base_slot>-<slot>-<hash>.tar.zst
# Capture groups, in order: the slot number(s), the hash, the compression suffix.
FULL_SNAP_RE: re.Pattern[str] = re.compile(r"^snapshot-(\d+)-([A-Za-z0-9]+)\.tar\.(zst|bz2)$")
INCR_SNAP_RE: re.Pattern[str] = re.compile(r"^incremental-snapshot-(\d+)-(\d+)-([A-Za-z0-9]+)\.tar\.(zst|bz2)$")
@dataclass
class SnapshotSource:
    """Snapshot files offered by one RPC node, plus the metrics used to rank it."""

    # Node address; interpolated into URLs as http://<rpc_address>/...
    rpc_address: str
    # Redirect paths exactly as the server returned them
    # (e.g. /snapshot-123-hash.tar.zst); full snapshot first, incremental after.
    file_paths: list[str] = field(default_factory=list)
    # Current cluster slot minus the full snapshot's slot.
    slots_diff: int = 0
    # Round-trip time of the discovery HEAD probe, in milliseconds.
    latency_ms: float = 0.0
    # Benchmarked throughput in bytes/sec; stays 0.0 until a benchmark sets it.
    download_speed: float = 0.0
# -- JSON-RPC helpers ----------------------------------------------------------
class _NoRedirectHandler(urllib.request.HTTPRedirectHandler):
"""Handler that captures redirect Location instead of following it."""
def redirect_request(
self,
req: Request,
fp: HTTPResponse,
code: int,
msg: str,
headers: dict[str, str], # type: ignore[override]
newurl: str,
) -> None:
return None
def rpc_post(url: str, method: str, params: list[object] | None = None,
             timeout: int = 25) -> object | None:
    """Send a JSON-RPC 2.0 POST and return the response's 'result' field.

    Args:
        url: Endpoint to post to.
        method: JSON-RPC method name.
        params: Positional parameters; an empty list is sent when omitted.
        timeout: Timeout in seconds passed to urlopen.

    Returns:
        The parsed 'result' value, or None when the request fails, the body
        is not valid JSON, the body is not a JSON object, or the object has
        no 'result' key.
    """
    # Local import: the module header only pulls HTTPResponse from http.client.
    from http.client import HTTPException

    payload: bytes = json.dumps({
        "jsonrpc": "2.0", "id": 1,
        "method": method, "params": params or [],
    }).encode()
    req = Request(url, data=payload,
                  headers={"Content-Type": "application/json"})
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            data: object = json.loads(resp.read())
    except (urllib.error.URLError, HTTPException, OSError, TimeoutError,
            ValueError) as e:
        # HTTPException covers malformed HTTP replies that urllib does not
        # wrap in URLError; ValueError covers JSONDecodeError and undecodable
        # (non-UTF-8) bodies.
        log.debug("rpc_post %s %s failed: %s", url, method, e)
        return None
    if not isinstance(data, dict):
        # A well-formed JSON document that is not an object (e.g. a bare list)
        # has no 'result' field to extract.
        log.debug("rpc_post %s %s failed: %s", url, method,
                  "response is not a JSON object")
        return None
    return data.get("result")
def head_no_follow(url: str, timeout: float = 3) -> tuple[str | None, float]:
    """Issue a HEAD request without following redirects.

    Args:
        url: URL to probe.
        timeout: Timeout in seconds for the request.

    Returns:
        (Location header value, latency in seconds) when the server answered
        with a 3xx redirect that carries a Location header; (None, 0.0) on any
        error or non-redirect response.
    """
    # Local import: the module header only pulls HTTPResponse from http.client.
    from http.client import HTTPException

    opener: urllib.request.OpenerDirector = urllib.request.build_opener(_NoRedirectHandler)
    req = Request(url, method="HEAD")
    # Taken before the try block so it is always bound in the except handler.
    start: float = time.monotonic()
    try:
        resp: HTTPResponse = opener.open(req, timeout=timeout)  # type: ignore[assignment]
    except urllib.error.HTTPError as e:
        # With redirects declined, a 3xx reply arrives here as an HTTPError
        # that still carries the response headers.
        latency: float = time.monotonic() - start
        location: str | None = e.headers.get("Location")
        code: int = e.code
        e.close()  # release the underlying response/socket
        if location and 300 <= code < 400:
            return location, latency
        return None, 0.0
    except (urllib.error.URLError, HTTPException, OSError, TimeoutError, ValueError):
        # HTTPException: peers that answer with something that is not valid
        # HTTP (bad status line, invalid host/port) are not wrapped in URLError.
        return None, 0.0
    # Non-redirect (2xx): the server did not redirect, not useful for discovery.
    resp.close()
    return None, 0.0
# -- Discovery -----------------------------------------------------------------
def get_current_slot(rpc_url: str) -> int | None:
    """Ask the RPC endpoint for the current slot via getSlot.

    Returns the slot when the reply's result is an int, otherwise None.
    """
    slot = rpc_post(rpc_url, "getSlot")
    return slot if isinstance(slot, int) else None
def get_cluster_rpc_nodes(rpc_url: str, version_filter: str | None = None) -> list[str]:
    """Collect the distinct RPC addresses advertised by getClusterNodes.

    Args:
        rpc_url: Endpoint that is asked for the cluster node list.
        version_filter: When given, nodes whose reported version does not
            start with this prefix are dropped. Nodes that report no version
            are kept.

    Returns:
        De-duplicated list of non-empty 'rpc' values (order unspecified);
        an empty list when the RPC call does not yield a list.
    """
    nodes = rpc_post(rpc_url, "getClusterNodes")
    if not isinstance(nodes, list):
        return []

    addresses: list[str] = []
    for entry in nodes:
        if not isinstance(entry, dict):
            continue
        # The version is only consulted when a filter was requested.
        version = entry.get("version") if version_filter is not None else None
        mismatched = bool(version) and not version.startswith(version_filter)
        if mismatched:
            continue
        address = entry.get("rpc")
        if address:
            addresses.append(address)
    return list(set(addresses))
def _parse_snapshot_filename(location: str) -> tuple[str, str | None]:
"""Extract filename and full redirect path from Location header.
Returns (filename, full_path). full_path includes any path prefix
the server returned (e.g. '/snapshots/snapshot-123-hash.tar.zst').
"""
# Location may be absolute URL or relative path
if location.startswith("http://") or location.startswith("https://"):
# Absolute URL — extract path
from urllib.parse import urlparse
path: str = urlparse(location).path
else:
path = location
filename: str = path.rsplit("/", 1)[-1]
return filename, path
def probe_rpc_snapshot(
    rpc_address: str,
    current_slot: int,
) -> SnapshotSource | None:
    """Probe one RPC node for the snapshots it currently serves.

    This is discovery only, with no age or latency filtering: whatever the
    node offers is reported so the caller can decide what to keep.

    Args:
        rpc_address: Node address, used as http://<rpc_address>/...
        current_slot: Current cluster slot, used to compute the snapshot age.

    Returns:
        A SnapshotSource describing the node's full snapshot (and its matching
        incremental snapshot, if any), or None when the node does not redirect
        to a recognisable full snapshot filename.
    """
    # A full snapshot is mandatory; without one the node is not a source.
    redirect, elapsed = head_no_follow(f"http://{rpc_address}/snapshot.tar.bz2", timeout=2)
    if not redirect:
        return None

    full_name, full_path = _parse_snapshot_filename(redirect)
    full_match = FULL_SNAP_RE.match(full_name)
    if not full_match:
        return None
    full_slot = int(full_match.group(1))
    paths: list[str] = [full_path]

    # The incremental snapshot is optional.
    inc_redirect, _ = head_no_follow(
        f"http://{rpc_address}/incremental-snapshot.tar.bz2", timeout=2)
    if inc_redirect:
        inc_name, inc_path = _parse_snapshot_filename(inc_redirect)
        inc_match = INCR_SNAP_RE.match(inc_name)
        # Only usable when it builds on this node's own full snapshot.
        if inc_match and int(inc_match.group(1)) == full_slot:
            paths.append(inc_path)

    return SnapshotSource(
        rpc_address=rpc_address,
        file_paths=paths,
        slots_diff=current_slot - full_slot,
        latency_ms=elapsed * 1000,
    )
def discover_sources(
    rpc_url: str,
    current_slot: int,
    max_age_slots: int,
    max_latency_ms: float,
    threads: int,
    version_filter: str | None,
) -> list[SnapshotSource]:
    """Discover all snapshot sources in the cluster, then filter them.

    Probing and filtering are separate steps: every reachable source is
    collected first so that what exists can be reported even when the filters
    reject everything.

    Args:
        rpc_url: RPC endpoint queried for the cluster node list.
        current_slot: Current cluster slot, used to compute snapshot age.
        max_age_slots: Sources whose snapshot is older than this are rejected.
        max_latency_ms: Sources whose probe latency exceeds this are rejected.
        threads: Worker threads used for probing (values below 1 are treated
            as 1).
        version_filter: Optional node version prefix passed to node discovery.

    Returns:
        The sources that passed both filters; empty when no nodes were found
        or every source was rejected.
    """
    rpc_nodes: list[str] = get_cluster_rpc_nodes(rpc_url, version_filter)
    if not rpc_nodes:
        log.error("No RPC nodes found via getClusterNodes")
        return []
    log.info("Found %d RPC nodes, probing for snapshots...", len(rpc_nodes))

    all_sources: list[SnapshotSource] = []
    # ThreadPoolExecutor rejects max_workers <= 0, so clamp to at least one.
    with concurrent.futures.ThreadPoolExecutor(max_workers=max(1, threads)) as pool:
        futures: dict[concurrent.futures.Future[SnapshotSource | None], str] = {
            pool.submit(probe_rpc_snapshot, addr, current_slot): addr
            for addr in rpc_nodes
        }
        for done, future in enumerate(concurrent.futures.as_completed(futures), start=1):
            try:
                result: SnapshotSource | None = future.result()
            except Exception as e:  # noqa: BLE001 - worker boundary
                # Nodes are arbitrary third-party hosts; whatever a single
                # probe raises must not abort discovery of all the others.
                log.debug("Probe failed for %s: %s", futures[future], e)
                result = None
            if result:
                all_sources.append(result)
            # Logged after the append so the count includes the current node.
            if done % 200 == 0:
                log.info(" probed %d/%d nodes, %d reachable",
                         done, len(rpc_nodes), len(all_sources))
    log.info("Discovered %d reachable sources", len(all_sources))

    # Apply filters
    filtered: list[SnapshotSource] = []
    rejected_age: int = 0
    rejected_latency: int = 0
    for src in all_sources:
        # A snapshot more than 100 slots ahead of the reference slot is
        # rejected along with the ones that are too old.
        if src.slots_diff > max_age_slots or src.slots_diff < -100:
            rejected_age += 1
            continue
        if src.latency_ms > max_latency_ms:
            rejected_latency += 1
            continue
        filtered.append(src)

    if rejected_age or rejected_latency:
        log.info("Filtered: %d rejected by age (>%d slots), %d by latency (>%.0fms)",
                 rejected_age, max_age_slots, rejected_latency, max_latency_ms)
    if not filtered and all_sources:
        # Show what was available so the user can adjust filters
        best = min(all_sources, key=lambda s: s.slots_diff)
        log.warning("All %d sources rejected by filters. Best available: "
                    "%s (age=%d slots, latency=%.0fms). "
                    "Try --max-snapshot-age %d --max-latency %.0f",
                    len(all_sources), best.rpc_address,
                    best.slots_diff, best.latency_ms,
                    best.slots_diff + 500,
                    max(best.latency_ms * 1.5, 500))
    log.info("Found %d sources after filtering", len(filtered))
    return filtered
# -- Speed benchmark -----------------------------------------------------------
def measure_speed(rpc_address: str, measure_time: int = 7) -> float:
    """Measure download throughput from an RPC node.

    Streams http://<rpc_address>/snapshot.tar.bz2 for up to measure_time
    seconds and averages the bytes received over the elapsed time.

    Args:
        rpc_address: Node address to download from.
        measure_time: Sampling window in seconds.

    Returns:
        Throughput in bytes/sec, or 0.0 when the node cannot be reached, the
        transfer fails, or the address is not usable in a URL.
    """
    # Local import: the module header only pulls HTTPResponse from http.client.
    from http.client import HTTPException

    url: str = f"http://{rpc_address}/snapshot.tar.bz2"
    req = Request(url)
    try:
        # The socket timeout is slightly longer than the sampling window so a
        # stalled read ends the benchmark instead of hanging it.
        with urllib.request.urlopen(req, timeout=measure_time + 5) as resp:
            start: float = time.monotonic()
            total: int = 0
            while time.monotonic() - start < measure_time:
                chunk: bytes = resp.read(81920)
                if not chunk:
                    break
                total += len(chunk)
            elapsed: float = time.monotonic() - start
    except (urllib.error.URLError, HTTPException, OSError, TimeoutError):
        # HTTPException (InvalidURL, BadStatusLine, IncompleteRead, ...) is
        # raised by http.client directly and is not wrapped in URLError.
        return 0.0
    if elapsed <= 0:
        return 0.0
    return total / elapsed
# -- Download ------------------------------------------------------------------
def download_aria2c(
    urls: list[str],
    output_dir: str,
    filename: str,
    connections: int = 16,
) -> bool:
    """Download a file using aria2c with parallel connections.

    When multiple URLs are provided, aria2c treats them as mirrors of the
    same file and distributes chunks across all of them.

    Args:
        urls: One or more URLs of the same file.
        output_dir: Directory the file is written to.
        filename: Name of the file inside output_dir.
        connections: Requested connections per mirror. The per-server value
            handed to aria2c is clamped to 1-16 (the range aria2c accepts);
            the total split count still scales with connections x mirrors.

    Returns:
        True when aria2c exits successfully and the file exists afterwards,
        False otherwise (including an empty URL list or a failure to launch
        aria2c).
    """
    if not urls:
        log.error("No URLs given for %s", filename)
        return False

    num_mirrors: int = len(urls)
    connections = max(1, connections)
    # aria2c refuses --max-connection-per-server values outside 1-16.
    per_server: int = min(connections, 16)
    total_splits: int = max(connections, connections * num_mirrors)
    cmd: list[str] = [
        "aria2c",
        "--file-allocation=none",
        "--continue=false",
        f"--max-connection-per-server={per_server}",
        f"--split={total_splits}",
        "--min-split-size=50M",
        # aria2c retries individual chunk connections on transient network
        # errors (TCP reset, timeout). This is transport-level retry analogous
        # to TCP retransmit, not application-level retry of a failed operation.
        "--max-tries=5",
        "--retry-wait=5",
        "--timeout=60",
        "--connect-timeout=10",
        "--summary-interval=10",
        "--console-log-level=notice",
        f"--dir={output_dir}",
        f"--out={filename}",
        "--auto-file-renaming=false",
        "--allow-overwrite=true",
        *urls,
    ]
    log.info("Downloading %s", filename)
    log.info(" aria2c: %d connections x %d mirrors (%d splits)",
             per_server, num_mirrors, total_splits)

    start: float = time.monotonic()
    try:
        # Output is not captured, so aria2c's progress goes to the console.
        result: subprocess.CompletedProcess[bytes] = subprocess.run(cmd)
    except OSError as e:
        # E.g. aria2c is not installed or not executable.
        log.error("Could not run aria2c: %s", e)
        return False
    elapsed: float = time.monotonic() - start

    if result.returncode != 0:
        log.error("aria2c failed with exit code %d", result.returncode)
        return False
    filepath: Path = Path(output_dir) / filename
    if not filepath.exists():
        log.error("aria2c reported success but %s does not exist", filepath)
        return False

    size_bytes: int = filepath.stat().st_size
    size_gb: float = size_bytes / (1024 ** 3)
    avg_mb: float = size_bytes / elapsed / (1024 ** 2) if elapsed > 0 else 0
    log.info(" Done: %.1f GB in %.0fs (%.1f MiB/s avg)", size_gb, elapsed, avg_mb)
    return True
# -- Public API ----------------------------------------------------------------
def download_best_snapshot(
    output_dir: str,
    *,
    cluster: str = "mainnet-beta",
    rpc_url: str | None = None,
    connections: int = 16,
    threads: int = 500,
    max_snapshot_age: int = 10000,
    max_latency: float = 500,
    min_download_speed: int = 20,
    measurement_time: int = 7,
    max_speed_checks: int = 15,
    version_filter: str | None = None,
    full_only: bool = False,
) -> bool:
    """Download the best available snapshot to output_dir.

    This is the programmatic API called by entrypoint.py for automatic
    snapshot download. All parameters have defaults matching the CLI.

    Args:
        output_dir: Directory the snapshot files are written to (created if
            missing).
        cluster: Cluster name used to pick the default RPC endpoint.
        rpc_url: Explicit RPC endpoint; overrides the cluster default.
        connections: aria2c connections per download.
        threads: Worker threads for source discovery.
        max_snapshot_age: Maximum snapshot age in slots.
        max_latency: Maximum probe latency in milliseconds.
        min_download_speed: Minimum benchmarked speed in MiB/s.
        measurement_time: Benchmark duration per source, in seconds.
        max_speed_checks: Maximum number of sources to benchmark.
        version_filter: Optional node version prefix.
        full_only: Skip the incremental snapshot.

    Returns:
        True on success, False on any failure (unknown cluster, aria2c
        missing, no usable source, or a failed download).
    """
    resolved_rpc: str | None = rpc_url or CLUSTER_RPC.get(cluster)
    if resolved_rpc is None:
        # Report failure through the return value rather than a KeyError.
        log.error("Unknown cluster %r and no RPC URL given (known: %s)",
                  cluster, ", ".join(CLUSTER_RPC))
        return False
    if not shutil.which("aria2c"):
        log.error("aria2c not found. Install with: apt install aria2")
        return False

    log.info("Cluster: %s | RPC: %s", cluster, resolved_rpc)
    current_slot: int | None = get_current_slot(resolved_rpc)
    if current_slot is None:
        log.error("Cannot get current slot from %s", resolved_rpc)
        return False
    log.info("Current slot: %d", current_slot)

    sources: list[SnapshotSource] = discover_sources(
        resolved_rpc, current_slot,
        max_age_slots=max_snapshot_age,
        max_latency_ms=max_latency,
        threads=threads,
        version_filter=version_filter,
    )
    if not sources:
        log.error("No snapshot sources found")
        return False

    # Sort by latency (lowest first) so the most responsive nodes are the
    # ones that get benchmarked.
    sources.sort(key=lambda s: s.latency_ms)

    # Benchmark top candidates
    log.info("Benchmarking download speed on top %d sources...", max_speed_checks)
    fast_sources: list[SnapshotSource] = []
    min_speed_bytes: int = min_download_speed * 1024 * 1024
    for source in sources[:max(0, max_speed_checks)]:
        speed: float = measure_speed(source.rpc_address, measurement_time)
        source.download_speed = speed
        speed_mib: float = speed / (1024 ** 2)
        if speed < min_speed_bytes:
            log.info(" %s: %.1f MiB/s (too slow, need >=%d MiB/s)",
                     source.rpc_address, speed_mib, min_download_speed)
            continue
        log.info(" %s: %.1f MiB/s (latency: %.0fms, age: %d slots)",
                 source.rpc_address, speed_mib,
                 source.latency_ms, source.slots_diff)
        fast_sources.append(source)
    if not fast_sources:
        log.error("No source met minimum speed requirement (%d MiB/s)",
                  min_download_speed)
        return False

    # Use the fastest source as primary, collect mirrors for each file.
    # The benchmark ran in latency order, so rank by measured speed here;
    # the sort is stable, keeping latency order among equal speeds.
    fast_sources.sort(key=lambda s: s.download_speed, reverse=True)
    best: SnapshotSource = fast_sources[0]
    file_paths: list[str] = best.file_paths
    if full_only:
        # Full snapshot filenames start with "snapshot-"; incremental ones
        # start with "incremental-snapshot-".
        file_paths = [fp for fp in file_paths
                      if fp.rsplit("/", 1)[-1].startswith("snapshot-")]

    # Build mirror URL lists: another source is a mirror for a file only if
    # it serves a file with exactly the same name.
    download_plan: list[tuple[str, list[str]]] = []
    for fp in file_paths:
        filename: str = fp.rsplit("/", 1)[-1]
        mirror_urls: list[str] = [f"http://{best.rpc_address}{fp}"]
        for other in fast_sources[1:]:
            for other_fp in other.file_paths:
                if other_fp.rsplit("/", 1)[-1] == filename:
                    mirror_urls.append(f"http://{other.rpc_address}{other_fp}")
                    break
        download_plan.append((filename, mirror_urls))

    log.info("Best source: %s (%.1f MiB/s), %d mirrors total",
             best.rpc_address, best.download_speed / (1024 ** 2),
             len(fast_sources))
    for filename, mirror_urls in download_plan:
        log.info(" %s (%d mirrors)", filename, len(mirror_urls))

    # Download
    os.makedirs(output_dir, exist_ok=True)
    total_start: float = time.monotonic()
    for filename, mirror_urls in download_plan:
        filepath: Path = Path(output_dir) / filename
        # aria2c keeps a "<file>.aria2" control file next to an unfinished
        # download; its presence means the file on disk is only partial.
        control_file: Path = filepath.with_name(filename + ".aria2")
        if (filepath.exists() and filepath.stat().st_size > 0
                and not control_file.exists()):
            log.info("Skipping %s (already exists: %.1f GB)",
                     filename, filepath.stat().st_size / (1024 ** 3))
            continue
        if not download_aria2c(mirror_urls, output_dir, filename, connections):
            log.error("Failed to download %s", filename)
            return False
    total_elapsed: float = time.monotonic() - total_start

    log.info("All downloads complete in %.0fs", total_elapsed)
    for filename, _ in download_plan:
        fp_path: Path = Path(output_dir) / filename
        if fp_path.exists():
            log.info(" %s (%.1f GB)", fp_path.name, fp_path.stat().st_size / (1024 ** 3))
    return True
# -- Main (CLI) ----------------------------------------------------------------
def main() -> int:
    """Command-line entry point.

    Parses the arguments, configures logging, then either prints the URLs of
    the lowest-latency source (--dry-run) or performs the download and runs
    the optional post-download command.

    Returns:
        Process exit code: 0 on success, 1 on any failure.
    """
    p: argparse.ArgumentParser = argparse.ArgumentParser(
        description="Download Solana snapshots with aria2c parallel downloads",
    )
    p.add_argument("-o", "--output", default="/srv/kind/solana/snapshots",
                   help="Snapshot output directory (default: /srv/kind/solana/snapshots)")
    p.add_argument("-c", "--cluster", default="mainnet-beta",
                   choices=list(CLUSTER_RPC),
                   help="Solana cluster (default: mainnet-beta)")
    p.add_argument("-r", "--rpc", default=None,
                   help="RPC URL for cluster discovery (default: public RPC)")
    p.add_argument("-n", "--connections", type=int, default=16,
                   help="aria2c connections per download (default: 16)")
    p.add_argument("-t", "--threads", type=int, default=500,
                   help="Threads for parallel RPC probing (default: 500)")
    p.add_argument("--max-snapshot-age", type=int, default=10000,
                   help="Max snapshot age in slots (default: 10000)")
    p.add_argument("--max-latency", type=float, default=500,
                   help="Max RPC probe latency in ms (default: 500)")
    p.add_argument("--min-download-speed", type=int, default=20,
                   help="Min download speed in MiB/s (default: 20)")
    p.add_argument("--measurement-time", type=int, default=7,
                   help="Speed measurement duration in seconds (default: 7)")
    p.add_argument("--max-speed-checks", type=int, default=15,
                   help="Max nodes to benchmark before giving up (default: 15)")
    p.add_argument("--version", default=None,
                   help="Filter nodes by version prefix (e.g. '2.2')")
    p.add_argument("--full-only", action="store_true",
                   help="Download only full snapshot, skip incremental")
    p.add_argument("--dry-run", action="store_true",
                   help="Find best source and print URL, don't download")
    p.add_argument("--post-cmd",
                   help="Shell command to run after successful download "
                        "(e.g. 'kubectl scale deployment ... --replicas=1')")
    p.add_argument("-v", "--verbose", action="store_true")
    args: argparse.Namespace = p.parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
        datefmt="%H:%M:%S",
    )

    # Dry-run uses the original inline flow (needs access to sources for URL printing)
    if args.dry_run:
        rpc_url: str = args.rpc or CLUSTER_RPC[args.cluster]
        current_slot: int | None = get_current_slot(rpc_url)
        if current_slot is None:
            log.error("Cannot get current slot from %s", rpc_url)
            return 1
        sources: list[SnapshotSource] = discover_sources(
            rpc_url, current_slot,
            max_age_slots=args.max_snapshot_age,
            max_latency_ms=args.max_latency,
            threads=args.threads,
            version_filter=args.version,
        )
        if not sources:
            log.error("No snapshot sources found")
            return 1
        # No speed benchmark in dry-run: the lowest-latency source is shown.
        sources.sort(key=lambda s: s.latency_ms)
        best = sources[0]
        for fp in best.file_paths:
            # Honour --full-only so the printed URLs match what a real run
            # would download (full snapshot filenames start with "snapshot-").
            if args.full_only and not fp.rsplit("/", 1)[-1].startswith("snapshot-"):
                continue
            print(f"http://{best.rpc_address}{fp}")
        return 0

    ok: bool = download_best_snapshot(
        args.output,
        cluster=args.cluster,
        rpc_url=args.rpc,
        connections=args.connections,
        threads=args.threads,
        max_snapshot_age=args.max_snapshot_age,
        max_latency=args.max_latency,
        min_download_speed=args.min_download_speed,
        measurement_time=args.measurement_time,
        max_speed_checks=args.max_speed_checks,
        version_filter=args.version,
        full_only=args.full_only,
    )

    if ok and args.post_cmd:
        log.info("Running post-download command: %s", args.post_cmd)
        # shell=True is deliberate: --post-cmd is documented as a shell
        # command supplied by the operator on the command line.
        result: subprocess.CompletedProcess[bytes] = subprocess.run(
            args.post_cmd, shell=True,
        )
        if result.returncode != 0:
            log.error("Post-download command failed with exit code %d",
                      result.returncode)
            return 1
        log.info("Post-download command completed successfully")
    return 0 if ok else 1
# Run the CLI only when executed as a script; main()'s return value becomes
# the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())