#!/usr/bin/env python3
"""Download Solana snapshots using aria2c for parallel multi-connection downloads.

Discovers snapshot sources by querying getClusterNodes for all RPCs in the
cluster, probing each for available snapshots, benchmarking download speed,
and downloading from the fastest source using aria2c (16 connections by
default).

Based on the discovery approach from etcusr/solana-snapshot-finder but replaces
the single-connection wget download with aria2c parallel chunked downloads.

Usage:
    # Download to /srv/solana/snapshots (mainnet, 16 connections)
    ./snapshot-download.py -o /srv/solana/snapshots

    # Dry run — find best source, print URL
    ./snapshot-download.py --dry-run

    # Custom RPC for cluster node discovery + 32 connections
    # (aria2c allows at most 16 connections per server; the remainder is
    # only usable when several mirrors serve the same file)
    ./snapshot-download.py -r https://api.mainnet-beta.solana.com -n 32

    # Testnet
    ./snapshot-download.py -c testnet -o /data/snapshots

Requirements:
    - aria2c (apt install aria2)
    - python3 >= 3.10 (stdlib only, no pip dependencies)
"""

from __future__ import annotations

import argparse
import concurrent.futures
import json
import logging
import os
import re
import shutil
import subprocess
import sys
import time
import urllib.error
import urllib.parse
import urllib.request
from dataclasses import dataclass, field
from http.client import HTTPException, HTTPResponse
from pathlib import Path
from typing import NoReturn
from urllib.request import Request

log: logging.Logger = logging.getLogger("snapshot-download")

CLUSTER_RPC: dict[str, str] = {
    "mainnet-beta": "https://api.mainnet-beta.solana.com",
    "testnet": "https://api.testnet.solana.com",
    "devnet": "https://api.devnet.solana.com",
}

# Snapshot filenames:
#   snapshot-<slot>-<hash>.tar.zst
#   incremental-snapshot-<base_slot>-<slot>-<hash>.tar.zst
FULL_SNAP_RE: re.Pattern[str] = re.compile(
    r"^snapshot-(\d+)-([A-Za-z0-9]+)\.tar\.(zst|bz2)$"
)
INCR_SNAP_RE: re.Pattern[str] = re.compile(
    r"^incremental-snapshot-(\d+)-(\d+)-([A-Za-z0-9]+)\.tar\.(zst|bz2)$"
)

# aria2c rejects --max-connection-per-server values outside 1..16.
ARIA2C_MAX_CONNECTIONS_PER_SERVER: int = 16

# Everything a request to an arbitrary, untrusted cluster node can raise:
# URLError/OSError/TimeoutError for socket-level failures, HTTPException for
# malformed responses or addresses (e.g. InvalidURL on a non-numeric port),
# ValueError for unparsable URLs and undecodable JSON bodies.
_TRANSPORT_ERRORS: tuple[type[Exception], ...] = (
    urllib.error.URLError,
    HTTPException,
    OSError,
    TimeoutError,
    ValueError,
)


@dataclass
class SnapshotSource:
    """A snapshot file available from a specific RPC node."""

    rpc_address: str
    # Full redirect paths as returned by the server (e.g. /snapshot-123-hash.tar.zst)
    file_paths: list[str] = field(default_factory=list)
    slots_diff: int = 0
    latency_ms: float = 0.0
    download_speed: float = 0.0  # bytes/sec


# -- JSON-RPC helpers ----------------------------------------------------------


class _NoRedirectHandler(urllib.request.HTTPRedirectHandler):
    """Handler that captures redirect Location instead of following it."""

    def redirect_request(  # type: ignore[override]
        self,
        req: Request,
        fp: HTTPResponse,
        code: int,
        msg: str,
        headers: dict[str, str],
        newurl: str,
    ) -> None:
        # Returning None makes urllib raise HTTPError for the 3xx response
        # instead of issuing a follow-up request.
        return None


