#!/usr/bin/env python3
"""Download Solana snapshots using aria2c for parallel multi-connection downloads.
|
||
|
||
Discovers snapshot sources by querying getClusterNodes for all RPCs in the
|
||
cluster, probing each for available snapshots, benchmarking download speed,
|
||
and downloading from the fastest source using aria2c (16 connections by default).
|
||
|
||
Based on the discovery approach from etcusr/solana-snapshot-finder but replaces
|
||
the single-connection wget download with aria2c parallel chunked downloads.
|
||
|
||
Usage:
|
||
# Download to /srv/solana/snapshots (mainnet, 16 connections)
|
||
./snapshot-download.py -o /srv/solana/snapshots
|
||
|
||
# Dry run — find best source, print URL
|
||
./snapshot-download.py --dry-run
|
||
|
||
# Custom RPC for cluster node discovery + 32 connections
|
||
./snapshot-download.py -r https://api.mainnet-beta.solana.com -n 32
|
||
|
||
# Testnet
|
||
./snapshot-download.py -c testnet -o /data/snapshots
|
||
|
||
Requirements:
|
||
- aria2c (apt install aria2)
|
||
- python3 >= 3.10 (stdlib only, no pip dependencies)
|
||
"""

from __future__ import annotations

import argparse
import concurrent.futures
import json
import logging
import os
import re
import shutil
import subprocess
import sys
import time
import urllib.error
import urllib.request
from dataclasses import dataclass, field
from http.client import HTTPResponse
from pathlib import Path
from urllib.parse import urlparse
from urllib.request import Request

log: logging.Logger = logging.getLogger("snapshot-download")

CLUSTER_RPC: dict[str, str] = {
    "mainnet-beta": "https://api.mainnet-beta.solana.com",
    "testnet": "https://api.testnet.solana.com",
    "devnet": "https://api.devnet.solana.com",
}

# Snapshot filenames:
# snapshot-<slot>-<hash>.tar.zst
# incremental-snapshot-<base_slot>-<slot>-<hash>.tar.zst
FULL_SNAP_RE: re.Pattern[str] = re.compile(
    r"^snapshot-(\d+)-([A-Za-z0-9]+)\.tar\.(zst|bz2)$"
)
INCR_SNAP_RE: re.Pattern[str] = re.compile(
    r"^incremental-snapshot-(\d+)-(\d+)-([A-Za-z0-9]+)\.tar\.(zst|bz2)$"
)
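
# Example (hypothetical slot/hash values):
#   FULL_SNAP_RE.match("snapshot-250000000-6XaBcDeF.tar.zst") captures
#   ("250000000", "6XaBcDeF", "zst"); INCR_SNAP_RE also captures the base
#   slot that the incremental snapshot builds on.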


@dataclass
class SnapshotSource:
    """A snapshot file available from a specific RPC node."""

    rpc_address: str
    # Full redirect paths as returned by the server (e.g. /snapshot-123-hash.tar.zst)
    file_paths: list[str] = field(default_factory=list)
    slots_diff: int = 0
    latency_ms: float = 0.0
    download_speed: float = 0.0  # bytes/sec


# -- JSON-RPC helpers ----------------------------------------------------------


class _NoRedirectHandler(urllib.request.HTTPRedirectHandler):
    """Handler that captures redirect Location instead of following it."""

    def redirect_request(
        self,
        req: Request,
        fp: HTTPResponse,
        code: int,
        msg: str,
        headers: dict[str, str],  # type: ignore[override]
        newurl: str,
    ) -> None:
        # Returning None tells urllib not to follow the redirect; the 3xx is
        # then raised as HTTPError, which head_no_follow() catches to read
        # the Location header.
        return None


def rpc_post(url: str, method: str, params: list[object] | None = None,
             timeout: int = 25) -> object | None:
    """JSON-RPC POST. Returns parsed 'result' field or None on error."""
    payload: bytes = json.dumps({
        "jsonrpc": "2.0", "id": 1,
        "method": method, "params": params or [],
    }).encode()
    req = Request(url, data=payload,
                  headers={"Content-Type": "application/json"})
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            data: dict[str, object] = json.loads(resp.read())
            return data.get("result")
    except (urllib.error.URLError, json.JSONDecodeError, OSError, TimeoutError) as e:
        log.debug("rpc_post %s %s failed: %s", url, method, e)
        return None
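
# Example (illustrative): rpc_post(CLUSTER_RPC["devnet"], "getSlot") returns an
# int slot on success; any transport, timeout, or JSON error yields None.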


def head_no_follow(url: str, timeout: float = 3) -> tuple[str | None, float]:
    """HEAD request without following redirects.

    Returns (Location header value, latency_sec) if the server returned a
    3xx redirect. Returns (None, 0.0) on any error or non-redirect response.
    """
    opener: urllib.request.OpenerDirector = urllib.request.build_opener(_NoRedirectHandler)
    req = Request(url, method="HEAD")
    start: float = time.monotonic()
    try:
        resp: HTTPResponse = opener.open(req, timeout=timeout)  # type: ignore[assignment]
        # A 2xx means the server answered directly instead of redirecting,
        # so there is no snapshot filename to learn here.
        resp.close()
        return None, 0.0
    except urllib.error.HTTPError as e:
        # _NoRedirectHandler turns 3xx responses into HTTPError carrying the
        # redirect info; recover the Location and the measured latency.
        latency: float = time.monotonic() - start
        location: str | None = e.headers.get("Location")
        if location and 300 <= e.code < 400:
            return location, latency
        return None, 0.0
    except (urllib.error.URLError, OSError, TimeoutError):
        return None, 0.0
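
# Solana RPC nodes answer GET/HEAD on /snapshot.tar.bz2 with a redirect whose
# Location names the concrete archive (e.g. /snapshot-<slot>-<hash>.tar.zst);
# head_no_follow() captures that filename without downloading anything.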


# -- Discovery -----------------------------------------------------------------


def get_current_slot(rpc_url: str) -> int | None:
    """Get current slot from RPC."""
    result: object | None = rpc_post(rpc_url, "getSlot")
    if isinstance(result, int):
        return result
    return None


def get_cluster_rpc_nodes(rpc_url: str, version_filter: str | None = None) -> list[str]:
    """Get all RPC node addresses from getClusterNodes."""
    result: object | None = rpc_post(rpc_url, "getClusterNodes")
    if not isinstance(result, list):
        return []

    rpc_addrs: list[str] = []
    for node in result:
        if not isinstance(node, dict):
            continue
        if version_filter is not None:
            node_version: str | None = node.get("version")
            # Nodes that report no version at all pass the filter.
            if node_version and not node_version.startswith(version_filter):
                continue
        rpc: str | None = node.get("rpc")
        if rpc:
            rpc_addrs.append(rpc)
    return list(set(rpc_addrs))
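
# Each getClusterNodes entry is a dict; the fields used here look like
# {"rpc": "203.0.113.5:8899", "version": "2.2.1", ...} (illustrative values).
# Nodes without a public RPC port report a null "rpc" and are skipped.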


def _parse_snapshot_filename(location: str) -> tuple[str, str]:
    """Extract filename and full redirect path from Location header.

    Returns (filename, full_path). full_path includes any path prefix
    the server returned (e.g. '/snapshots/snapshot-123-hash.tar.zst').
    """
    # Location may be an absolute URL or a relative path
    if location.startswith("http://") or location.startswith("https://"):
        # Absolute URL — extract the path component
        path: str = urlparse(location).path
    else:
        path = location

    filename: str = path.rsplit("/", 1)[-1]
    return filename, path
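
# e.g. (hypothetical values):
#   _parse_snapshot_filename("http://203.0.113.5:8899/snapshot-250000000-6XaBcDeF.tar.zst")
#   -> ("snapshot-250000000-6XaBcDeF.tar.zst", "/snapshot-250000000-6XaBcDeF.tar.zst")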


def probe_rpc_snapshot(
    rpc_address: str,
    current_slot: int,
    max_age_slots: int,
    max_latency_ms: float,
) -> SnapshotSource | None:
    """Probe a single RPC node for available snapshots.

    Probes for a full snapshot first (required), then an incremental. Records
    all available files. Which files to actually download is decided at
    download time based on what already exists locally — not here.

    Based on the discovery approach from etcusr/solana-snapshot-finder.
    """
    full_url: str = f"http://{rpc_address}/snapshot.tar.bz2"

    # Full snapshot is required — every source must have one
    full_location, full_latency = head_no_follow(full_url, timeout=2)
    if not full_location:
        return None

    latency_ms: float = full_latency * 1000
    if latency_ms > max_latency_ms:
        return None

    full_filename, full_path = _parse_snapshot_filename(full_location)
    fm: re.Match[str] | None = FULL_SNAP_RE.match(full_filename)
    if not fm:
        return None

    full_snap_slot: int = int(fm.group(1))
    slots_diff: int = current_slot - full_snap_slot

    # Reject snapshots that are too old, or implausibly far ahead of the
    # reference RPC's view of the tip (small negative skew is tolerated).
    if slots_diff > max_age_slots or slots_diff < -100:
        return None

    file_paths: list[str] = [full_path]

    # Also check for incremental snapshot
    inc_url: str = f"http://{rpc_address}/incremental-snapshot.tar.bz2"
    inc_location, _ = head_no_follow(inc_url, timeout=2)
    if inc_location:
        inc_filename, inc_path = _parse_snapshot_filename(inc_location)
        m: re.Match[str] | None = INCR_SNAP_RE.match(inc_filename)
        if m:
            inc_base_slot: int = int(m.group(1))
            # Incremental must be based on this source's full snapshot
            if inc_base_slot == full_snap_slot:
                file_paths.append(inc_path)

    return SnapshotSource(
        rpc_address=rpc_address,
        file_paths=file_paths,
        slots_diff=slots_diff,
        latency_ms=latency_ms,
    )
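
# A successful probe yields something like (hypothetical values):
#   SnapshotSource(rpc_address="203.0.113.5:8899",
#                  file_paths=["/snapshot-250000000-6XaBcDeF.tar.zst",
#                              "/incremental-snapshot-250000000-250000900-9ZyXwV.tar.zst"],
#                  slots_diff=420, latency_ms=35.0)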


def discover_sources(
    rpc_url: str,
    current_slot: int,
    max_age_slots: int,
    max_latency_ms: float,
    threads: int,
    version_filter: str | None,
) -> list[SnapshotSource]:
    """Discover all snapshot sources from the cluster."""
    rpc_nodes: list[str] = get_cluster_rpc_nodes(rpc_url, version_filter)
    if not rpc_nodes:
        log.error("No RPC nodes found via getClusterNodes")
        return []

    log.info("Found %d RPC nodes, probing for snapshots...", len(rpc_nodes))

    sources: list[SnapshotSource] = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as pool:
        futures: dict[concurrent.futures.Future[SnapshotSource | None], str] = {
            pool.submit(
                probe_rpc_snapshot, addr, current_slot,
                max_age_slots, max_latency_ms,
            ): addr
            for addr in rpc_nodes
        }
        done: int = 0
        for future in concurrent.futures.as_completed(futures):
            done += 1
            if done % 200 == 0:
                log.info("  probed %d/%d nodes, %d sources found",
                         done, len(rpc_nodes), len(sources))
            try:
                result: SnapshotSource | None = future.result()
            except (urllib.error.URLError, OSError, TimeoutError) as e:
                log.debug("Probe failed for %s: %s", futures[future], e)
                continue
            if result:
                sources.append(result)

    log.info("Found %d RPC nodes with suitable snapshots", len(sources))
    return sources


# -- Speed benchmark -----------------------------------------------------------


def measure_speed(rpc_address: str, measure_time: int = 7) -> float:
    """Measure download speed from an RPC node. Returns bytes/sec."""
    url: str = f"http://{rpc_address}/snapshot.tar.bz2"
    req = Request(url)
    try:
        # Unlike head_no_follow(), urlopen follows the snapshot redirect here,
        # so this reads real archive bytes for measure_time seconds.
        with urllib.request.urlopen(req, timeout=measure_time + 5) as resp:
            start: float = time.monotonic()
            total: int = 0
            while True:
                elapsed: float = time.monotonic() - start
                if elapsed >= measure_time:
                    break
                chunk: bytes = resp.read(81920)
                if not chunk:
                    break
                total += len(chunk)
            elapsed = time.monotonic() - start
            if elapsed <= 0:
                return 0.0
            return total / elapsed
    except (urllib.error.URLError, OSError, TimeoutError):
        return 0.0


# -- Download ------------------------------------------------------------------


def download_aria2c(
    urls: list[str],
    output_dir: str,
    filename: str,
    connections: int = 16,
) -> bool:
    """Download a file using aria2c with parallel connections.

    When multiple URLs are provided, aria2c treats them as mirrors of the
    same file and distributes chunks across all of them.
    """
    num_mirrors: int = len(urls)
    total_splits: int = max(connections, connections * num_mirrors)
    # aria2c documents --max-connection-per-server as 1-16, so clamp the
    # per-server value; parallelism beyond that comes from splitting across
    # multiple mirrors.
    per_server: int = min(connections, 16)
    cmd: list[str] = [
        "aria2c",
        "--file-allocation=none",
        "--continue=true",
        f"--max-connection-per-server={per_server}",
        f"--split={total_splits}",
        "--min-split-size=50M",
        # aria2c retries individual chunk connections on transient network
        # errors (TCP reset, timeout). This is transport-level retry analogous
        # to TCP retransmit, not application-level retry of a failed operation.
        "--max-tries=5",
        "--retry-wait=5",
        "--timeout=60",
        "--connect-timeout=10",
        "--summary-interval=10",
        "--console-log-level=notice",
        f"--dir={output_dir}",
        f"--out={filename}",
        "--auto-file-renaming=false",
        "--allow-overwrite=true",
        *urls,
    ]

    log.info("Downloading %s", filename)
    log.info("  aria2c: %d connections × %d mirrors (%d splits)",
             per_server, num_mirrors, total_splits)

    start: float = time.monotonic()
    result: subprocess.CompletedProcess[bytes] = subprocess.run(cmd)
    elapsed: float = time.monotonic() - start

    if result.returncode != 0:
        log.error("aria2c failed with exit code %d", result.returncode)
        return False

    filepath: Path = Path(output_dir) / filename
    if not filepath.exists():
        log.error("aria2c reported success but %s does not exist", filepath)
        return False

    size_bytes: int = filepath.stat().st_size
    size_gib: float = size_bytes / (1024 ** 3)
    avg_mib: float = size_bytes / elapsed / (1024 ** 2) if elapsed > 0 else 0
    log.info("  Done: %.1f GiB in %.0fs (%.1f MiB/s avg)", size_gib, elapsed, avg_mib)
    return True
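
# With the defaults and two fast mirrors, the assembled command looks roughly
# like this (hypothetical hosts and filename, middle flags elided):
#   aria2c --file-allocation=none --continue=true --max-connection-per-server=16 \
#          --split=32 --min-split-size=50M ... --dir=/srv/solana/snapshots \
#          --out=snapshot-250000000-6XaBcDeF.tar.zst \
#          http://203.0.113.5:8899/snapshot-250000000-6XaBcDeF.tar.zst \
#          http://203.0.113.7:8899/snapshot-250000000-6XaBcDeF.tar.zst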


# -- Main ----------------------------------------------------------------------


def main() -> int:
    p: argparse.ArgumentParser = argparse.ArgumentParser(
        description="Download Solana snapshots with aria2c parallel downloads",
    )
    p.add_argument("-o", "--output", default="/srv/solana/snapshots",
                   help="Snapshot output directory (default: /srv/solana/snapshots)")
    p.add_argument("-c", "--cluster", default="mainnet-beta",
                   choices=list(CLUSTER_RPC),
                   help="Solana cluster (default: mainnet-beta)")
    p.add_argument("-r", "--rpc", default=None,
                   help="RPC URL for cluster discovery (default: public RPC)")
    p.add_argument("-n", "--connections", type=int, default=16,
                   help="aria2c connections per download (default: 16)")
    p.add_argument("-t", "--threads", type=int, default=500,
                   help="Threads for parallel RPC probing (default: 500)")
    p.add_argument("--max-snapshot-age", type=int, default=1300,
                   help="Max snapshot age in slots (default: 1300)")
    p.add_argument("--max-latency", type=float, default=100,
                   help="Max RPC probe latency in ms (default: 100)")
    p.add_argument("--min-download-speed", type=int, default=20,
                   help="Min download speed in MiB/s (default: 20)")
    p.add_argument("--measurement-time", type=int, default=7,
                   help="Speed measurement duration in seconds (default: 7)")
    p.add_argument("--max-speed-checks", type=int, default=15,
                   help="Max nodes to benchmark before giving up (default: 15)")
    p.add_argument("--version", default=None,
                   help="Filter nodes by version prefix (e.g. '2.2')")
    p.add_argument("--full-only", action="store_true",
                   help="Download only full snapshot, skip incremental")
    p.add_argument("--dry-run", action="store_true",
                   help="Find best source and print URL, don't download")
    p.add_argument("-v", "--verbose", action="store_true",
                   help="Enable debug logging")
    args: argparse.Namespace = p.parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
        datefmt="%H:%M:%S",
    )

    rpc_url: str = args.rpc or CLUSTER_RPC[args.cluster]

    # aria2c is required for actual downloads (not dry-run)
    if not args.dry_run and not shutil.which("aria2c"):
        log.error("aria2c not found. Install with: apt install aria2")
        return 1

    # Get current slot
    log.info("Cluster: %s | RPC: %s", args.cluster, rpc_url)
    current_slot: int | None = get_current_slot(rpc_url)
    if current_slot is None:
        log.error("Cannot get current slot from %s", rpc_url)
        return 1
    log.info("Current slot: %d", current_slot)

    # Discover sources
    sources: list[SnapshotSource] = discover_sources(
        rpc_url, current_slot,
        max_age_slots=args.max_snapshot_age,
        max_latency_ms=args.max_latency,
        threads=args.threads,
        version_filter=args.version,
    )
    if not sources:
        log.error("No snapshot sources found")
        return 1

    # Sort by latency (lowest first) for speed benchmarking
    sources.sort(key=lambda s: s.latency_ms)

    # Benchmark top candidates — all speeds in MiB/s (binary, 1 MiB = 1048576 bytes)
    log.info("Benchmarking download speed on top %d sources...", args.max_speed_checks)
    fast_sources: list[SnapshotSource] = []
    checked: int = 0
    min_speed_bytes: int = args.min_download_speed * 1024 * 1024  # MiB to bytes

    for source in sources:
        if checked >= args.max_speed_checks:
            break
        checked += 1

        speed: float = measure_speed(source.rpc_address, args.measurement_time)
        source.download_speed = speed
        speed_mib: float = speed / (1024 ** 2)

        if speed < min_speed_bytes:
            log.info("  %s: %.1f MiB/s (too slow, need >=%d MiB/s)",
                     source.rpc_address, speed_mib, args.min_download_speed)
            continue

        log.info("  %s: %.1f MiB/s (latency: %.0fms, age: %d slots)",
                 source.rpc_address, speed_mib,
                 source.latency_ms, source.slots_diff)
        fast_sources.append(source)

    if not fast_sources:
        log.error("No source met minimum speed requirement (%d MiB/s)",
                  args.min_download_speed)
        log.info("Try: --min-download-speed 10")
        return 1

    # Use the fastest source as primary, collect mirrors for each file
    # (fast_sources is in latency order here, so re-sort by measured speed)
    fast_sources.sort(key=lambda s: s.download_speed, reverse=True)
    best: SnapshotSource = fast_sources[0]
    file_paths: list[str] = best.file_paths
    if args.full_only:
        file_paths = [fp for fp in file_paths
                      if fp.rsplit("/", 1)[-1].startswith("snapshot-")]

    # Build mirror URL lists: for each file, collect URLs from all fast sources
    # that serve the same filename
    download_plan: list[tuple[str, list[str]]] = []
    for fp in file_paths:
        filename: str = fp.rsplit("/", 1)[-1]
        mirror_urls: list[str] = [f"http://{best.rpc_address}{fp}"]
        for other in fast_sources[1:]:
            for other_fp in other.file_paths:
                if other_fp.rsplit("/", 1)[-1] == filename:
                    mirror_urls.append(f"http://{other.rpc_address}{other_fp}")
                    break
        download_plan.append((filename, mirror_urls))

    speed_mib = best.download_speed / (1024 ** 2)
    log.info("Best source: %s (%.1f MiB/s), %d mirrors total",
             best.rpc_address, speed_mib, len(fast_sources))
    for filename, mirror_urls in download_plan:
        log.info("  %s (%d mirrors)", filename, len(mirror_urls))
        for url in mirror_urls:
            log.info("    %s", url)

    if args.dry_run:
        for _, mirror_urls in download_plan:
            for url in mirror_urls:
                print(url)
        return 0

    # Download — skip files that already exist locally
    os.makedirs(args.output, exist_ok=True)
    total_start: float = time.monotonic()

    for filename, mirror_urls in download_plan:
        filepath: Path = Path(args.output) / filename
        if filepath.exists() and filepath.stat().st_size > 0:
            log.info("Skipping %s (already exists: %.1f GiB)",
                     filename, filepath.stat().st_size / (1024 ** 3))
            continue
        if not download_aria2c(mirror_urls, args.output, filename, args.connections):
            log.error("Failed to download %s", filename)
            return 1

    total_elapsed: float = time.monotonic() - total_start
    log.info("All downloads complete in %.0fs", total_elapsed)
    for filename, _ in download_plan:
        filepath = Path(args.output) / filename
        if filepath.exists():
            log.info("  %s (%.1f GiB)", filepath.name,
                     filepath.stat().st_size / (1024 ** 3))

    return 0


if __name__ == "__main__":
    sys.exit(main())