stack-orchestrator/agave-stack/scripts/snapshot-download.py

547 lines
20 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#!/usr/bin/env python3
"""Download Solana snapshots using aria2c for parallel multi-connection downloads.
Discovers snapshot sources by querying getClusterNodes for all RPCs in the
cluster, probing each for available snapshots, benchmarking download speed,
and downloading from the fastest source using aria2c (16 connections by default).
Based on the discovery approach from etcusr/solana-snapshot-finder but replaces
the single-connection wget download with aria2c parallel chunked downloads.
Usage:
# Download to /srv/solana/snapshots (mainnet, 16 connections)
./snapshot-download.py -o /srv/solana/snapshots
# Dry run — find best source, print URL
./snapshot-download.py --dry-run
# Custom RPC for cluster node discovery + 32 connections
./snapshot-download.py -r https://api.mainnet-beta.solana.com -n 32
# Testnet
./snapshot-download.py -c testnet -o /data/snapshots
Requirements:
- aria2c (apt install aria2)
- python3 >= 3.10 (stdlib only, no pip dependencies)
"""
from __future__ import annotations
import argparse
import concurrent.futures
import json
import logging
import os
import re
import shutil
import subprocess
import sys
import time
import urllib.error
import urllib.request
from dataclasses import dataclass, field
from http.client import HTTPResponse
from pathlib import Path
from typing import NoReturn
from urllib.request import Request
# Module-wide logger; main() configures its output through logging.basicConfig.
log: logging.Logger = logging.getLogger("snapshot-download")

# Public JSON-RPC endpoint for each supported cluster, used when --rpc is not given.
CLUSTER_RPC: dict[str, str] = {
    "mainnet-beta": "https://api.mainnet-beta.solana.com",
    "testnet": "https://api.testnet.solana.com",
    "devnet": "https://api.devnet.solana.com",
}

# Snapshot archive names (either .tar.zst or .tar.bz2 is accepted):
#   full:        snapshot-<slot>-<hash>.tar.zst
#   incremental: incremental-snapshot-<base_slot>-<slot>-<hash>.tar.zst
FULL_SNAP_RE: re.Pattern[str] = re.compile(r"^snapshot-(\d+)-([A-Za-z0-9]+)\.tar\.(zst|bz2)$")
INCR_SNAP_RE: re.Pattern[str] = re.compile(r"^incremental-snapshot-(\d+)-(\d+)-([A-Za-z0-9]+)\.tar\.(zst|bz2)$")
@dataclass
class SnapshotSource:
    """Snapshot files offered by one RPC node, with the metrics measured for it."""

    # Node address; interpolated directly after "http://" when building URLs.
    rpc_address: str
    # Redirect paths exactly as the server returned them
    # (e.g. /snapshot-123-hash.tar.zst).
    file_paths: list[str] = field(default_factory=list)
    # Current slot minus the slot of the node's full snapshot.
    slots_diff: int = 0
    # Latency of the full-snapshot HEAD probe, in milliseconds.
    latency_ms: float = 0.0
    # Benchmarked throughput in bytes/sec (0.0 until measured).
    download_speed: float = 0.0
# -- JSON-RPC helpers ----------------------------------------------------------
class _NoRedirectHandler(urllib.request.HTTPRedirectHandler):
"""Handler that captures redirect Location instead of following it."""
def redirect_request(
self,
req: Request,
fp: HTTPResponse,
code: int,
msg: str,
headers: dict[str, str], # type: ignore[override]
newurl: str,
) -> None:
return None
def rpc_post(url: str, method: str, params: list[object] | None = None,
             timeout: int = 25) -> object | None:
    """Send a JSON-RPC 2.0 POST and return the parsed ``result`` field.

    Args:
        url: JSON-RPC endpoint URL.
        method: RPC method name (e.g. ``getSlot``).
        params: Positional RPC parameters; ``None`` is sent as an empty list.
        timeout: Timeout in seconds passed to ``urlopen``.

    Returns:
        The value of the response's ``result`` key, or ``None`` when the
        request fails, the body is not a JSON object, or no ``result`` key
        is present.
    """
    # Imported locally so this function does not depend on a module-level name.
    from http.client import HTTPException

    payload: bytes = json.dumps({
        "jsonrpc": "2.0", "id": 1,
        "method": method, "params": params or [],
    }).encode()
    req = Request(url, data=payload,
                  headers={"Content-Type": "application/json"})
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            data: object = json.loads(resp.read())
    except (urllib.error.URLError, HTTPException, ValueError, OSError, TimeoutError) as e:
        # ValueError covers json.JSONDecodeError and undecodable bodies;
        # HTTPException covers protocol-level failures (bad status line,
        # truncated body) that urllib does not wrap in URLError.
        log.debug("rpc_post %s %s failed: %s", url, method, e)
        return None
    if not isinstance(data, dict):
        # A JSON array or scalar is not a valid single JSON-RPC response.
        log.debug("rpc_post %s %s failed: response is not a JSON object", url, method)
        return None
    return data.get("result")
