2026-03-08 19:13:38 +00:00
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
"""Download Solana snapshots using aria2c for parallel multi-connection downloads.
|
|
|
|
|
|
|
|
|
|
Discovers snapshot sources by querying getClusterNodes for all RPCs in the
|
|
|
|
|
cluster, probing each for available snapshots, benchmarking download speed,
|
|
|
|
|
and downloading from the fastest source using aria2c (16 connections by default).
|
|
|
|
|
|
|
|
|
|
Based on the discovery approach from etcusr/solana-snapshot-finder but replaces
|
|
|
|
|
the single-connection wget download with aria2c parallel chunked downloads.
|
|
|
|
|
|
|
|
|
|
Usage:
|
|
|
|
|
# Download to /srv/kind/solana/snapshots (mainnet, 16 connections)
|
|
|
|
|
./snapshot_download.py -o /srv/kind/solana/snapshots
|
|
|
|
|
|
|
|
|
|
# Dry run — find best source, print URL
|
|
|
|
|
./snapshot_download.py --dry-run
|
|
|
|
|
|
|
|
|
|
# Custom RPC for cluster discovery + 32 connections
|
|
|
|
|
./snapshot_download.py -r https://api.mainnet-beta.solana.com -n 32
|
|
|
|
|
|
|
|
|
|
# Testnet
|
|
|
|
|
./snapshot_download.py -c testnet -o /data/snapshots
|
|
|
|
|
|
|
|
|
|
# Programmatic use from entrypoint.py:
|
|
|
|
|
from snapshot_download import download_best_snapshot
|
|
|
|
|
ok = download_best_snapshot("/data/snapshots")
|
|
|
|
|
|
|
|
|
|
Requirements:
|
|
|
|
|
- aria2c (apt install aria2)
|
|
|
|
|
- python3 >= 3.10 (stdlib only, no pip dependencies)
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
|
import concurrent.futures
|
|
|
|
|
import json
|
|
|
|
|
import logging
|
|
|
|
|
import os
|
|
|
|
|
import re
|
|
|
|
|
import shutil
|
|
|
|
|
import subprocess
|
|
|
|
|
import sys
|
|
|
|
|
import time
|
|
|
|
|
import urllib.error
|
|
|
|
|
import urllib.request
|
|
|
|
|
from dataclasses import dataclass, field
|
|
|
|
|
from http.client import HTTPResponse
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from urllib.request import Request
|
|
|
|
|
|
|
|
|
|
# Logger shared by every function in this module.
log: logging.Logger = logging.getLogger("snapshot-download")

# Public JSON-RPC endpoint used for discovery, keyed by cluster name.
CLUSTER_RPC: dict[str, str] = {
    "mainnet-beta": "https://api.mainnet-beta.solana.com",
    "testnet": "https://api.testnet.solana.com",
    "devnet": "https://api.devnet.solana.com",
}

# Snapshot filenames:
#   snapshot-<slot>-<hash>.tar.zst
#   incremental-snapshot-<base_slot>-<slot>-<hash>.tar.zst
# Groups (full):        slot, hash, compression suffix.
# Groups (incremental): base slot, slot, hash, compression suffix.
FULL_SNAP_RE: re.Pattern[str] = re.compile(r"^snapshot-(\d+)-([A-Za-z0-9]+)\.tar\.(zst|bz2)$")
INCR_SNAP_RE: re.Pattern[str] = re.compile(
    r"^incremental-snapshot-(\d+)-(\d+)-([A-Za-z0-9]+)\.tar\.(zst|bz2)$"
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class SnapshotSource:
    """Snapshot files offered by one RPC node, plus what was measured about it."""

    # Address used to build download URLs as http://<rpc_address><path>.
    rpc_address: str
    # Full redirect paths as returned by the server (e.g. /snapshot-123-hash.tar.zst)
    file_paths: list[str] = field(default_factory=list)
    # Current cluster slot minus the full snapshot's slot.
    slots_diff: int = field(default=0)
    # Round-trip time of the probe request, in milliseconds.
    latency_ms: float = field(default=0.0)
    # Measured throughput in bytes/sec (0.0 until benchmarked).
    download_speed: float = field(default=0.0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -- JSON-RPC helpers ----------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _NoRedirectHandler(urllib.request.HTTPRedirectHandler):
|
|
|
|
|
"""Handler that captures redirect Location instead of following it."""
|
|
|
|
|
|
|
|
|
|
def redirect_request(
|
|
|
|
|
self,
|
|
|
|
|
req: Request,
|
|
|
|
|
fp: HTTPResponse,
|
|
|
|
|
code: int,
|
|
|
|
|
msg: str,
|
|
|
|
|
headers: dict[str, str], # type: ignore[override]
|
|
|
|
|
newurl: str,
|
|
|
|
|
) -> None:
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def rpc_post(url: str, method: str, params: list[object] | None = None,
|
|
|
|
|
timeout: int = 25) -> object | None:
|
|
|
|
|
"""JSON-RPC POST. Returns parsed 'result' field or None on error."""
|
|
|
|
|
payload: bytes = json.dumps({
|
|
|
|
|
"jsonrpc": "2.0", "id": 1,
|
|
|
|
|
"method": method, "params": params or [],
|
|
|
|
|
}).encode()
|
|
|
|
|
req = Request(url, data=payload,
|
|
|
|
|
headers={"Content-Type": "application/json"})
|
|
|
|
|
try:
|
|
|
|
|
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
|
|
|
|
data: dict[str, object] = json.loads(resp.read())
|
|
|
|
|
return data.get("result")
|
|
|
|
|
except (urllib.error.URLError, json.JSONDecodeError, OSError, TimeoutError) as e:
|
|
|
|
|
log.debug("rpc_post %s %s failed: %s", url, method, e)
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def head_no_follow(url: str, timeout: float = 3) -> tuple[str | None, float]:
|
|
|
|
|
"""HEAD request without following redirects.
|
|
|
|
|
|
|
|
|
|
Returns (Location header value, latency_sec) if the server returned a
|
|
|
|
|
3xx redirect. Returns (None, 0.0) on any error or non-redirect response.
|
|
|
|
|
"""
|
|
|
|
|
opener: urllib.request.OpenerDirector = urllib.request.build_opener(_NoRedirectHandler)
|
|
|
|
|
req = Request(url, method="HEAD")
|
|
|
|
|
try:
|
|
|
|
|
start: float = time.monotonic()
|
|
|
|
|
resp: HTTPResponse = opener.open(req, timeout=timeout) # type: ignore[assignment]
|
|
|
|
|
latency: float = time.monotonic() - start
|
|
|
|
|
# Non-redirect (2xx) — server didn't redirect, not useful for discovery
|
|
|
|
|
location: str | None = resp.headers.get("Location")
|
|
|
|
|
resp.close()
|
|
|
|
|
return location, latency
|
|
|
|
|
except urllib.error.HTTPError as e:
|
|
|
|
|
# 3xx redirects raise HTTPError with the redirect info
|
|
|
|
|
latency = time.monotonic() - start # type: ignore[possibly-undefined]
|
|
|
|
|
location = e.headers.get("Location")
|
|
|
|
|
if location and 300 <= e.code < 400:
|
|
|
|
|
return location, latency
|
|
|
|
|
return None, 0.0
|
|
|
|
|
except (urllib.error.URLError, OSError, TimeoutError):
|
|
|
|
|
return None, 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -- Discovery -----------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_current_slot(rpc_url: str) -> int | None:
    """Ask the RPC endpoint for ``getSlot``; None unless the reply is an int."""
    slot: object | None = rpc_post(rpc_url, "getSlot")
    return slot if isinstance(slot, int) else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_cluster_rpc_nodes(rpc_url: str, version_filter: str | None = None) -> list[str]:
    """Return the distinct ``rpc`` addresses reported by ``getClusterNodes``.

    When ``version_filter`` is given, nodes whose reported version does not
    start with that prefix are skipped; nodes reporting no version are kept.
    An empty list is returned when the RPC call does not yield a list.
    """
    nodes: object | None = rpc_post(rpc_url, "getClusterNodes")
    if not isinstance(nodes, list):
        return []

    unique_addrs: set[str] = set()
    for entry in nodes:
        if not isinstance(entry, dict):
            continue
        if version_filter is not None:
            reported: str | None = entry.get("version")
            if reported and not reported.startswith(version_filter):
                continue
        addr: str | None = entry.get("rpc")
        if addr:
            unique_addrs.add(addr)
    return list(unique_addrs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _parse_snapshot_filename(location: str) -> tuple[str, str | None]:
|
|
|
|
|
"""Extract filename and full redirect path from Location header.
|
|
|
|
|
|
|
|
|
|
Returns (filename, full_path). full_path includes any path prefix
|
|
|
|
|
the server returned (e.g. '/snapshots/snapshot-123-hash.tar.zst').
|
|
|
|
|
"""
|
|
|
|
|
# Location may be absolute URL or relative path
|
|
|
|
|
if location.startswith("http://") or location.startswith("https://"):
|
|
|
|
|
# Absolute URL — extract path
|
|
|
|
|
from urllib.parse import urlparse
|
|
|
|
|
path: str = urlparse(location).path
|
|
|
|
|
else:
|
|
|
|
|
path = location
|
|
|
|
|
|
|
|
|
|
filename: str = path.rsplit("/", 1)[-1]
|
|
|
|
|
return filename, path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def probe_rpc_snapshot(
    rpc_address: str,
    current_slot: int,
) -> SnapshotSource | None:
    """Ask one RPC node which snapshot files it currently serves.

    This only discovers; it applies no age or latency filtering, so the
    caller sees every reachable source and can decide what to keep.

    Returns None when the node does not redirect ``/snapshot.tar.bz2`` to a
    recognisable full-snapshot filename. Otherwise returns a SnapshotSource
    holding the full snapshot path, plus the incremental path when the node
    also offers an incremental built on that same full snapshot.
    """
    base_url: str = f"http://{rpc_address}"

    # A full snapshot is mandatory; without one the node is not a source.
    full_location, full_latency = head_no_follow(f"{base_url}/snapshot.tar.bz2", timeout=2)
    if not full_location:
        return None

    full_name, full_path = _parse_snapshot_filename(full_location)
    full_match: re.Match[str] | None = FULL_SNAP_RE.match(full_name)
    if full_match is None:
        return None
    full_slot: int = int(full_match.group(1))

    paths: list[str] = [full_path]

    # The incremental is optional and only usable on top of this node's full snapshot.
    inc_location, _ = head_no_follow(f"{base_url}/incremental-snapshot.tar.bz2", timeout=2)
    if inc_location:
        inc_name, inc_path = _parse_snapshot_filename(inc_location)
        inc_match: re.Match[str] | None = INCR_SNAP_RE.match(inc_name)
        if inc_match is not None and int(inc_match.group(1)) == full_slot:
            paths.append(inc_path)

    return SnapshotSource(
        rpc_address=rpc_address,
        file_paths=paths,
        slots_diff=current_slot - full_slot,
        latency_ms=full_latency * 1000,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def discover_sources(
    rpc_url: str,
    current_slot: int,
    max_age_slots: int,
    max_latency_ms: float,
    threads: int,
    version_filter: str | None,
) -> list[SnapshotSource]:
    """Discover all snapshot sources in the cluster, then filter them.

    Probing and filtering are separate steps: every reachable source is
    collected first so that, if the filters reject everything, the best
    rejected candidate can still be reported.

    Args:
        rpc_url: RPC endpoint queried for the cluster node list.
        current_slot: Current cluster slot, used to compute snapshot age.
        max_age_slots: Reject sources whose full snapshot is older than this.
        max_latency_ms: Reject sources whose probe latency exceeds this.
        threads: Number of concurrent probe workers (values below 1 are
            treated as 1).
        version_filter: Optional node-version prefix passed to node discovery.

    Returns:
        Sources that passed both filters; empty when none were found.
    """
    rpc_nodes: list[str] = get_cluster_rpc_nodes(rpc_url, version_filter)
    if not rpc_nodes:
        log.error("No RPC nodes found via getClusterNodes")
        return []

    log.info("Found %d RPC nodes, probing for snapshots...", len(rpc_nodes))

    all_sources: list[SnapshotSource] = []
    # ThreadPoolExecutor raises ValueError for max_workers <= 0.
    with concurrent.futures.ThreadPoolExecutor(max_workers=max(1, threads)) as pool:
        futures: dict[concurrent.futures.Future[SnapshotSource | None], str] = {
            pool.submit(probe_rpc_snapshot, addr, current_slot): addr
            for addr in rpc_nodes
        }
        for done, future in enumerate(concurrent.futures.as_completed(futures), start=1):
            try:
                result: SnapshotSource | None = future.result()
            except Exception as e:  # noqa: BLE001 - worker boundary
                # One misbehaving node must never abort discovery of the rest,
                # whatever exception its probe raised.
                log.debug("Probe failed for %s: %s", futures[future], e)
                result = None
            if result:
                all_sources.append(result)
            if done % 200 == 0:
                log.info(" probed %d/%d nodes, %d reachable",
                         done, len(rpc_nodes), len(all_sources))

    log.info("Discovered %d reachable sources", len(all_sources))

    # Apply filters
    filtered: list[SnapshotSource] = []
    rejected_age: int = 0
    rejected_latency: int = 0
    for src in all_sources:
        # A slightly negative age is tolerated; anything further ahead of
        # current_slot than 100 slots is rejected.
        if src.slots_diff > max_age_slots or src.slots_diff < -100:
            rejected_age += 1
            continue
        if src.latency_ms > max_latency_ms:
            rejected_latency += 1
            continue
        filtered.append(src)

    if rejected_age or rejected_latency:
        log.info("Filtered: %d rejected by age (>%d slots), %d by latency (>%.0fms)",
                 rejected_age, max_age_slots, rejected_latency, max_latency_ms)

    if not filtered and all_sources:
        # Show what was available so the user can adjust filters
        best = min(all_sources, key=lambda s: s.slots_diff)
        log.warning("All %d sources rejected by filters. Best available: "
                    "%s (age=%d slots, latency=%.0fms). "
                    "Try --max-snapshot-age %d --max-latency %.0f",
                    len(all_sources), best.rpc_address,
                    best.slots_diff, best.latency_ms,
                    best.slots_diff + 500,
                    max(best.latency_ms * 1.5, 500))

    log.info("Found %d sources after filtering", len(filtered))
    return filtered
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -- Speed benchmark -----------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def measure_speed(rpc_address: str, measure_time: int = 7) -> float:
    """Measure download speed from an RPC node.

    Streams ``http://<rpc_address>/snapshot.tar.bz2`` for up to
    ``measure_time`` seconds and divides the bytes received by the elapsed
    time.

    Args:
        rpc_address: Address of the node to benchmark.
        measure_time: Sampling duration in seconds.

    Returns:
        Throughput in bytes/sec, or 0.0 if the request or transfer failed.
    """
    # Local import: the module only imports HTTPResponse from http.client.
    from http.client import HTTPException

    url: str = f"http://{rpc_address}/snapshot.tar.bz2"
    total: int = 0
    try:
        with urllib.request.urlopen(Request(url), timeout=measure_time + 5) as resp:
            # The clock starts once headers have arrived, so connection setup
            # is not counted against throughput.
            start: float = time.monotonic()
            while time.monotonic() - start < measure_time:
                chunk: bytes = resp.read(81920)
                if not chunk:
                    break
                total += len(chunk)
            elapsed: float = time.monotonic() - start
    except (urllib.error.URLError, HTTPException, OSError, TimeoutError, ValueError):
        # HTTPException covers non-HTTP replies (BadStatusLine) and truncated
        # bodies (IncompleteRead), which urllib does not wrap in URLError.
        return 0.0

    if elapsed <= 0:
        return 0.0
    return total / elapsed
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -- Download ------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def download_aria2c(
    urls: list[str],
    output_dir: str,
    filename: str,
    connections: int = 16,
) -> bool:
    """Download a file using aria2c with parallel connections.

    When multiple URLs are provided, aria2c treats them as mirrors of the
    same file and distributes chunks across all of them.

    Args:
        urls: One or more mirror URLs for the same file.
        output_dir: Directory the file is written into.
        filename: Name of the output file inside ``output_dir``.
        connections: Requested connections per mirror. aria2c accepts at
            most 16 per server, so larger values are clamped for
            ``--max-connection-per-server`` while still raising ``--split``.

    Returns:
        True if aria2c exited successfully and the file exists afterwards;
        False otherwise (including an empty URL list or a failure to launch
        aria2c).
    """
    if not urls:
        log.error("No URLs given for %s", filename)
        return False

    # aria2c rejects --max-connection-per-server values outside 1..16.
    aria2c_max_per_server: int = 16
    requested: int = max(connections, 1)
    per_server: int = min(requested, aria2c_max_per_server)
    if per_server != connections:
        log.info("Clamping aria2c connections per server from %d to %d",
                 connections, per_server)

    num_mirrors: int = len(urls)
    total_splits: int = requested * num_mirrors
    cmd: list[str] = [
        "aria2c",
        "--file-allocation=none",
        "--continue=false",
        f"--max-connection-per-server={per_server}",
        f"--split={total_splits}",
        "--min-split-size=50M",
        # aria2c retries individual chunk connections on transient network
        # errors (TCP reset, timeout). This is transport-level retry analogous
        # to TCP retransmit, not application-level retry of a failed operation.
        "--max-tries=5",
        "--retry-wait=5",
        "--timeout=60",
        "--connect-timeout=10",
        "--summary-interval=10",
        "--console-log-level=notice",
        f"--dir={output_dir}",
        f"--out={filename}",
        "--auto-file-renaming=false",
        "--allow-overwrite=true",
        *urls,
    ]

    log.info("Downloading %s", filename)
    log.info(" aria2c: %d connections x %d mirrors (%d splits)",
             per_server, num_mirrors, total_splits)

    start: float = time.monotonic()
    try:
        result: subprocess.CompletedProcess[bytes] = subprocess.run(cmd)
    except OSError as e:
        # e.g. aria2c missing from PATH or not executable.
        log.error("Could not run aria2c: %s", e)
        return False
    elapsed: float = time.monotonic() - start

    if result.returncode != 0:
        log.error("aria2c failed with exit code %d", result.returncode)
        return False

    filepath: Path = Path(output_dir) / filename
    if not filepath.exists():
        log.error("aria2c reported success but %s does not exist", filepath)
        return False

    size_bytes: int = filepath.stat().st_size
    size_gb: float = size_bytes / (1024 ** 3)
    avg_mb: float = size_bytes / elapsed / (1024 ** 2) if elapsed > 0 else 0
    log.info(" Done: %.1f GB in %.0fs (%.1f MiB/s avg)", size_gb, elapsed, avg_mb)
    return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -- Public API ----------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _download_fresh_incremental(
    fast_sources: "list[SnapshotSource]",
    full_snap_slot: int,
    output_dir: str,
    connections: int,
) -> "str | None":
    """Re-probe sources for an incremental built on full_snap_slot and fetch it.

    Sources are tried in order. For the first one that advertises a matching
    incremental, every other source serving the same filename is added as a
    mirror. If that download fails, the next source is tried.

    Returns the downloaded filename, or None if nothing could be downloaded.
    """
    for source in fast_sources:
        inc_location, _ = head_no_follow(
            f"http://{source.rpc_address}/incremental-snapshot.tar.bz2", timeout=2)
        if not inc_location:
            continue
        inc_fn, inc_fp = _parse_snapshot_filename(inc_location)
        m_inc = INCR_SNAP_RE.match(inc_fn)
        if not m_inc:
            continue
        if int(m_inc.group(1)) != full_snap_slot:
            log.debug(" %s: incremental base slot %s != full %d, skipping",
                      source.rpc_address, m_inc.group(1), full_snap_slot)
            continue

        # Found a matching incremental — build mirror list and download
        inc_mirrors: list[str] = [f"http://{source.rpc_address}{inc_fp}"]
        for other in fast_sources:
            if other.rpc_address == source.rpc_address:
                continue
            other_loc, _ = head_no_follow(
                f"http://{other.rpc_address}/incremental-snapshot.tar.bz2", timeout=2)
            if other_loc:
                other_fn, other_fp = _parse_snapshot_filename(other_loc)
                if other_fn == inc_fn:
                    inc_mirrors.append(f"http://{other.rpc_address}{other_fp}")

        log.info(" Found incremental %s (%d mirrors)", inc_fn, len(inc_mirrors))
        if download_aria2c(inc_mirrors, output_dir, inc_fn, connections):
            return inc_fn
        log.warning("Incremental download via %s failed, trying next source",
                    source.rpc_address)
    return None


def download_best_snapshot(
    output_dir: str,
    *,
    cluster: str = "mainnet-beta",
    rpc_url: str | None = None,
    connections: int = 16,
    threads: int = 500,
    max_snapshot_age: int = 10000,
    max_latency: float = 500,
    min_download_speed: int = 20,
    measurement_time: int = 7,
    max_speed_checks: int = 15,
    version_filter: str | None = None,
    full_only: bool = False,
) -> bool:
    """Download the best available snapshot to output_dir.

    This is the programmatic API — called by entrypoint.py for automatic
    snapshot download. All parameters have defaults matching the CLI.

    Steps: resolve the RPC endpoint, discover and filter sources, benchmark
    the lowest-latency candidates, download the full snapshot from the
    fastest one (other fast sources serving the same file act as mirrors),
    then, unless ``full_only``, re-probe for a fresh incremental built on
    that full snapshot.

    Args:
        output_dir: Directory to write snapshot files into (created if missing).
        cluster: Cluster name used to pick the default RPC endpoint.
        rpc_url: Explicit RPC endpoint; overrides ``cluster``.
        connections: aria2c connections per mirror.
        threads: Concurrent probe workers for discovery.
        max_snapshot_age: Maximum full-snapshot age in slots.
        max_latency: Maximum probe latency in milliseconds.
        min_download_speed: Minimum acceptable speed in MiB/s.
        measurement_time: Seconds spent benchmarking each candidate.
        max_speed_checks: Maximum number of candidates to benchmark.
        version_filter: Optional node-version prefix filter.
        full_only: Skip the incremental snapshot.

    Returns:
        True on success, False on failure. A missing incremental is not a
        failure.
    """
    if rpc_url:
        resolved_rpc: str = rpc_url
    elif cluster in CLUSTER_RPC:
        resolved_rpc = CLUSTER_RPC[cluster]
    else:
        # Report an unknown cluster through the bool contract, not a KeyError.
        log.error("Unknown cluster %r (expected one of: %s)",
                  cluster, ", ".join(CLUSTER_RPC))
        return False

    if not shutil.which("aria2c"):
        log.error("aria2c not found. Install with: apt install aria2")
        return False

    log.info("Cluster: %s | RPC: %s", cluster, resolved_rpc)
    current_slot: int | None = get_current_slot(resolved_rpc)
    if current_slot is None:
        log.error("Cannot get current slot from %s", resolved_rpc)
        return False
    log.info("Current slot: %d", current_slot)

    sources: list[SnapshotSource] = discover_sources(
        resolved_rpc, current_slot,
        max_age_slots=max_snapshot_age,
        max_latency_ms=max_latency,
        threads=threads,
        version_filter=version_filter,
    )
    if not sources:
        log.error("No snapshot sources found")
        return False

    # Sort by latency (lowest first) for speed benchmarking
    sources.sort(key=lambda s: s.latency_ms)

    # Benchmark top candidates
    log.info("Benchmarking download speed on top %d sources...", max_speed_checks)
    fast_sources: list[SnapshotSource] = []
    min_speed_bytes: int = min_download_speed * 1024 * 1024

    for source in sources[:max(max_speed_checks, 0)]:
        speed: float = measure_speed(source.rpc_address, measurement_time)
        source.download_speed = speed
        speed_mib: float = speed / (1024 ** 2)

        if speed < min_speed_bytes:
            log.info(" %s: %.1f MiB/s (too slow, need >=%d MiB/s)",
                     source.rpc_address, speed_mib, min_download_speed)
            continue

        log.info(" %s: %.1f MiB/s (latency: %.0fms, age: %d slots)",
                 source.rpc_address, speed_mib,
                 source.latency_ms, source.slots_diff)
        fast_sources.append(source)

    if not fast_sources:
        log.error("No source met minimum speed requirement (%d MiB/s)",
                  min_download_speed)
        return False

    # Candidates were benchmarked in latency order; rank them by measured
    # speed so the primary really is the fastest source. The sort is stable,
    # so ties keep their latency order.
    fast_sources.sort(key=lambda s: s.download_speed, reverse=True)

    # Use the fastest source as primary, collect mirrors for each file
    best: SnapshotSource = fast_sources[0]
    file_paths: list[str] = best.file_paths
    if full_only:
        file_paths = [fp for fp in file_paths
                      if fp.rsplit("/", 1)[-1].startswith("snapshot-")]

    # Build mirror URL lists
    download_plan: list[tuple[str, list[str]]] = []
    for fp in file_paths:
        filename: str = fp.rsplit("/", 1)[-1]
        mirror_urls: list[str] = [f"http://{best.rpc_address}{fp}"]
        for other in fast_sources[1:]:
            for other_fp in other.file_paths:
                if other_fp.rsplit("/", 1)[-1] == filename:
                    mirror_urls.append(f"http://{other.rpc_address}{other_fp}")
                    break
        download_plan.append((filename, mirror_urls))

    log.info("Best source: %s (%.1f MiB/s), %d mirrors total",
             best.rpc_address, best.download_speed / (1024 ** 2), len(fast_sources))
    for filename, mirror_urls in download_plan:
        log.info(" %s (%d mirrors)", filename, len(mirror_urls))

    # Download — full snapshot first, then re-probe for fresh incremental
    os.makedirs(output_dir, exist_ok=True)
    total_start: float = time.monotonic()

    # Separate full and incremental from the initial plan
    full_downloads: list[tuple[str, list[str]]] = [
        (filename, mirror_urls)
        for filename, mirror_urls in download_plan
        if filename.startswith("snapshot-")
    ]
    if not full_downloads:
        log.error("Best source %s offers no full snapshot", best.rpc_address)
        return False

    # Download full snapshot(s)
    for filename, mirror_urls in full_downloads:
        filepath: Path = Path(output_dir) / filename
        if filepath.exists() and filepath.stat().st_size > 0:
            log.info("Skipping %s (already exists: %.1f GB)",
                     filename, filepath.stat().st_size / (1024 ** 3))
            continue
        if not download_aria2c(mirror_urls, output_dir, filename, connections):
            log.error("Failed to download %s", filename)
            return False

    # After full snapshot download, re-probe for a fresh incremental.
    # The initial incremental is stale by now (full download takes 10+ min).
    fresh_incremental: str | None = None
    if not full_only:
        # Get the full snapshot slot from the filename we just downloaded
        fm_post: re.Match[str] | None = FULL_SNAP_RE.match(full_downloads[0][0])
        if fm_post:
            full_snap_slot: int = int(fm_post.group(1))
            log.info("Re-probing for fresh incremental based on slot %d...", full_snap_slot)
            fresh_incremental = _download_fresh_incremental(
                fast_sources, full_snap_slot, output_dir, connections)
            if fresh_incremental is None:
                log.info("No matching incremental found — validator will replay from full snapshot")

    total_elapsed: float = time.monotonic() - total_start
    log.info("All downloads complete in %.0fs", total_elapsed)

    # Report planned files plus the freshly probed incremental, which is
    # usually a different file from the one in the initial plan.
    summary_names: list[str] = [filename for filename, _ in download_plan]
    if fresh_incremental is not None and fresh_incremental not in summary_names:
        summary_names.append(fresh_incremental)
    for filename in summary_names:
        fp_path: Path = Path(output_dir) / filename
        if fp_path.exists():
            log.info(" %s (%.1f GB)", fp_path.name, fp_path.stat().st_size / (1024 ** 3))

    return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -- Main (CLI) ----------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main() -> int:
    """CLI entry point: parse arguments, then dry-run or download.

    With ``--dry-run`` the sources are discovered and filtered, and the URLs
    of the lowest-latency source are printed without benchmarking or
    downloading; ``--full-only`` limits the printed URLs to the full
    snapshot, mirroring the real download path. Otherwise the work is
    delegated to ``download_best_snapshot`` and, on success, the optional
    ``--post-cmd`` shell command is run.

    Returns:
        Process exit status: 0 on success, 1 on any failure.
    """
    parser: argparse.ArgumentParser = argparse.ArgumentParser(
        description="Download Solana snapshots with aria2c parallel downloads",
    )
    parser.add_argument("-o", "--output", default="/srv/kind/solana/snapshots",
                        help="Snapshot output directory (default: /srv/kind/solana/snapshots)")
    parser.add_argument("-c", "--cluster", default="mainnet-beta",
                        choices=list(CLUSTER_RPC),
                        help="Solana cluster (default: mainnet-beta)")
    parser.add_argument("-r", "--rpc", default=None,
                        help="RPC URL for cluster discovery (default: public RPC)")
    parser.add_argument("-n", "--connections", type=int, default=16,
                        help="aria2c connections per download (default: 16)")
    parser.add_argument("-t", "--threads", type=int, default=500,
                        help="Threads for parallel RPC probing (default: 500)")
    parser.add_argument("--max-snapshot-age", type=int, default=10000,
                        help="Max snapshot age in slots (default: 10000)")
    parser.add_argument("--max-latency", type=float, default=500,
                        help="Max RPC probe latency in ms (default: 500)")
    parser.add_argument("--min-download-speed", type=int, default=20,
                        help="Min download speed in MiB/s (default: 20)")
    parser.add_argument("--measurement-time", type=int, default=7,
                        help="Speed measurement duration in seconds (default: 7)")
    parser.add_argument("--max-speed-checks", type=int, default=15,
                        help="Max nodes to benchmark before giving up (default: 15)")
    parser.add_argument("--version", default=None,
                        help="Filter nodes by version prefix (e.g. '2.2')")
    parser.add_argument("--full-only", action="store_true",
                        help="Download only full snapshot, skip incremental")
    parser.add_argument("--dry-run", action="store_true",
                        help="Find best source and print URL, don't download")
    parser.add_argument("--post-cmd",
                        help="Shell command to run after successful download "
                             "(e.g. 'kubectl scale deployment ... --replicas=1')")
    parser.add_argument("-v", "--verbose", action="store_true")
    args: argparse.Namespace = parser.parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
        datefmt="%H:%M:%S",
    )

    # Dry-run uses the original inline flow (needs access to sources for URL printing)
    if args.dry_run:
        rpc_url: str = args.rpc or CLUSTER_RPC[args.cluster]
        current_slot: int | None = get_current_slot(rpc_url)
        if current_slot is None:
            log.error("Cannot get current slot from %s", rpc_url)
            return 1

        sources: list[SnapshotSource] = discover_sources(
            rpc_url, current_slot,
            max_age_slots=args.max_snapshot_age,
            max_latency_ms=args.max_latency,
            threads=args.threads,
            version_filter=args.version,
        )
        if not sources:
            log.error("No snapshot sources found")
            return 1

        # No benchmarking in dry-run: the lowest-latency source is reported.
        best = min(sources, key=lambda s: s.latency_ms)
        paths: list[str] = best.file_paths
        if args.full_only:
            # Same filter download_best_snapshot applies for --full-only.
            paths = [fp for fp in paths
                     if fp.rsplit("/", 1)[-1].startswith("snapshot-")]
        for fp in paths:
            print(f"http://{best.rpc_address}{fp}")
        return 0

    ok: bool = download_best_snapshot(
        args.output,
        cluster=args.cluster,
        rpc_url=args.rpc,
        connections=args.connections,
        threads=args.threads,
        max_snapshot_age=args.max_snapshot_age,
        max_latency=args.max_latency,
        min_download_speed=args.min_download_speed,
        measurement_time=args.measurement_time,
        max_speed_checks=args.max_speed_checks,
        version_filter=args.version,
        full_only=args.full_only,
    )

    if ok and args.post_cmd:
        log.info("Running post-download command: %s", args.post_cmd)
        # shell=True is deliberate: --post-cmd is documented as a shell
        # command supplied by the operator on the command line.
        result: subprocess.CompletedProcess[bytes] = subprocess.run(
            args.post_cmd, shell=True,
        )
        if result.returncode != 0:
            log.error("Post-download command failed with exit code %d",
                      result.returncode)
            return 1
        log.info("Post-download command completed successfully")

    return 0 if ok else 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: exit with the status code returned by main().
if __name__ == "__main__":
    raise SystemExit(main())
|