def rpc_post(
    url: str,
    method: str,
    params: list[object] | None = None,
    timeout: int = 25,
) -> object | None:
    """JSON-RPC POST. Returns parsed 'result' field or None on error.

    Any transport failure, invalid URL, non-JSON body or non-object JSON
    response is logged at debug level and reported as None.
    """
    payload: bytes = json.dumps({
        "jsonrpc": "2.0",
        "id": 1,
        "method": method,
        "params": params or [],
    }).encode()
    try:
        req = Request(url, data=payload, headers={"Content-Type": "application/json"})
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            data: object = json.loads(resp.read())
    except _TRANSPORT_ERRORS as e:
        log.debug("rpc_post %s %s failed: %s", url, method, e)
        return None
    if not isinstance(data, dict):
        log.debug("rpc_post %s %s failed: response is not a JSON object", url, method)
        return None
    return data.get("result")


def head_no_follow(url: str, timeout: float = 3) -> tuple[str | None, float]:
    """HEAD request without following redirects.

    Returns (Location header value, latency_sec) if the server returned a
    3xx redirect. Returns (None, 0.0) on any error or non-redirect response.
    """
    opener: urllib.request.OpenerDirector = urllib.request.build_opener(_NoRedirectHandler)
    start: float = time.monotonic()
    try:
        req = Request(url, method="HEAD")
        with opener.open(req, timeout=timeout):
            # Non-redirect (2xx) — server didn't redirect, not useful for discovery
            return None, 0.0
    except urllib.error.HTTPError as e:
        # 3xx redirects raise HTTPError with the redirect info
        latency: float = time.monotonic() - start
        try:
            location: str | None = e.headers.get("Location") if e.headers else None
            code: int = e.code
        finally:
            e.close()
        if location and 300 <= code < 400:
            return location, latency
        return None, 0.0
    except _TRANSPORT_ERRORS:
        return None, 0.0


# -- Discovery -----------------------------------------------------------------


def get_current_slot(rpc_url: str) -> int | None:
    """Get current slot from RPC. Returns None if the RPC call fails."""
    result: object | None = rpc_post(rpc_url, "getSlot")
    # bool is a subclass of int; a JSON true/false is not a slot number.
    if isinstance(result, int) and not isinstance(result, bool):
        return result
    return None


def get_cluster_rpc_nodes(rpc_url: str, version_filter: str | None = None) -> list[str]:
    """Get all RPC node addresses from getClusterNodes.

    Returns the distinct non-empty 'rpc' addresses in first-seen order. When
    version_filter is given, nodes whose reported version does not start with
    it are dropped; nodes that report no version are kept.
    """
    result: object | None = rpc_post(rpc_url, "getClusterNodes")
    if not isinstance(result, list):
        return []

    # dict keys give order-preserving de-duplication.
    rpc_addrs: dict[str, None] = {}
    for node in result:
        if not isinstance(node, dict):
            continue
        if version_filter is not None:
            node_version: object = node.get("version")
            if (
                isinstance(node_version, str)
                and node_version
                and not node_version.startswith(version_filter)
            ):
                continue
        rpc: object = node.get("rpc")
        if isinstance(rpc, str) and rpc:
            rpc_addrs[rpc] = None
    return list(rpc_addrs)


def _parse_snapshot_filename(location: str) -> tuple[str, str]:
    """Extract filename and full redirect path from Location header.

    Returns (filename, full_path). full_path includes any path prefix the
    server returned (e.g. '/snapshots/snapshot-123-hash.tar.zst') and always
    starts with '/', so it can be appended directly to 'http://host:port'.
    Query string and fragment are dropped.
    """
    # Location may be an absolute URL or a relative path; urlparse handles both.
    path: str = urllib.parse.urlparse(location).path
    if not path.startswith("/"):
        # The probe URLs live at the server root, so a relative reference
        # resolves against '/'.
        path = "/" + path
    filename: str = path.rsplit("/", 1)[-1]
    return filename, path