def head_no_follow(url: str, timeout: float = 3) -> tuple[str | None, float]:
    """Issue a HEAD request without following redirects.

    Args:
        url: URL to probe.
        timeout: Timeout in seconds passed to the opener.

    Returns:
        ``(location, latency_sec)`` when the server answers with a 3xx
        redirect carrying a ``Location`` header; ``(None, 0.0)`` for any
        error or non-redirect response.
    """
    # Imported locally so this function does not depend on a module-level name.
    from http.client import HTTPException

    opener: urllib.request.OpenerDirector = urllib.request.build_opener(_NoRedirectHandler)
    req = Request(url, method="HEAD")
    start: float = time.monotonic()
    try:
        # Reaching the body of this `with` means a non-redirect (2xx) answer,
        # which is not useful for discovery; the response is closed on exit.
        with opener.open(req, timeout=timeout):
            return None, 0.0
    except urllib.error.HTTPError as e:
        # With redirects disabled, 3xx responses surface as HTTPError.
        latency: float = time.monotonic() - start
        try:
            location: str | None = e.headers.get("Location")
            code: int = e.code
        finally:
            # HTTPError wraps the live response; release its socket promptly.
            e.close()
        if location and 300 <= code < 400:
            return location, latency
        return None, 0.0
    except (urllib.error.URLError, HTTPException, ValueError, OSError, TimeoutError):
        # HTTPException (bad status line, invalid host/port) is neither a
        # URLError nor an OSError, so it has to be named explicitly here.
        return None, 0.0
# -- Discovery -----------------------------------------------------------------
def get_current_slot(rpc_url: str) -> int | None:
    """Return the slot reported by ``getSlot``, or None if no integer came back."""
    slot: object | None = rpc_post(rpc_url, "getSlot")
    return slot if isinstance(slot, int) else None
def get_cluster_rpc_nodes(rpc_url: str, version_filter: str | None = None) -> list[str]:
    """Return the distinct RPC addresses listed by ``getClusterNodes``.

    When ``version_filter`` is given, nodes whose reported version does not
    start with it are dropped; nodes reporting no version are kept. The
    result is de-duplicated and its order is unspecified.
    """
    nodes: object | None = rpc_post(rpc_url, "getClusterNodes")
    if not isinstance(nodes, list):
        return []

    unique_addrs: set[str] = set()
    for entry in nodes:
        if not isinstance(entry, dict):
            continue
        if version_filter is not None:
            version: str | None = entry.get("version")
            if version and not version.startswith(version_filter):
                continue
        address: str | None = entry.get("rpc")
        if address:
            unique_addrs.add(address)
    return list(unique_addrs)
