diff --git a/scripts/agave-container/Dockerfile b/scripts/agave-container/Dockerfile deleted file mode 100644 index 97c6227f..00000000 --- a/scripts/agave-container/Dockerfile +++ /dev/null @@ -1,81 +0,0 @@ -# Unified Agave/Jito Solana image -# Supports three modes via AGAVE_MODE env: test, rpc, validator -# -# Build args: -# AGAVE_REPO - git repo URL (anza-xyz/agave or jito-foundation/jito-solana) -# AGAVE_VERSION - git tag to build (e.g. v3.1.9, v3.1.8-jito) - -ARG AGAVE_REPO=https://github.com/anza-xyz/agave.git -ARG AGAVE_VERSION=v3.1.9 - -# ---------- Stage 1: Build ---------- -FROM rust:1.85-bookworm AS builder - -ARG AGAVE_REPO -ARG AGAVE_VERSION - -RUN apt-get update && apt-get install -y --no-install-recommends \ - build-essential \ - pkg-config \ - libssl-dev \ - libudev-dev \ - libclang-dev \ - protobuf-compiler \ - ca-certificates \ - git \ - cmake \ - && rm -rf /var/lib/apt/lists/* - -WORKDIR /build -RUN git clone "$AGAVE_REPO" --depth 1 --branch "$AGAVE_VERSION" --recurse-submodules agave -WORKDIR /build/agave - -# Cherry-pick --public-tvu-address support (anza-xyz/agave PR #6778, commit 9f4b3ae) -# This flag only exists on master, not in v3.1.9 — fetch the PR ref and cherry-pick -ARG TVU_ADDRESS_PR=6778 -RUN if [ -n "$TVU_ADDRESS_PR" ]; then \ - git fetch --depth 50 origin "pull/${TVU_ADDRESS_PR}/head:tvu-pr" && \ - git cherry-pick --no-commit tvu-pr; \ - fi - -# Build all binaries using the upstream install script -RUN CI_COMMIT=$(git rev-parse HEAD) scripts/cargo-install-all.sh /solana-release - -# ---------- Stage 2: Runtime ---------- -FROM debian:bookworm-slim - -RUN apt-get update && apt-get install -y --no-install-recommends \ - ca-certificates \ - libssl3 \ - libudev1 \ - curl \ - sudo \ - aria2 \ - python3 \ - && rm -rf /var/lib/apt/lists/* - -# Create non-root user with sudo -RUN useradd -m -s /bin/bash agave \ - && echo "agave ALL=(ALL) NOPASSWD: ALL" >> /etc/sudoers - -# Copy all compiled binaries -COPY --from=builder /solana-release/bin/ /usr/local/bin/ - -# Copy entrypoint and support scripts -COPY entrypoint.py snapshot_download.py ip_echo_preflight.py /usr/local/bin/ -COPY start-test.sh /usr/local/bin/ -RUN chmod +x /usr/local/bin/entrypoint.py /usr/local/bin/start-test.sh - -# Create data directories -RUN mkdir -p /data/config /data/ledger /data/accounts /data/snapshots \ - && chown -R agave:agave /data - -USER agave -WORKDIR /data - -ENV RUST_LOG=info -ENV RUST_BACKTRACE=1 - -EXPOSE 8899 8900 8001 8001/udp - -ENTRYPOINT ["entrypoint.py"] diff --git a/scripts/agave-container/build.sh b/scripts/agave-container/build.sh deleted file mode 100644 index 4c4d940f..00000000 --- a/scripts/agave-container/build.sh +++ /dev/null @@ -1,17 +0,0 @@ -#!/usr/bin/env bash - -# Build laconicnetwork/agave -# Set AGAVE_REPO and AGAVE_VERSION env vars to build Jito or a different version -source ${CERC_CONTAINER_BASE_DIR}/build-base.sh - -SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) - -AGAVE_REPO="${AGAVE_REPO:-https://github.com/anza-xyz/agave.git}" -AGAVE_VERSION="${AGAVE_VERSION:-v3.1.9}" - -docker build -t laconicnetwork/agave:local \ - --build-arg AGAVE_REPO="$AGAVE_REPO" \ - --build-arg AGAVE_VERSION="$AGAVE_VERSION" \ - ${build_command_args} \ - -f ${SCRIPT_DIR}/Dockerfile \ - ${SCRIPT_DIR} diff --git a/scripts/agave-container/entrypoint.py b/scripts/agave-container/entrypoint.py deleted file mode 100644 index 2b7324c3..00000000 --- a/scripts/agave-container/entrypoint.py +++ /dev/null @@ -1,686 +0,0 @@ -#!/usr/bin/env python3 -"""Agave validator entrypoint — snapshot management, arg construction, liveness probe. - -Two subcommands: - entrypoint.py serve (default) — snapshot freshness check + run agave-validator - entrypoint.py probe — liveness probe (slot lag check, exits 0/1) - -Replaces the bash entrypoint.sh / start-rpc.sh / start-validator.sh with a single -Python module. Test mode still dispatches to start-test.sh. - -Python stays as PID 1 and traps SIGTERM. On SIGTERM, it runs -``agave-validator exit --force --ledger /data/ledger`` which connects to the -admin RPC Unix socket and tells the validator to flush I/O and exit cleanly. -This avoids the io_uring/ZFS deadlock that occurs when the process is killed. - -All configuration comes from environment variables — same vars as the original -bash scripts. See compose files for defaults. -""" - -from __future__ import annotations - -import json -import logging -import os -import re -import signal -import subprocess -import sys -import threading -import time -import urllib.error -import urllib.request -from pathlib import Path -from urllib.request import Request - -log: logging.Logger = logging.getLogger("entrypoint") - -# Directories -CONFIG_DIR = "/data/config" -LEDGER_DIR = "/data/ledger" -ACCOUNTS_DIR = "/data/accounts" -SNAPSHOTS_DIR = "/data/snapshots" -LOG_DIR = "/data/log" -IDENTITY_FILE = f"{CONFIG_DIR}/validator-identity.json" - -# Snapshot filename patterns -FULL_SNAP_RE: re.Pattern[str] = re.compile( - r"^snapshot-(\d+)-[A-Za-z0-9]+\.tar\.(zst|bz2)$" -) -INCR_SNAP_RE: re.Pattern[str] = re.compile( - r"^incremental-snapshot-(\d+)-(\d+)-[A-Za-z0-9]+\.tar\.(zst|bz2)$" -) - -MAINNET_RPC = "https://api.mainnet-beta.solana.com" - - -# -- Helpers ------------------------------------------------------------------- - - -def env(name: str, default: str = "") -> str: - """Read env var with default.""" - return os.environ.get(name, default) - - -def env_required(name: str) -> str: - """Read required env var, exit if missing.""" - val = os.environ.get(name) - if not val: - log.error("%s is required but not set", name) - sys.exit(1) - return val - - -def env_bool(name: str, default: bool = False) -> bool: - """Read boolean env var (true/false/1/0).""" - val = os.environ.get(name, "").lower() - if not val: - return default - return val in ("true", "1", "yes") - - -def rpc_get_slot(url: str, timeout: int = 10) -> int | None: - """Get current slot from a Solana RPC endpoint.""" - payload = json.dumps({ - "jsonrpc": "2.0", "id": 1, - "method": "getSlot", "params": [], - }).encode() - req = Request(url, data=payload, - headers={"Content-Type": "application/json"}) - try: - with urllib.request.urlopen(req, timeout=timeout) as resp: - data = json.loads(resp.read()) - result = data.get("result") - if isinstance(result, int): - return result - except (urllib.error.URLError, json.JSONDecodeError, OSError, TimeoutError): - pass - return None - - -# -- Snapshot management ------------------------------------------------------- - - -def get_local_snapshot_slot(snapshots_dir: str) -> int | None: - """Find the highest slot among local snapshot files.""" - best_slot: int | None = None - snap_path = Path(snapshots_dir) - if not snap_path.is_dir(): - return None - for entry in snap_path.iterdir(): - m = FULL_SNAP_RE.match(entry.name) - if m: - slot = int(m.group(1)) - if best_slot is None or slot > best_slot: - best_slot = slot - return best_slot - - -def clean_snapshots(snapshots_dir: str) -> None: - """Remove all snapshot files from the directory.""" - snap_path = Path(snapshots_dir) - if not snap_path.is_dir(): - return - for entry in snap_path.iterdir(): - if entry.name.startswith(("snapshot-", "incremental-snapshot-")): - log.info("Removing old snapshot: %s", entry.name) - entry.unlink(missing_ok=True) - - -def get_incremental_slot(snapshots_dir: str, full_slot: int | None) -> int | None: - """Get the highest incremental snapshot slot matching the full's base slot.""" - if full_slot is None: - return None - snap_path = Path(snapshots_dir) - if not snap_path.is_dir(): - return None - best: int | None = None - for entry in snap_path.iterdir(): - m = INCR_SNAP_RE.match(entry.name) - if m and int(m.group(1)) == full_slot: - slot = int(m.group(2)) - if best is None or slot > best: - best = slot - return best - - -def maybe_download_snapshot(snapshots_dir: str) -> None: - """Ensure full + incremental snapshots exist before starting. - - The validator should always start from a full + incremental pair to - minimize replay time. If either is missing or the full is too old, - download fresh ones via download_best_snapshot (which does rolling - incremental convergence after downloading the full). - - Controlled by env vars: - SNAPSHOT_AUTO_DOWNLOAD (default: true) — enable/disable - SNAPSHOT_MAX_AGE_SLOTS (default: 100000) — full snapshot staleness threshold - (one full snapshot generation, ~11 hours) - """ - if not env_bool("SNAPSHOT_AUTO_DOWNLOAD", default=True): - log.info("Snapshot auto-download disabled") - return - - max_age = int(env("SNAPSHOT_MAX_AGE_SLOTS", "100000")) - - mainnet_slot = rpc_get_slot(MAINNET_RPC) - if mainnet_slot is None: - log.warning("Cannot reach mainnet RPC — skipping snapshot check") - return - - script_dir = Path(__file__).resolve().parent - sys.path.insert(0, str(script_dir)) - from snapshot_download import download_best_snapshot, download_incremental_for_slot - - convergence = int(env("SNAPSHOT_CONVERGENCE_SLOTS", "500")) - retry_delay = int(env("SNAPSHOT_RETRY_DELAY", "60")) - - # Check local full snapshot - local_slot = get_local_snapshot_slot(snapshots_dir) - have_fresh_full = (local_slot is not None - and (mainnet_slot - local_slot) <= max_age) - - if have_fresh_full: - assert local_slot is not None - inc_slot = get_incremental_slot(snapshots_dir, local_slot) - if inc_slot is not None: - inc_gap = mainnet_slot - inc_slot - if inc_gap <= convergence: - log.info("Full (slot %d) + incremental (slot %d, gap %d) " - "within convergence, starting", - local_slot, inc_slot, inc_gap) - return - log.info("Incremental too stale (slot %d, gap %d > %d)", - inc_slot, inc_gap, convergence) - # Fresh full, need a fresh incremental - log.info("Downloading incremental for full at slot %d", local_slot) - while True: - if download_incremental_for_slot(snapshots_dir, local_slot, - convergence_slots=convergence): - return - log.warning("Incremental download failed — retrying in %ds", - retry_delay) - time.sleep(retry_delay) - - # No full or full too old — download both - log.info("Downloading full + incremental") - clean_snapshots(snapshots_dir) - while True: - if download_best_snapshot(snapshots_dir, convergence_slots=convergence): - return - log.warning("Snapshot download failed — retrying in %ds", retry_delay) - time.sleep(retry_delay) - - -# -- Directory and identity setup ---------------------------------------------- - - -def ensure_dirs(*dirs: str) -> None: - """Create directories and fix ownership.""" - uid = os.getuid() - gid = os.getgid() - for d in dirs: - os.makedirs(d, exist_ok=True) - try: - subprocess.run( - ["sudo", "chown", "-R", f"{uid}:{gid}", d], - check=False, capture_output=True, - ) - except FileNotFoundError: - pass # sudo not available — dirs already owned correctly - - -def ensure_identity_rpc() -> None: - """Generate ephemeral identity keypair for RPC mode if not mounted.""" - if os.path.isfile(IDENTITY_FILE): - return - log.info("Generating RPC node identity keypair...") - subprocess.run( - ["solana-keygen", "new", "--no-passphrase", "--silent", - "--force", "--outfile", IDENTITY_FILE], - check=True, - ) - - -def print_identity() -> None: - """Print the node identity pubkey.""" - result = subprocess.run( - ["solana-keygen", "pubkey", IDENTITY_FILE], - capture_output=True, text=True, check=False, - ) - if result.returncode == 0: - log.info("Node identity: %s", result.stdout.strip()) - - -# -- Arg construction ---------------------------------------------------------- - - -def build_common_args() -> list[str]: - """Build agave-validator args common to both RPC and validator modes.""" - args: list[str] = [ - "--identity", IDENTITY_FILE, - "--entrypoint", env_required("VALIDATOR_ENTRYPOINT"), - "--known-validator", env_required("KNOWN_VALIDATOR"), - "--ledger", LEDGER_DIR, - "--accounts", ACCOUNTS_DIR, - "--snapshots", SNAPSHOTS_DIR, - "--rpc-port", env("RPC_PORT", "8899"), - "--rpc-bind-address", env("RPC_BIND_ADDRESS", "127.0.0.1"), - "--gossip-port", env("GOSSIP_PORT", "8001"), - "--dynamic-port-range", env("DYNAMIC_PORT_RANGE", "9000-10000"), - "--no-os-network-limits-test", - "--wal-recovery-mode", "skip_any_corrupted_record", - "--limit-ledger-size", env("LIMIT_LEDGER_SIZE", "50000000"), - "--no-snapshot-fetch", # entrypoint handles snapshot download - ] - - # Snapshot generation - if env("NO_SNAPSHOTS") == "true": - args.append("--no-snapshots") - else: - args += [ - "--full-snapshot-interval-slots", env("SNAPSHOT_INTERVAL_SLOTS", "100000"), - "--maximum-full-snapshots-to-retain", env("MAXIMUM_SNAPSHOTS_TO_RETAIN", "1"), - ] - if env("NO_INCREMENTAL_SNAPSHOTS") != "true": - args += ["--maximum-incremental-snapshots-to-retain", "2"] - - # Account indexes - account_indexes = env("ACCOUNT_INDEXES") - if account_indexes: - for idx in account_indexes.split(","): - idx = idx.strip() - if idx: - args += ["--account-index", idx] - - # Additional entrypoints - for ep in env("EXTRA_ENTRYPOINTS").split(): - if ep: - args += ["--entrypoint", ep] - - # Additional known validators - for kv in env("EXTRA_KNOWN_VALIDATORS").split(): - if kv: - args += ["--known-validator", kv] - - # Cluster verification - genesis_hash = env("EXPECTED_GENESIS_HASH") - if genesis_hash: - args += ["--expected-genesis-hash", genesis_hash] - shred_version = env("EXPECTED_SHRED_VERSION") - if shred_version: - args += ["--expected-shred-version", shred_version] - - # Metrics — just needs to be in the environment, agave reads it directly - # (env var is already set, nothing to pass as arg) - - # Gossip host / TVU address - gossip_host = env("GOSSIP_HOST") - if gossip_host: - args += ["--gossip-host", gossip_host] - elif env("PUBLIC_TVU_ADDRESS"): - args += ["--public-tvu-address", env("PUBLIC_TVU_ADDRESS")] - - # Jito flags - if env("JITO_ENABLE") == "true": - log.info("Jito MEV enabled") - jito_flags: list[tuple[str, str]] = [ - ("JITO_TIP_PAYMENT_PROGRAM", "--tip-payment-program-pubkey"), - ("JITO_DISTRIBUTION_PROGRAM", "--tip-distribution-program-pubkey"), - ("JITO_MERKLE_ROOT_AUTHORITY", "--merkle-root-upload-authority"), - ("JITO_COMMISSION_BPS", "--commission-bps"), - ("JITO_BLOCK_ENGINE_URL", "--block-engine-url"), - ("JITO_SHRED_RECEIVER_ADDR", "--shred-receiver-address"), - ] - for env_name, flag in jito_flags: - val = env(env_name) - if val: - args += [flag, val] - - return args - - -def build_rpc_args() -> list[str]: - """Build agave-validator args for RPC (non-voting) mode.""" - args = build_common_args() - args += [ - "--no-voting", - "--log", f"{LOG_DIR}/validator.log", - "--full-rpc-api", - "--enable-rpc-transaction-history", - "--rpc-pubsub-enable-block-subscription", - "--enable-extended-tx-metadata-storage", - "--no-wait-for-vote-to-start-leader", - ] - - # Public vs private RPC - public_rpc = env("PUBLIC_RPC_ADDRESS") - if public_rpc: - args += ["--public-rpc-address", public_rpc] - else: - args += ["--private-rpc", "--allow-private-addr", "--only-known-rpc"] - - # Jito relayer URL (RPC mode doesn't use it, but validator mode does — - # handled in build_validator_args) - - return args - - -def build_validator_args() -> list[str]: - """Build agave-validator args for voting validator mode.""" - vote_keypair = env("VOTE_ACCOUNT_KEYPAIR", - "/data/config/vote-account-keypair.json") - - # Identity must be mounted for validator mode - if not os.path.isfile(IDENTITY_FILE): - log.error("Validator identity keypair not found at %s", IDENTITY_FILE) - log.error("Mount your validator keypair to %s", IDENTITY_FILE) - sys.exit(1) - - # Vote account keypair must exist - if not os.path.isfile(vote_keypair): - log.error("Vote account keypair not found at %s", vote_keypair) - log.error("Mount your vote account keypair or set VOTE_ACCOUNT_KEYPAIR") - sys.exit(1) - - # Print vote account pubkey - result = subprocess.run( - ["solana-keygen", "pubkey", vote_keypair], - capture_output=True, text=True, check=False, - ) - if result.returncode == 0: - log.info("Vote account: %s", result.stdout.strip()) - - args = build_common_args() - args += [ - "--vote-account", vote_keypair, - "--log", "-", - ] - - # Jito relayer URL (validator-only) - relayer_url = env("JITO_RELAYER_URL") - if env("JITO_ENABLE") == "true" and relayer_url: - args += ["--relayer-url", relayer_url] - - return args - - -def append_extra_args(args: list[str]) -> list[str]: - """Append EXTRA_ARGS passthrough flags.""" - extra = env("EXTRA_ARGS") - if extra: - args += extra.split() - return args - - -# -- Graceful shutdown -------------------------------------------------------- - -# Timeout for graceful exit via admin RPC. Leave 30s margin for k8s -# terminationGracePeriodSeconds (300s). -GRACEFUL_EXIT_TIMEOUT = 270 - - -def graceful_exit(child: subprocess.Popen[bytes], reason: str = "SIGTERM") -> None: - """Request graceful shutdown via the admin RPC Unix socket. - - Runs ``agave-validator exit --force --ledger /data/ledger`` which connects - to the admin RPC socket at ``/data/ledger/admin.rpc`` and sets the - validator's exit flag. The validator flushes all I/O and exits cleanly, - avoiding the io_uring/ZFS deadlock. - - If the admin RPC exit fails or the child doesn't exit within the timeout, - falls back to SIGTERM then SIGKILL. - """ - log.info("%s — requesting graceful exit via admin RPC", reason) - try: - result = subprocess.run( - ["agave-validator", "exit", "--force", "--ledger", LEDGER_DIR], - capture_output=True, text=True, timeout=30, - ) - if result.returncode == 0: - log.info("Admin RPC exit requested successfully") - else: - log.warning( - "Admin RPC exit returned %d: %s", - result.returncode, result.stderr.strip(), - ) - except subprocess.TimeoutExpired: - log.warning("Admin RPC exit command timed out after 30s") - except FileNotFoundError: - log.warning("agave-validator binary not found for exit command") - - # Wait for child to exit - try: - child.wait(timeout=GRACEFUL_EXIT_TIMEOUT) - log.info("Validator exited cleanly with code %d", child.returncode) - return - except subprocess.TimeoutExpired: - log.warning( - "Validator did not exit within %ds — sending SIGTERM", - GRACEFUL_EXIT_TIMEOUT, - ) - - # Fallback: SIGTERM - child.terminate() - try: - child.wait(timeout=15) - log.info("Validator exited after SIGTERM with code %d", child.returncode) - return - except subprocess.TimeoutExpired: - log.warning("Validator did not exit after SIGTERM — sending SIGKILL") - - # Last resort: SIGKILL - child.kill() - child.wait() - log.info("Validator killed with SIGKILL, code %d", child.returncode) - - -# -- Serve subcommand --------------------------------------------------------- - - -def _gap_monitor( - child: subprocess.Popen[bytes], - leapfrog: threading.Event, - shutting_down: threading.Event, -) -> None: - """Background thread: poll slot gap and trigger leapfrog if too far behind. - - Waits for a grace period (SNAPSHOT_MONITOR_GRACE, default 600s) before - monitoring — the validator needs time to extract snapshots and catch up. - Then polls every SNAPSHOT_MONITOR_INTERVAL (default 30s). If the gap - exceeds SNAPSHOT_LEAPFROG_SLOTS (default 5000) for SNAPSHOT_LEAPFROG_CHECKS - (default 3) consecutive checks, triggers graceful shutdown and sets the - leapfrog event so cmd_serve loops back to download a fresh incremental. - """ - threshold = int(env("SNAPSHOT_LEAPFROG_SLOTS", "5000")) - required_checks = int(env("SNAPSHOT_LEAPFROG_CHECKS", "3")) - interval = int(env("SNAPSHOT_MONITOR_INTERVAL", "30")) - grace = int(env("SNAPSHOT_MONITOR_GRACE", "600")) - rpc_port = env("RPC_PORT", "8899") - local_url = f"http://127.0.0.1:{rpc_port}" - - # Grace period — don't monitor during initial catch-up - if shutting_down.wait(grace): - return - - consecutive = 0 - while not shutting_down.is_set(): - local_slot = rpc_get_slot(local_url, timeout=5) - mainnet_slot = rpc_get_slot(MAINNET_RPC, timeout=10) - - if local_slot is not None and mainnet_slot is not None: - gap = mainnet_slot - local_slot - if gap > threshold: - consecutive += 1 - log.warning("Gap %d > %d (%d/%d consecutive)", - gap, threshold, consecutive, required_checks) - if consecutive >= required_checks: - log.warning("Leapfrog triggered: gap %d", gap) - leapfrog.set() - graceful_exit(child, reason="Leapfrog") - return - else: - if consecutive > 0: - log.info("Gap %d within threshold, resetting counter", gap) - consecutive = 0 - - shutting_down.wait(interval) - - -def cmd_serve() -> None: - """Main serve flow: snapshot download, run validator, monitor gap, leapfrog. - - Python stays as PID 1. On each iteration: - 1. Download full + incremental snapshots (if needed) - 2. Start agave-validator as child process - 3. Monitor slot gap in background thread - 4. If gap exceeds threshold → graceful stop → loop back to step 1 - 5. If SIGTERM → graceful stop → exit - 6. If validator crashes → exit with its return code - """ - mode = env("AGAVE_MODE", "test") - log.info("AGAVE_MODE=%s", mode) - - if mode == "test": - os.execvp("start-test.sh", ["start-test.sh"]) - - if mode not in ("rpc", "validator"): - log.error("Unknown AGAVE_MODE: %s (valid: test, rpc, validator)", mode) - sys.exit(1) - - # One-time setup - dirs = [CONFIG_DIR, LEDGER_DIR, ACCOUNTS_DIR, SNAPSHOTS_DIR] - if mode == "rpc": - dirs.append(LOG_DIR) - ensure_dirs(*dirs) - - if not env_bool("SKIP_IP_ECHO_PREFLIGHT"): - script_dir = Path(__file__).resolve().parent - sys.path.insert(0, str(script_dir)) - from ip_echo_preflight import main as ip_echo_main - if ip_echo_main() != 0: - sys.exit(1) - - if mode == "rpc": - ensure_identity_rpc() - print_identity() - - if mode == "rpc": - args = build_rpc_args() - else: - args = build_validator_args() - args = append_extra_args(args) - - # Main loop: download → run → monitor → leapfrog if needed - while True: - maybe_download_snapshot(SNAPSHOTS_DIR) - - Path("/tmp/entrypoint-start").write_text(str(time.time())) - log.info("Starting agave-validator with %d arguments", len(args)) - child = subprocess.Popen(["agave-validator"] + args) - - shutting_down = threading.Event() - leapfrog = threading.Event() - - signal.signal(signal.SIGUSR1, - lambda _sig, _frame: child.send_signal(signal.SIGUSR1)) - - def _on_sigterm(_sig: int, _frame: object) -> None: - shutting_down.set() - threading.Thread( - target=graceful_exit, args=(child,), daemon=True, - ).start() - - signal.signal(signal.SIGTERM, _on_sigterm) - - # Start gap monitor - monitor = threading.Thread( - target=_gap_monitor, - args=(child, leapfrog, shutting_down), - daemon=True, - ) - monitor.start() - - child.wait() - - if leapfrog.is_set(): - log.info("Leapfrog: restarting with fresh incremental") - continue - - sys.exit(child.returncode) - - -# -- Probe subcommand --------------------------------------------------------- - - -def cmd_probe() -> None: - """Liveness probe: check local RPC slot vs mainnet. - - Exit 0 = healthy, exit 1 = unhealthy. - - Grace period: PROBE_GRACE_SECONDS (default 600) — probe always passes - during grace period to allow for snapshot unpacking and initial replay. - """ - grace_seconds = int(env("PROBE_GRACE_SECONDS", "600")) - max_lag = int(env("PROBE_MAX_SLOT_LAG", "20000")) - - # Check grace period - start_file = Path("/tmp/entrypoint-start") - if start_file.exists(): - try: - start_time = float(start_file.read_text().strip()) - elapsed = time.time() - start_time - if elapsed < grace_seconds: - # Within grace period — always healthy - sys.exit(0) - except (ValueError, OSError): - pass - else: - # No start file — serve hasn't started yet, within grace - sys.exit(0) - - # Query local RPC - rpc_port = env("RPC_PORT", "8899") - local_url = f"http://127.0.0.1:{rpc_port}" - local_slot = rpc_get_slot(local_url, timeout=5) - if local_slot is None: - # Local RPC unreachable after grace period — unhealthy - sys.exit(1) - - # Query mainnet - mainnet_slot = rpc_get_slot(MAINNET_RPC, timeout=10) - if mainnet_slot is None: - # Can't reach mainnet to compare — assume healthy (don't penalize - # the validator for mainnet RPC being down) - sys.exit(0) - - lag = mainnet_slot - local_slot - if lag > max_lag: - sys.exit(1) - - sys.exit(0) - - -# -- Main ---------------------------------------------------------------------- - - -def main() -> None: - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s %(levelname)s [%(name)s] %(message)s", - datefmt="%H:%M:%S", - ) - - subcmd = sys.argv[1] if len(sys.argv) > 1 else "serve" - - if subcmd == "serve": - cmd_serve() - elif subcmd == "probe": - cmd_probe() - else: - log.error("Unknown subcommand: %s (valid: serve, probe)", subcmd) - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/scripts/agave-container/ip_echo_preflight.py b/scripts/agave-container/ip_echo_preflight.py deleted file mode 100644 index 20cbb259..00000000 --- a/scripts/agave-container/ip_echo_preflight.py +++ /dev/null @@ -1,249 +0,0 @@ -#!/usr/bin/env python3 -"""ip_echo preflight — verify UDP port reachability before starting the validator. - -Implements the Solana ip_echo client protocol exactly: -1. Bind UDP sockets on the ports the validator will use -2. TCP connect to entrypoint gossip port, send IpEchoServerMessage -3. Parse IpEchoServerResponse (our IP as seen by entrypoint) -4. Wait for entrypoint's UDP probes on each port -5. Exit 0 if all ports reachable, exit 1 if any fail - -Wire format (from agave net-utils/src/): - Request: 4 null bytes + [u16; 4] tcp_ports LE + [u16; 4] udp_ports LE + \n - Response: 4 null bytes + bincode IpAddr (variant byte + addr) + optional shred_version - -Called from entrypoint.py before snapshot download. Prevents wasting hours -downloading a snapshot only to crash-loop on port reachability. -""" - -from __future__ import annotations - -import logging -import os -import socket -import struct -import sys -import threading -import time - -log = logging.getLogger("ip_echo_preflight") - -HEADER = b"\x00\x00\x00\x00" -TERMINUS = b"\x0a" -RESPONSE_BUF = 27 -IO_TIMEOUT = 5.0 -PROBE_TIMEOUT = 10.0 -MAX_RETRIES = 3 -RETRY_DELAY = 2.0 - - -def build_request(tcp_ports: list[int], udp_ports: list[int]) -> bytes: - """Build IpEchoServerMessage: header + [u16;4] tcp + [u16;4] udp + newline.""" - tcp = (tcp_ports + [0, 0, 0, 0])[:4] - udp = (udp_ports + [0, 0, 0, 0])[:4] - return HEADER + struct.pack("<4H", *tcp) + struct.pack("<4H", *udp) + TERMINUS - - -def parse_response(data: bytes) -> tuple[str, int | None]: - """Parse IpEchoServerResponse → (ip_string, shred_version | None). - - Wire format (bincode): - 4 bytes header (\0\0\0\0) - 4 bytes IpAddr enum variant (u32 LE: 0=IPv4, 1=IPv6) - 4|16 bytes address octets - 1 byte Option tag (0=None, 1=Some) - 2 bytes shred_version (u16 LE, only if Some) - """ - if len(data) < 8: - raise ValueError(f"response too short: {len(data)} bytes") - if data[:4] == b"HTTP": - raise ValueError("got HTTP response — not an ip_echo server") - if data[:4] != HEADER: - raise ValueError(f"unexpected header: {data[:4].hex()}") - variant = struct.unpack("= 3 and rest[0] == 1: - shred_version = struct.unpack(" None: - """Bind a UDP socket and wait for a probe packet.""" - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(("0.0.0.0", port)) - sock.settimeout(0.5) - try: - while not stop.is_set(): - try: - _data, addr = sock.recvfrom(64) - results[port] = ("ok", addr) - return - except socket.timeout: - continue - finally: - sock.close() - except OSError as exc: - results[port] = ("bind_error", str(exc)) - - -def ip_echo_check( - entrypoint_host: str, - entrypoint_port: int, - udp_ports: list[int], -) -> tuple[str, dict[int, bool]]: - """Run one ip_echo exchange and return (seen_ip, {port: reachable}). - - Raises on TCP failure (caller retries). - """ - udp_ports = [p for p in udp_ports if p != 0][:4] - - # Start UDP listeners before sending the TCP request - results: dict[int, tuple] = {} - stop = threading.Event() - threads = [] - for port in udp_ports: - t = threading.Thread(target=_listen_udp, args=(port, results, stop), daemon=True) - t.start() - threads.append(t) - time.sleep(0.1) # let listeners bind - - # TCP: send request, read response - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(IO_TIMEOUT) - try: - sock.connect((entrypoint_host, entrypoint_port)) - sock.sendall(build_request([], udp_ports)) - resp = sock.recv(RESPONSE_BUF) - finally: - sock.close() - - seen_ip, shred_version = parse_response(resp) - log.info( - "entrypoint %s:%d sees us as %s (shred_version=%s)", - entrypoint_host, entrypoint_port, seen_ip, shred_version, - ) - - # Wait for UDP probes - deadline = time.monotonic() + PROBE_TIMEOUT - while time.monotonic() < deadline: - if all(p in results for p in udp_ports): - break - time.sleep(0.2) - - stop.set() - for t in threads: - t.join(timeout=1) - - port_ok: dict[int, bool] = {} - for port in udp_ports: - if port not in results: - log.error("port %d: no probe received within %.0fs", port, PROBE_TIMEOUT) - port_ok[port] = False - else: - status, detail = results[port] - if status == "ok": - log.info("port %d: probe received from %s", port, detail) - port_ok[port] = True - else: - log.error("port %d: %s: %s", port, status, detail) - port_ok[port] = False - - return seen_ip, port_ok - - -def run_preflight( - entrypoint_host: str, - entrypoint_port: int, - udp_ports: list[int], - expected_ip: str = "", -) -> bool: - """Run ip_echo check with retries. Returns True if all ports pass.""" - for attempt in range(1, MAX_RETRIES + 1): - log.info("ip_echo attempt %d/%d → %s:%d, ports %s", - attempt, MAX_RETRIES, entrypoint_host, entrypoint_port, udp_ports) - try: - seen_ip, port_ok = ip_echo_check(entrypoint_host, entrypoint_port, udp_ports) - except Exception as exc: - log.error("attempt %d TCP failed: %s", attempt, exc) - if attempt < MAX_RETRIES: - time.sleep(RETRY_DELAY) - continue - - if expected_ip and seen_ip != expected_ip: - log.error( - "IP MISMATCH: entrypoint sees %s, expected %s (GOSSIP_HOST). " - "Outbound mangle/SNAT path is broken.", - seen_ip, expected_ip, - ) - if attempt < MAX_RETRIES: - time.sleep(RETRY_DELAY) - continue - - reachable = [p for p, ok in port_ok.items() if ok] - unreachable = [p for p, ok in port_ok.items() if not ok] - - if not unreachable: - log.info("PASS: all ports reachable %s, seen as %s", reachable, seen_ip) - return True - - log.error( - "attempt %d: unreachable %s, reachable %s, seen as %s", - attempt, unreachable, reachable, seen_ip, - ) - if attempt < MAX_RETRIES: - time.sleep(RETRY_DELAY) - - log.error("FAIL: ip_echo preflight exhausted %d attempts", MAX_RETRIES) - return False - - -def main() -> int: - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s %(levelname)s [%(name)s] %(message)s", - datefmt="%H:%M:%S", - ) - - # Parse entrypoint — VALIDATOR_ENTRYPOINT is "host:port" - raw = os.environ.get("VALIDATOR_ENTRYPOINT", "") - if not raw and len(sys.argv) > 1: - raw = sys.argv[1] - if not raw: - log.error("set VALIDATOR_ENTRYPOINT or pass host:port as argument") - return 1 - - if ":" in raw: - host, port_str = raw.rsplit(":", 1) - ep_port = int(port_str) - else: - host = raw - ep_port = 8001 - - gossip_port = int(os.environ.get("GOSSIP_PORT", "8001")) - dynamic_range = os.environ.get("DYNAMIC_PORT_RANGE", "9000-10000") - range_start = int(dynamic_range.split("-")[0]) - expected_ip = os.environ.get("GOSSIP_HOST", "") - - # Test gossip + first 3 ports from dynamic range (4 max per ip_echo message) - udp_ports = [gossip_port, range_start, range_start + 2, range_start + 3] - - ok = run_preflight(host, ep_port, udp_ports, expected_ip) - return 0 if ok else 1 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/scripts/agave-container/snapshot_download.py b/scripts/agave-container/snapshot_download.py deleted file mode 100644 index 2af2b976..00000000 --- a/scripts/agave-container/snapshot_download.py +++ /dev/null @@ -1,878 +0,0 @@ -#!/usr/bin/env python3 -"""Download Solana snapshots using aria2c for parallel multi-connection downloads. - -Discovers snapshot sources by querying getClusterNodes for all RPCs in the -cluster, probing each for available snapshots, benchmarking download speed, -and downloading from the fastest source using aria2c (16 connections by default). - -Based on the discovery approach from etcusr/solana-snapshot-finder but replaces -the single-connection wget download with aria2c parallel chunked downloads. - -Usage: - # Download to /srv/kind/solana/snapshots (mainnet, 16 connections) - ./snapshot_download.py -o /srv/kind/solana/snapshots - - # Dry run — find best source, print URL - ./snapshot_download.py --dry-run - - # Custom RPC for cluster discovery + 32 connections - ./snapshot_download.py -r https://api.mainnet-beta.solana.com -n 32 - - # Testnet - ./snapshot_download.py -c testnet -o /data/snapshots - - # Programmatic use from entrypoint.py: - from snapshot_download import download_best_snapshot - ok = download_best_snapshot("/data/snapshots") - -Requirements: - - aria2c (apt install aria2) - - python3 >= 3.10 (stdlib only, no pip dependencies) -""" - -from __future__ import annotations - -import argparse -import concurrent.futures -import json -import logging -import os -import re -import shutil -import subprocess -import sys -import time -import urllib.error -import urllib.request -from dataclasses import dataclass, field -from http.client import HTTPResponse -from pathlib import Path -from urllib.request import Request - -log: logging.Logger = logging.getLogger("snapshot-download") - -CLUSTER_RPC: dict[str, str] = { - "mainnet-beta": "https://api.mainnet-beta.solana.com", - "testnet": "https://api.testnet.solana.com", - "devnet": "https://api.devnet.solana.com", -} - -# Snapshot filenames: -# snapshot--.tar.zst -# incremental-snapshot---.tar.zst -FULL_SNAP_RE: re.Pattern[str] = re.compile( - r"^snapshot-(\d+)-([A-Za-z0-9]+)\.tar\.(zst|bz2)$" -) -INCR_SNAP_RE: re.Pattern[str] = re.compile( - r"^incremental-snapshot-(\d+)-(\d+)-([A-Za-z0-9]+)\.tar\.(zst|bz2)$" -) - - -@dataclass -class SnapshotSource: - """A snapshot file available from a specific RPC node.""" - - rpc_address: str - # Full redirect paths as returned by the server (e.g. /snapshot-123-hash.tar.zst) - file_paths: list[str] = field(default_factory=list) - slots_diff: int = 0 - latency_ms: float = 0.0 - download_speed: float = 0.0 # bytes/sec - - -# -- JSON-RPC helpers ---------------------------------------------------------- - - -class _NoRedirectHandler(urllib.request.HTTPRedirectHandler): - """Handler that captures redirect Location instead of following it.""" - - def redirect_request( - self, - req: Request, - fp: HTTPResponse, - code: int, - msg: str, - headers: dict[str, str], # type: ignore[override] - newurl: str, - ) -> None: - return None - - -def rpc_post(url: str, method: str, params: list[object] | None = None, - timeout: int = 25) -> object | None: - """JSON-RPC POST. Returns parsed 'result' field or None on error.""" - payload: bytes = json.dumps({ - "jsonrpc": "2.0", "id": 1, - "method": method, "params": params or [], - }).encode() - req = Request(url, data=payload, - headers={"Content-Type": "application/json"}) - try: - with urllib.request.urlopen(req, timeout=timeout) as resp: - data: dict[str, object] = json.loads(resp.read()) - return data.get("result") - except (urllib.error.URLError, json.JSONDecodeError, OSError, TimeoutError) as e: - log.debug("rpc_post %s %s failed: %s", url, method, e) - return None - - -def head_no_follow(url: str, timeout: float = 3) -> tuple[str | None, float]: - """HEAD request without following redirects. - - Returns (Location header value, latency_sec) if the server returned a - 3xx redirect. Returns (None, 0.0) on any error or non-redirect response. - """ - opener: urllib.request.OpenerDirector = urllib.request.build_opener(_NoRedirectHandler) - req = Request(url, method="HEAD") - try: - start: float = time.monotonic() - resp: HTTPResponse = opener.open(req, timeout=timeout) # type: ignore[assignment] - latency: float = time.monotonic() - start - # Non-redirect (2xx) — server didn't redirect, not useful for discovery - location: str | None = resp.headers.get("Location") - resp.close() - return location, latency - except urllib.error.HTTPError as e: - # 3xx redirects raise HTTPError with the redirect info - latency = time.monotonic() - start # type: ignore[possibly-undefined] - location = e.headers.get("Location") - if location and 300 <= e.code < 400: - return location, latency - return None, 0.0 - except (urllib.error.URLError, OSError, TimeoutError): - return None, 0.0 - - -# -- Discovery ----------------------------------------------------------------- - - -def get_current_slot(rpc_url: str) -> int | None: - """Get current slot from RPC.""" - result: object | None = rpc_post(rpc_url, "getSlot") - if isinstance(result, int): - return result - return None - - -def get_cluster_rpc_nodes(rpc_url: str, version_filter: str | None = None) -> list[str]: - """Get all RPC node addresses from getClusterNodes.""" - result: object | None = rpc_post(rpc_url, "getClusterNodes") - if not isinstance(result, list): - return [] - - rpc_addrs: list[str] = [] - for node in result: - if not isinstance(node, dict): - continue - if version_filter is not None: - node_version: str | None = node.get("version") - if node_version and not node_version.startswith(version_filter): - continue - rpc: str | None = node.get("rpc") - if rpc: - rpc_addrs.append(rpc) - return list(set(rpc_addrs)) - - -def _parse_snapshot_filename(location: str) -> tuple[str, str | None]: - """Extract filename and full redirect path from Location header. - - Returns (filename, full_path). full_path includes any path prefix - the server returned (e.g. '/snapshots/snapshot-123-hash.tar.zst'). - """ - # Location may be absolute URL or relative path - if location.startswith("http://") or location.startswith("https://"): - # Absolute URL — extract path - from urllib.parse import urlparse - path: str = urlparse(location).path - else: - path = location - - filename: str = path.rsplit("/", 1)[-1] - return filename, path - - -def probe_rpc_snapshot( - rpc_address: str, - current_slot: int, -) -> SnapshotSource | None: - """Probe a single RPC node for available snapshots. - - Discovery only — no filtering. Returns a SnapshotSource with all available - info so the caller can decide what to keep. Filtering happens after all - probes complete, so rejected sources are still visible for debugging. - """ - full_url: str = f"http://{rpc_address}/snapshot.tar.bz2" - - # Full snapshot is required — every source must have one - full_location, full_latency = head_no_follow(full_url, timeout=2) - if not full_location: - return None - - latency_ms: float = full_latency * 1000 - - full_filename, full_path = _parse_snapshot_filename(full_location) - fm: re.Match[str] | None = FULL_SNAP_RE.match(full_filename) - if not fm: - return None - - full_snap_slot: int = int(fm.group(1)) - slots_diff: int = current_slot - full_snap_slot - - file_paths: list[str] = [full_path] - - # Also check for incremental snapshot - inc_url: str = f"http://{rpc_address}/incremental-snapshot.tar.bz2" - inc_location, _ = head_no_follow(inc_url, timeout=2) - if inc_location: - inc_filename, inc_path = _parse_snapshot_filename(inc_location) - m: re.Match[str] | None = INCR_SNAP_RE.match(inc_filename) - if m: - inc_base_slot: int = int(m.group(1)) - # Incremental must be based on this source's full snapshot - if inc_base_slot == full_snap_slot: - file_paths.append(inc_path) - - return SnapshotSource( - rpc_address=rpc_address, - file_paths=file_paths, - slots_diff=slots_diff, - latency_ms=latency_ms, - ) - - -def discover_sources( - rpc_url: str, - current_slot: int, - max_age_slots: int, - max_latency_ms: float, - threads: int, - version_filter: str | None, -) -> list[SnapshotSource]: - """Discover all snapshot sources, then filter. - - Probing and filtering are separate: all reachable sources are collected - first so we can report what exists even if filters reject everything. - """ - rpc_nodes: list[str] = get_cluster_rpc_nodes(rpc_url, version_filter) - if not rpc_nodes: - log.error("No RPC nodes found via getClusterNodes") - return [] - - log.info("Found %d RPC nodes, probing for snapshots...", len(rpc_nodes)) - - all_sources: list[SnapshotSource] = [] - with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as pool: - futures: dict[concurrent.futures.Future[SnapshotSource | None], str] = { - pool.submit(probe_rpc_snapshot, addr, current_slot): addr - for addr in rpc_nodes - } - done: int = 0 - for future in concurrent.futures.as_completed(futures): - done += 1 - if done % 200 == 0: - log.info(" probed %d/%d nodes, %d reachable", - done, len(rpc_nodes), len(all_sources)) - try: - result: SnapshotSource | None = future.result() - except (urllib.error.URLError, OSError, TimeoutError) as e: - log.debug("Probe failed for %s: %s", futures[future], e) - continue - if result: - all_sources.append(result) - - log.info("Discovered %d reachable sources", len(all_sources)) - - # Apply filters - filtered: list[SnapshotSource] = [] - rejected_age: int = 0 - rejected_latency: int = 0 - for src in all_sources: - if src.slots_diff > max_age_slots or src.slots_diff < -100: - rejected_age += 1 - continue - if src.latency_ms > max_latency_ms: - rejected_latency += 1 - continue - filtered.append(src) - - if rejected_age or rejected_latency: - log.info("Filtered: %d rejected by age (>%d slots), %d by latency (>%.0fms)", - rejected_age, max_age_slots, rejected_latency, max_latency_ms) - - if not filtered and all_sources: - # Show what was available so the user can adjust filters - all_sources.sort(key=lambda s: s.slots_diff) - best = all_sources[0] - log.warning("All %d sources rejected by filters. Best available: " - "%s (age=%d slots, latency=%.0fms). " - "Try --max-snapshot-age %d --max-latency %.0f", - len(all_sources), best.rpc_address, - best.slots_diff, best.latency_ms, - best.slots_diff + 500, - max(best.latency_ms * 1.5, 500)) - - log.info("Found %d sources after filtering", len(filtered)) - return filtered - - -# -- Speed benchmark ----------------------------------------------------------- - - -def measure_speed(rpc_address: str, measure_time: int = 7) -> float: - """Measure download speed from an RPC node. Returns bytes/sec.""" - url: str = f"http://{rpc_address}/snapshot.tar.bz2" - req = Request(url) - try: - with urllib.request.urlopen(req, timeout=measure_time + 5) as resp: - start: float = time.monotonic() - total: int = 0 - while True: - elapsed: float = time.monotonic() - start - if elapsed >= measure_time: - break - chunk: bytes = resp.read(81920) - if not chunk: - break - total += len(chunk) - elapsed = time.monotonic() - start - if elapsed <= 0: - return 0.0 - return total / elapsed - except (urllib.error.URLError, OSError, TimeoutError): - return 0.0 - - -# -- Incremental probing ------------------------------------------------------- - - -def probe_incremental( - fast_sources: list[SnapshotSource], - full_snap_slot: int, -) -> tuple[str | None, list[str]]: - """Probe fast sources for the best incremental matching full_snap_slot. - - Returns (filename, mirror_urls) or (None, []) if no match found. - The "best" incremental is the one with the highest slot (closest to head). - """ - best_filename: str | None = None - best_slot: int = 0 - best_source: SnapshotSource | None = None - best_path: str | None = None - - for source in fast_sources: - inc_url: str = f"http://{source.rpc_address}/incremental-snapshot.tar.bz2" - inc_location, _ = head_no_follow(inc_url, timeout=2) - if not inc_location: - continue - inc_fn, inc_fp = _parse_snapshot_filename(inc_location) - m: re.Match[str] | None = INCR_SNAP_RE.match(inc_fn) - if not m: - continue - if int(m.group(1)) != full_snap_slot: - log.debug(" %s: incremental base slot %s != full %d, skipping", - source.rpc_address, m.group(1), full_snap_slot) - continue - inc_slot: int = int(m.group(2)) - if inc_slot > best_slot: - best_slot = inc_slot - best_filename = inc_fn - best_source = source - best_path = inc_fp - - if best_filename is None or best_source is None or best_path is None: - return None, [] - - # Build mirror list — check other sources for the same filename - mirror_urls: list[str] = [f"http://{best_source.rpc_address}{best_path}"] - for other in fast_sources: - if other.rpc_address == best_source.rpc_address: - continue - other_loc, _ = head_no_follow( - f"http://{other.rpc_address}/incremental-snapshot.tar.bz2", timeout=2) - if other_loc: - other_fn, other_fp = _parse_snapshot_filename(other_loc) - if other_fn == best_filename: - mirror_urls.append(f"http://{other.rpc_address}{other_fp}") - - return best_filename, mirror_urls - - -# -- Download ------------------------------------------------------------------ - - -def download_aria2c( - urls: list[str], - output_dir: str, - filename: str, - connections: int = 16, -) -> bool: - """Download a file using aria2c with parallel connections. - - When multiple URLs are provided, aria2c treats them as mirrors of the - same file and distributes chunks across all of them. - """ - num_mirrors: int = len(urls) - total_splits: int = max(connections, connections * num_mirrors) - cmd: list[str] = [ - "aria2c", - "--file-allocation=none", - "--continue=false", - f"--max-connection-per-server={connections}", - f"--split={total_splits}", - "--min-split-size=50M", - # aria2c retries individual chunk connections on transient network - # errors (TCP reset, timeout). This is transport-level retry analogous - # to TCP retransmit, not application-level retry of a failed operation. - "--max-tries=5", - "--retry-wait=5", - "--timeout=60", - "--connect-timeout=10", - "--summary-interval=10", - "--console-log-level=notice", - f"--dir={output_dir}", - f"--out={filename}", - "--auto-file-renaming=false", - "--allow-overwrite=true", - *urls, - ] - - log.info("Downloading %s", filename) - log.info(" aria2c: %d connections x %d mirrors (%d splits)", - connections, num_mirrors, total_splits) - - start: float = time.monotonic() - result: subprocess.CompletedProcess[bytes] = subprocess.run(cmd) - elapsed: float = time.monotonic() - start - - if result.returncode != 0: - log.error("aria2c failed with exit code %d", result.returncode) - return False - - filepath: Path = Path(output_dir) / filename - if not filepath.exists(): - log.error("aria2c reported success but %s does not exist", filepath) - return False - - size_bytes: int = filepath.stat().st_size - size_gb: float = size_bytes / (1024 ** 3) - avg_mb: float = size_bytes / elapsed / (1024 ** 2) if elapsed > 0 else 0 - log.info(" Done: %.1f GB in %.0fs (%.1f MiB/s avg)", size_gb, elapsed, avg_mb) - return True - - -# -- Shared helpers ------------------------------------------------------------ - - -def _discover_and_benchmark( - rpc_url: str, - current_slot: int, - *, - max_snapshot_age: int = 10000, - max_latency: float = 500, - threads: int = 500, - min_download_speed: int = 20, - measurement_time: int = 7, - max_speed_checks: int = 15, - version_filter: str | None = None, -) -> list[SnapshotSource]: - """Discover snapshot sources and benchmark download speed. - - Returns sources that meet the minimum speed requirement, sorted by speed. - """ - sources: list[SnapshotSource] = discover_sources( - rpc_url, current_slot, - max_age_slots=max_snapshot_age, - max_latency_ms=max_latency, - threads=threads, - version_filter=version_filter, - ) - if not sources: - return [] - - sources.sort(key=lambda s: s.latency_ms) - - log.info("Benchmarking download speed on top %d sources...", max_speed_checks) - fast_sources: list[SnapshotSource] = [] - checked: int = 0 - min_speed_bytes: int = min_download_speed * 1024 * 1024 - - for source in sources: - if checked >= max_speed_checks: - break - checked += 1 - - speed: float = measure_speed(source.rpc_address, measurement_time) - source.download_speed = speed - speed_mib: float = speed / (1024 ** 2) - - if speed < min_speed_bytes: - log.info(" %s: %.1f MiB/s (too slow, need >=%d MiB/s)", - source.rpc_address, speed_mib, min_download_speed) - continue - - log.info(" %s: %.1f MiB/s (latency: %.0fms, age: %d slots)", - source.rpc_address, speed_mib, - source.latency_ms, source.slots_diff) - fast_sources.append(source) - - return fast_sources - - -def _rolling_incremental_download( - fast_sources: list[SnapshotSource], - full_snap_slot: int, - output_dir: str, - convergence_slots: int, - connections: int, - rpc_url: str, -) -> str | None: - """Download incrementals in a loop until converged. - - Probes fast_sources for incrementals matching full_snap_slot, downloads - the freshest one, then re-probes until the gap to head is within - convergence_slots. Returns the filename of the final incremental, - or None if no incremental was found. - """ - prev_inc_filename: str | None = None - loop_start: float = time.monotonic() - max_convergence_time: float = 1800.0 # 30 min wall-clock limit - - while True: - if time.monotonic() - loop_start > max_convergence_time: - if prev_inc_filename: - log.warning("Convergence timeout (%.0fs) — using %s", - max_convergence_time, prev_inc_filename) - else: - log.warning("Convergence timeout (%.0fs) — no incremental downloaded", - max_convergence_time) - break - - inc_fn, inc_mirrors = probe_incremental(fast_sources, full_snap_slot) - if inc_fn is None: - if prev_inc_filename is None: - log.error("No matching incremental found for base slot %d", - full_snap_slot) - else: - log.info("No newer incremental available, using %s", prev_inc_filename) - break - - m_inc: re.Match[str] | None = INCR_SNAP_RE.match(inc_fn) - assert m_inc is not None - inc_slot: int = int(m_inc.group(2)) - - head_slot: int | None = get_current_slot(rpc_url) - if head_slot is None: - log.warning("Cannot get current slot — downloading best available incremental") - gap: int = convergence_slots + 1 - else: - gap = head_slot - inc_slot - - if inc_fn == prev_inc_filename: - if gap <= convergence_slots: - log.info("Incremental %s already downloaded (gap %d slots, converged)", - inc_fn, gap) - break - log.info("No newer incremental yet (slot %d, gap %d slots), waiting...", - inc_slot, gap) - time.sleep(10) - continue - - if prev_inc_filename is not None: - old_path: Path = Path(output_dir) / prev_inc_filename - if old_path.exists(): - log.info("Removing superseded incremental %s", prev_inc_filename) - old_path.unlink() - - log.info("Downloading incremental %s (%d mirrors, slot %d, gap %d slots)", - inc_fn, len(inc_mirrors), inc_slot, gap) - if not download_aria2c(inc_mirrors, output_dir, inc_fn, connections): - log.warning("Failed to download incremental %s — re-probing in 10s", inc_fn) - time.sleep(10) - continue - - prev_inc_filename = inc_fn - - if gap <= convergence_slots: - log.info("Converged: incremental slot %d is %d slots behind head", - inc_slot, gap) - break - - if head_slot is None: - break - - log.info("Not converged (gap %d > %d), re-probing in 10s...", - gap, convergence_slots) - time.sleep(10) - - return prev_inc_filename - - -# -- Public API ---------------------------------------------------------------- - - -def download_incremental_for_slot( - output_dir: str, - full_snap_slot: int, - *, - cluster: str = "mainnet-beta", - rpc_url: str | None = None, - connections: int = 16, - threads: int = 500, - max_snapshot_age: int = 10000, - max_latency: float = 500, - min_download_speed: int = 20, - measurement_time: int = 7, - max_speed_checks: int = 15, - version_filter: str | None = None, - convergence_slots: int = 500, -) -> bool: - """Download an incremental snapshot for an existing full snapshot. - - Discovers sources, benchmarks speed, then runs the rolling incremental - download loop for the given full snapshot base slot. Does NOT download - a full snapshot. - - Returns True if an incremental was downloaded, False otherwise. - """ - resolved_rpc: str = rpc_url or CLUSTER_RPC[cluster] - - if not shutil.which("aria2c"): - log.error("aria2c not found. Install with: apt install aria2") - return False - - log.info("Incremental download for base slot %d", full_snap_slot) - current_slot: int | None = get_current_slot(resolved_rpc) - if current_slot is None: - log.error("Cannot get current slot from %s", resolved_rpc) - return False - - fast_sources: list[SnapshotSource] = _discover_and_benchmark( - resolved_rpc, current_slot, - max_snapshot_age=max_snapshot_age, - max_latency=max_latency, - threads=threads, - min_download_speed=min_download_speed, - measurement_time=measurement_time, - max_speed_checks=max_speed_checks, - version_filter=version_filter, - ) - if not fast_sources: - log.error("No fast sources found") - return False - - os.makedirs(output_dir, exist_ok=True) - result: str | None = _rolling_incremental_download( - fast_sources, full_snap_slot, output_dir, - convergence_slots, connections, resolved_rpc, - ) - return result is not None - - -def download_best_snapshot( - output_dir: str, - *, - cluster: str = "mainnet-beta", - rpc_url: str | None = None, - connections: int = 16, - threads: int = 500, - max_snapshot_age: int = 10000, - max_latency: float = 500, - min_download_speed: int = 20, - measurement_time: int = 7, - max_speed_checks: int = 15, - version_filter: str | None = None, - full_only: bool = False, - convergence_slots: int = 500, -) -> bool: - """Download the best available snapshot to output_dir. - - This is the programmatic API — called by entrypoint.py for automatic - snapshot download. Returns True on success, False on failure. - - All parameters have sensible defaults matching the CLI interface. - """ - resolved_rpc: str = rpc_url or CLUSTER_RPC[cluster] - - if not shutil.which("aria2c"): - log.error("aria2c not found. Install with: apt install aria2") - return False - - log.info("Cluster: %s | RPC: %s", cluster, resolved_rpc) - current_slot: int | None = get_current_slot(resolved_rpc) - if current_slot is None: - log.error("Cannot get current slot from %s", resolved_rpc) - return False - log.info("Current slot: %d", current_slot) - - fast_sources: list[SnapshotSource] = _discover_and_benchmark( - resolved_rpc, current_slot, - max_snapshot_age=max_snapshot_age, - max_latency=max_latency, - threads=threads, - min_download_speed=min_download_speed, - measurement_time=measurement_time, - max_speed_checks=max_speed_checks, - version_filter=version_filter, - ) - if not fast_sources: - log.error("No fast sources found") - return False - - # Use the fastest source as primary, build full snapshot download plan - best: SnapshotSource = fast_sources[0] - full_paths: list[str] = [fp for fp in best.file_paths - if fp.rsplit("/", 1)[-1].startswith("snapshot-")] - if not full_paths: - log.error("Best source has no full snapshot") - return False - - # Build mirror URLs for the full snapshot - full_filename: str = full_paths[0].rsplit("/", 1)[-1] - full_mirrors: list[str] = [f"http://{best.rpc_address}{full_paths[0]}"] - for other in fast_sources[1:]: - for other_fp in other.file_paths: - if other_fp.rsplit("/", 1)[-1] == full_filename: - full_mirrors.append(f"http://{other.rpc_address}{other_fp}") - break - - speed_mib: float = best.download_speed / (1024 ** 2) - log.info("Best source: %s (%.1f MiB/s), %d mirrors", - best.rpc_address, speed_mib, len(full_mirrors)) - - # Download full snapshot - os.makedirs(output_dir, exist_ok=True) - total_start: float = time.monotonic() - - filepath: Path = Path(output_dir) / full_filename - if filepath.exists() and filepath.stat().st_size > 0: - log.info("Skipping %s (already exists: %.1f GB)", - full_filename, filepath.stat().st_size / (1024 ** 3)) - else: - if not download_aria2c(full_mirrors, output_dir, full_filename, connections): - log.error("Failed to download %s", full_filename) - return False - - # Download incremental separately — the full download took minutes, - # so any incremental from discovery is stale. Re-probe for fresh ones. - if not full_only: - fm: re.Match[str] | None = FULL_SNAP_RE.match(full_filename) - if fm: - full_snap_slot: int = int(fm.group(1)) - log.info("Downloading incremental for base slot %d...", full_snap_slot) - _rolling_incremental_download( - fast_sources, full_snap_slot, output_dir, - convergence_slots, connections, resolved_rpc, - ) - - total_elapsed: float = time.monotonic() - total_start - log.info("All downloads complete in %.0fs", total_elapsed) - - return True - - -# -- Main (CLI) ---------------------------------------------------------------- - - -def main() -> int: - p: argparse.ArgumentParser = argparse.ArgumentParser( - description="Download Solana snapshots with aria2c parallel downloads", - ) - p.add_argument("-o", "--output", default="/srv/kind/solana/snapshots", - help="Snapshot output directory (default: /srv/kind/solana/snapshots)") - p.add_argument("-c", "--cluster", default="mainnet-beta", - choices=list(CLUSTER_RPC), - help="Solana cluster (default: mainnet-beta)") - p.add_argument("-r", "--rpc", default=None, - help="RPC URL for cluster discovery (default: public RPC)") - p.add_argument("-n", "--connections", type=int, default=16, - help="aria2c connections per download (default: 16)") - p.add_argument("-t", "--threads", type=int, default=500, - help="Threads for parallel RPC probing (default: 500)") - p.add_argument("--max-snapshot-age", type=int, default=10000, - help="Max snapshot age in slots (default: 10000)") - p.add_argument("--max-latency", type=float, default=500, - help="Max RPC probe latency in ms (default: 500)") - p.add_argument("--min-download-speed", type=int, default=20, - help="Min download speed in MiB/s (default: 20)") - p.add_argument("--measurement-time", type=int, default=7, - help="Speed measurement duration in seconds (default: 7)") - p.add_argument("--max-speed-checks", type=int, default=15, - help="Max nodes to benchmark before giving up (default: 15)") - p.add_argument("--version", default=None, - help="Filter nodes by version prefix (e.g. '2.2')") - p.add_argument("--convergence-slots", type=int, default=500, - help="Max slot gap for incremental convergence (default: 500)") - p.add_argument("--full-only", action="store_true", - help="Download only full snapshot, skip incremental") - p.add_argument("--dry-run", action="store_true", - help="Find best source and print URL, don't download") - p.add_argument("--post-cmd", - help="Shell command to run after successful download " - "(e.g. 'kubectl scale deployment ... --replicas=1')") - p.add_argument("-v", "--verbose", action="store_true") - args: argparse.Namespace = p.parse_args() - - logging.basicConfig( - level=logging.DEBUG if args.verbose else logging.INFO, - format="%(asctime)s %(levelname)s %(message)s", - datefmt="%H:%M:%S", - ) - - # Dry-run uses the original inline flow (needs access to sources for URL printing) - if args.dry_run: - rpc_url: str = args.rpc or CLUSTER_RPC[args.cluster] - current_slot: int | None = get_current_slot(rpc_url) - if current_slot is None: - log.error("Cannot get current slot from %s", rpc_url) - return 1 - - sources: list[SnapshotSource] = discover_sources( - rpc_url, current_slot, - max_age_slots=args.max_snapshot_age, - max_latency_ms=args.max_latency, - threads=args.threads, - version_filter=args.version, - ) - if not sources: - log.error("No snapshot sources found") - return 1 - - sources.sort(key=lambda s: s.latency_ms) - best = sources[0] - for fp in best.file_paths: - print(f"http://{best.rpc_address}{fp}") - return 0 - - ok: bool = download_best_snapshot( - args.output, - cluster=args.cluster, - rpc_url=args.rpc, - connections=args.connections, - threads=args.threads, - max_snapshot_age=args.max_snapshot_age, - max_latency=args.max_latency, - min_download_speed=args.min_download_speed, - measurement_time=args.measurement_time, - max_speed_checks=args.max_speed_checks, - version_filter=args.version, - full_only=args.full_only, - convergence_slots=args.convergence_slots, - ) - - if ok and args.post_cmd: - log.info("Running post-download command: %s", args.post_cmd) - result: subprocess.CompletedProcess[bytes] = subprocess.run( - args.post_cmd, shell=True, - ) - if result.returncode != 0: - log.error("Post-download command failed with exit code %d", - result.returncode) - return 1 - log.info("Post-download command completed successfully") - - return 0 if ok else 1 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/scripts/agave-container/start-test.sh b/scripts/agave-container/start-test.sh deleted file mode 100644 index e003a97a..00000000 --- a/scripts/agave-container/start-test.sh +++ /dev/null @@ -1,112 +0,0 @@ -#!/usr/bin/env bash -set -euo pipefail - -# ----------------------------------------------------------------------- -# Start solana-test-validator with optional SPL token setup -# -# Environment variables: -# FACILITATOR_PUBKEY - facilitator fee-payer public key (base58) -# SERVER_PUBKEY - server/payee wallet public key (base58) -# CLIENT_PUBKEY - client/payer wallet public key (base58) -# MINT_DECIMALS - token decimals (default: 6, matching USDC) -# MINT_AMOUNT - amount to mint to client (default: 1000000000) -# LEDGER_DIR - ledger directory (default: /data/ledger) -# ----------------------------------------------------------------------- - -LEDGER_DIR="${LEDGER_DIR:-/data/ledger}" -MINT_DECIMALS="${MINT_DECIMALS:-6}" -MINT_AMOUNT="${MINT_AMOUNT:-1000000000}" -SETUP_MARKER="${LEDGER_DIR}/.setup-done" - -sudo chown -R "$(id -u):$(id -g)" "$LEDGER_DIR" 2>/dev/null || true - -# Start test-validator in the background -solana-test-validator \ - --ledger "${LEDGER_DIR}" \ - --rpc-port 8899 \ - --bind-address 0.0.0.0 \ - --quiet & - -VALIDATOR_PID=$! - -# Wait for RPC to become available -echo "Waiting for test-validator RPC..." -for i in $(seq 1 60); do - if solana cluster-version --url http://127.0.0.1:8899 >/dev/null 2>&1; then - echo "Test-validator is ready (attempt ${i})" - break - fi - sleep 1 -done - -solana config set --url http://127.0.0.1:8899 - -# Only run setup once (idempotent via marker file) -if [ ! -f "${SETUP_MARKER}" ]; then - echo "Running first-time setup..." - - # Airdrop SOL to all wallets for gas - for PUBKEY in "${FACILITATOR_PUBKEY:-}" "${SERVER_PUBKEY:-}" "${CLIENT_PUBKEY:-}"; do - if [ -n "${PUBKEY}" ]; then - echo "Airdropping 100 SOL to ${PUBKEY}..." - solana airdrop 100 "${PUBKEY}" --url http://127.0.0.1:8899 || true - fi - done - - # Create a USDC-equivalent SPL token mint if any pubkeys are set - if [ -n "${CLIENT_PUBKEY:-}" ] || [ -n "${FACILITATOR_PUBKEY:-}" ] || [ -n "${SERVER_PUBKEY:-}" ]; then - MINT_AUTHORITY_FILE="${LEDGER_DIR}/mint-authority.json" - if [ ! -f "${MINT_AUTHORITY_FILE}" ]; then - solana-keygen new --no-bip39-passphrase --outfile "${MINT_AUTHORITY_FILE}" --force - MINT_AUTH_PUBKEY=$(solana-keygen pubkey "${MINT_AUTHORITY_FILE}") - solana airdrop 10 "${MINT_AUTH_PUBKEY}" --url http://127.0.0.1:8899 - fi - - MINT_ADDRESS_FILE="${LEDGER_DIR}/usdc-mint-address.txt" - if [ ! -f "${MINT_ADDRESS_FILE}" ]; then - spl-token create-token \ - --decimals "${MINT_DECIMALS}" \ - --mint-authority "${MINT_AUTHORITY_FILE}" \ - --url http://127.0.0.1:8899 \ - 2>&1 | grep "Creating token" | awk '{print $3}' > "${MINT_ADDRESS_FILE}" - echo "Created USDC mint: $(cat "${MINT_ADDRESS_FILE}")" - fi - - USDC_MINT=$(cat "${MINT_ADDRESS_FILE}") - - # Create ATAs and mint tokens for the client - if [ -n "${CLIENT_PUBKEY:-}" ]; then - echo "Creating ATA for client ${CLIENT_PUBKEY}..." - spl-token create-account "${USDC_MINT}" \ - --owner "${CLIENT_PUBKEY}" \ - --fee-payer "${MINT_AUTHORITY_FILE}" \ - --url http://127.0.0.1:8899 || true - - echo "Minting ${MINT_AMOUNT} tokens to client..." - spl-token mint "${USDC_MINT}" "${MINT_AMOUNT}" \ - --recipient-owner "${CLIENT_PUBKEY}" \ - --mint-authority "${MINT_AUTHORITY_FILE}" \ - --url http://127.0.0.1:8899 || true - fi - - # Create ATAs for server and facilitator - for PUBKEY in "${SERVER_PUBKEY:-}" "${FACILITATOR_PUBKEY:-}"; do - if [ -n "${PUBKEY}" ]; then - echo "Creating ATA for ${PUBKEY}..." - spl-token create-account "${USDC_MINT}" \ - --owner "${PUBKEY}" \ - --fee-payer "${MINT_AUTHORITY_FILE}" \ - --url http://127.0.0.1:8899 || true - fi - done - - # Expose mint address for other containers - cp "${MINT_ADDRESS_FILE}" /tmp/usdc-mint-address.txt 2>/dev/null || true - fi - - touch "${SETUP_MARKER}" - echo "Setup complete." -fi - -echo "solana-test-validator running (PID ${VALIDATOR_PID})" -wait ${VALIDATOR_PID}