feat: add iptables playbook, symlink snapshot-download.py to agave-stack
- playbooks/biscayne-iptables.yml: manages PREROUTING DNAT and DOCKER-USER rules for both the host IP (186.233.184.235) and the relay loopback (137.239.194.65). Idempotent; persists rules via netfilter-persistent. - scripts/snapshot-download.py: replaced the standalone copy with a symlink to the agave-stack source of truth, eliminating duplication. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Branch: fix/kind-mount-propagation
parent
ec12e6079b
commit
078872d78d
|
|
@ -0,0 +1,116 @@
|
||||||
|
---
# Configure iptables DNAT rules for agave validator
#
# Routes external traffic to the kind node's pod IP (172.20.0.2).
# Rules must be inserted BEFORE Docker's ADDRTYPE LOCAL rule in
# PREROUTING, otherwise Docker's chain swallows the traffic.
#
# Two external IPs:
#   186.233.184.235 — primary host IP (direct access)
#   137.239.194.65  — Ashburn relay loopback (GRE tunnel endpoint)
#
# DOCKER-USER chain rules allow forwarded traffic to reach the pod.
#
# Idempotent: checks for existing rules before inserting.
# Persistent: saves rules via netfilter-persistent (from the
# iptables-persistent package).
#
# Usage:
#   ansible-playbook playbooks/biscayne-iptables.yml
#
- name: Configure iptables for agave validator
  hosts: all
  gather_facts: false
  become: true
  vars:
    pod_ip: 172.20.0.2
    host_ip: 186.233.184.235
    relay_ip: 137.239.194.65
    # Ports to forward
    tcp_ports:
      - 8899  # RPC
      - 8900  # RPC WebSocket
      - 8001  # Gossip
    udp_ports:
      - 8001  # Gossip UDP
    udp_ranges:
      - 9000:9025  # Validator dynamic ports

  tasks:
    # ---- PREROUTING DNAT rules ---------------------------------------------------
    # Host IP rules (186.233.184.235 → pod)
    # NOTE(review): only TCP is DNAT'd for the host IP, while the relay IP
    # also gets UDP DNAT below — confirm this asymmetry is intentional.
    - name: "PREROUTING DNAT host IP TCP ports to pod"
      ansible.builtin.iptables:
        table: nat
        chain: PREROUTING
        protocol: tcp
        destination: "{{ host_ip }}"
        destination_port: "{{ item }}"
        jump: DNAT
        to_destination: "{{ pod_ip }}:{{ item }}"
        # Insert at position 1 so the rule lands ahead of Docker's own
        # PREROUTING entry (see header comment).
        action: insert
        rule_num: 1
      loop: "{{ tcp_ports }}"
      register: prerouting_host_tcp

    # Relay IP rules (137.239.194.65 → pod)
    - name: "PREROUTING DNAT relay IP TCP ports to pod"
      ansible.builtin.iptables:
        table: nat
        chain: PREROUTING
        protocol: tcp
        destination: "{{ relay_ip }}"
        destination_port: "{{ item }}"
        jump: DNAT
        to_destination: "{{ pod_ip }}:{{ item }}"
        action: insert
        rule_num: 1
      loop: "{{ tcp_ports }}"
      register: prerouting_relay_tcp

    - name: "PREROUTING DNAT relay IP UDP ports to pod"
      ansible.builtin.iptables:
        table: nat
        chain: PREROUTING
        protocol: udp
        destination: "{{ relay_ip }}"
        destination_port: "{{ item }}"
        jump: DNAT
        # No port in to_destination: the original destination port is kept,
        # which is required for the 9000:9025 port-range entries.
        to_destination: "{{ pod_ip }}"
        action: insert
        rule_num: 1
      loop: "{{ udp_ports + udp_ranges }}"
      register: prerouting_relay_udp

    # ---- DOCKER-USER accept rules ------------------------------------------------
    - name: "DOCKER-USER accept TCP to pod"
      ansible.builtin.iptables:
        chain: DOCKER-USER
        protocol: tcp
        destination: "{{ pod_ip }}"
        destination_port: "{{ item }}"
        jump: ACCEPT
        action: insert
        rule_num: 1
      loop: "{{ tcp_ports }}"
      register: dockeruser_tcp

    - name: "DOCKER-USER accept UDP to pod"
      ansible.builtin.iptables:
        chain: DOCKER-USER
        protocol: udp
        destination: "{{ pod_ip }}"
        destination_port: "{{ item }}"
        jump: ACCEPT
        action: insert
        rule_num: 1
      loop: "{{ udp_ports + udp_ranges }}"
      register: dockeruser_udp

    # ---- Persist rules -----------------------------------------------------------
    # Only runs when at least one rule task above reported a change.
    - name: Save iptables rules
      ansible.builtin.command: netfilter-persistent save
      changed_when: true
      when: >-
        prerouting_host_tcp.changed or prerouting_relay_tcp.changed or
        prerouting_relay_udp.changed or dockeruser_tcp.changed or
        dockeruser_udp.changed