def _parse_snapshot_filename(location: str) -> tuple[str, str | None]:
"""Extract filename and full redirect path from Location header.
Returns (filename, full_path). full_path includes any path prefix
the server returned (e.g. '/snapshots/snapshot-123-hash.tar.zst').
"""
# Location may be absolute URL or relative path
if location.startswith("http://") or location.startswith("https://"):
# Absolute URL — extract path
from urllib.parse import urlparse
path: str = urlparse(location).path
else:
path = location
filename: str = path.rsplit("/", 1)[-1]
return filename, path
def probe_rpc_snapshot(
    rpc_address: str,
    current_slot: int,
    max_age_slots: int,
    max_latency_ms: float,
) -> SnapshotSource | None:
    """Probe one RPC node and describe the snapshots it serves.

    The full snapshot is probed first and is mandatory; the incremental
    snapshot is recorded only when it is based on that same full snapshot.
    This function only records what is available; it does not decide which
    files get downloaded.

    Returns None when the node has no usable full snapshot: no redirect,
    probe latency above ``max_latency_ms``, an unrecognised filename, or a
    snapshot slot more than ``max_age_slots`` behind (or more than 100
    slots ahead of) ``current_slot``.

    Based on the discovery approach from etcusr/solana-snapshot-finder.
    """
    # --- Full snapshot (required) ---
    redirect, elapsed_sec = head_no_follow(
        f"http://{rpc_address}/snapshot.tar.bz2", timeout=2)
    if not redirect:
        return None

    latency_ms: float = elapsed_sec * 1000
    if latency_ms > max_latency_ms:
        return None

    full_name, full_path = _parse_snapshot_filename(redirect)
    full_match: re.Match[str] | None = FULL_SNAP_RE.match(full_name)
    if not full_match:
        return None

    base_slot: int = int(full_match.group(1))
    age: int = current_slot - base_slot
    if not (-100 <= age <= max_age_slots):
        return None

    available: list[str] = [full_path]

    # --- Incremental snapshot (optional) ---
    inc_redirect, _ = head_no_follow(
        f"http://{rpc_address}/incremental-snapshot.tar.bz2", timeout=2)
    if inc_redirect:
        inc_name, inc_path = _parse_snapshot_filename(inc_redirect)
        inc_match: re.Match[str] | None = INCR_SNAP_RE.match(inc_name)
        # Only usable if it is built on this source's full snapshot.
        if inc_match and int(inc_match.group(1)) == base_slot:
            available.append(inc_path)

    return SnapshotSource(
        rpc_address=rpc_address,
        file_paths=available,
        slots_diff=age,
        latency_ms=latency_ms,
    )
def discover_sources(
    rpc_url: str,
    current_slot: int,
    max_age_slots: int,
    max_latency_ms: float,
    threads: int,
    version_filter: str | None,
) -> list[SnapshotSource]:
    """Discover all snapshot sources from the cluster.

    Lists RPC nodes via ``get_cluster_rpc_nodes`` and probes each of them
    concurrently with ``probe_rpc_snapshot``.

    Args:
        rpc_url: RPC endpoint used for ``getClusterNodes``.
        current_slot: Current cluster slot, used to judge snapshot age.
        max_age_slots: Maximum accepted snapshot age in slots.
        max_latency_ms: Maximum accepted probe latency in milliseconds.
        threads: Upper bound on concurrent probe threads.
        version_filter: Optional node version prefix filter.

    Returns:
        Sources whose probe succeeded, in completion order; an empty list
        when no RPC nodes were found.
    """
    rpc_nodes: list[str] = get_cluster_rpc_nodes(rpc_url, version_filter)
    if not rpc_nodes:
        log.error("No RPC nodes found via getClusterNodes")
        return []
    log.info("Found %d RPC nodes, probing for snapshots...", len(rpc_nodes))

    sources: list[SnapshotSource] = []
    # ThreadPoolExecutor rejects max_workers <= 0, and there is no point in
    # allowing more workers than there are nodes to probe.
    workers: int = max(1, min(threads, len(rpc_nodes)))
    with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as pool:
        futures: dict[concurrent.futures.Future[SnapshotSource | None], str] = {
            pool.submit(
                probe_rpc_snapshot, addr, current_slot,
                max_age_slots, max_latency_ms,
            ): addr
            for addr in rpc_nodes
        }
        for done, future in enumerate(concurrent.futures.as_completed(futures), start=1):
            try:
                result: SnapshotSource | None = future.result()
            except Exception as e:  # noqa: BLE001 - best-effort probe boundary
                # One misbehaving node (malformed address, garbage reply)
                # must not abort the probing of every other node.
                log.debug("Probe failed for %s: %s", futures[future], e)
                result = None
            if result:
                sources.append(result)
            # Report progress after recording, so the count includes this node.
            if done % 200 == 0:
                log.info("  probed %d/%d nodes, %d sources found",
                         done, len(rpc_nodes), len(sources))
    log.info("Found %d RPC nodes with suitable snapshots", len(sources))
    return sources