def probe_rpc_snapshot(
    rpc_address: str,
    current_slot: int,
    max_age_slots: int,
    max_latency_ms: float,
) -> SnapshotSource | None:
    """Probe a single RPC node for available snapshots.

    Probes for full snapshot first (required), then incremental. Records all
    available files. Which files to actually download is decided at download
    time based on what already exists locally — not here.

    Returns None when the node serves no full snapshot, responds slower than
    max_latency_ms, or its full snapshot is older than max_age_slots (or more
    than 100 slots ahead of current_slot).

    Based on the discovery approach from etcusr/solana-snapshot-finder.
    """
    full_url: str = f"http://{rpc_address}/snapshot.tar.bz2"

    # Full snapshot is required — every source must have one
    full_location, full_latency = head_no_follow(full_url, timeout=2)
    if not full_location:
        return None

    latency_ms: float = full_latency * 1000
    if latency_ms > max_latency_ms:
        return None

    full_filename, full_path = _parse_snapshot_filename(full_location)
    fm: re.Match[str] | None = FULL_SNAP_RE.match(full_filename)
    if not fm:
        return None

    full_snap_slot: int = int(fm.group(1))
    slots_diff: int = current_slot - full_snap_slot
    if slots_diff > max_age_slots or slots_diff < -100:
        return None

    file_paths: list[str] = [full_path]

    # Also check for incremental snapshot
    inc_url: str = f"http://{rpc_address}/incremental-snapshot.tar.bz2"
    inc_location, _ = head_no_follow(inc_url, timeout=2)
    if inc_location:
        inc_filename, inc_path = _parse_snapshot_filename(inc_location)
        m: re.Match[str] | None = INCR_SNAP_RE.match(inc_filename)
        # Incremental must be based on this source's full snapshot
        if m and int(m.group(1)) == full_snap_slot:
            file_paths.append(inc_path)

    return SnapshotSource(
        rpc_address=rpc_address,
        file_paths=file_paths,
        slots_diff=slots_diff,
        latency_ms=latency_ms,
    )


def discover_sources(
    rpc_url: str,
    current_slot: int,
    max_age_slots: int,
    max_latency_ms: float,
    threads: int,
    version_filter: str | None,
) -> list[SnapshotSource]:
    """Discover all snapshot sources from the cluster.

    Probes every RPC address returned by getClusterNodes in a thread pool and
    returns the sources that passed probe_rpc_snapshot, in completion order.
    """
    rpc_nodes: list[str] = get_cluster_rpc_nodes(rpc_url, version_filter)
    if not rpc_nodes:
        log.error("No RPC nodes found via getClusterNodes")
        return []

    log.info("Found %d RPC nodes, probing for snapshots...", len(rpc_nodes))

    # ThreadPoolExecutor rejects max_workers <= 0; more workers than nodes
    # would only sit idle.
    workers: int = max(1, min(threads, len(rpc_nodes)))
    sources: list[SnapshotSource] = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as pool:
        futures: dict[concurrent.futures.Future[SnapshotSource | None], str] = {
            pool.submit(
                probe_rpc_snapshot, addr, current_slot, max_age_slots, max_latency_ms,
            ): addr
            for addr in rpc_nodes
        }
        for done, future in enumerate(concurrent.futures.as_completed(futures), start=1):
            if done % 200 == 0:
                log.info("  probed %d/%d nodes, %d sources found",
                         done, len(rpc_nodes), len(sources))
            try:
                result: SnapshotSource | None = future.result()
            except _TRANSPORT_ERRORS as e:
                # One misbehaving node must not abort the whole discovery.
                log.debug("Probe failed for %s: %s", futures[future], e)
                continue
            if result:
                sources.append(result)

    log.info("Found %d RPC nodes with suitable snapshots", len(sources))
    return sources


# -- Speed benchmark -----------------------------------------------------------