|
||||||
|
|
@ -1,635 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
"""Download Solana snapshots using aria2c for parallel multi-connection downloads.
|
|
||||||
|
|
||||||
Discovers snapshot sources by querying getClusterNodes for all RPCs in the
|
|
||||||
cluster, probing each for available snapshots, benchmarking download speed,
|
|
||||||
and downloading from the fastest source using aria2c (16 connections by default).
|
|
||||||
|
|
||||||
Based on the discovery approach from etcusr/solana-snapshot-finder but replaces
|
|
||||||
the single-connection wget download with aria2c parallel chunked downloads.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
# Download to /srv/kind/solana/snapshots (mainnet, 16 connections)
|
|
||||||
./snapshot-download.py -o /srv/kind/solana/snapshots
|
|
||||||
|
|
||||||
# Dry run — find best source, print URL
|
|
||||||
./snapshot-download.py --dry-run
|
|
||||||
|
|
||||||
# Custom RPC for cluster node discovery + 32 connections
|
|
||||||
./snapshot-download.py -r https://api.mainnet-beta.solana.com -n 32
|
|
||||||
|
|
||||||
# Testnet
|
|
||||||
./snapshot-download.py -c testnet -o /data/snapshots
|
|
||||||
|
|
||||||
Requirements:
|
|
||||||
- aria2c (apt install aria2)
|
|
||||||
- python3 >= 3.10 (stdlib only, no pip dependencies)
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import concurrent.futures
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import shutil
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
import urllib.error
|
|
||||||
import urllib.request
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from http.client import HTTPResponse
|
|
||||||
from pathlib import Path
|
|
||||||
from urllib.request import Request
|
|
||||||
|
|
||||||
# Module-level logger; configured by main() via logging.basicConfig.
log: logging.Logger = logging.getLogger("snapshot-download")

# Public RPC endpoint per cluster, used when the caller does not pass -r/--rpc.
CLUSTER_RPC: dict[str, str] = {
    "mainnet-beta": "https://api.mainnet-beta.solana.com",
    "testnet": "https://api.testnet.solana.com",
    "devnet": "https://api.devnet.solana.com",
}

# Snapshot filenames:
#   snapshot-<slot>-<hash>.tar.zst
#   incremental-snapshot-<base_slot>-<slot>-<hash>.tar.zst
# Both .zst and legacy .bz2 extensions are accepted by the patterns.
FULL_SNAP_RE: re.Pattern[str] = re.compile(
    r"^snapshot-(\d+)-([A-Za-z0-9]+)\.tar\.(zst|bz2)$"
)
INCR_SNAP_RE: re.Pattern[str] = re.compile(
    r"^incremental-snapshot-(\d+)-(\d+)-([A-Za-z0-9]+)\.tar\.(zst|bz2)$"
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class SnapshotSource:
    """A snapshot file available from a specific RPC node."""

    # Node address used to build download URLs (f"http://{rpc_address}{path}");
    # presumably host:port — confirm against getClusterNodes output.
    rpc_address: str
    # Full redirect paths as returned by the server (e.g. /snapshot-123-hash.tar.zst)
    file_paths: list[str] = field(default_factory=list)
    # Age of the full snapshot: current_slot - full_snapshot_slot.
    slots_diff: int = 0
    # Round-trip latency of the HEAD probe against this node.
    latency_ms: float = 0.0
    download_speed: float = 0.0  # bytes/sec
|
|
||||||
|
|
||||||
|
|
||||||
# -- JSON-RPC helpers ----------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class _NoRedirectHandler(urllib.request.HTTPRedirectHandler):
    """Handler that captures redirect Location instead of following it.

    Returning None from redirect_request makes urllib raise HTTPError for
    3xx responses, so the caller can read the Location header itself
    (see head_no_follow).
    """

    def redirect_request(
        self,
        req: Request,
        fp: HTTPResponse,
        code: int,
        msg: str,
        headers: dict[str, str],  # type: ignore[override]
        newurl: str,
    ) -> None:
        # None means "do not follow"; urllib then raises HTTPError with the
        # redirect response attached.
        return None
|
|
||||||
|
|
||||||
|
|
||||||
def rpc_post(url: str, method: str, params: list[object] | None = None,
             timeout: int = 25) -> object | None:
    """JSON-RPC 2.0 POST helper.

    Sends `method` with `params` to `url` and returns the parsed 'result'
    field, or None on any network/JSON error (logged at debug level).
    """
    body: bytes = json.dumps({
        "jsonrpc": "2.0",
        "id": 1,
        "method": method,
        "params": params or [],
    }).encode()
    request = Request(url, data=body,
                      headers={"Content-Type": "application/json"})
    try:
        with urllib.request.urlopen(request, timeout=timeout) as response:
            parsed: dict[str, object] = json.loads(response.read())
            return parsed.get("result")
    except (urllib.error.URLError, json.JSONDecodeError, OSError, TimeoutError) as e:
        log.debug("rpc_post %s %s failed: %s", url, method, e)
        return None
|
|
||||||
|
|
||||||
|
|
||||||
def head_no_follow(url: str, timeout: float = 3) -> tuple[str | None, float]:
    """HEAD request without following redirects.

    Returns (Location header value, latency_sec) when the server answered
    with a 3xx redirect. Returns (None, 0.0) on errors and on 3xx responses
    lacking a Location header; a plain 2xx yields (Location-or-None, latency).
    """
    no_follow: urllib.request.OpenerDirector = urllib.request.build_opener(_NoRedirectHandler)
    head_req = Request(url, method="HEAD")
    started: float = time.monotonic()
    try:
        response: HTTPResponse = no_follow.open(head_req, timeout=timeout)  # type: ignore[assignment]
    except urllib.error.HTTPError as err:
        # Our handler refuses to follow, so 3xx responses land here with the
        # redirect info attached to the HTTPError.
        elapsed: float = time.monotonic() - started
        target: str | None = err.headers.get("Location")
        if target and 300 <= err.code < 400:
            return target, elapsed
        return None, 0.0
    except (urllib.error.URLError, OSError, TimeoutError):
        return None, 0.0
    # 2xx path: the server did not redirect — not useful for discovery.
    elapsed = time.monotonic() - started
    target = response.headers.get("Location")
    response.close()
    return target, elapsed
|
|
||||||
|
|
||||||
|
|
||||||
# -- Discovery -----------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def get_current_slot(rpc_url: str) -> int | None:
    """Return the cluster's current slot via getSlot, or None on failure."""
    slot: object | None = rpc_post(rpc_url, "getSlot")
    # Anything other than an int (error, None, malformed reply) maps to None.
    return slot if isinstance(slot, int) else None