# -- Speed benchmark -----------------------------------------------------------
def measure_speed(rpc_address: str, measure_time: int = 7) -> float:
    """Measure download speed from an RPC node.

    Streams ``http://<rpc_address>/snapshot.tar.bz2`` for up to
    ``measure_time`` seconds and averages the throughput over the time
    actually spent reading.

    Args:
        rpc_address: Node address, interpolated after ``http://``.
        measure_time: Measurement window in seconds.

    Returns:
        Average throughput in bytes/sec, or ``0.0`` when the request or
        the transfer fails for any network- or protocol-level reason.
    """
    # Imported locally so this function does not depend on a module-level name.
    from http.client import HTTPException

    url: str = f"http://{rpc_address}/snapshot.tar.bz2"
    try:
        with urllib.request.urlopen(Request(url), timeout=measure_time + 5) as resp:
            # The clock starts once headers have arrived, so connection
            # setup does not count against throughput.
            start: float = time.monotonic()
            total: int = 0
            while time.monotonic() - start < measure_time:
                chunk: bytes = resp.read(81920)
                if not chunk:
                    break
                total += len(chunk)
            elapsed: float = time.monotonic() - start
    except (urllib.error.URLError, HTTPException, ValueError, OSError, TimeoutError):
        # HTTPException covers IncompleteRead / BadStatusLine / InvalidURL,
        # none of which are URLError or OSError subclasses.
        return 0.0
    if elapsed <= 0:
        return 0.0
    return total / elapsed
# -- Download ------------------------------------------------------------------
def download_aria2c(
    urls: list[str],
    output_dir: str,
    filename: str,
    connections: int = 16,
) -> bool:
    """Download a file using aria2c with parallel connections.

    When multiple URLs are provided, aria2c treats them as mirrors of the
    same file and distributes chunks across all of them.

    Args:
        urls: One or more mirror URLs for the same file.
        output_dir: Directory the file is written to.
        filename: Output file name inside ``output_dir``.
        connections: Requested connections per server. aria2c accepts only
            1-16 for ``--max-connection-per-server``, so the value is
            clamped into that range.

    Returns:
        True when aria2c exits with code 0 and the output file exists,
        False otherwise (including an empty ``urls`` list).
    """
    if not urls:
        log.error("No URLs given for %s", filename)
        return False

    # aria2c refuses to start when --max-connection-per-server is outside 1-16.
    aria2c_max_per_server: int = 16
    per_server: int = min(max(connections, 1), aria2c_max_per_server)
    if per_server != connections:
        log.warning("aria2c allows 1-%d connections per server; using %d instead of %d",
                    aria2c_max_per_server, per_server, connections)

    num_mirrors: int = len(urls)
    # One split per usable connection across all mirrors.
    total_splits: int = per_server * num_mirrors
    cmd: list[str] = [
        "aria2c",
        "--file-allocation=none",
        "--continue=true",
        f"--max-connection-per-server={per_server}",
        f"--split={total_splits}",
        "--min-split-size=50M",
        # aria2c retries individual chunk connections on transient network
        # errors (TCP reset, timeout). This is transport-level retry analogous
        # to TCP retransmit, not application-level retry of a failed operation.
        "--max-tries=5",
        "--retry-wait=5",
        "--timeout=60",
        "--connect-timeout=10",
        "--summary-interval=10",
        "--console-log-level=notice",
        f"--dir={output_dir}",
        f"--out={filename}",
        "--auto-file-renaming=false",
        "--allow-overwrite=true",
        *urls,
    ]
    log.info("Downloading %s", filename)
    log.info("  aria2c: %d connections × %d mirrors (%d splits)",
             per_server, num_mirrors, total_splits)

    start: float = time.monotonic()
    # Argument list with no shell; aria2c output goes straight to the console.
    result: subprocess.CompletedProcess[bytes] = subprocess.run(cmd)
    elapsed: float = time.monotonic() - start
    if result.returncode != 0:
        log.error("aria2c failed with exit code %d", result.returncode)
        return False

    filepath: Path = Path(output_dir) / filename
    if not filepath.exists():
        log.error("aria2c reported success but %s does not exist", filepath)
        return False

    size_bytes: int = filepath.stat().st_size
    size_gb: float = size_bytes / (1024 ** 3)
    avg_mb: float = size_bytes / elapsed / (1024 ** 2) if elapsed > 0 else 0
    log.info("  Done: %.1f GB in %.0fs (%.1f MiB/s avg)", size_gb, elapsed, avg_mb)
    return True