def measure_speed(rpc_address: str, measure_time: int = 7) -> float:
    """Measure download speed from an RPC node. Returns bytes/sec.

    Reads the node's full snapshot for up to measure_time seconds and
    averages the throughput. Returns 0.0 on any request failure.
    """
    url: str = f"http://{rpc_address}/snapshot.tar.bz2"
    try:
        with urllib.request.urlopen(Request(url), timeout=measure_time + 5) as resp:
            # The clock starts once headers are in, so connection setup and
            # redirects are not counted against throughput.
            start: float = time.monotonic()
            total: int = 0
            while time.monotonic() - start < measure_time:
                chunk: bytes = resp.read(81920)
                if not chunk:
                    break
                total += len(chunk)
            elapsed: float = time.monotonic() - start
    except _TRANSPORT_ERRORS:
        return 0.0
    if elapsed <= 0:
        return 0.0
    return total / elapsed


# -- Download ------------------------------------------------------------------


def _connection_plan(connections: int, num_mirrors: int) -> tuple[int, int]:
    """Return (connections per server, total splits) acceptable to aria2c."""
    requested: int = max(1, connections)
    per_server: int = min(requested, ARIA2C_MAX_CONNECTIONS_PER_SERVER)
    total_splits: int = requested * max(1, num_mirrors)
    return per_server, total_splits


def download_aria2c(
    urls: list[str],
    output_dir: str,
    filename: str,
    connections: int = 16,
) -> bool:
    """Download a file using aria2c with parallel connections.

    When multiple URLs are provided, aria2c treats them as mirrors of the
    same file and distributes chunks across all of them.

    connections is the requested parallelism per mirror; the value passed as
    --max-connection-per-server is capped at 16 because aria2c refuses larger
    values. Returns True only if aria2c exits with 0 and the file exists.
    """
    if not urls:
        log.error("No URLs to download %s from", filename)
        return False

    num_mirrors: int = len(urls)
    per_server, total_splits = _connection_plan(connections, num_mirrors)

    cmd: list[str] = [
        "aria2c",
        "--file-allocation=none",
        "--continue=true",
        f"--max-connection-per-server={per_server}",
        f"--split={total_splits}",
        "--min-split-size=50M",
        # aria2c retries individual chunk connections on transient network
        # errors (TCP reset, timeout). This is transport-level retry analogous
        # to TCP retransmit, not application-level retry of a failed operation.
        "--max-tries=5",
        "--retry-wait=5",
        "--timeout=60",
        "--connect-timeout=10",
        "--summary-interval=10",
        "--console-log-level=notice",
        f"--dir={output_dir}",
        f"--out={filename}",
        "--auto-file-renaming=false",
        "--allow-overwrite=true",
        *urls,
    ]

    log.info("Downloading %s", filename)
    log.info("  aria2c: %d connections × %d mirrors (%d splits)",
             per_server, num_mirrors, total_splits)

    start: float = time.monotonic()
    try:
        result: subprocess.CompletedProcess[bytes] = subprocess.run(cmd, check=False)
    except OSError as e:
        # e.g. aria2c missing or not executable
        log.error("Could not run aria2c: %s", e)
        return False
    elapsed: float = time.monotonic() - start

    if result.returncode != 0:
        log.error("aria2c failed with exit code %d", result.returncode)
        return False

    filepath: Path = Path(output_dir) / filename
    if not filepath.exists():
        log.error("aria2c reported success but %s does not exist", filepath)
        return False

    size_bytes: int = filepath.stat().st_size
    size_gb: float = size_bytes / (1024 ** 3)
    avg_mb: float = size_bytes / elapsed / (1024 ** 2) if elapsed > 0 else 0
    log.info("  Done: %.1f GB in %.0fs (%.1f MiB/s avg)", size_gb, elapsed, avg_mb)
    return True