|
|
||||||
|
|
||||||
|
|
||||||
def get_cluster_rpc_nodes(rpc_url: str, version_filter: str | None = None) -> list[str]:
    """Return unique RPC node addresses discovered via getClusterNodes.

    Nodes whose reported version does not start with `version_filter` are
    skipped; nodes reporting no version at all are kept.
    """
    nodes: object | None = rpc_post(rpc_url, "getClusterNodes")
    if not isinstance(nodes, list):
        return []

    unique_addrs: set[str] = set()
    for entry in nodes:
        if not isinstance(entry, dict):
            continue
        if version_filter is not None:
            version: str | None = entry.get("version")
            # Only reject when a version is present and mismatched.
            if version and not version.startswith(version_filter):
                continue
        addr: str | None = entry.get("rpc")
        if addr:
            unique_addrs.add(addr)
    return list(unique_addrs)
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_snapshot_filename(location: str) -> tuple[str, str | None]:
|
|
||||||
"""Extract filename and full redirect path from Location header.
|
|
||||||
|
|
||||||
Returns (filename, full_path). full_path includes any path prefix
|
|
||||||
the server returned (e.g. '/snapshots/snapshot-123-hash.tar.zst').
|
|
||||||
"""
|
|
||||||
# Location may be absolute URL or relative path
|
|
||||||
if location.startswith("http://") or location.startswith("https://"):
|
|
||||||
# Absolute URL — extract path
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
path: str = urlparse(location).path
|
|
||||||
else:
|
|
||||||
path = location
|
|
||||||
|
|
||||||
filename: str = path.rsplit("/", 1)[-1]
|
|
||||||
return filename, path
|
|
||||||
|
|
||||||
|
|
||||||
def probe_rpc_snapshot(
    rpc_address: str,
    current_slot: int,
) -> SnapshotSource | None:
    """Probe a single RPC node for available snapshots.

    Discovery only — no filtering. Returns a SnapshotSource with all available
    info so the caller can decide what to keep. Filtering happens after all
    probes complete, so rejected sources are still visible for debugging.

    Returns None when the node does not redirect for the full snapshot or
    the redirect filename does not match the expected pattern.
    """
    full_url: str = f"http://{rpc_address}/snapshot.tar.bz2"

    # Full snapshot is required — every source must have one
    full_location, full_latency = head_no_follow(full_url, timeout=2)
    if not full_location:
        return None

    latency_ms: float = full_latency * 1000

    full_filename, full_path = _parse_snapshot_filename(full_location)
    fm: re.Match[str] | None = FULL_SNAP_RE.match(full_filename)
    if not fm:
        # Redirect target isn't a recognizable snapshot filename.
        return None

    # Group 1 of FULL_SNAP_RE is the snapshot's slot number.
    full_snap_slot: int = int(fm.group(1))
    slots_diff: int = current_slot - full_snap_slot

    file_paths: list[str] = [full_path]

    # Also check for incremental snapshot
    inc_url: str = f"http://{rpc_address}/incremental-snapshot.tar.bz2"
    inc_location, _ = head_no_follow(inc_url, timeout=2)
    if inc_location:
        inc_filename, inc_path = _parse_snapshot_filename(inc_location)
        m: re.Match[str] | None = INCR_SNAP_RE.match(inc_filename)
        if m:
            # Group 1 of INCR_SNAP_RE is the incremental's base slot.
            inc_base_slot: int = int(m.group(1))
            # Incremental must be based on this source's full snapshot
            if inc_base_slot == full_snap_slot:
                file_paths.append(inc_path)

    return SnapshotSource(
        rpc_address=rpc_address,
        file_paths=file_paths,
        slots_diff=slots_diff,
        latency_ms=latency_ms,
    )