# -- Main ----------------------------------------------------------------------
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the snapshot downloader."""
    p: argparse.ArgumentParser = argparse.ArgumentParser(
        description="Download Solana snapshots with aria2c parallel downloads",
    )
    p.add_argument("-o", "--output", default="/srv/solana/snapshots",
                   help="Snapshot output directory (default: /srv/solana/snapshots)")
    p.add_argument("-c", "--cluster", default="mainnet-beta",
                   choices=list(CLUSTER_RPC),
                   help="Solana cluster (default: mainnet-beta)")
    p.add_argument("-r", "--rpc", default=None,
                   help="RPC URL for cluster discovery (default: public RPC)")
    p.add_argument("-n", "--connections", type=int, default=16,
                   help="aria2c connections per download (default: 16)")
    p.add_argument("-t", "--threads", type=int, default=500,
                   help="Threads for parallel RPC probing (default: 500)")
    p.add_argument("--max-snapshot-age", type=int, default=1300,
                   help="Max snapshot age in slots (default: 1300)")
    p.add_argument("--max-latency", type=float, default=100,
                   help="Max RPC probe latency in ms (default: 100)")
    p.add_argument("--min-download-speed", type=int, default=20,
                   help="Min download speed in MiB/s (default: 20)")
    p.add_argument("--measurement-time", type=int, default=7,
                   help="Speed measurement duration in seconds (default: 7)")
    p.add_argument("--max-speed-checks", type=int, default=15,
                   help="Max nodes to benchmark before giving up (default: 15)")
    p.add_argument("--version", default=None,
                   help="Filter nodes by version prefix (e.g. '2.2')")
    p.add_argument("--full-only", action="store_true",
                   help="Download only full snapshot, skip incremental")
    p.add_argument("--dry-run", action="store_true",
                   help="Find best source and print URL, don't download")
    p.add_argument("-v", "--verbose", action="store_true")
    return p


def _benchmark_sources(
    sources: "list[SnapshotSource]",
    max_checks: int,
    measure_time: int,
    min_speed_mib: int,
) -> "list[SnapshotSource]":
    """Benchmark the first ``max_checks`` sources; return the fast ones, fastest first."""
    fast: list[SnapshotSource] = []
    # All speeds are reported in MiB/s (binary, 1 MiB = 1048576 bytes).
    min_speed_bytes: int = min_speed_mib * 1024 * 1024
    for source in sources[:max(0, max_checks)]:
        speed: float = measure_speed(source.rpc_address, measure_time)
        source.download_speed = speed
        speed_mib: float = speed / (1024 ** 2)
        if speed < min_speed_bytes:
            log.info("  %s: %.1f MiB/s (too slow, need >=%d MiB/s)",
                     source.rpc_address, speed_mib, min_speed_mib)
            continue
        log.info("  %s: %.1f MiB/s (latency: %.0fms, age: %d slots)",
                 source.rpc_address, speed_mib,
                 source.latency_ms, source.slots_diff)
        fast.append(source)
    # Candidates arrive ordered by latency; rank the survivors by measured
    # throughput so index 0 really is the fastest source.
    fast.sort(key=lambda s: s.download_speed, reverse=True)
    return fast


def _build_download_plan(
    fast_sources: "list[SnapshotSource]",
    full_only: bool,
) -> list[tuple[str, list[str]]]:
    """Pair each file of the best source with every mirror URL serving the same filename."""
    best: SnapshotSource = fast_sources[0]
    file_paths: list[str] = best.file_paths
    if full_only:
        file_paths = [path for path in file_paths
                      if path.rsplit("/", 1)[-1].startswith("snapshot-")]

    plan: list[tuple[str, list[str]]] = []
    for path in file_paths:
        filename: str = path.rsplit("/", 1)[-1]
        mirror_urls: list[str] = [f"http://{best.rpc_address}{path}"]
        for other in fast_sources[1:]:
            # At most one URL per mirror: the first path with the same filename.
            for other_path in other.file_paths:
                if other_path.rsplit("/", 1)[-1] == filename:
                    mirror_urls.append(f"http://{other.rpc_address}{other_path}")
                    break
        plan.append((filename, mirror_urls))
    return plan


def main() -> int:
    """Run the snapshot downloader CLI.

    Discovers snapshot sources, benchmarks the lowest-latency candidates,
    then downloads the files of the fastest source, using other fast
    sources that serve the same filenames as mirrors.

    Returns:
        Process exit code: 0 on success (or a completed dry run), 1 on any
        failure.
    """
    args: argparse.Namespace = _build_arg_parser().parse_args()
    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
        datefmt="%H:%M:%S",
    )
    rpc_url: str = args.rpc or CLUSTER_RPC[args.cluster]

    # aria2c is required for actual downloads (not dry-run)
    if not args.dry_run and not shutil.which("aria2c"):
        log.error("aria2c not found. Install with: apt install aria2")
        return 1

    # Get current slot
    log.info("Cluster: %s | RPC: %s", args.cluster, rpc_url)
    current_slot: int | None = get_current_slot(rpc_url)
    if current_slot is None:
        log.error("Cannot get current slot from %s", rpc_url)
        return 1
    log.info("Current slot: %d", current_slot)

    # Discover sources
    sources: list[SnapshotSource] = discover_sources(
        rpc_url, current_slot,
        max_age_slots=args.max_snapshot_age,
        max_latency_ms=args.max_latency,
        threads=args.threads,
        version_filter=args.version,
    )
    if not sources:
        log.error("No snapshot sources found")
        return 1

    # Sort by latency (lowest first) to choose which sources get benchmarked
    sources.sort(key=lambda s: s.latency_ms)
    log.info("Benchmarking download speed on top %d sources...", args.max_speed_checks)
    fast_sources: list[SnapshotSource] = _benchmark_sources(
        sources, args.max_speed_checks, args.measurement_time,
        args.min_download_speed,
    )
    if not fast_sources:
        log.error("No source met minimum speed requirement (%d MiB/s)",
                  args.min_download_speed)
        log.info("Try: --min-download-speed 10")
        return 1

    # Fastest source is primary; other fast sources act as mirrors per file
    best: SnapshotSource = fast_sources[0]
    download_plan: list[tuple[str, list[str]]] = _build_download_plan(
        fast_sources, args.full_only)
    log.info("Best source: %s (%.1f MiB/s), %d mirrors total",
             best.rpc_address, best.download_speed / (1024 ** 2), len(fast_sources))
    for filename, mirror_urls in download_plan:
        log.info("  %s (%d mirrors)", filename, len(mirror_urls))
        for url in mirror_urls:
            log.info("    %s", url)

    if args.dry_run:
        for _, mirror_urls in download_plan:
            for url in mirror_urls:
                print(url)
        return 0

    # Download — skip files that already exist locally
    os.makedirs(args.output, exist_ok=True)
    total_start: float = time.monotonic()
    for filename, mirror_urls in download_plan:
        filepath: Path = Path(args.output) / filename
        # aria2c keeps a "<name>.aria2" control file next to an unfinished
        # download and removes it on completion. While it exists, the data
        # file is partial and must be resumed, not skipped.
        control_file: Path = filepath.with_name(filepath.name + ".aria2")
        if filepath.exists() and filepath.stat().st_size > 0:
            if not control_file.exists():
                log.info("Skipping %s (already exists: %.1f GB)",
                         filename, filepath.stat().st_size / (1024 ** 3))
                continue
            log.info("Resuming partial download of %s", filename)
        if not download_aria2c(mirror_urls, args.output, filename, args.connections):
            log.error("Failed to download %s", filename)
            return 1

    total_elapsed: float = time.monotonic() - total_start
    log.info("All downloads complete in %.0fs", total_elapsed)
    for filename, _ in download_plan:
        downloaded: Path = Path(args.output) / filename
        if downloaded.exists():
            log.info("  %s (%.1f GB)", downloaded.name,
                     downloaded.stat().st_size / (1024 ** 3))
    return 0
# Script entry point: propagate main()'s return value as the process exit code.
if __name__ == "__main__":
    exit_code: int = main()
    sys.exit(exit_code)