# -- Main ----------------------------------------------------------------------


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser."""
    p: argparse.ArgumentParser = argparse.ArgumentParser(
        description="Download Solana snapshots with aria2c parallel downloads",
    )
    p.add_argument("-o", "--output", default="/srv/solana/snapshots",
                   help="Snapshot output directory (default: /srv/solana/snapshots)")
    p.add_argument("-c", "--cluster", default="mainnet-beta",
                   choices=list(CLUSTER_RPC),
                   help="Solana cluster (default: mainnet-beta)")
    p.add_argument("-r", "--rpc", default=None,
                   help="RPC URL for cluster discovery (default: public RPC)")
    p.add_argument("-n", "--connections", type=int, default=16,
                   help="aria2c connections per download (default: 16)")
    p.add_argument("-t", "--threads", type=int, default=500,
                   help="Threads for parallel RPC probing (default: 500)")
    p.add_argument("--max-snapshot-age", type=int, default=1300,
                   help="Max snapshot age in slots (default: 1300)")
    p.add_argument("--max-latency", type=float, default=100,
                   help="Max RPC probe latency in ms (default: 100)")
    p.add_argument("--min-download-speed", type=int, default=20,
                   help="Min download speed in MiB/s (default: 20)")
    p.add_argument("--measurement-time", type=int, default=7,
                   help="Speed measurement duration in seconds (default: 7)")
    p.add_argument("--max-speed-checks", type=int, default=15,
                   help="Max nodes to benchmark before giving up (default: 15)")
    p.add_argument("--version", default=None,
                   help="Filter nodes by version prefix (e.g. '2.2')")
    p.add_argument("--full-only", action="store_true",
                   help="Download only full snapshot, skip incremental")
    p.add_argument("--dry-run", action="store_true",
                   help="Find best source and print URL, don't download")
    p.add_argument("-v", "--verbose", action="store_true")
    return p


def _benchmark_sources(
    sources: list[SnapshotSource],
    max_checks: int,
    measurement_time: int,
    min_speed_mib: int,
) -> list[SnapshotSource]:
    """Benchmark the lowest-latency sources; return the fast ones, fastest first."""
    # All speeds in MiB/s (binary, 1 MiB = 1048576 bytes)
    min_speed_bytes: int = min_speed_mib * 1024 * 1024
    candidates: list[SnapshotSource] = sorted(sources, key=lambda s: s.latency_ms)
    fast_sources: list[SnapshotSource] = []

    for source in candidates[:max(0, max_checks)]:
        speed: float = measure_speed(source.rpc_address, measurement_time)
        source.download_speed = speed
        speed_mib: float = speed / (1024 ** 2)

        if speed < min_speed_bytes:
            log.info("  %s: %.1f MiB/s (too slow, need >=%d MiB/s)",
                     source.rpc_address, speed_mib, min_speed_mib)
            continue

        log.info("  %s: %.1f MiB/s (latency: %.0fms, age: %d slots)",
                 source.rpc_address, speed_mib,
                 source.latency_ms, source.slots_diff)
        fast_sources.append(source)

    # Candidates were visited in latency order; the primary source must be
    # the one with the highest measured throughput.
    fast_sources.sort(key=lambda s: s.download_speed, reverse=True)
    return fast_sources


def _build_download_plan(
    fast_sources: list[SnapshotSource],
    full_only: bool,
) -> list[tuple[str, list[str]]]:
    """Return [(filename, mirror URLs)] for the files of fast_sources[0].

    For each file of the primary source, every other source serving a file
    with the same name contributes one mirror URL. With full_only,
    incremental snapshots are left out.
    """
    if not fast_sources:
        return []

    best: SnapshotSource = fast_sources[0]
    plan: list[tuple[str, list[str]]] = []
    for path in best.file_paths:
        filename: str = path.rsplit("/", 1)[-1]
        if full_only and not filename.startswith("snapshot-"):
            continue
        mirror_urls: list[str] = [f"http://{best.rpc_address}{path}"]
        for other in fast_sources[1:]:
            same_file: str | None = next(
                (p for p in other.file_paths if p.rsplit("/", 1)[-1] == filename),
                None,
            )
            if same_file is not None:
                mirror_urls.append(f"http://{other.rpc_address}{same_file}")
        plan.append((filename, mirror_urls))
    return plan


def _is_complete_download(filepath: Path) -> bool:
    """Return True if filepath is a non-empty file with no aria2c control file.

    aria2c keeps '<name>.aria2' next to a download while it is in progress
    and removes it on completion, so its presence marks a partial file that
    must be resumed rather than skipped.
    """
    control_file: Path = filepath.with_name(filepath.name + ".aria2")
    if control_file.exists():
        return False
    try:
        return filepath.is_file() and filepath.stat().st_size > 0
    except OSError:
        return False


def main() -> int:
    """Run discovery, benchmarking and download. Returns the process exit code."""
    args: argparse.Namespace = _build_arg_parser().parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
        datefmt="%H:%M:%S",
    )

    rpc_url: str = args.rpc or CLUSTER_RPC[args.cluster]

    # aria2c is required for actual downloads (not dry-run)
    if not args.dry_run and not shutil.which("aria2c"):
        log.error("aria2c not found. Install with: apt install aria2")
        return 1

    # Get current slot
    log.info("Cluster: %s | RPC: %s", args.cluster, rpc_url)
    current_slot: int | None = get_current_slot(rpc_url)
    if current_slot is None:
        log.error("Cannot get current slot from %s", rpc_url)
        return 1
    log.info("Current slot: %d", current_slot)

    # Discover sources
    sources: list[SnapshotSource] = discover_sources(
        rpc_url, current_slot,
        max_age_slots=args.max_snapshot_age,
        max_latency_ms=args.max_latency,
        threads=args.threads,
        version_filter=args.version,
    )
    if not sources:
        log.error("No snapshot sources found")
        return 1

    # Benchmark top candidates (lowest latency first)
    log.info("Benchmarking download speed on top %d sources...", args.max_speed_checks)
    fast_sources: list[SnapshotSource] = _benchmark_sources(
        sources, args.max_speed_checks, args.measurement_time, args.min_download_speed,
    )
    if not fast_sources:
        log.error("No source met minimum speed requirement (%d MiB/s)",
                  args.min_download_speed)
        log.info("Try: --min-download-speed 10")
        return 1

    # Use the fastest source as primary, collect mirrors for each file
    best: SnapshotSource = fast_sources[0]
    download_plan: list[tuple[str, list[str]]] = _build_download_plan(
        fast_sources, args.full_only,
    )

    log.info("Best source: %s (%.1f MiB/s), %d mirrors total",
             best.rpc_address, best.download_speed / (1024 ** 2), len(fast_sources))
    for filename, mirror_urls in download_plan:
        log.info("  %s (%d mirrors)", filename, len(mirror_urls))
        for url in mirror_urls:
            log.info("    %s", url)

    if args.dry_run:
        for _, mirror_urls in download_plan:
            for url in mirror_urls:
                print(url)
        return 0

    # Download — skip files that are already complete locally
    os.makedirs(args.output, exist_ok=True)
    output_dir: Path = Path(args.output)
    total_start: float = time.monotonic()

    for filename, mirror_urls in download_plan:
        filepath: Path = output_dir / filename
        if _is_complete_download(filepath):
            log.info("Skipping %s (already exists: %.1f GB)",
                     filename, filepath.stat().st_size / (1024 ** 3))
            continue
        if not download_aria2c(mirror_urls, args.output, filename, args.connections):
            log.error("Failed to download %s", filename)
            return 1

    total_elapsed: float = time.monotonic() - total_start
    log.info("All downloads complete in %.0fs", total_elapsed)
    for filename, _ in download_plan:
        downloaded: Path = output_dir / filename
        if downloaded.exists():
            log.info("  %s (%.1f GB)", downloaded.name, downloaded.stat().st_size / (1024 ** 3))

    return 0


if __name__ == "__main__":
    sys.exit(main())