|
|
||||||
|
|
||||||
|
|
||||||
def discover_sources(
    rpc_url: str,
    current_slot: int,
    max_age_slots: int,
    max_latency_ms: float,
    threads: int,
    version_filter: str | None,
) -> list[SnapshotSource]:
    """Discover all snapshot sources, then filter.

    Probing and filtering are separate: all reachable sources are collected
    first so we can report what exists even if filters reject everything.

    Args:
        rpc_url: RPC endpoint used for getClusterNodes discovery.
        current_slot: Cluster tip slot, used to compute snapshot age.
        max_age_slots: Reject sources whose snapshot age exceeds this.
        max_latency_ms: Reject sources whose probe latency exceeds this.
        threads: Thread-pool size for parallel probing.
        version_filter: Optional node-version prefix filter.

    Returns:
        Sources passing both filters (possibly empty).
    """
    rpc_nodes: list[str] = get_cluster_rpc_nodes(rpc_url, version_filter)
    if not rpc_nodes:
        log.error("No RPC nodes found via getClusterNodes")
        return []

    log.info("Found %d RPC nodes, probing for snapshots...", len(rpc_nodes))

    all_sources: list[SnapshotSource] = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as pool:
        # Map each future back to its node address for error reporting.
        futures: dict[concurrent.futures.Future[SnapshotSource | None], str] = {
            pool.submit(probe_rpc_snapshot, addr, current_slot): addr
            for addr in rpc_nodes
        }
        done: int = 0
        for future in concurrent.futures.as_completed(futures):
            done += 1
            # Periodic progress for large clusters.
            if done % 200 == 0:
                log.info(" probed %d/%d nodes, %d reachable",
                         done, len(rpc_nodes), len(all_sources))
            try:
                result: SnapshotSource | None = future.result()
            except (urllib.error.URLError, OSError, TimeoutError) as e:
                log.debug("Probe failed for %s: %s", futures[future], e)
                continue
            if result:
                all_sources.append(result)

    log.info("Discovered %d reachable sources", len(all_sources))

    # Apply filters
    filtered: list[SnapshotSource] = []
    rejected_age: int = 0
    rejected_latency: int = 0
    for src in all_sources:
        # Negative slots_diff means the snapshot claims to be newer than the
        # reference RPC's tip; tolerate small skew (down to -100) but no more.
        if src.slots_diff > max_age_slots or src.slots_diff < -100:
            rejected_age += 1
            continue
        if src.latency_ms > max_latency_ms:
            rejected_latency += 1
            continue
        filtered.append(src)

    if rejected_age or rejected_latency:
        log.info("Filtered: %d rejected by age (>%d slots), %d by latency (>%.0fms)",
                 rejected_age, max_age_slots, rejected_latency, max_latency_ms)

    if not filtered and all_sources:
        # Show what was available so the user can adjust filters
        all_sources.sort(key=lambda s: s.slots_diff)
        best = all_sources[0]
        # Suggested relaxed thresholds: best age + 500 slots, 1.5x best
        # latency (floored at 500 ms).
        log.warning("All %d sources rejected by filters. Best available: "
                    "%s (age=%d slots, latency=%.0fms). "
                    "Try --max-snapshot-age %d --max-latency %.0f",
                    len(all_sources), best.rpc_address,
                    best.slots_diff, best.latency_ms,
                    best.slots_diff + 500,
                    max(best.latency_ms * 1.5, 500))

    log.info("Found %d sources after filtering", len(filtered))
    return filtered
