commit f4b3a46109a8da00fdd68d8999160ddc45dcc88a Author: A. F. Dudley Date: Sun Mar 8 19:13:38 2026 +0000 Squashed 'scripts/agave-container/' content from commit 4b5c875 git-subtree-dir: scripts/agave-container git-subtree-split: 4b5c875a05cbbfbde38eeb053fd5443a8a50228c diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..68a12508 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,81 @@ +# Unified Agave/Jito Solana image +# Supports three modes via AGAVE_MODE env: test, rpc, validator +# +# Build args: +# AGAVE_REPO - git repo URL (anza-xyz/agave or jito-foundation/jito-solana) +# AGAVE_VERSION - git tag to build (e.g. v3.1.9, v3.1.8-jito) + +ARG AGAVE_REPO=https://github.com/anza-xyz/agave.git +ARG AGAVE_VERSION=v3.1.9 + +# ---------- Stage 1: Build ---------- +FROM rust:1.85-bookworm AS builder + +ARG AGAVE_REPO +ARG AGAVE_VERSION + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + pkg-config \ + libssl-dev \ + libudev-dev \ + libclang-dev \ + protobuf-compiler \ + ca-certificates \ + git \ + cmake \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /build +RUN git clone "$AGAVE_REPO" --depth 1 --branch "$AGAVE_VERSION" --recurse-submodules agave +WORKDIR /build/agave + +# Cherry-pick --public-tvu-address support (anza-xyz/agave PR #6778, commit 9f4b3ae) +# This flag only exists on master, not in v3.1.9 — fetch the PR ref and cherry-pick +ARG TVU_ADDRESS_PR=6778 +RUN if [ -n "$TVU_ADDRESS_PR" ]; then \ + git fetch --depth 50 origin "pull/${TVU_ADDRESS_PR}/head:tvu-pr" && \ + git cherry-pick --no-commit tvu-pr; \ + fi + +# Build all binaries using the upstream install script +RUN CI_COMMIT=$(git rev-parse HEAD) scripts/cargo-install-all.sh /solana-release + +# ---------- Stage 2: Runtime ---------- +FROM debian:bookworm-slim + +RUN apt-get update && apt-get install -y --no-install-recommends \ + ca-certificates \ + libssl3 \ + libudev1 \ + curl \ + sudo \ + aria2 \ + python3 \ + && rm -rf /var/lib/apt/lists/* + +# Create non-root user with sudo +RUN useradd -m -s /bin/bash agave \ + && echo "agave ALL=(ALL) NOPASSWD: ALL" >> /etc/sudoers + +# Copy all compiled binaries +COPY --from=builder /solana-release/bin/ /usr/local/bin/ + +# Copy entrypoint and support scripts +COPY entrypoint.py snapshot_download.py /usr/local/bin/ +COPY start-test.sh /usr/local/bin/ +RUN chmod +x /usr/local/bin/entrypoint.py /usr/local/bin/start-test.sh + +# Create data directories +RUN mkdir -p /data/config /data/ledger /data/accounts /data/snapshots \ + && chown -R agave:agave /data + +USER agave +WORKDIR /data + +ENV RUST_LOG=info +ENV RUST_BACKTRACE=1 + +EXPOSE 8899 8900 8001 8001/udp + +ENTRYPOINT ["entrypoint.py"] diff --git a/build.sh b/build.sh new file mode 100644 index 00000000..4c4d940f --- /dev/null +++ b/build.sh @@ -0,0 +1,17 @@ +#!/usr/bin/env bash + +# Build laconicnetwork/agave +# Set AGAVE_REPO and AGAVE_VERSION env vars to build Jito or a different version +source ${CERC_CONTAINER_BASE_DIR}/build-base.sh + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) + +AGAVE_REPO="${AGAVE_REPO:-https://github.com/anza-xyz/agave.git}" +AGAVE_VERSION="${AGAVE_VERSION:-v3.1.9}" + +docker build -t laconicnetwork/agave:local \ + --build-arg AGAVE_REPO="$AGAVE_REPO" \ + --build-arg AGAVE_VERSION="$AGAVE_VERSION" \ + ${build_command_args} \ + -f ${SCRIPT_DIR}/Dockerfile \ + ${SCRIPT_DIR} diff --git a/entrypoint.py b/entrypoint.py new file mode 100644 index 00000000..1122fc9c --- /dev/null +++ b/entrypoint.py @@ -0,0 +1,485 @@ +#!/usr/bin/env python3 +"""Agave validator entrypoint — snapshot management, arg construction, liveness probe. + +Two subcommands: + entrypoint.py serve (default) — snapshot freshness check + exec agave-validator + entrypoint.py probe — liveness probe (slot lag check, exits 0/1) + +Replaces the bash entrypoint.sh / start-rpc.sh / start-validator.sh with a single +Python module. Test mode still dispatches to start-test.sh. + +All configuration comes from environment variables — same vars as the original +bash scripts. See compose files for defaults. +""" + +from __future__ import annotations + +import json +import logging +import os +import re +import subprocess +import sys +import time +import urllib.error +import urllib.request +from pathlib import Path +from urllib.request import Request + +log: logging.Logger = logging.getLogger("entrypoint") + +# Directories +CONFIG_DIR = "/data/config" +LEDGER_DIR = "/data/ledger" +ACCOUNTS_DIR = "/data/accounts" +SNAPSHOTS_DIR = "/data/snapshots" +LOG_DIR = "/data/log" +IDENTITY_FILE = f"{CONFIG_DIR}/validator-identity.json" + +# Snapshot filename pattern +FULL_SNAP_RE: re.Pattern[str] = re.compile( + r"^snapshot-(\d+)-[A-Za-z0-9]+\.tar\.(zst|bz2)$" +) + +MAINNET_RPC = "https://api.mainnet-beta.solana.com" + + +# -- Helpers ------------------------------------------------------------------- + + +def env(name: str, default: str = "") -> str: + """Read env var with default.""" + return os.environ.get(name, default) + + +def env_required(name: str) -> str: + """Read required env var, exit if missing.""" + val = os.environ.get(name) + if not val: + log.error("%s is required but not set", name) + sys.exit(1) + return val + + +def env_bool(name: str, default: bool = False) -> bool: + """Read boolean env var (true/false/1/0).""" + val = os.environ.get(name, "").lower() + if not val: + return default + return val in ("true", "1", "yes") + + +def rpc_get_slot(url: str, timeout: int = 10) -> int | None: + """Get current slot from a Solana RPC endpoint.""" + payload = json.dumps({ + "jsonrpc": "2.0", "id": 1, + "method": "getSlot", "params": [], + }).encode() + req = Request(url, data=payload, + headers={"Content-Type": "application/json"}) + try: + with urllib.request.urlopen(req, timeout=timeout) as resp: + data = json.loads(resp.read()) + result = data.get("result") + if isinstance(result, int): + return result + except (urllib.error.URLError, json.JSONDecodeError, OSError, TimeoutError): + pass + return None + + +# -- Snapshot management ------------------------------------------------------- + + +def get_local_snapshot_slot(snapshots_dir: str) -> int | None: + """Find the highest slot among local snapshot files.""" + best_slot: int | None = None + snap_path = Path(snapshots_dir) + if not snap_path.is_dir(): + return None + for entry in snap_path.iterdir(): + m = FULL_SNAP_RE.match(entry.name) + if m: + slot = int(m.group(1)) + if best_slot is None or slot > best_slot: + best_slot = slot + return best_slot + + +def clean_snapshots(snapshots_dir: str) -> None: + """Remove all snapshot files from the directory.""" + snap_path = Path(snapshots_dir) + if not snap_path.is_dir(): + return + for entry in snap_path.iterdir(): + if entry.name.startswith(("snapshot-", "incremental-snapshot-")): + log.info("Removing old snapshot: %s", entry.name) + entry.unlink(missing_ok=True) + + +def maybe_download_snapshot(snapshots_dir: str) -> None: + """Check snapshot freshness and download if needed. + + Controlled by env vars: + SNAPSHOT_AUTO_DOWNLOAD (default: true) — enable/disable + SNAPSHOT_MAX_AGE_SLOTS (default: 20000) — staleness threshold + """ + if not env_bool("SNAPSHOT_AUTO_DOWNLOAD", default=True): + log.info("Snapshot auto-download disabled") + return + + max_age = int(env("SNAPSHOT_MAX_AGE_SLOTS", "20000")) + + # Get mainnet current slot + mainnet_slot = rpc_get_slot(MAINNET_RPC) + if mainnet_slot is None: + log.warning("Cannot reach mainnet RPC — skipping snapshot check") + return + + # Check local snapshot + local_slot = get_local_snapshot_slot(snapshots_dir) + if local_slot is not None: + age = mainnet_slot - local_slot + log.info("Local snapshot at slot %d (mainnet: %d, age: %d slots)", + local_slot, mainnet_slot, age) + if age <= max_age: + log.info("Snapshot is fresh enough (age %d <= %d), skipping download", age, max_age) + return + log.info("Snapshot is stale (age %d > %d), downloading fresh", age, max_age) + else: + log.info("No local snapshot found, downloading") + + # Clean old snapshots before downloading + clean_snapshots(snapshots_dir) + + # Import and call snapshot download + # snapshot_download.py is installed alongside this file in /usr/local/bin/ + script_dir = Path(__file__).resolve().parent + sys.path.insert(0, str(script_dir)) + from snapshot_download import download_best_snapshot + + ok = download_best_snapshot(snapshots_dir) + if not ok: + log.error("Snapshot download failed — starting without fresh snapshot") + + +# -- Directory and identity setup ---------------------------------------------- + + +def ensure_dirs(*dirs: str) -> None: + """Create directories and fix ownership.""" + uid = os.getuid() + gid = os.getgid() + for d in dirs: + os.makedirs(d, exist_ok=True) + try: + subprocess.run( + ["sudo", "chown", "-R", f"{uid}:{gid}", d], + check=False, capture_output=True, + ) + except FileNotFoundError: + pass # sudo not available — dirs already owned correctly + + +def ensure_identity_rpc() -> None: + """Generate ephemeral identity keypair for RPC mode if not mounted.""" + if os.path.isfile(IDENTITY_FILE): + return + log.info("Generating RPC node identity keypair...") + subprocess.run( + ["solana-keygen", "new", "--no-passphrase", "--silent", + "--force", "--outfile", IDENTITY_FILE], + check=True, + ) + + +def print_identity() -> None: + """Print the node identity pubkey.""" + result = subprocess.run( + ["solana-keygen", "pubkey", IDENTITY_FILE], + capture_output=True, text=True, check=False, + ) + if result.returncode == 0: + log.info("Node identity: %s", result.stdout.strip()) + + +# -- Arg construction ---------------------------------------------------------- + + +def build_common_args() -> list[str]: + """Build agave-validator args common to both RPC and validator modes.""" + args: list[str] = [ + "--identity", IDENTITY_FILE, + "--entrypoint", env_required("VALIDATOR_ENTRYPOINT"), + "--known-validator", env_required("KNOWN_VALIDATOR"), + "--ledger", LEDGER_DIR, + "--accounts", ACCOUNTS_DIR, + "--snapshots", SNAPSHOTS_DIR, + "--rpc-port", env("RPC_PORT", "8899"), + "--rpc-bind-address", env("RPC_BIND_ADDRESS", "127.0.0.1"), + "--gossip-port", env("GOSSIP_PORT", "8001"), + "--dynamic-port-range", env("DYNAMIC_PORT_RANGE", "9000-10000"), + "--no-os-network-limits-test", + "--wal-recovery-mode", "skip_any_corrupted_record", + "--limit-ledger-size", env("LIMIT_LEDGER_SIZE", "50000000"), + ] + + # Snapshot generation + if env("NO_SNAPSHOTS") == "true": + args.append("--no-snapshots") + else: + args += [ + "--full-snapshot-interval-slots", env("SNAPSHOT_INTERVAL_SLOTS", "100000"), + "--maximum-full-snapshots-to-retain", env("MAXIMUM_SNAPSHOTS_TO_RETAIN", "5"), + ] + if env("NO_INCREMENTAL_SNAPSHOTS") != "true": + args += ["--maximum-incremental-snapshots-to-retain", "2"] + + # Account indexes + account_indexes = env("ACCOUNT_INDEXES") + if account_indexes: + for idx in account_indexes.split(","): + idx = idx.strip() + if idx: + args += ["--account-index", idx] + + # Additional entrypoints + for ep in env("EXTRA_ENTRYPOINTS").split(): + if ep: + args += ["--entrypoint", ep] + + # Additional known validators + for kv in env("EXTRA_KNOWN_VALIDATORS").split(): + if kv: + args += ["--known-validator", kv] + + # Cluster verification + genesis_hash = env("EXPECTED_GENESIS_HASH") + if genesis_hash: + args += ["--expected-genesis-hash", genesis_hash] + shred_version = env("EXPECTED_SHRED_VERSION") + if shred_version: + args += ["--expected-shred-version", shred_version] + + # Metrics — just needs to be in the environment, agave reads it directly + # (env var is already set, nothing to pass as arg) + + # Gossip host / TVU address + gossip_host = env("GOSSIP_HOST") + if gossip_host: + args += ["--gossip-host", gossip_host] + elif env("PUBLIC_TVU_ADDRESS"): + args += ["--public-tvu-address", env("PUBLIC_TVU_ADDRESS")] + + # Jito flags + if env("JITO_ENABLE") == "true": + log.info("Jito MEV enabled") + jito_flags: list[tuple[str, str]] = [ + ("JITO_TIP_PAYMENT_PROGRAM", "--tip-payment-program-pubkey"), + ("JITO_DISTRIBUTION_PROGRAM", "--tip-distribution-program-pubkey"), + ("JITO_MERKLE_ROOT_AUTHORITY", "--merkle-root-upload-authority"), + ("JITO_COMMISSION_BPS", "--commission-bps"), + ("JITO_BLOCK_ENGINE_URL", "--block-engine-url"), + ("JITO_SHRED_RECEIVER_ADDR", "--shred-receiver-address"), + ] + for env_name, flag in jito_flags: + val = env(env_name) + if val: + args += [flag, val] + + return args + + +def build_rpc_args() -> list[str]: + """Build agave-validator args for RPC (non-voting) mode.""" + args = build_common_args() + args += [ + "--no-voting", + "--log", f"{LOG_DIR}/validator.log", + "--full-rpc-api", + "--enable-rpc-transaction-history", + "--rpc-pubsub-enable-block-subscription", + "--enable-extended-tx-metadata-storage", + "--no-wait-for-vote-to-start-leader", + "--no-snapshot-fetch", + ] + + # Public vs private RPC + public_rpc = env("PUBLIC_RPC_ADDRESS") + if public_rpc: + args += ["--public-rpc-address", public_rpc] + else: + args += ["--private-rpc", "--allow-private-addr", "--only-known-rpc"] + + # Jito relayer URL (RPC mode doesn't use it, but validator mode does — + # handled in build_validator_args) + + return args + + +def build_validator_args() -> list[str]: + """Build agave-validator args for voting validator mode.""" + vote_keypair = env("VOTE_ACCOUNT_KEYPAIR", + "/data/config/vote-account-keypair.json") + + # Identity must be mounted for validator mode + if not os.path.isfile(IDENTITY_FILE): + log.error("Validator identity keypair not found at %s", IDENTITY_FILE) + log.error("Mount your validator keypair to %s", IDENTITY_FILE) + sys.exit(1) + + # Vote account keypair must exist + if not os.path.isfile(vote_keypair): + log.error("Vote account keypair not found at %s", vote_keypair) + log.error("Mount your vote account keypair or set VOTE_ACCOUNT_KEYPAIR") + sys.exit(1) + + # Print vote account pubkey + result = subprocess.run( + ["solana-keygen", "pubkey", vote_keypair], + capture_output=True, text=True, check=False, + ) + if result.returncode == 0: + log.info("Vote account: %s", result.stdout.strip()) + + args = build_common_args() + args += [ + "--vote-account", vote_keypair, + "--log", "-", + ] + + # Jito relayer URL (validator-only) + relayer_url = env("JITO_RELAYER_URL") + if env("JITO_ENABLE") == "true" and relayer_url: + args += ["--relayer-url", relayer_url] + + return args + + +def append_extra_args(args: list[str]) -> list[str]: + """Append EXTRA_ARGS passthrough flags.""" + extra = env("EXTRA_ARGS") + if extra: + args += extra.split() + return args + + +# -- Serve subcommand --------------------------------------------------------- + + +def cmd_serve() -> None: + """Main serve flow: snapshot check, setup, exec agave-validator.""" + mode = env("AGAVE_MODE", "test") + log.info("AGAVE_MODE=%s", mode) + + # Test mode dispatches to start-test.sh + if mode == "test": + os.execvp("start-test.sh", ["start-test.sh"]) + + if mode not in ("rpc", "validator"): + log.error("Unknown AGAVE_MODE: %s (valid: test, rpc, validator)", mode) + sys.exit(1) + + # Ensure directories + dirs = [CONFIG_DIR, LEDGER_DIR, ACCOUNTS_DIR, SNAPSHOTS_DIR] + if mode == "rpc": + dirs.append(LOG_DIR) + ensure_dirs(*dirs) + + # Snapshot freshness check and auto-download + maybe_download_snapshot(SNAPSHOTS_DIR) + + # Identity setup + if mode == "rpc": + ensure_identity_rpc() + print_identity() + + # Build args + if mode == "rpc": + args = build_rpc_args() + else: + args = build_validator_args() + + args = append_extra_args(args) + + # Write startup timestamp for probe grace period + Path("/tmp/entrypoint-start").write_text(str(time.time())) + + log.info("Starting agave-validator with %d arguments", len(args)) + os.execvp("agave-validator", ["agave-validator"] + args) + + +# -- Probe subcommand --------------------------------------------------------- + + +def cmd_probe() -> None: + """Liveness probe: check local RPC slot vs mainnet. + + Exit 0 = healthy, exit 1 = unhealthy. + + Grace period: PROBE_GRACE_SECONDS (default 600) — probe always passes + during grace period to allow for snapshot unpacking and initial replay. + """ + grace_seconds = int(env("PROBE_GRACE_SECONDS", "600")) + max_lag = int(env("PROBE_MAX_SLOT_LAG", "20000")) + + # Check grace period + start_file = Path("/tmp/entrypoint-start") + if start_file.exists(): + try: + start_time = float(start_file.read_text().strip()) + elapsed = time.time() - start_time + if elapsed < grace_seconds: + # Within grace period — always healthy + sys.exit(0) + except (ValueError, OSError): + pass + else: + # No start file — serve hasn't started yet, within grace + sys.exit(0) + + # Query local RPC + rpc_port = env("RPC_PORT", "8899") + local_url = f"http://127.0.0.1:{rpc_port}" + local_slot = rpc_get_slot(local_url, timeout=5) + if local_slot is None: + # Local RPC unreachable after grace period — unhealthy + sys.exit(1) + + # Query mainnet + mainnet_slot = rpc_get_slot(MAINNET_RPC, timeout=10) + if mainnet_slot is None: + # Can't reach mainnet to compare — assume healthy (don't penalize + # the validator for mainnet RPC being down) + sys.exit(0) + + lag = mainnet_slot - local_slot + if lag > max_lag: + sys.exit(1) + + sys.exit(0) + + +# -- Main ---------------------------------------------------------------------- + + +def main() -> None: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s [%(name)s] %(message)s", + datefmt="%H:%M:%S", + ) + + subcmd = sys.argv[1] if len(sys.argv) > 1 else "serve" + + if subcmd == "serve": + cmd_serve() + elif subcmd == "probe": + cmd_probe() + else: + log.error("Unknown subcommand: %s (valid: serve, probe)", subcmd) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/snapshot_download.py b/snapshot_download.py new file mode 100644 index 00000000..61a39019 --- /dev/null +++ b/snapshot_download.py @@ -0,0 +1,641 @@ +#!/usr/bin/env python3 +"""Download Solana snapshots using aria2c for parallel multi-connection downloads. + +Discovers snapshot sources by querying getClusterNodes for all RPCs in the +cluster, probing each for available snapshots, benchmarking download speed, +and downloading from the fastest source using aria2c (16 connections by default). + +Based on the discovery approach from etcusr/solana-snapshot-finder but replaces +the single-connection wget download with aria2c parallel chunked downloads. + +Usage: + # Download to /srv/kind/solana/snapshots (mainnet, 16 connections) + ./snapshot_download.py -o /srv/kind/solana/snapshots + + # Dry run — find best source, print URL + ./snapshot_download.py --dry-run + + # Custom RPC for cluster discovery + 32 connections + ./snapshot_download.py -r https://api.mainnet-beta.solana.com -n 32 + + # Testnet + ./snapshot_download.py -c testnet -o /data/snapshots + + # Programmatic use from entrypoint.py: + from snapshot_download import download_best_snapshot + ok = download_best_snapshot("/data/snapshots") + +Requirements: + - aria2c (apt install aria2) + - python3 >= 3.10 (stdlib only, no pip dependencies) +""" + +from __future__ import annotations + +import argparse +import concurrent.futures +import json +import logging +import os +import re +import shutil +import subprocess +import sys +import time +import urllib.error +import urllib.request +from dataclasses import dataclass, field +from http.client import HTTPResponse +from pathlib import Path +from urllib.request import Request + +log: logging.Logger = logging.getLogger("snapshot-download") + +CLUSTER_RPC: dict[str, str] = { + "mainnet-beta": "https://api.mainnet-beta.solana.com", + "testnet": "https://api.testnet.solana.com", + "devnet": "https://api.devnet.solana.com", +} + +# Snapshot filenames: +# snapshot--.tar.zst +# incremental-snapshot---.tar.zst +FULL_SNAP_RE: re.Pattern[str] = re.compile( + r"^snapshot-(\d+)-([A-Za-z0-9]+)\.tar\.(zst|bz2)$" +) +INCR_SNAP_RE: re.Pattern[str] = re.compile( + r"^incremental-snapshot-(\d+)-(\d+)-([A-Za-z0-9]+)\.tar\.(zst|bz2)$" +) + + +@dataclass +class SnapshotSource: + """A snapshot file available from a specific RPC node.""" + + rpc_address: str + # Full redirect paths as returned by the server (e.g. /snapshot-123-hash.tar.zst) + file_paths: list[str] = field(default_factory=list) + slots_diff: int = 0 + latency_ms: float = 0.0 + download_speed: float = 0.0 # bytes/sec + + +# -- JSON-RPC helpers ---------------------------------------------------------- + + +class _NoRedirectHandler(urllib.request.HTTPRedirectHandler): + """Handler that captures redirect Location instead of following it.""" + + def redirect_request( + self, + req: Request, + fp: HTTPResponse, + code: int, + msg: str, + headers: dict[str, str], # type: ignore[override] + newurl: str, + ) -> None: + return None + + +def rpc_post(url: str, method: str, params: list[object] | None = None, + timeout: int = 25) -> object | None: + """JSON-RPC POST. Returns parsed 'result' field or None on error.""" + payload: bytes = json.dumps({ + "jsonrpc": "2.0", "id": 1, + "method": method, "params": params or [], + }).encode() + req = Request(url, data=payload, + headers={"Content-Type": "application/json"}) + try: + with urllib.request.urlopen(req, timeout=timeout) as resp: + data: dict[str, object] = json.loads(resp.read()) + return data.get("result") + except (urllib.error.URLError, json.JSONDecodeError, OSError, TimeoutError) as e: + log.debug("rpc_post %s %s failed: %s", url, method, e) + return None + + +def head_no_follow(url: str, timeout: float = 3) -> tuple[str | None, float]: + """HEAD request without following redirects. + + Returns (Location header value, latency_sec) if the server returned a + 3xx redirect. Returns (None, 0.0) on any error or non-redirect response. + """ + opener: urllib.request.OpenerDirector = urllib.request.build_opener(_NoRedirectHandler) + req = Request(url, method="HEAD") + try: + start: float = time.monotonic() + resp: HTTPResponse = opener.open(req, timeout=timeout) # type: ignore[assignment] + latency: float = time.monotonic() - start + # Non-redirect (2xx) — server didn't redirect, not useful for discovery + location: str | None = resp.headers.get("Location") + resp.close() + return location, latency + except urllib.error.HTTPError as e: + # 3xx redirects raise HTTPError with the redirect info + latency = time.monotonic() - start # type: ignore[possibly-undefined] + location = e.headers.get("Location") + if location and 300 <= e.code < 400: + return location, latency + return None, 0.0 + except (urllib.error.URLError, OSError, TimeoutError): + return None, 0.0 + + +# -- Discovery ----------------------------------------------------------------- + + +def get_current_slot(rpc_url: str) -> int | None: + """Get current slot from RPC.""" + result: object | None = rpc_post(rpc_url, "getSlot") + if isinstance(result, int): + return result + return None + + +def get_cluster_rpc_nodes(rpc_url: str, version_filter: str | None = None) -> list[str]: + """Get all RPC node addresses from getClusterNodes.""" + result: object | None = rpc_post(rpc_url, "getClusterNodes") + if not isinstance(result, list): + return [] + + rpc_addrs: list[str] = [] + for node in result: + if not isinstance(node, dict): + continue + if version_filter is not None: + node_version: str | None = node.get("version") + if node_version and not node_version.startswith(version_filter): + continue + rpc: str | None = node.get("rpc") + if rpc: + rpc_addrs.append(rpc) + return list(set(rpc_addrs)) + + +def _parse_snapshot_filename(location: str) -> tuple[str, str | None]: + """Extract filename and full redirect path from Location header. + + Returns (filename, full_path). full_path includes any path prefix + the server returned (e.g. '/snapshots/snapshot-123-hash.tar.zst'). + """ + # Location may be absolute URL or relative path + if location.startswith("http://") or location.startswith("https://"): + # Absolute URL — extract path + from urllib.parse import urlparse + path: str = urlparse(location).path + else: + path = location + + filename: str = path.rsplit("/", 1)[-1] + return filename, path + + +def probe_rpc_snapshot( + rpc_address: str, + current_slot: int, +) -> SnapshotSource | None: + """Probe a single RPC node for available snapshots. + + Discovery only — no filtering. Returns a SnapshotSource with all available + info so the caller can decide what to keep. Filtering happens after all + probes complete, so rejected sources are still visible for debugging. + """ + full_url: str = f"http://{rpc_address}/snapshot.tar.bz2" + + # Full snapshot is required — every source must have one + full_location, full_latency = head_no_follow(full_url, timeout=2) + if not full_location: + return None + + latency_ms: float = full_latency * 1000 + + full_filename, full_path = _parse_snapshot_filename(full_location) + fm: re.Match[str] | None = FULL_SNAP_RE.match(full_filename) + if not fm: + return None + + full_snap_slot: int = int(fm.group(1)) + slots_diff: int = current_slot - full_snap_slot + + file_paths: list[str] = [full_path] + + # Also check for incremental snapshot + inc_url: str = f"http://{rpc_address}/incremental-snapshot.tar.bz2" + inc_location, _ = head_no_follow(inc_url, timeout=2) + if inc_location: + inc_filename, inc_path = _parse_snapshot_filename(inc_location) + m: re.Match[str] | None = INCR_SNAP_RE.match(inc_filename) + if m: + inc_base_slot: int = int(m.group(1)) + # Incremental must be based on this source's full snapshot + if inc_base_slot == full_snap_slot: + file_paths.append(inc_path) + + return SnapshotSource( + rpc_address=rpc_address, + file_paths=file_paths, + slots_diff=slots_diff, + latency_ms=latency_ms, + ) + + +def discover_sources( + rpc_url: str, + current_slot: int, + max_age_slots: int, + max_latency_ms: float, + threads: int, + version_filter: str | None, +) -> list[SnapshotSource]: + """Discover all snapshot sources, then filter. + + Probing and filtering are separate: all reachable sources are collected + first so we can report what exists even if filters reject everything. + """ + rpc_nodes: list[str] = get_cluster_rpc_nodes(rpc_url, version_filter) + if not rpc_nodes: + log.error("No RPC nodes found via getClusterNodes") + return [] + + log.info("Found %d RPC nodes, probing for snapshots...", len(rpc_nodes)) + + all_sources: list[SnapshotSource] = [] + with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as pool: + futures: dict[concurrent.futures.Future[SnapshotSource | None], str] = { + pool.submit(probe_rpc_snapshot, addr, current_slot): addr + for addr in rpc_nodes + } + done: int = 0 + for future in concurrent.futures.as_completed(futures): + done += 1 + if done % 200 == 0: + log.info(" probed %d/%d nodes, %d reachable", + done, len(rpc_nodes), len(all_sources)) + try: + result: SnapshotSource | None = future.result() + except (urllib.error.URLError, OSError, TimeoutError) as e: + log.debug("Probe failed for %s: %s", futures[future], e) + continue + if result: + all_sources.append(result) + + log.info("Discovered %d reachable sources", len(all_sources)) + + # Apply filters + filtered: list[SnapshotSource] = [] + rejected_age: int = 0 + rejected_latency: int = 0 + for src in all_sources: + if src.slots_diff > max_age_slots or src.slots_diff < -100: + rejected_age += 1 + continue + if src.latency_ms > max_latency_ms: + rejected_latency += 1 + continue + filtered.append(src) + + if rejected_age or rejected_latency: + log.info("Filtered: %d rejected by age (>%d slots), %d by latency (>%.0fms)", + rejected_age, max_age_slots, rejected_latency, max_latency_ms) + + if not filtered and all_sources: + # Show what was available so the user can adjust filters + all_sources.sort(key=lambda s: s.slots_diff) + best = all_sources[0] + log.warning("All %d sources rejected by filters. Best available: " + "%s (age=%d slots, latency=%.0fms). " + "Try --max-snapshot-age %d --max-latency %.0f", + len(all_sources), best.rpc_address, + best.slots_diff, best.latency_ms, + best.slots_diff + 500, + max(best.latency_ms * 1.5, 500)) + + log.info("Found %d sources after filtering", len(filtered)) + return filtered + + +# -- Speed benchmark ----------------------------------------------------------- + + +def measure_speed(rpc_address: str, measure_time: int = 7) -> float: + """Measure download speed from an RPC node. Returns bytes/sec.""" + url: str = f"http://{rpc_address}/snapshot.tar.bz2" + req = Request(url) + try: + with urllib.request.urlopen(req, timeout=measure_time + 5) as resp: + start: float = time.monotonic() + total: int = 0 + while True: + elapsed: float = time.monotonic() - start + if elapsed >= measure_time: + break + chunk: bytes = resp.read(81920) + if not chunk: + break + total += len(chunk) + elapsed = time.monotonic() - start + if elapsed <= 0: + return 0.0 + return total / elapsed + except (urllib.error.URLError, OSError, TimeoutError): + return 0.0 + + +# -- Download ------------------------------------------------------------------ + + +def download_aria2c( + urls: list[str], + output_dir: str, + filename: str, + connections: int = 16, +) -> bool: + """Download a file using aria2c with parallel connections. + + When multiple URLs are provided, aria2c treats them as mirrors of the + same file and distributes chunks across all of them. + """ + num_mirrors: int = len(urls) + total_splits: int = max(connections, connections * num_mirrors) + cmd: list[str] = [ + "aria2c", + "--file-allocation=none", + "--continue=false", + f"--max-connection-per-server={connections}", + f"--split={total_splits}", + "--min-split-size=50M", + # aria2c retries individual chunk connections on transient network + # errors (TCP reset, timeout). This is transport-level retry analogous + # to TCP retransmit, not application-level retry of a failed operation. + "--max-tries=5", + "--retry-wait=5", + "--timeout=60", + "--connect-timeout=10", + "--summary-interval=10", + "--console-log-level=notice", + f"--dir={output_dir}", + f"--out={filename}", + "--auto-file-renaming=false", + "--allow-overwrite=true", + *urls, + ] + + log.info("Downloading %s", filename) + log.info(" aria2c: %d connections x %d mirrors (%d splits)", + connections, num_mirrors, total_splits) + + start: float = time.monotonic() + result: subprocess.CompletedProcess[bytes] = subprocess.run(cmd) + elapsed: float = time.monotonic() - start + + if result.returncode != 0: + log.error("aria2c failed with exit code %d", result.returncode) + return False + + filepath: Path = Path(output_dir) / filename + if not filepath.exists(): + log.error("aria2c reported success but %s does not exist", filepath) + return False + + size_bytes: int = filepath.stat().st_size + size_gb: float = size_bytes / (1024 ** 3) + avg_mb: float = size_bytes / elapsed / (1024 ** 2) if elapsed > 0 else 0 + log.info(" Done: %.1f GB in %.0fs (%.1f MiB/s avg)", size_gb, elapsed, avg_mb) + return True + + +# -- Public API ---------------------------------------------------------------- + + +def download_best_snapshot( + output_dir: str, + *, + cluster: str = "mainnet-beta", + rpc_url: str | None = None, + connections: int = 16, + threads: int = 500, + max_snapshot_age: int = 10000, + max_latency: float = 500, + min_download_speed: int = 20, + measurement_time: int = 7, + max_speed_checks: int = 15, + version_filter: str | None = None, + full_only: bool = False, +) -> bool: + """Download the best available snapshot to output_dir. + + This is the programmatic API — called by entrypoint.py for automatic + snapshot download. Returns True on success, False on failure. + + All parameters have sensible defaults matching the CLI interface. + """ + resolved_rpc: str = rpc_url or CLUSTER_RPC[cluster] + + if not shutil.which("aria2c"): + log.error("aria2c not found. Install with: apt install aria2") + return False + + log.info("Cluster: %s | RPC: %s", cluster, resolved_rpc) + current_slot: int | None = get_current_slot(resolved_rpc) + if current_slot is None: + log.error("Cannot get current slot from %s", resolved_rpc) + return False + log.info("Current slot: %d", current_slot) + + sources: list[SnapshotSource] = discover_sources( + resolved_rpc, current_slot, + max_age_slots=max_snapshot_age, + max_latency_ms=max_latency, + threads=threads, + version_filter=version_filter, + ) + if not sources: + log.error("No snapshot sources found") + return False + + # Sort by latency (lowest first) for speed benchmarking + sources.sort(key=lambda s: s.latency_ms) + + # Benchmark top candidates + log.info("Benchmarking download speed on top %d sources...", max_speed_checks) + fast_sources: list[SnapshotSource] = [] + checked: int = 0 + min_speed_bytes: int = min_download_speed * 1024 * 1024 + + for source in sources: + if checked >= max_speed_checks: + break + checked += 1 + + speed: float = measure_speed(source.rpc_address, measurement_time) + source.download_speed = speed + speed_mib: float = speed / (1024 ** 2) + + if speed < min_speed_bytes: + log.info(" %s: %.1f MiB/s (too slow, need >=%d MiB/s)", + source.rpc_address, speed_mib, min_download_speed) + continue + + log.info(" %s: %.1f MiB/s (latency: %.0fms, age: %d slots)", + source.rpc_address, speed_mib, + source.latency_ms, source.slots_diff) + fast_sources.append(source) + + if not fast_sources: + log.error("No source met minimum speed requirement (%d MiB/s)", + min_download_speed) + return False + + # Use the fastest source as primary, collect mirrors for each file + best: SnapshotSource = fast_sources[0] + file_paths: list[str] = best.file_paths + if full_only: + file_paths = [fp for fp in file_paths + if fp.rsplit("/", 1)[-1].startswith("snapshot-")] + + # Build mirror URL lists + download_plan: list[tuple[str, list[str]]] = [] + for fp in file_paths: + filename: str = fp.rsplit("/", 1)[-1] + mirror_urls: list[str] = [f"http://{best.rpc_address}{fp}"] + for other in fast_sources[1:]: + for other_fp in other.file_paths: + if other_fp.rsplit("/", 1)[-1] == filename: + mirror_urls.append(f"http://{other.rpc_address}{other_fp}") + break + download_plan.append((filename, mirror_urls)) + + speed_mib: float = best.download_speed / (1024 ** 2) + log.info("Best source: %s (%.1f MiB/s), %d mirrors total", + best.rpc_address, speed_mib, len(fast_sources)) + for filename, mirror_urls in download_plan: + log.info(" %s (%d mirrors)", filename, len(mirror_urls)) + + # Download + os.makedirs(output_dir, exist_ok=True) + total_start: float = time.monotonic() + + for filename, mirror_urls in download_plan: + filepath: Path = Path(output_dir) / filename + if filepath.exists() and filepath.stat().st_size > 0: + log.info("Skipping %s (already exists: %.1f GB)", + filename, filepath.stat().st_size / (1024 ** 3)) + continue + if not download_aria2c(mirror_urls, output_dir, filename, connections): + log.error("Failed to download %s", filename) + return False + + total_elapsed: float = time.monotonic() - total_start + log.info("All downloads complete in %.0fs", total_elapsed) + for filename, _ in download_plan: + fp_path: Path = Path(output_dir) / filename + if fp_path.exists(): + log.info(" %s (%.1f GB)", fp_path.name, fp_path.stat().st_size / (1024 ** 3)) + + return True + + +# -- Main (CLI) ---------------------------------------------------------------- + + +def main() -> int: + p: argparse.ArgumentParser = argparse.ArgumentParser( + description="Download Solana snapshots with aria2c parallel downloads", + ) + p.add_argument("-o", "--output", default="/srv/kind/solana/snapshots", + help="Snapshot output directory (default: /srv/kind/solana/snapshots)") + p.add_argument("-c", "--cluster", default="mainnet-beta", + choices=list(CLUSTER_RPC), + help="Solana cluster (default: mainnet-beta)") + p.add_argument("-r", "--rpc", default=None, + help="RPC URL for cluster discovery (default: public RPC)") + p.add_argument("-n", "--connections", type=int, default=16, + help="aria2c connections per download (default: 16)") + p.add_argument("-t", "--threads", type=int, default=500, + help="Threads for parallel RPC probing (default: 500)") + p.add_argument("--max-snapshot-age", type=int, default=10000, + help="Max snapshot age in slots (default: 10000)") + p.add_argument("--max-latency", type=float, default=500, + help="Max RPC probe latency in ms (default: 500)") + p.add_argument("--min-download-speed", type=int, default=20, + help="Min download speed in MiB/s (default: 20)") + p.add_argument("--measurement-time", type=int, default=7, + help="Speed measurement duration in seconds (default: 7)") + p.add_argument("--max-speed-checks", type=int, default=15, + help="Max nodes to benchmark before giving up (default: 15)") + p.add_argument("--version", default=None, + help="Filter nodes by version prefix (e.g. '2.2')") + p.add_argument("--full-only", action="store_true", + help="Download only full snapshot, skip incremental") + p.add_argument("--dry-run", action="store_true", + help="Find best source and print URL, don't download") + p.add_argument("--post-cmd", + help="Shell command to run after successful download " + "(e.g. 'kubectl scale deployment ... --replicas=1')") + p.add_argument("-v", "--verbose", action="store_true") + args: argparse.Namespace = p.parse_args() + + logging.basicConfig( + level=logging.DEBUG if args.verbose else logging.INFO, + format="%(asctime)s %(levelname)s %(message)s", + datefmt="%H:%M:%S", + ) + + # Dry-run uses the original inline flow (needs access to sources for URL printing) + if args.dry_run: + rpc_url: str = args.rpc or CLUSTER_RPC[args.cluster] + current_slot: int | None = get_current_slot(rpc_url) + if current_slot is None: + log.error("Cannot get current slot from %s", rpc_url) + return 1 + + sources: list[SnapshotSource] = discover_sources( + rpc_url, current_slot, + max_age_slots=args.max_snapshot_age, + max_latency_ms=args.max_latency, + threads=args.threads, + version_filter=args.version, + ) + if not sources: + log.error("No snapshot sources found") + return 1 + + sources.sort(key=lambda s: s.latency_ms) + best = sources[0] + for fp in best.file_paths: + print(f"http://{best.rpc_address}{fp}") + return 0 + + ok: bool = download_best_snapshot( + args.output, + cluster=args.cluster, + rpc_url=args.rpc, + connections=args.connections, + threads=args.threads, + max_snapshot_age=args.max_snapshot_age, + max_latency=args.max_latency, + min_download_speed=args.min_download_speed, + measurement_time=args.measurement_time, + max_speed_checks=args.max_speed_checks, + version_filter=args.version, + full_only=args.full_only, + ) + + if ok and args.post_cmd: + log.info("Running post-download command: %s", args.post_cmd) + result: subprocess.CompletedProcess[bytes] = subprocess.run( + args.post_cmd, shell=True, + ) + if result.returncode != 0: + log.error("Post-download command failed with exit code %d", + result.returncode) + return 1 + log.info("Post-download command completed successfully") + + return 0 if ok else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/start-test.sh b/start-test.sh new file mode 100644 index 00000000..e003a97a --- /dev/null +++ b/start-test.sh @@ -0,0 +1,112 @@ +#!/usr/bin/env bash +set -euo pipefail + +# ----------------------------------------------------------------------- +# Start solana-test-validator with optional SPL token setup +# +# Environment variables: +# FACILITATOR_PUBKEY - facilitator fee-payer public key (base58) +# SERVER_PUBKEY - server/payee wallet public key (base58) +# CLIENT_PUBKEY - client/payer wallet public key (base58) +# MINT_DECIMALS - token decimals (default: 6, matching USDC) +# MINT_AMOUNT - amount to mint to client (default: 1000000000) +# LEDGER_DIR - ledger directory (default: /data/ledger) +# ----------------------------------------------------------------------- + +LEDGER_DIR="${LEDGER_DIR:-/data/ledger}" +MINT_DECIMALS="${MINT_DECIMALS:-6}" +MINT_AMOUNT="${MINT_AMOUNT:-1000000000}" +SETUP_MARKER="${LEDGER_DIR}/.setup-done" + +sudo chown -R "$(id -u):$(id -g)" "$LEDGER_DIR" 2>/dev/null || true + +# Start test-validator in the background +solana-test-validator \ + --ledger "${LEDGER_DIR}" \ + --rpc-port 8899 \ + --bind-address 0.0.0.0 \ + --quiet & + +VALIDATOR_PID=$! + +# Wait for RPC to become available +echo "Waiting for test-validator RPC..." +for i in $(seq 1 60); do + if solana cluster-version --url http://127.0.0.1:8899 >/dev/null 2>&1; then + echo "Test-validator is ready (attempt ${i})" + break + fi + sleep 1 +done + +solana config set --url http://127.0.0.1:8899 + +# Only run setup once (idempotent via marker file) +if [ ! -f "${SETUP_MARKER}" ]; then + echo "Running first-time setup..." + + # Airdrop SOL to all wallets for gas + for PUBKEY in "${FACILITATOR_PUBKEY:-}" "${SERVER_PUBKEY:-}" "${CLIENT_PUBKEY:-}"; do + if [ -n "${PUBKEY}" ]; then + echo "Airdropping 100 SOL to ${PUBKEY}..." + solana airdrop 100 "${PUBKEY}" --url http://127.0.0.1:8899 || true + fi + done + + # Create a USDC-equivalent SPL token mint if any pubkeys are set + if [ -n "${CLIENT_PUBKEY:-}" ] || [ -n "${FACILITATOR_PUBKEY:-}" ] || [ -n "${SERVER_PUBKEY:-}" ]; then + MINT_AUTHORITY_FILE="${LEDGER_DIR}/mint-authority.json" + if [ ! -f "${MINT_AUTHORITY_FILE}" ]; then + solana-keygen new --no-bip39-passphrase --outfile "${MINT_AUTHORITY_FILE}" --force + MINT_AUTH_PUBKEY=$(solana-keygen pubkey "${MINT_AUTHORITY_FILE}") + solana airdrop 10 "${MINT_AUTH_PUBKEY}" --url http://127.0.0.1:8899 + fi + + MINT_ADDRESS_FILE="${LEDGER_DIR}/usdc-mint-address.txt" + if [ ! -f "${MINT_ADDRESS_FILE}" ]; then + spl-token create-token \ + --decimals "${MINT_DECIMALS}" \ + --mint-authority "${MINT_AUTHORITY_FILE}" \ + --url http://127.0.0.1:8899 \ + 2>&1 | grep "Creating token" | awk '{print $3}' > "${MINT_ADDRESS_FILE}" + echo "Created USDC mint: $(cat "${MINT_ADDRESS_FILE}")" + fi + + USDC_MINT=$(cat "${MINT_ADDRESS_FILE}") + + # Create ATAs and mint tokens for the client + if [ -n "${CLIENT_PUBKEY:-}" ]; then + echo "Creating ATA for client ${CLIENT_PUBKEY}..." + spl-token create-account "${USDC_MINT}" \ + --owner "${CLIENT_PUBKEY}" \ + --fee-payer "${MINT_AUTHORITY_FILE}" \ + --url http://127.0.0.1:8899 || true + + echo "Minting ${MINT_AMOUNT} tokens to client..." + spl-token mint "${USDC_MINT}" "${MINT_AMOUNT}" \ + --recipient-owner "${CLIENT_PUBKEY}" \ + --mint-authority "${MINT_AUTHORITY_FILE}" \ + --url http://127.0.0.1:8899 || true + fi + + # Create ATAs for server and facilitator + for PUBKEY in "${SERVER_PUBKEY:-}" "${FACILITATOR_PUBKEY:-}"; do + if [ -n "${PUBKEY}" ]; then + echo "Creating ATA for ${PUBKEY}..." + spl-token create-account "${USDC_MINT}" \ + --owner "${PUBKEY}" \ + --fee-payer "${MINT_AUTHORITY_FILE}" \ + --url http://127.0.0.1:8899 || true + fi + done + + # Expose mint address for other containers + cp "${MINT_ADDRESS_FILE}" /tmp/usdc-mint-address.txt 2>/dev/null || true + fi + + touch "${SETUP_MARKER}" + echo "Setup complete." +fi + +echo "solana-test-validator running (PID ${VALIDATOR_PID})" +wait ${VALIDATOR_PID}