|
|
||||||
|
|
||||||
|
|
||||||
# -- Speed benchmark -----------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def measure_speed(rpc_address: str, measure_time: int = 7) -> float:
    """Measure download speed from an RPC node.

    Streams the node's full snapshot for up to `measure_time` seconds and
    returns the observed throughput in bytes/sec, or 0.0 on any error.
    """
    request = Request(f"http://{rpc_address}/snapshot.tar.bz2")
    try:
        with urllib.request.urlopen(request, timeout=measure_time + 5) as response:
            started: float = time.monotonic()
            received: int = 0
            # Keep pulling fixed-size chunks until the window closes or EOF.
            while time.monotonic() - started < measure_time:
                chunk: bytes = response.read(81920)
                if not chunk:
                    break
                received += len(chunk)
            duration: float = time.monotonic() - started
            return received / duration if duration > 0 else 0.0
    except (urllib.error.URLError, OSError, TimeoutError):
        return 0.0
|
|
||||||
|
|
||||||
|
|
||||||
# -- Download ------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def download_aria2c(
    urls: list[str],
    output_dir: str,
    filename: str,
    connections: int = 16,
) -> bool:
    """Download a file using aria2c with parallel connections.

    When multiple URLs are provided, aria2c treats them as mirrors of the
    same file and distributes chunks across all of them.

    Args:
        urls: Mirror URLs that all serve the same file.
        output_dir: Directory to place the download in.
        filename: Output filename inside output_dir.
        connections: Max connections per server.

    Returns:
        True when aria2c exits 0 and the output file exists, else False.
    """
    num_mirrors: int = len(urls)
    # With >=1 mirror this is connections * num_mirrors; the max() guards the
    # degenerate empty-urls case.
    total_splits: int = max(connections, connections * num_mirrors)
    cmd: list[str] = [
        "aria2c",
        "--file-allocation=none",
        "--continue=false",
        f"--max-connection-per-server={connections}",
        f"--split={total_splits}",
        "--min-split-size=50M",
        # aria2c retries individual chunk connections on transient network
        # errors (TCP reset, timeout). This is transport-level retry analogous
        # to TCP retransmit, not application-level retry of a failed operation.
        "--max-tries=5",
        "--retry-wait=5",
        "--timeout=60",
        "--connect-timeout=10",
        "--summary-interval=10",
        "--console-log-level=notice",
        f"--dir={output_dir}",
        # FIX: the output name must be the requested filename so the
        # existence check below (and callers' skip logic) match the file
        # aria2c actually writes. Previously this was a corrupted literal.
        f"--out={filename}",
        "--auto-file-renaming=false",
        "--allow-overwrite=true",
        *urls,
    ]

    log.info("Downloading %s", filename)
    log.info(" aria2c: %d connections × %d mirrors (%d splits)",
             connections, num_mirrors, total_splits)

    start: float = time.monotonic()
    # stdout/stderr are inherited so aria2c's progress reaches the console.
    result: subprocess.CompletedProcess[bytes] = subprocess.run(cmd)
    elapsed: float = time.monotonic() - start

    if result.returncode != 0:
        log.error("aria2c failed with exit code %d", result.returncode)
        return False

    filepath: Path = Path(output_dir) / filename
    if not filepath.exists():
        log.error("aria2c reported success but %s does not exist", filepath)
        return False

    size_bytes: int = filepath.stat().st_size
    size_gb: float = size_bytes / (1024 ** 3)
    avg_mb: float = size_bytes / elapsed / (1024 ** 2) if elapsed > 0 else 0
    log.info(" Done: %.1f GB in %.0fs (%.1f MiB/s avg)", size_gb, elapsed, avg_mb)
    return True
|
|
||||||
|
|
||||||
|
|
||||||
# -- Public API ----------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def download_best_snapshot(
    output_dir: str,
    *,
    cluster: str = "mainnet-beta",
    rpc_url: str | None = None,
    connections: int = 16,
    threads: int = 500,
    max_snapshot_age: int = 10000,
    max_latency: float = 500,
    min_download_speed: int = 20,
    measurement_time: int = 7,
    max_speed_checks: int = 15,
    version_filter: str | None = None,
    full_only: bool = False,
) -> bool:
    """Download the best available snapshot to output_dir.

    Programmatic API for use by entrypoint.py or other callers.
    Returns True on success, False on failure.

    Pipeline: resolve RPC → read current slot → discover and filter sources →
    benchmark the lowest-latency candidates → download from the fastest
    source, using the other fast sources as aria2c mirrors.
    """
    resolved_rpc: str = rpc_url or CLUSTER_RPC[cluster]

    if not shutil.which("aria2c"):
        log.error("aria2c not found. Install with: apt install aria2")
        return False

    log.info("Cluster: %s | RPC: %s", cluster, resolved_rpc)
    current_slot: int | None = get_current_slot(resolved_rpc)
    if current_slot is None:
        log.error("Cannot get current slot from %s", resolved_rpc)
        return False
    log.info("Current slot: %d", current_slot)

    sources: list[SnapshotSource] = discover_sources(
        resolved_rpc, current_slot,
        max_age_slots=max_snapshot_age,
        max_latency_ms=max_latency,
        threads=threads,
        version_filter=version_filter,
    )
    if not sources:
        log.error("No snapshot sources found")
        return False

    # Sort by latency (lowest first) for speed benchmarking
    sources.sort(key=lambda s: s.latency_ms)

    # Benchmark top candidates
    log.info("Benchmarking download speed on top %d sources...", max_speed_checks)
    fast_sources: list[SnapshotSource] = []
    checked: int = 0
    # min_download_speed is MiB/s; convert once to bytes/sec for comparison.
    min_speed_bytes: int = min_download_speed * 1024 * 1024

    for source in sources:
        if checked >= max_speed_checks:
            break
        checked += 1

        speed: float = measure_speed(source.rpc_address, measurement_time)
        source.download_speed = speed
        speed_mib: float = speed / (1024 ** 2)

        if speed < min_speed_bytes:
            log.info(" %s: %.1f MiB/s (too slow, need >=%d MiB/s)",
                     source.rpc_address, speed_mib, min_download_speed)
            continue

        log.info(" %s: %.1f MiB/s (latency: %.0fms, age: %d slots)",
                 source.rpc_address, speed_mib,
                 source.latency_ms, source.slots_diff)
        fast_sources.append(source)

    if not fast_sources:
        log.error("No source met minimum speed requirement (%d MiB/s)",
                  min_download_speed)
        return False

    # Use the fastest source as primary, collect mirrors for each file
    best: SnapshotSource = fast_sources[0]
    file_paths: list[str] = best.file_paths
    if full_only:
        # Keep only the full snapshot (filenames starting with "snapshot-").
        file_paths = [fp for fp in file_paths
                      if fp.rsplit("/", 1)[-1].startswith("snapshot-")]

    # Build mirror URL lists
    download_plan: list[tuple[str, list[str]]] = []
    for fp in file_paths:
        filename: str = fp.rsplit("/", 1)[-1]
        mirror_urls: list[str] = [f"http://{best.rpc_address}{fp}"]
        for other in fast_sources[1:]:
            for other_fp in other.file_paths:
                # A source only mirrors a file if it serves the exact same
                # filename (same slot and hash).
                if other_fp.rsplit("/", 1)[-1] == filename:
                    mirror_urls.append(f"http://{other.rpc_address}{other_fp}")
                    break
        download_plan.append((filename, mirror_urls))

    speed_mib: float = best.download_speed / (1024 ** 2)
    log.info("Best source: %s (%.1f MiB/s), %d mirrors total",
             best.rpc_address, speed_mib, len(fast_sources))
    for filename, mirror_urls in download_plan:
        log.info(" %s (%d mirrors)", filename, len(mirror_urls))

    # Download
    os.makedirs(output_dir, exist_ok=True)
    total_start: float = time.monotonic()

    for filename, mirror_urls in download_plan:
        filepath: Path = Path(output_dir) / filename
        # An existing non-empty file is treated as already downloaded.
        if filepath.exists() and filepath.stat().st_size > 0:
            log.info("Skipping %s (already exists: %.1f GB)",
                     filename, filepath.stat().st_size / (1024 ** 3))
            continue
        if not download_aria2c(mirror_urls, output_dir, filename, connections):
            log.error("Failed to download %s", filename)
            return False

    total_elapsed: float = time.monotonic() - total_start
    log.info("All downloads complete in %.0fs", total_elapsed)
    for filename, _ in download_plan:
        fp_path: Path = Path(output_dir) / filename
        if fp_path.exists():
            log.info(" %s (%.1f GB)", fp_path.name, fp_path.stat().st_size / (1024 ** 3))

    return True
|
|
||||||
|
|
||||||
|
|
||||||
# -- Main (CLI) ----------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> int:
    """CLI entry point.

    Parses arguments, configures logging, and either prints the best source
    URL(s) (--dry-run) or performs the full download, optionally running a
    post-download shell command. Returns a process exit code (0/1).
    """
    p: argparse.ArgumentParser = argparse.ArgumentParser(
        description="Download Solana snapshots with aria2c parallel downloads",
    )
    p.add_argument("-o", "--output", default="/srv/kind/solana/snapshots",
                   help="Snapshot output directory (default: /srv/kind/solana/snapshots)")
    p.add_argument("-c", "--cluster", default="mainnet-beta",
                   choices=list(CLUSTER_RPC),
                   help="Solana cluster (default: mainnet-beta)")
    p.add_argument("-r", "--rpc", default=None,
                   help="RPC URL for cluster discovery (default: public RPC)")
    p.add_argument("-n", "--connections", type=int, default=16,
                   help="aria2c connections per download (default: 16)")
    p.add_argument("-t", "--threads", type=int, default=500,
                   help="Threads for parallel RPC probing (default: 500)")
    p.add_argument("--max-snapshot-age", type=int, default=10000,
                   help="Max snapshot age in slots (default: 10000)")
    p.add_argument("--max-latency", type=float, default=500,
                   help="Max RPC probe latency in ms (default: 500)")
    p.add_argument("--min-download-speed", type=int, default=20,
                   help="Min download speed in MiB/s (default: 20)")
    p.add_argument("--measurement-time", type=int, default=7,
                   help="Speed measurement duration in seconds (default: 7)")
    p.add_argument("--max-speed-checks", type=int, default=15,
                   help="Max nodes to benchmark before giving up (default: 15)")
    p.add_argument("--version", default=None,
                   help="Filter nodes by version prefix (e.g. '2.2')")
    p.add_argument("--full-only", action="store_true",
                   help="Download only full snapshot, skip incremental")
    p.add_argument("--dry-run", action="store_true",
                   help="Find best source and print URL, don't download")
    p.add_argument("--post-cmd",
                   help="Shell command to run after successful download "
                        "(e.g. 'kubectl scale deployment ... --replicas=1')")
    p.add_argument("-v", "--verbose", action="store_true")
    args: argparse.Namespace = p.parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
        datefmt="%H:%M:%S",
    )

    # Dry-run uses inline flow (needs access to sources for URL printing)
    if args.dry_run:
        rpc_url: str = args.rpc or CLUSTER_RPC[args.cluster]
        current_slot: int | None = get_current_slot(rpc_url)
        if current_slot is None:
            log.error("Cannot get current slot from %s", rpc_url)
            return 1

        sources: list[SnapshotSource] = discover_sources(
            rpc_url, current_slot,
            max_age_slots=args.max_snapshot_age,
            max_latency_ms=args.max_latency,
            threads=args.threads,
            version_filter=args.version,
        )
        if not sources:
            log.error("No snapshot sources found")
            return 1

        # Lowest-latency source wins; print its URL(s) and exit without
        # benchmarking or downloading.
        sources.sort(key=lambda s: s.latency_ms)
        best = sources[0]
        for fp in best.file_paths:
            print(f"http://{best.rpc_address}{fp}")
        return 0

    ok: bool = download_best_snapshot(
        args.output,
        cluster=args.cluster,
        rpc_url=args.rpc,
        connections=args.connections,
        threads=args.threads,
        max_snapshot_age=args.max_snapshot_age,
        max_latency=args.max_latency,
        min_download_speed=args.min_download_speed,
        measurement_time=args.measurement_time,
        max_speed_checks=args.max_speed_checks,
        version_filter=args.version,
        full_only=args.full_only,
    )

    if ok and args.post_cmd:
        log.info("Running post-download command: %s", args.post_cmd)
        # shell=True is deliberate: --post-cmd is an operator-supplied shell
        # string, not untrusted input.
        result: subprocess.CompletedProcess[bytes] = subprocess.run(
            args.post_cmd, shell=True,
        )
        if result.returncode != 0:
            log.error("Post-download command failed with exit code %d",
                      result.returncode)
            return 1
        log.info("Post-download command completed successfully")

    return 0 if ok else 1


if __name__ == "__main__":
    sys.exit(main())
|
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
../../agave-stack/stack-orchestrator/container-build/laconicnetwork-agave/snapshot_download.py
|
||||||
Loading…
Reference in New Issue