chore: remove scripts/agave-container before subtree add

Moving container scripts into agave-stack subtree (correct direction).
The source of truth will be agave-stack/ in this repo, pushed out to
LaconicNetwork/agave-stack via git subtree push.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
fix/kind-mount-propagation
A. F. Dudley 2026-03-10 06:21:12 +00:00
parent 08380ec070
commit 7c58809cc1
6 changed files with 0 additions and 2023 deletions

View File

@ -1,81 +0,0 @@
# Unified Agave/Jito Solana image
# Supports three modes via AGAVE_MODE env: test, rpc, validator
#
# Build args:
# AGAVE_REPO - git repo URL (anza-xyz/agave or jito-foundation/jito-solana)
# AGAVE_VERSION - git tag to build (e.g. v3.1.9, v3.1.8-jito)
ARG AGAVE_REPO=https://github.com/anza-xyz/agave.git
ARG AGAVE_VERSION=v3.1.9
# ---------- Stage 1: Build ----------
FROM rust:1.85-bookworm AS builder
ARG AGAVE_REPO
ARG AGAVE_VERSION
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
pkg-config \
libssl-dev \
libudev-dev \
libclang-dev \
protobuf-compiler \
ca-certificates \
git \
cmake \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /build
RUN git clone "$AGAVE_REPO" --depth 1 --branch "$AGAVE_VERSION" --recurse-submodules agave
WORKDIR /build/agave
# Cherry-pick --public-tvu-address support (anza-xyz/agave PR #6778, commit 9f4b3ae)
# This flag only exists on master, not in v3.1.9 — fetch the PR ref and cherry-pick
ARG TVU_ADDRESS_PR=6778
RUN if [ -n "$TVU_ADDRESS_PR" ]; then \
git fetch --depth 50 origin "pull/${TVU_ADDRESS_PR}/head:tvu-pr" && \
git cherry-pick --no-commit tvu-pr; \
fi
# Build all binaries using the upstream install script
RUN CI_COMMIT=$(git rev-parse HEAD) scripts/cargo-install-all.sh /solana-release
# ---------- Stage 2: Runtime ----------
FROM debian:bookworm-slim
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates \
libssl3 \
libudev1 \
curl \
sudo \
aria2 \
python3 \
&& rm -rf /var/lib/apt/lists/*
# Create non-root user with sudo
RUN useradd -m -s /bin/bash agave \
&& echo "agave ALL=(ALL) NOPASSWD: ALL" >> /etc/sudoers
# Copy all compiled binaries
COPY --from=builder /solana-release/bin/ /usr/local/bin/
# Copy entrypoint and support scripts
COPY entrypoint.py snapshot_download.py ip_echo_preflight.py /usr/local/bin/
COPY start-test.sh /usr/local/bin/
RUN chmod +x /usr/local/bin/entrypoint.py /usr/local/bin/start-test.sh
# Create data directories
RUN mkdir -p /data/config /data/ledger /data/accounts /data/snapshots \
&& chown -R agave:agave /data
USER agave
WORKDIR /data
ENV RUST_LOG=info
ENV RUST_BACKTRACE=1
EXPOSE 8899 8900 8001 8001/udp
ENTRYPOINT ["entrypoint.py"]

View File

@ -1,17 +0,0 @@
#!/usr/bin/env bash
# Build laconicnetwork/agave
# Set AGAVE_REPO and AGAVE_VERSION env vars to build Jito or a different version
source ${CERC_CONTAINER_BASE_DIR}/build-base.sh
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
AGAVE_REPO="${AGAVE_REPO:-https://github.com/anza-xyz/agave.git}"
AGAVE_VERSION="${AGAVE_VERSION:-v3.1.9}"
docker build -t laconicnetwork/agave:local \
--build-arg AGAVE_REPO="$AGAVE_REPO" \
--build-arg AGAVE_VERSION="$AGAVE_VERSION" \
${build_command_args} \
-f ${SCRIPT_DIR}/Dockerfile \
${SCRIPT_DIR}

View File

@ -1,686 +0,0 @@
#!/usr/bin/env python3
"""Agave validator entrypoint — snapshot management, arg construction, liveness probe.
Two subcommands:
entrypoint.py serve (default) snapshot freshness check + run agave-validator
entrypoint.py probe liveness probe (slot lag check, exits 0/1)
Replaces the bash entrypoint.sh / start-rpc.sh / start-validator.sh with a single
Python module. Test mode still dispatches to start-test.sh.
Python stays as PID 1 and traps SIGTERM. On SIGTERM, it runs
``agave-validator exit --force --ledger /data/ledger`` which connects to the
admin RPC Unix socket and tells the validator to flush I/O and exit cleanly.
This avoids the io_uring/ZFS deadlock that occurs when the process is killed.
All configuration comes from environment variables same vars as the original
bash scripts. See compose files for defaults.
"""
from __future__ import annotations
import json
import logging
import os
import re
import signal
import subprocess
import sys
import threading
import time
import urllib.error
import urllib.request
from pathlib import Path
from urllib.request import Request
log: logging.Logger = logging.getLogger("entrypoint")
# Directories
CONFIG_DIR = "/data/config"
LEDGER_DIR = "/data/ledger"
ACCOUNTS_DIR = "/data/accounts"
SNAPSHOTS_DIR = "/data/snapshots"
LOG_DIR = "/data/log"
IDENTITY_FILE = f"{CONFIG_DIR}/validator-identity.json"
# Snapshot filename patterns
FULL_SNAP_RE: re.Pattern[str] = re.compile(
r"^snapshot-(\d+)-[A-Za-z0-9]+\.tar\.(zst|bz2)$"
)
INCR_SNAP_RE: re.Pattern[str] = re.compile(
r"^incremental-snapshot-(\d+)-(\d+)-[A-Za-z0-9]+\.tar\.(zst|bz2)$"
)
MAINNET_RPC = "https://api.mainnet-beta.solana.com"
# -- Helpers -------------------------------------------------------------------
def env(name: str, default: str = "") -> str:
"""Read env var with default."""
return os.environ.get(name, default)
def env_required(name: str) -> str:
"""Read required env var, exit if missing."""
val = os.environ.get(name)
if not val:
log.error("%s is required but not set", name)
sys.exit(1)
return val
def env_bool(name: str, default: bool = False) -> bool:
"""Read boolean env var (true/false/1/0)."""
val = os.environ.get(name, "").lower()
if not val:
return default
return val in ("true", "1", "yes")
def rpc_get_slot(url: str, timeout: int = 10) -> int | None:
"""Get current slot from a Solana RPC endpoint."""
payload = json.dumps({
"jsonrpc": "2.0", "id": 1,
"method": "getSlot", "params": [],
}).encode()
req = Request(url, data=payload,
headers={"Content-Type": "application/json"})
try:
with urllib.request.urlopen(req, timeout=timeout) as resp:
data = json.loads(resp.read())
result = data.get("result")
if isinstance(result, int):
return result
except (urllib.error.URLError, json.JSONDecodeError, OSError, TimeoutError):
pass
return None
# -- Snapshot management -------------------------------------------------------
def get_local_snapshot_slot(snapshots_dir: str) -> int | None:
"""Find the highest slot among local snapshot files."""
best_slot: int | None = None
snap_path = Path(snapshots_dir)
if not snap_path.is_dir():
return None
for entry in snap_path.iterdir():
m = FULL_SNAP_RE.match(entry.name)
if m:
slot = int(m.group(1))
if best_slot is None or slot > best_slot:
best_slot = slot
return best_slot
def clean_snapshots(snapshots_dir: str) -> None:
"""Remove all snapshot files from the directory."""
snap_path = Path(snapshots_dir)
if not snap_path.is_dir():
return
for entry in snap_path.iterdir():
if entry.name.startswith(("snapshot-", "incremental-snapshot-")):
log.info("Removing old snapshot: %s", entry.name)
entry.unlink(missing_ok=True)
def get_incremental_slot(snapshots_dir: str, full_slot: int | None) -> int | None:
"""Get the highest incremental snapshot slot matching the full's base slot."""
if full_slot is None:
return None
snap_path = Path(snapshots_dir)
if not snap_path.is_dir():
return None
best: int | None = None
for entry in snap_path.iterdir():
m = INCR_SNAP_RE.match(entry.name)
if m and int(m.group(1)) == full_slot:
slot = int(m.group(2))
if best is None or slot > best:
best = slot
return best
def maybe_download_snapshot(snapshots_dir: str) -> None:
"""Ensure full + incremental snapshots exist before starting.
The validator should always start from a full + incremental pair to
minimize replay time. If either is missing or the full is too old,
download fresh ones via download_best_snapshot (which does rolling
incremental convergence after downloading the full).
Controlled by env vars:
SNAPSHOT_AUTO_DOWNLOAD (default: true) enable/disable
SNAPSHOT_MAX_AGE_SLOTS (default: 100000) full snapshot staleness threshold
(one full snapshot generation, ~11 hours)
"""
if not env_bool("SNAPSHOT_AUTO_DOWNLOAD", default=True):
log.info("Snapshot auto-download disabled")
return
max_age = int(env("SNAPSHOT_MAX_AGE_SLOTS", "100000"))
mainnet_slot = rpc_get_slot(MAINNET_RPC)
if mainnet_slot is None:
log.warning("Cannot reach mainnet RPC — skipping snapshot check")
return
script_dir = Path(__file__).resolve().parent
sys.path.insert(0, str(script_dir))
from snapshot_download import download_best_snapshot, download_incremental_for_slot
convergence = int(env("SNAPSHOT_CONVERGENCE_SLOTS", "500"))
retry_delay = int(env("SNAPSHOT_RETRY_DELAY", "60"))
# Check local full snapshot
local_slot = get_local_snapshot_slot(snapshots_dir)
have_fresh_full = (local_slot is not None
and (mainnet_slot - local_slot) <= max_age)
if have_fresh_full:
assert local_slot is not None
inc_slot = get_incremental_slot(snapshots_dir, local_slot)
if inc_slot is not None:
inc_gap = mainnet_slot - inc_slot
if inc_gap <= convergence:
log.info("Full (slot %d) + incremental (slot %d, gap %d) "
"within convergence, starting",
local_slot, inc_slot, inc_gap)
return
log.info("Incremental too stale (slot %d, gap %d > %d)",
inc_slot, inc_gap, convergence)
# Fresh full, need a fresh incremental
log.info("Downloading incremental for full at slot %d", local_slot)
while True:
if download_incremental_for_slot(snapshots_dir, local_slot,
convergence_slots=convergence):
return
log.warning("Incremental download failed — retrying in %ds",
retry_delay)
time.sleep(retry_delay)
# No full or full too old — download both
log.info("Downloading full + incremental")
clean_snapshots(snapshots_dir)
while True:
if download_best_snapshot(snapshots_dir, convergence_slots=convergence):
return
log.warning("Snapshot download failed — retrying in %ds", retry_delay)
time.sleep(retry_delay)
# -- Directory and identity setup ----------------------------------------------
def ensure_dirs(*dirs: str) -> None:
"""Create directories and fix ownership."""
uid = os.getuid()
gid = os.getgid()
for d in dirs:
os.makedirs(d, exist_ok=True)
try:
subprocess.run(
["sudo", "chown", "-R", f"{uid}:{gid}", d],
check=False, capture_output=True,
)
except FileNotFoundError:
pass # sudo not available — dirs already owned correctly
def ensure_identity_rpc() -> None:
"""Generate ephemeral identity keypair for RPC mode if not mounted."""
if os.path.isfile(IDENTITY_FILE):
return
log.info("Generating RPC node identity keypair...")
subprocess.run(
["solana-keygen", "new", "--no-passphrase", "--silent",
"--force", "--outfile", IDENTITY_FILE],
check=True,
)
def print_identity() -> None:
"""Print the node identity pubkey."""
result = subprocess.run(
["solana-keygen", "pubkey", IDENTITY_FILE],
capture_output=True, text=True, check=False,
)
if result.returncode == 0:
log.info("Node identity: %s", result.stdout.strip())
# -- Arg construction ----------------------------------------------------------
def build_common_args() -> list[str]:
"""Build agave-validator args common to both RPC and validator modes."""
args: list[str] = [
"--identity", IDENTITY_FILE,
"--entrypoint", env_required("VALIDATOR_ENTRYPOINT"),
"--known-validator", env_required("KNOWN_VALIDATOR"),
"--ledger", LEDGER_DIR,
"--accounts", ACCOUNTS_DIR,
"--snapshots", SNAPSHOTS_DIR,
"--rpc-port", env("RPC_PORT", "8899"),
"--rpc-bind-address", env("RPC_BIND_ADDRESS", "127.0.0.1"),
"--gossip-port", env("GOSSIP_PORT", "8001"),
"--dynamic-port-range", env("DYNAMIC_PORT_RANGE", "9000-10000"),
"--no-os-network-limits-test",
"--wal-recovery-mode", "skip_any_corrupted_record",
"--limit-ledger-size", env("LIMIT_LEDGER_SIZE", "50000000"),
"--no-snapshot-fetch", # entrypoint handles snapshot download
]
# Snapshot generation
if env("NO_SNAPSHOTS") == "true":
args.append("--no-snapshots")
else:
args += [
"--full-snapshot-interval-slots", env("SNAPSHOT_INTERVAL_SLOTS", "100000"),
"--maximum-full-snapshots-to-retain", env("MAXIMUM_SNAPSHOTS_TO_RETAIN", "1"),
]
if env("NO_INCREMENTAL_SNAPSHOTS") != "true":
args += ["--maximum-incremental-snapshots-to-retain", "2"]
# Account indexes
account_indexes = env("ACCOUNT_INDEXES")
if account_indexes:
for idx in account_indexes.split(","):
idx = idx.strip()
if idx:
args += ["--account-index", idx]
# Additional entrypoints
for ep in env("EXTRA_ENTRYPOINTS").split():
if ep:
args += ["--entrypoint", ep]
# Additional known validators
for kv in env("EXTRA_KNOWN_VALIDATORS").split():
if kv:
args += ["--known-validator", kv]
# Cluster verification
genesis_hash = env("EXPECTED_GENESIS_HASH")
if genesis_hash:
args += ["--expected-genesis-hash", genesis_hash]
shred_version = env("EXPECTED_SHRED_VERSION")
if shred_version:
args += ["--expected-shred-version", shred_version]
# Metrics — just needs to be in the environment, agave reads it directly
# (env var is already set, nothing to pass as arg)
# Gossip host / TVU address
gossip_host = env("GOSSIP_HOST")
if gossip_host:
args += ["--gossip-host", gossip_host]
elif env("PUBLIC_TVU_ADDRESS"):
args += ["--public-tvu-address", env("PUBLIC_TVU_ADDRESS")]
# Jito flags
if env("JITO_ENABLE") == "true":
log.info("Jito MEV enabled")
jito_flags: list[tuple[str, str]] = [
("JITO_TIP_PAYMENT_PROGRAM", "--tip-payment-program-pubkey"),
("JITO_DISTRIBUTION_PROGRAM", "--tip-distribution-program-pubkey"),
("JITO_MERKLE_ROOT_AUTHORITY", "--merkle-root-upload-authority"),
("JITO_COMMISSION_BPS", "--commission-bps"),
("JITO_BLOCK_ENGINE_URL", "--block-engine-url"),
("JITO_SHRED_RECEIVER_ADDR", "--shred-receiver-address"),
]
for env_name, flag in jito_flags:
val = env(env_name)
if val:
args += [flag, val]
return args
def build_rpc_args() -> list[str]:
"""Build agave-validator args for RPC (non-voting) mode."""
args = build_common_args()
args += [
"--no-voting",
"--log", f"{LOG_DIR}/validator.log",
"--full-rpc-api",
"--enable-rpc-transaction-history",
"--rpc-pubsub-enable-block-subscription",
"--enable-extended-tx-metadata-storage",
"--no-wait-for-vote-to-start-leader",
]
# Public vs private RPC
public_rpc = env("PUBLIC_RPC_ADDRESS")
if public_rpc:
args += ["--public-rpc-address", public_rpc]
else:
args += ["--private-rpc", "--allow-private-addr", "--only-known-rpc"]
# Jito relayer URL (RPC mode doesn't use it, but validator mode does —
# handled in build_validator_args)
return args
def build_validator_args() -> list[str]:
"""Build agave-validator args for voting validator mode."""
vote_keypair = env("VOTE_ACCOUNT_KEYPAIR",
"/data/config/vote-account-keypair.json")
# Identity must be mounted for validator mode
if not os.path.isfile(IDENTITY_FILE):
log.error("Validator identity keypair not found at %s", IDENTITY_FILE)
log.error("Mount your validator keypair to %s", IDENTITY_FILE)
sys.exit(1)
# Vote account keypair must exist
if not os.path.isfile(vote_keypair):
log.error("Vote account keypair not found at %s", vote_keypair)
log.error("Mount your vote account keypair or set VOTE_ACCOUNT_KEYPAIR")
sys.exit(1)
# Print vote account pubkey
result = subprocess.run(
["solana-keygen", "pubkey", vote_keypair],
capture_output=True, text=True, check=False,
)
if result.returncode == 0:
log.info("Vote account: %s", result.stdout.strip())
args = build_common_args()
args += [
"--vote-account", vote_keypair,
"--log", "-",
]
# Jito relayer URL (validator-only)
relayer_url = env("JITO_RELAYER_URL")
if env("JITO_ENABLE") == "true" and relayer_url:
args += ["--relayer-url", relayer_url]
return args
def append_extra_args(args: list[str]) -> list[str]:
"""Append EXTRA_ARGS passthrough flags."""
extra = env("EXTRA_ARGS")
if extra:
args += extra.split()
return args
# -- Graceful shutdown --------------------------------------------------------
# Timeout for graceful exit via admin RPC. Leave 30s margin for k8s
# terminationGracePeriodSeconds (300s).
GRACEFUL_EXIT_TIMEOUT = 270
def graceful_exit(child: subprocess.Popen[bytes], reason: str = "SIGTERM") -> None:
"""Request graceful shutdown via the admin RPC Unix socket.
Runs ``agave-validator exit --force --ledger /data/ledger`` which connects
to the admin RPC socket at ``/data/ledger/admin.rpc`` and sets the
validator's exit flag. The validator flushes all I/O and exits cleanly,
avoiding the io_uring/ZFS deadlock.
If the admin RPC exit fails or the child doesn't exit within the timeout,
falls back to SIGTERM then SIGKILL.
"""
log.info("%s — requesting graceful exit via admin RPC", reason)
try:
result = subprocess.run(
["agave-validator", "exit", "--force", "--ledger", LEDGER_DIR],
capture_output=True, text=True, timeout=30,
)
if result.returncode == 0:
log.info("Admin RPC exit requested successfully")
else:
log.warning(
"Admin RPC exit returned %d: %s",
result.returncode, result.stderr.strip(),
)
except subprocess.TimeoutExpired:
log.warning("Admin RPC exit command timed out after 30s")
except FileNotFoundError:
log.warning("agave-validator binary not found for exit command")
# Wait for child to exit
try:
child.wait(timeout=GRACEFUL_EXIT_TIMEOUT)
log.info("Validator exited cleanly with code %d", child.returncode)
return
except subprocess.TimeoutExpired:
log.warning(
"Validator did not exit within %ds — sending SIGTERM",
GRACEFUL_EXIT_TIMEOUT,
)
# Fallback: SIGTERM
child.terminate()
try:
child.wait(timeout=15)
log.info("Validator exited after SIGTERM with code %d", child.returncode)
return
except subprocess.TimeoutExpired:
log.warning("Validator did not exit after SIGTERM — sending SIGKILL")
# Last resort: SIGKILL
child.kill()
child.wait()
log.info("Validator killed with SIGKILL, code %d", child.returncode)
# -- Serve subcommand ---------------------------------------------------------
def _gap_monitor(
child: subprocess.Popen[bytes],
leapfrog: threading.Event,
shutting_down: threading.Event,
) -> None:
"""Background thread: poll slot gap and trigger leapfrog if too far behind.
Waits for a grace period (SNAPSHOT_MONITOR_GRACE, default 600s) before
monitoring the validator needs time to extract snapshots and catch up.
Then polls every SNAPSHOT_MONITOR_INTERVAL (default 30s). If the gap
exceeds SNAPSHOT_LEAPFROG_SLOTS (default 5000) for SNAPSHOT_LEAPFROG_CHECKS
(default 3) consecutive checks, triggers graceful shutdown and sets the
leapfrog event so cmd_serve loops back to download a fresh incremental.
"""
threshold = int(env("SNAPSHOT_LEAPFROG_SLOTS", "5000"))
required_checks = int(env("SNAPSHOT_LEAPFROG_CHECKS", "3"))
interval = int(env("SNAPSHOT_MONITOR_INTERVAL", "30"))
grace = int(env("SNAPSHOT_MONITOR_GRACE", "600"))
rpc_port = env("RPC_PORT", "8899")
local_url = f"http://127.0.0.1:{rpc_port}"
# Grace period — don't monitor during initial catch-up
if shutting_down.wait(grace):
return
consecutive = 0
while not shutting_down.is_set():
local_slot = rpc_get_slot(local_url, timeout=5)
mainnet_slot = rpc_get_slot(MAINNET_RPC, timeout=10)
if local_slot is not None and mainnet_slot is not None:
gap = mainnet_slot - local_slot
if gap > threshold:
consecutive += 1
log.warning("Gap %d > %d (%d/%d consecutive)",
gap, threshold, consecutive, required_checks)
if consecutive >= required_checks:
log.warning("Leapfrog triggered: gap %d", gap)
leapfrog.set()
graceful_exit(child, reason="Leapfrog")
return
else:
if consecutive > 0:
log.info("Gap %d within threshold, resetting counter", gap)
consecutive = 0
shutting_down.wait(interval)
def cmd_serve() -> None:
"""Main serve flow: snapshot download, run validator, monitor gap, leapfrog.
Python stays as PID 1. On each iteration:
1. Download full + incremental snapshots (if needed)
2. Start agave-validator as child process
3. Monitor slot gap in background thread
4. If gap exceeds threshold graceful stop loop back to step 1
5. If SIGTERM graceful stop exit
6. If validator crashes exit with its return code
"""
mode = env("AGAVE_MODE", "test")
log.info("AGAVE_MODE=%s", mode)
if mode == "test":
os.execvp("start-test.sh", ["start-test.sh"])
if mode not in ("rpc", "validator"):
log.error("Unknown AGAVE_MODE: %s (valid: test, rpc, validator)", mode)
sys.exit(1)
# One-time setup
dirs = [CONFIG_DIR, LEDGER_DIR, ACCOUNTS_DIR, SNAPSHOTS_DIR]
if mode == "rpc":
dirs.append(LOG_DIR)
ensure_dirs(*dirs)
if not env_bool("SKIP_IP_ECHO_PREFLIGHT"):
script_dir = Path(__file__).resolve().parent
sys.path.insert(0, str(script_dir))
from ip_echo_preflight import main as ip_echo_main
if ip_echo_main() != 0:
sys.exit(1)
if mode == "rpc":
ensure_identity_rpc()
print_identity()
if mode == "rpc":
args = build_rpc_args()
else:
args = build_validator_args()
args = append_extra_args(args)
# Main loop: download → run → monitor → leapfrog if needed
while True:
maybe_download_snapshot(SNAPSHOTS_DIR)
Path("/tmp/entrypoint-start").write_text(str(time.time()))
log.info("Starting agave-validator with %d arguments", len(args))
child = subprocess.Popen(["agave-validator"] + args)
shutting_down = threading.Event()
leapfrog = threading.Event()
signal.signal(signal.SIGUSR1,
lambda _sig, _frame: child.send_signal(signal.SIGUSR1))
def _on_sigterm(_sig: int, _frame: object) -> None:
shutting_down.set()
threading.Thread(
target=graceful_exit, args=(child,), daemon=True,
).start()
signal.signal(signal.SIGTERM, _on_sigterm)
# Start gap monitor
monitor = threading.Thread(
target=_gap_monitor,
args=(child, leapfrog, shutting_down),
daemon=True,
)
monitor.start()
child.wait()
if leapfrog.is_set():
log.info("Leapfrog: restarting with fresh incremental")
continue
sys.exit(child.returncode)
# -- Probe subcommand ---------------------------------------------------------
def cmd_probe() -> None:
"""Liveness probe: check local RPC slot vs mainnet.
Exit 0 = healthy, exit 1 = unhealthy.
Grace period: PROBE_GRACE_SECONDS (default 600) probe always passes
during grace period to allow for snapshot unpacking and initial replay.
"""
grace_seconds = int(env("PROBE_GRACE_SECONDS", "600"))
max_lag = int(env("PROBE_MAX_SLOT_LAG", "20000"))
# Check grace period
start_file = Path("/tmp/entrypoint-start")
if start_file.exists():
try:
start_time = float(start_file.read_text().strip())
elapsed = time.time() - start_time
if elapsed < grace_seconds:
# Within grace period — always healthy
sys.exit(0)
except (ValueError, OSError):
pass
else:
# No start file — serve hasn't started yet, within grace
sys.exit(0)
# Query local RPC
rpc_port = env("RPC_PORT", "8899")
local_url = f"http://127.0.0.1:{rpc_port}"
local_slot = rpc_get_slot(local_url, timeout=5)
if local_slot is None:
# Local RPC unreachable after grace period — unhealthy
sys.exit(1)
# Query mainnet
mainnet_slot = rpc_get_slot(MAINNET_RPC, timeout=10)
if mainnet_slot is None:
# Can't reach mainnet to compare — assume healthy (don't penalize
# the validator for mainnet RPC being down)
sys.exit(0)
lag = mainnet_slot - local_slot
if lag > max_lag:
sys.exit(1)
sys.exit(0)
# -- Main ----------------------------------------------------------------------
def main() -> None:
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
datefmt="%H:%M:%S",
)
subcmd = sys.argv[1] if len(sys.argv) > 1 else "serve"
if subcmd == "serve":
cmd_serve()
elif subcmd == "probe":
cmd_probe()
else:
log.error("Unknown subcommand: %s (valid: serve, probe)", subcmd)
sys.exit(1)
if __name__ == "__main__":
main()

View File

@ -1,249 +0,0 @@
#!/usr/bin/env python3
"""ip_echo preflight — verify UDP port reachability before starting the validator.
Implements the Solana ip_echo client protocol exactly:
1. Bind UDP sockets on the ports the validator will use
2. TCP connect to entrypoint gossip port, send IpEchoServerMessage
3. Parse IpEchoServerResponse (our IP as seen by entrypoint)
4. Wait for entrypoint's UDP probes on each port
5. Exit 0 if all ports reachable, exit 1 if any fail
Wire format (from agave net-utils/src/):
Request: 4 null bytes + [u16; 4] tcp_ports LE + [u16; 4] udp_ports LE + \n
Response: 4 null bytes + bincode IpAddr (variant byte + addr) + optional shred_version
Called from entrypoint.py before snapshot download. Prevents wasting hours
downloading a snapshot only to crash-loop on port reachability.
"""
from __future__ import annotations
import logging
import os
import socket
import struct
import sys
import threading
import time
log = logging.getLogger("ip_echo_preflight")
HEADER = b"\x00\x00\x00\x00"
TERMINUS = b"\x0a"
RESPONSE_BUF = 27
IO_TIMEOUT = 5.0
PROBE_TIMEOUT = 10.0
MAX_RETRIES = 3
RETRY_DELAY = 2.0
def build_request(tcp_ports: list[int], udp_ports: list[int]) -> bytes:
"""Build IpEchoServerMessage: header + [u16;4] tcp + [u16;4] udp + newline."""
tcp = (tcp_ports + [0, 0, 0, 0])[:4]
udp = (udp_ports + [0, 0, 0, 0])[:4]
return HEADER + struct.pack("<4H", *tcp) + struct.pack("<4H", *udp) + TERMINUS
def parse_response(data: bytes) -> tuple[str, int | None]:
"""Parse IpEchoServerResponse → (ip_string, shred_version | None).
Wire format (bincode):
4 bytes header (\0\0\0\0)
4 bytes IpAddr enum variant (u32 LE: 0=IPv4, 1=IPv6)
4|16 bytes address octets
1 byte Option tag (0=None, 1=Some)
2 bytes shred_version (u16 LE, only if Some)
"""
if len(data) < 8:
raise ValueError(f"response too short: {len(data)} bytes")
if data[:4] == b"HTTP":
raise ValueError("got HTTP response — not an ip_echo server")
if data[:4] != HEADER:
raise ValueError(f"unexpected header: {data[:4].hex()}")
variant = struct.unpack("<I", data[4:8])[0]
if variant == 0: # IPv4
if len(data) < 12:
raise ValueError(f"IPv4 response truncated: {len(data)} bytes")
ip = socket.inet_ntoa(data[8:12])
rest = data[12:]
elif variant == 1: # IPv6
if len(data) < 24:
raise ValueError(f"IPv6 response truncated: {len(data)} bytes")
ip = socket.inet_ntop(socket.AF_INET6, data[8:24])
rest = data[24:]
else:
raise ValueError(f"unknown IpAddr variant: {variant}")
shred_version = None
if len(rest) >= 3 and rest[0] == 1:
shred_version = struct.unpack("<H", rest[1:3])[0]
return ip, shred_version
def _listen_udp(port: int, results: dict, stop: threading.Event) -> None:
"""Bind a UDP socket and wait for a probe packet."""
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(("0.0.0.0", port))
sock.settimeout(0.5)
try:
while not stop.is_set():
try:
_data, addr = sock.recvfrom(64)
results[port] = ("ok", addr)
return
except socket.timeout:
continue
finally:
sock.close()
except OSError as exc:
results[port] = ("bind_error", str(exc))
def ip_echo_check(
entrypoint_host: str,
entrypoint_port: int,
udp_ports: list[int],
) -> tuple[str, dict[int, bool]]:
"""Run one ip_echo exchange and return (seen_ip, {port: reachable}).
Raises on TCP failure (caller retries).
"""
udp_ports = [p for p in udp_ports if p != 0][:4]
# Start UDP listeners before sending the TCP request
results: dict[int, tuple] = {}
stop = threading.Event()
threads = []
for port in udp_ports:
t = threading.Thread(target=_listen_udp, args=(port, results, stop), daemon=True)
t.start()
threads.append(t)
time.sleep(0.1) # let listeners bind
# TCP: send request, read response
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(IO_TIMEOUT)
try:
sock.connect((entrypoint_host, entrypoint_port))
sock.sendall(build_request([], udp_ports))
resp = sock.recv(RESPONSE_BUF)
finally:
sock.close()
seen_ip, shred_version = parse_response(resp)
log.info(
"entrypoint %s:%d sees us as %s (shred_version=%s)",
entrypoint_host, entrypoint_port, seen_ip, shred_version,
)
# Wait for UDP probes
deadline = time.monotonic() + PROBE_TIMEOUT
while time.monotonic() < deadline:
if all(p in results for p in udp_ports):
break
time.sleep(0.2)
stop.set()
for t in threads:
t.join(timeout=1)
port_ok: dict[int, bool] = {}
for port in udp_ports:
if port not in results:
log.error("port %d: no probe received within %.0fs", port, PROBE_TIMEOUT)
port_ok[port] = False
else:
status, detail = results[port]
if status == "ok":
log.info("port %d: probe received from %s", port, detail)
port_ok[port] = True
else:
log.error("port %d: %s: %s", port, status, detail)
port_ok[port] = False
return seen_ip, port_ok
def run_preflight(
entrypoint_host: str,
entrypoint_port: int,
udp_ports: list[int],
expected_ip: str = "",
) -> bool:
"""Run ip_echo check with retries. Returns True if all ports pass."""
for attempt in range(1, MAX_RETRIES + 1):
log.info("ip_echo attempt %d/%d%s:%d, ports %s",
attempt, MAX_RETRIES, entrypoint_host, entrypoint_port, udp_ports)
try:
seen_ip, port_ok = ip_echo_check(entrypoint_host, entrypoint_port, udp_ports)
except Exception as exc:
log.error("attempt %d TCP failed: %s", attempt, exc)
if attempt < MAX_RETRIES:
time.sleep(RETRY_DELAY)
continue
if expected_ip and seen_ip != expected_ip:
log.error(
"IP MISMATCH: entrypoint sees %s, expected %s (GOSSIP_HOST). "
"Outbound mangle/SNAT path is broken.",
seen_ip, expected_ip,
)
if attempt < MAX_RETRIES:
time.sleep(RETRY_DELAY)
continue
reachable = [p for p, ok in port_ok.items() if ok]
unreachable = [p for p, ok in port_ok.items() if not ok]
if not unreachable:
log.info("PASS: all ports reachable %s, seen as %s", reachable, seen_ip)
return True
log.error(
"attempt %d: unreachable %s, reachable %s, seen as %s",
attempt, unreachable, reachable, seen_ip,
)
if attempt < MAX_RETRIES:
time.sleep(RETRY_DELAY)
log.error("FAIL: ip_echo preflight exhausted %d attempts", MAX_RETRIES)
return False
def main() -> int:
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
datefmt="%H:%M:%S",
)
# Parse entrypoint — VALIDATOR_ENTRYPOINT is "host:port"
raw = os.environ.get("VALIDATOR_ENTRYPOINT", "")
if not raw and len(sys.argv) > 1:
raw = sys.argv[1]
if not raw:
log.error("set VALIDATOR_ENTRYPOINT or pass host:port as argument")
return 1
if ":" in raw:
host, port_str = raw.rsplit(":", 1)
ep_port = int(port_str)
else:
host = raw
ep_port = 8001
gossip_port = int(os.environ.get("GOSSIP_PORT", "8001"))
dynamic_range = os.environ.get("DYNAMIC_PORT_RANGE", "9000-10000")
range_start = int(dynamic_range.split("-")[0])
expected_ip = os.environ.get("GOSSIP_HOST", "")
# Test gossip + first 3 ports from dynamic range (4 max per ip_echo message)
udp_ports = [gossip_port, range_start, range_start + 2, range_start + 3]
ok = run_preflight(host, ep_port, udp_ports, expected_ip)
return 0 if ok else 1
if __name__ == "__main__":
sys.exit(main())

View File

@ -1,878 +0,0 @@
#!/usr/bin/env python3
"""Download Solana snapshots using aria2c for parallel multi-connection downloads.
Discovers snapshot sources by querying getClusterNodes for all RPCs in the
cluster, probing each for available snapshots, benchmarking download speed,
and downloading from the fastest source using aria2c (16 connections by default).
Based on the discovery approach from etcusr/solana-snapshot-finder but replaces
the single-connection wget download with aria2c parallel chunked downloads.
Usage:
# Download to /srv/kind/solana/snapshots (mainnet, 16 connections)
./snapshot_download.py -o /srv/kind/solana/snapshots
# Dry run — find best source, print URL
./snapshot_download.py --dry-run
# Custom RPC for cluster discovery + 32 connections
./snapshot_download.py -r https://api.mainnet-beta.solana.com -n 32
# Testnet
./snapshot_download.py -c testnet -o /data/snapshots
# Programmatic use from entrypoint.py:
from snapshot_download import download_best_snapshot
ok = download_best_snapshot("/data/snapshots")
Requirements:
- aria2c (apt install aria2)
- python3 >= 3.10 (stdlib only, no pip dependencies)
"""
from __future__ import annotations
import argparse
import concurrent.futures
import json
import logging
import os
import re
import shutil
import subprocess
import sys
import time
import urllib.error
import urllib.request
from dataclasses import dataclass, field
from http.client import HTTPResponse
from pathlib import Path
from urllib.request import Request
log: logging.Logger = logging.getLogger("snapshot-download")
CLUSTER_RPC: dict[str, str] = {
"mainnet-beta": "https://api.mainnet-beta.solana.com",
"testnet": "https://api.testnet.solana.com",
"devnet": "https://api.devnet.solana.com",
}
# Snapshot filenames:
# snapshot-<slot>-<hash>.tar.zst
# incremental-snapshot-<base_slot>-<slot>-<hash>.tar.zst
FULL_SNAP_RE: re.Pattern[str] = re.compile(
r"^snapshot-(\d+)-([A-Za-z0-9]+)\.tar\.(zst|bz2)$"
)
INCR_SNAP_RE: re.Pattern[str] = re.compile(
r"^incremental-snapshot-(\d+)-(\d+)-([A-Za-z0-9]+)\.tar\.(zst|bz2)$"
)
@dataclass
class SnapshotSource:
"""A snapshot file available from a specific RPC node."""
rpc_address: str
# Full redirect paths as returned by the server (e.g. /snapshot-123-hash.tar.zst)
file_paths: list[str] = field(default_factory=list)
slots_diff: int = 0
latency_ms: float = 0.0
download_speed: float = 0.0 # bytes/sec
# -- JSON-RPC helpers ----------------------------------------------------------
class _NoRedirectHandler(urllib.request.HTTPRedirectHandler):
"""Handler that captures redirect Location instead of following it."""
def redirect_request(
self,
req: Request,
fp: HTTPResponse,
code: int,
msg: str,
headers: dict[str, str], # type: ignore[override]
newurl: str,
) -> None:
return None
def rpc_post(url: str, method: str, params: list[object] | None = None,
timeout: int = 25) -> object | None:
"""JSON-RPC POST. Returns parsed 'result' field or None on error."""
payload: bytes = json.dumps({
"jsonrpc": "2.0", "id": 1,
"method": method, "params": params or [],
}).encode()
req = Request(url, data=payload,
headers={"Content-Type": "application/json"})
try:
with urllib.request.urlopen(req, timeout=timeout) as resp:
data: dict[str, object] = json.loads(resp.read())
return data.get("result")
except (urllib.error.URLError, json.JSONDecodeError, OSError, TimeoutError) as e:
log.debug("rpc_post %s %s failed: %s", url, method, e)
return None
def head_no_follow(url: str, timeout: float = 3) -> tuple[str | None, float]:
"""HEAD request without following redirects.
Returns (Location header value, latency_sec) if the server returned a
3xx redirect. Returns (None, 0.0) on any error or non-redirect response.
"""
opener: urllib.request.OpenerDirector = urllib.request.build_opener(_NoRedirectHandler)
req = Request(url, method="HEAD")
try:
start: float = time.monotonic()
resp: HTTPResponse = opener.open(req, timeout=timeout) # type: ignore[assignment]
latency: float = time.monotonic() - start
# Non-redirect (2xx) — server didn't redirect, not useful for discovery
location: str | None = resp.headers.get("Location")
resp.close()
return location, latency
except urllib.error.HTTPError as e:
# 3xx redirects raise HTTPError with the redirect info
latency = time.monotonic() - start # type: ignore[possibly-undefined]
location = e.headers.get("Location")
if location and 300 <= e.code < 400:
return location, latency
return None, 0.0
except (urllib.error.URLError, OSError, TimeoutError):
return None, 0.0
# -- Discovery -----------------------------------------------------------------
def get_current_slot(rpc_url: str) -> int | None:
"""Get current slot from RPC."""
result: object | None = rpc_post(rpc_url, "getSlot")
if isinstance(result, int):
return result
return None
def get_cluster_rpc_nodes(rpc_url: str, version_filter: str | None = None) -> list[str]:
"""Get all RPC node addresses from getClusterNodes."""
result: object | None = rpc_post(rpc_url, "getClusterNodes")
if not isinstance(result, list):
return []
rpc_addrs: list[str] = []
for node in result:
if not isinstance(node, dict):
continue
if version_filter is not None:
node_version: str | None = node.get("version")
if node_version and not node_version.startswith(version_filter):
continue
rpc: str | None = node.get("rpc")
if rpc:
rpc_addrs.append(rpc)
return list(set(rpc_addrs))
def _parse_snapshot_filename(location: str) -> tuple[str, str | None]:
"""Extract filename and full redirect path from Location header.
Returns (filename, full_path). full_path includes any path prefix
the server returned (e.g. '/snapshots/snapshot-123-hash.tar.zst').
"""
# Location may be absolute URL or relative path
if location.startswith("http://") or location.startswith("https://"):
# Absolute URL — extract path
from urllib.parse import urlparse
path: str = urlparse(location).path
else:
path = location
filename: str = path.rsplit("/", 1)[-1]
return filename, path
def probe_rpc_snapshot(
rpc_address: str,
current_slot: int,
) -> SnapshotSource | None:
"""Probe a single RPC node for available snapshots.
Discovery only no filtering. Returns a SnapshotSource with all available
info so the caller can decide what to keep. Filtering happens after all
probes complete, so rejected sources are still visible for debugging.
"""
full_url: str = f"http://{rpc_address}/snapshot.tar.bz2"
# Full snapshot is required — every source must have one
full_location, full_latency = head_no_follow(full_url, timeout=2)
if not full_location:
return None
latency_ms: float = full_latency * 1000
full_filename, full_path = _parse_snapshot_filename(full_location)
fm: re.Match[str] | None = FULL_SNAP_RE.match(full_filename)
if not fm:
return None
full_snap_slot: int = int(fm.group(1))
slots_diff: int = current_slot - full_snap_slot
file_paths: list[str] = [full_path]
# Also check for incremental snapshot
inc_url: str = f"http://{rpc_address}/incremental-snapshot.tar.bz2"
inc_location, _ = head_no_follow(inc_url, timeout=2)
if inc_location:
inc_filename, inc_path = _parse_snapshot_filename(inc_location)
m: re.Match[str] | None = INCR_SNAP_RE.match(inc_filename)
if m:
inc_base_slot: int = int(m.group(1))
# Incremental must be based on this source's full snapshot
if inc_base_slot == full_snap_slot:
file_paths.append(inc_path)
return SnapshotSource(
rpc_address=rpc_address,
file_paths=file_paths,
slots_diff=slots_diff,
latency_ms=latency_ms,
)
def discover_sources(
rpc_url: str,
current_slot: int,
max_age_slots: int,
max_latency_ms: float,
threads: int,
version_filter: str | None,
) -> list[SnapshotSource]:
"""Discover all snapshot sources, then filter.
Probing and filtering are separate: all reachable sources are collected
first so we can report what exists even if filters reject everything.
"""
rpc_nodes: list[str] = get_cluster_rpc_nodes(rpc_url, version_filter)
if not rpc_nodes:
log.error("No RPC nodes found via getClusterNodes")
return []
log.info("Found %d RPC nodes, probing for snapshots...", len(rpc_nodes))
all_sources: list[SnapshotSource] = []
with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as pool:
futures: dict[concurrent.futures.Future[SnapshotSource | None], str] = {
pool.submit(probe_rpc_snapshot, addr, current_slot): addr
for addr in rpc_nodes
}
done: int = 0
for future in concurrent.futures.as_completed(futures):
done += 1
if done % 200 == 0:
log.info(" probed %d/%d nodes, %d reachable",
done, len(rpc_nodes), len(all_sources))
try:
result: SnapshotSource | None = future.result()
except (urllib.error.URLError, OSError, TimeoutError) as e:
log.debug("Probe failed for %s: %s", futures[future], e)
continue
if result:
all_sources.append(result)
log.info("Discovered %d reachable sources", len(all_sources))
# Apply filters
filtered: list[SnapshotSource] = []
rejected_age: int = 0
rejected_latency: int = 0
for src in all_sources:
if src.slots_diff > max_age_slots or src.slots_diff < -100:
rejected_age += 1
continue
if src.latency_ms > max_latency_ms:
rejected_latency += 1
continue
filtered.append(src)
if rejected_age or rejected_latency:
log.info("Filtered: %d rejected by age (>%d slots), %d by latency (>%.0fms)",
rejected_age, max_age_slots, rejected_latency, max_latency_ms)
if not filtered and all_sources:
# Show what was available so the user can adjust filters
all_sources.sort(key=lambda s: s.slots_diff)
best = all_sources[0]
log.warning("All %d sources rejected by filters. Best available: "
"%s (age=%d slots, latency=%.0fms). "
"Try --max-snapshot-age %d --max-latency %.0f",
len(all_sources), best.rpc_address,
best.slots_diff, best.latency_ms,
best.slots_diff + 500,
max(best.latency_ms * 1.5, 500))
log.info("Found %d sources after filtering", len(filtered))
return filtered
# -- Speed benchmark -----------------------------------------------------------
def measure_speed(rpc_address: str, measure_time: int = 7) -> float:
"""Measure download speed from an RPC node. Returns bytes/sec."""
url: str = f"http://{rpc_address}/snapshot.tar.bz2"
req = Request(url)
try:
with urllib.request.urlopen(req, timeout=measure_time + 5) as resp:
start: float = time.monotonic()
total: int = 0
while True:
elapsed: float = time.monotonic() - start
if elapsed >= measure_time:
break
chunk: bytes = resp.read(81920)
if not chunk:
break
total += len(chunk)
elapsed = time.monotonic() - start
if elapsed <= 0:
return 0.0
return total / elapsed
except (urllib.error.URLError, OSError, TimeoutError):
return 0.0
# -- Incremental probing -------------------------------------------------------
def probe_incremental(
fast_sources: list[SnapshotSource],
full_snap_slot: int,
) -> tuple[str | None, list[str]]:
"""Probe fast sources for the best incremental matching full_snap_slot.
Returns (filename, mirror_urls) or (None, []) if no match found.
The "best" incremental is the one with the highest slot (closest to head).
"""
best_filename: str | None = None
best_slot: int = 0
best_source: SnapshotSource | None = None
best_path: str | None = None
for source in fast_sources:
inc_url: str = f"http://{source.rpc_address}/incremental-snapshot.tar.bz2"
inc_location, _ = head_no_follow(inc_url, timeout=2)
if not inc_location:
continue
inc_fn, inc_fp = _parse_snapshot_filename(inc_location)
m: re.Match[str] | None = INCR_SNAP_RE.match(inc_fn)
if not m:
continue
if int(m.group(1)) != full_snap_slot:
log.debug(" %s: incremental base slot %s != full %d, skipping",
source.rpc_address, m.group(1), full_snap_slot)
continue
inc_slot: int = int(m.group(2))
if inc_slot > best_slot:
best_slot = inc_slot
best_filename = inc_fn
best_source = source
best_path = inc_fp
if best_filename is None or best_source is None or best_path is None:
return None, []
# Build mirror list — check other sources for the same filename
mirror_urls: list[str] = [f"http://{best_source.rpc_address}{best_path}"]
for other in fast_sources:
if other.rpc_address == best_source.rpc_address:
continue
other_loc, _ = head_no_follow(
f"http://{other.rpc_address}/incremental-snapshot.tar.bz2", timeout=2)
if other_loc:
other_fn, other_fp = _parse_snapshot_filename(other_loc)
if other_fn == best_filename:
mirror_urls.append(f"http://{other.rpc_address}{other_fp}")
return best_filename, mirror_urls
# -- Download ------------------------------------------------------------------
def download_aria2c(
urls: list[str],
output_dir: str,
filename: str,
connections: int = 16,
) -> bool:
"""Download a file using aria2c with parallel connections.
When multiple URLs are provided, aria2c treats them as mirrors of the
same file and distributes chunks across all of them.
"""
num_mirrors: int = len(urls)
total_splits: int = max(connections, connections * num_mirrors)
cmd: list[str] = [
"aria2c",
"--file-allocation=none",
"--continue=false",
f"--max-connection-per-server={connections}",
f"--split={total_splits}",
"--min-split-size=50M",
# aria2c retries individual chunk connections on transient network
# errors (TCP reset, timeout). This is transport-level retry analogous
# to TCP retransmit, not application-level retry of a failed operation.
"--max-tries=5",
"--retry-wait=5",
"--timeout=60",
"--connect-timeout=10",
"--summary-interval=10",
"--console-log-level=notice",
f"--dir={output_dir}",
f"--out={filename}",
"--auto-file-renaming=false",
"--allow-overwrite=true",
*urls,
]
log.info("Downloading %s", filename)
log.info(" aria2c: %d connections x %d mirrors (%d splits)",
connections, num_mirrors, total_splits)
start: float = time.monotonic()
result: subprocess.CompletedProcess[bytes] = subprocess.run(cmd)
elapsed: float = time.monotonic() - start
if result.returncode != 0:
log.error("aria2c failed with exit code %d", result.returncode)
return False
filepath: Path = Path(output_dir) / filename
if not filepath.exists():
log.error("aria2c reported success but %s does not exist", filepath)
return False
size_bytes: int = filepath.stat().st_size
size_gb: float = size_bytes / (1024 ** 3)
avg_mb: float = size_bytes / elapsed / (1024 ** 2) if elapsed > 0 else 0
log.info(" Done: %.1f GB in %.0fs (%.1f MiB/s avg)", size_gb, elapsed, avg_mb)
return True
# -- Shared helpers ------------------------------------------------------------
def _discover_and_benchmark(
rpc_url: str,
current_slot: int,
*,
max_snapshot_age: int = 10000,
max_latency: float = 500,
threads: int = 500,
min_download_speed: int = 20,
measurement_time: int = 7,
max_speed_checks: int = 15,
version_filter: str | None = None,
) -> list[SnapshotSource]:
"""Discover snapshot sources and benchmark download speed.
Returns sources that meet the minimum speed requirement, sorted by speed.
"""
sources: list[SnapshotSource] = discover_sources(
rpc_url, current_slot,
max_age_slots=max_snapshot_age,
max_latency_ms=max_latency,
threads=threads,
version_filter=version_filter,
)
if not sources:
return []
sources.sort(key=lambda s: s.latency_ms)
log.info("Benchmarking download speed on top %d sources...", max_speed_checks)
fast_sources: list[SnapshotSource] = []
checked: int = 0
min_speed_bytes: int = min_download_speed * 1024 * 1024
for source in sources:
if checked >= max_speed_checks:
break
checked += 1
speed: float = measure_speed(source.rpc_address, measurement_time)
source.download_speed = speed
speed_mib: float = speed / (1024 ** 2)
if speed < min_speed_bytes:
log.info(" %s: %.1f MiB/s (too slow, need >=%d MiB/s)",
source.rpc_address, speed_mib, min_download_speed)
continue
log.info(" %s: %.1f MiB/s (latency: %.0fms, age: %d slots)",
source.rpc_address, speed_mib,
source.latency_ms, source.slots_diff)
fast_sources.append(source)
return fast_sources
def _rolling_incremental_download(
fast_sources: list[SnapshotSource],
full_snap_slot: int,
output_dir: str,
convergence_slots: int,
connections: int,
rpc_url: str,
) -> str | None:
"""Download incrementals in a loop until converged.
Probes fast_sources for incrementals matching full_snap_slot, downloads
the freshest one, then re-probes until the gap to head is within
convergence_slots. Returns the filename of the final incremental,
or None if no incremental was found.
"""
prev_inc_filename: str | None = None
loop_start: float = time.monotonic()
max_convergence_time: float = 1800.0 # 30 min wall-clock limit
while True:
if time.monotonic() - loop_start > max_convergence_time:
if prev_inc_filename:
log.warning("Convergence timeout (%.0fs) — using %s",
max_convergence_time, prev_inc_filename)
else:
log.warning("Convergence timeout (%.0fs) — no incremental downloaded",
max_convergence_time)
break
inc_fn, inc_mirrors = probe_incremental(fast_sources, full_snap_slot)
if inc_fn is None:
if prev_inc_filename is None:
log.error("No matching incremental found for base slot %d",
full_snap_slot)
else:
log.info("No newer incremental available, using %s", prev_inc_filename)
break
m_inc: re.Match[str] | None = INCR_SNAP_RE.match(inc_fn)
assert m_inc is not None
inc_slot: int = int(m_inc.group(2))
head_slot: int | None = get_current_slot(rpc_url)
if head_slot is None:
log.warning("Cannot get current slot — downloading best available incremental")
gap: int = convergence_slots + 1
else:
gap = head_slot - inc_slot
if inc_fn == prev_inc_filename:
if gap <= convergence_slots:
log.info("Incremental %s already downloaded (gap %d slots, converged)",
inc_fn, gap)
break
log.info("No newer incremental yet (slot %d, gap %d slots), waiting...",
inc_slot, gap)
time.sleep(10)
continue
if prev_inc_filename is not None:
old_path: Path = Path(output_dir) / prev_inc_filename
if old_path.exists():
log.info("Removing superseded incremental %s", prev_inc_filename)
old_path.unlink()
log.info("Downloading incremental %s (%d mirrors, slot %d, gap %d slots)",
inc_fn, len(inc_mirrors), inc_slot, gap)
if not download_aria2c(inc_mirrors, output_dir, inc_fn, connections):
log.warning("Failed to download incremental %s — re-probing in 10s", inc_fn)
time.sleep(10)
continue
prev_inc_filename = inc_fn
if gap <= convergence_slots:
log.info("Converged: incremental slot %d is %d slots behind head",
inc_slot, gap)
break
if head_slot is None:
break
log.info("Not converged (gap %d > %d), re-probing in 10s...",
gap, convergence_slots)
time.sleep(10)
return prev_inc_filename
# -- Public API ----------------------------------------------------------------
def download_incremental_for_slot(
output_dir: str,
full_snap_slot: int,
*,
cluster: str = "mainnet-beta",
rpc_url: str | None = None,
connections: int = 16,
threads: int = 500,
max_snapshot_age: int = 10000,
max_latency: float = 500,
min_download_speed: int = 20,
measurement_time: int = 7,
max_speed_checks: int = 15,
version_filter: str | None = None,
convergence_slots: int = 500,
) -> bool:
"""Download an incremental snapshot for an existing full snapshot.
Discovers sources, benchmarks speed, then runs the rolling incremental
download loop for the given full snapshot base slot. Does NOT download
a full snapshot.
Returns True if an incremental was downloaded, False otherwise.
"""
resolved_rpc: str = rpc_url or CLUSTER_RPC[cluster]
if not shutil.which("aria2c"):
log.error("aria2c not found. Install with: apt install aria2")
return False
log.info("Incremental download for base slot %d", full_snap_slot)
current_slot: int | None = get_current_slot(resolved_rpc)
if current_slot is None:
log.error("Cannot get current slot from %s", resolved_rpc)
return False
fast_sources: list[SnapshotSource] = _discover_and_benchmark(
resolved_rpc, current_slot,
max_snapshot_age=max_snapshot_age,
max_latency=max_latency,
threads=threads,
min_download_speed=min_download_speed,
measurement_time=measurement_time,
max_speed_checks=max_speed_checks,
version_filter=version_filter,
)
if not fast_sources:
log.error("No fast sources found")
return False
os.makedirs(output_dir, exist_ok=True)
result: str | None = _rolling_incremental_download(
fast_sources, full_snap_slot, output_dir,
convergence_slots, connections, resolved_rpc,
)
return result is not None
def download_best_snapshot(
output_dir: str,
*,
cluster: str = "mainnet-beta",
rpc_url: str | None = None,
connections: int = 16,
threads: int = 500,
max_snapshot_age: int = 10000,
max_latency: float = 500,
min_download_speed: int = 20,
measurement_time: int = 7,
max_speed_checks: int = 15,
version_filter: str | None = None,
full_only: bool = False,
convergence_slots: int = 500,
) -> bool:
"""Download the best available snapshot to output_dir.
This is the programmatic API called by entrypoint.py for automatic
snapshot download. Returns True on success, False on failure.
All parameters have sensible defaults matching the CLI interface.
"""
resolved_rpc: str = rpc_url or CLUSTER_RPC[cluster]
if not shutil.which("aria2c"):
log.error("aria2c not found. Install with: apt install aria2")
return False
log.info("Cluster: %s | RPC: %s", cluster, resolved_rpc)
current_slot: int | None = get_current_slot(resolved_rpc)
if current_slot is None:
log.error("Cannot get current slot from %s", resolved_rpc)
return False
log.info("Current slot: %d", current_slot)
fast_sources: list[SnapshotSource] = _discover_and_benchmark(
resolved_rpc, current_slot,
max_snapshot_age=max_snapshot_age,
max_latency=max_latency,
threads=threads,
min_download_speed=min_download_speed,
measurement_time=measurement_time,
max_speed_checks=max_speed_checks,
version_filter=version_filter,
)
if not fast_sources:
log.error("No fast sources found")
return False
# Use the fastest source as primary, build full snapshot download plan
best: SnapshotSource = fast_sources[0]
full_paths: list[str] = [fp for fp in best.file_paths
if fp.rsplit("/", 1)[-1].startswith("snapshot-")]
if not full_paths:
log.error("Best source has no full snapshot")
return False
# Build mirror URLs for the full snapshot
full_filename: str = full_paths[0].rsplit("/", 1)[-1]
full_mirrors: list[str] = [f"http://{best.rpc_address}{full_paths[0]}"]
for other in fast_sources[1:]:
for other_fp in other.file_paths:
if other_fp.rsplit("/", 1)[-1] == full_filename:
full_mirrors.append(f"http://{other.rpc_address}{other_fp}")
break
speed_mib: float = best.download_speed / (1024 ** 2)
log.info("Best source: %s (%.1f MiB/s), %d mirrors",
best.rpc_address, speed_mib, len(full_mirrors))
# Download full snapshot
os.makedirs(output_dir, exist_ok=True)
total_start: float = time.monotonic()
filepath: Path = Path(output_dir) / full_filename
if filepath.exists() and filepath.stat().st_size > 0:
log.info("Skipping %s (already exists: %.1f GB)",
full_filename, filepath.stat().st_size / (1024 ** 3))
else:
if not download_aria2c(full_mirrors, output_dir, full_filename, connections):
log.error("Failed to download %s", full_filename)
return False
# Download incremental separately — the full download took minutes,
# so any incremental from discovery is stale. Re-probe for fresh ones.
if not full_only:
fm: re.Match[str] | None = FULL_SNAP_RE.match(full_filename)
if fm:
full_snap_slot: int = int(fm.group(1))
log.info("Downloading incremental for base slot %d...", full_snap_slot)
_rolling_incremental_download(
fast_sources, full_snap_slot, output_dir,
convergence_slots, connections, resolved_rpc,
)
total_elapsed: float = time.monotonic() - total_start
log.info("All downloads complete in %.0fs", total_elapsed)
return True
# -- Main (CLI) ----------------------------------------------------------------
def main() -> int:
p: argparse.ArgumentParser = argparse.ArgumentParser(
description="Download Solana snapshots with aria2c parallel downloads",
)
p.add_argument("-o", "--output", default="/srv/kind/solana/snapshots",
help="Snapshot output directory (default: /srv/kind/solana/snapshots)")
p.add_argument("-c", "--cluster", default="mainnet-beta",
choices=list(CLUSTER_RPC),
help="Solana cluster (default: mainnet-beta)")
p.add_argument("-r", "--rpc", default=None,
help="RPC URL for cluster discovery (default: public RPC)")
p.add_argument("-n", "--connections", type=int, default=16,
help="aria2c connections per download (default: 16)")
p.add_argument("-t", "--threads", type=int, default=500,
help="Threads for parallel RPC probing (default: 500)")
p.add_argument("--max-snapshot-age", type=int, default=10000,
help="Max snapshot age in slots (default: 10000)")
p.add_argument("--max-latency", type=float, default=500,
help="Max RPC probe latency in ms (default: 500)")
p.add_argument("--min-download-speed", type=int, default=20,
help="Min download speed in MiB/s (default: 20)")
p.add_argument("--measurement-time", type=int, default=7,
help="Speed measurement duration in seconds (default: 7)")
p.add_argument("--max-speed-checks", type=int, default=15,
help="Max nodes to benchmark before giving up (default: 15)")
p.add_argument("--version", default=None,
help="Filter nodes by version prefix (e.g. '2.2')")
p.add_argument("--convergence-slots", type=int, default=500,
help="Max slot gap for incremental convergence (default: 500)")
p.add_argument("--full-only", action="store_true",
help="Download only full snapshot, skip incremental")
p.add_argument("--dry-run", action="store_true",
help="Find best source and print URL, don't download")
p.add_argument("--post-cmd",
help="Shell command to run after successful download "
"(e.g. 'kubectl scale deployment ... --replicas=1')")
p.add_argument("-v", "--verbose", action="store_true")
args: argparse.Namespace = p.parse_args()
logging.basicConfig(
level=logging.DEBUG if args.verbose else logging.INFO,
format="%(asctime)s %(levelname)s %(message)s",
datefmt="%H:%M:%S",
)
# Dry-run uses the original inline flow (needs access to sources for URL printing)
if args.dry_run:
rpc_url: str = args.rpc or CLUSTER_RPC[args.cluster]
current_slot: int | None = get_current_slot(rpc_url)
if current_slot is None:
log.error("Cannot get current slot from %s", rpc_url)
return 1
sources: list[SnapshotSource] = discover_sources(
rpc_url, current_slot,
max_age_slots=args.max_snapshot_age,
max_latency_ms=args.max_latency,
threads=args.threads,
version_filter=args.version,
)
if not sources:
log.error("No snapshot sources found")
return 1
sources.sort(key=lambda s: s.latency_ms)
best = sources[0]
for fp in best.file_paths:
print(f"http://{best.rpc_address}{fp}")
return 0
ok: bool = download_best_snapshot(
args.output,
cluster=args.cluster,
rpc_url=args.rpc,
connections=args.connections,
threads=args.threads,
max_snapshot_age=args.max_snapshot_age,
max_latency=args.max_latency,
min_download_speed=args.min_download_speed,
measurement_time=args.measurement_time,
max_speed_checks=args.max_speed_checks,
version_filter=args.version,
full_only=args.full_only,
convergence_slots=args.convergence_slots,
)
if ok and args.post_cmd:
log.info("Running post-download command: %s", args.post_cmd)
result: subprocess.CompletedProcess[bytes] = subprocess.run(
args.post_cmd, shell=True,
)
if result.returncode != 0:
log.error("Post-download command failed with exit code %d",
result.returncode)
return 1
log.info("Post-download command completed successfully")
return 0 if ok else 1
if __name__ == "__main__":
sys.exit(main())

View File

@ -1,112 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail
# -----------------------------------------------------------------------
# Start solana-test-validator with optional SPL token setup
#
# Environment variables:
# FACILITATOR_PUBKEY - facilitator fee-payer public key (base58)
# SERVER_PUBKEY - server/payee wallet public key (base58)
# CLIENT_PUBKEY - client/payer wallet public key (base58)
# MINT_DECIMALS - token decimals (default: 6, matching USDC)
# MINT_AMOUNT - amount to mint to client (default: 1000000000)
# LEDGER_DIR - ledger directory (default: /data/ledger)
# -----------------------------------------------------------------------
LEDGER_DIR="${LEDGER_DIR:-/data/ledger}"
MINT_DECIMALS="${MINT_DECIMALS:-6}"
MINT_AMOUNT="${MINT_AMOUNT:-1000000000}"
SETUP_MARKER="${LEDGER_DIR}/.setup-done"
sudo chown -R "$(id -u):$(id -g)" "$LEDGER_DIR" 2>/dev/null || true
# Start test-validator in the background
solana-test-validator \
--ledger "${LEDGER_DIR}" \
--rpc-port 8899 \
--bind-address 0.0.0.0 \
--quiet &
VALIDATOR_PID=$!
# Wait for RPC to become available
echo "Waiting for test-validator RPC..."
for i in $(seq 1 60); do
if solana cluster-version --url http://127.0.0.1:8899 >/dev/null 2>&1; then
echo "Test-validator is ready (attempt ${i})"
break
fi
sleep 1
done
solana config set --url http://127.0.0.1:8899
# Only run setup once (idempotent via marker file)
if [ ! -f "${SETUP_MARKER}" ]; then
echo "Running first-time setup..."
# Airdrop SOL to all wallets for gas
for PUBKEY in "${FACILITATOR_PUBKEY:-}" "${SERVER_PUBKEY:-}" "${CLIENT_PUBKEY:-}"; do
if [ -n "${PUBKEY}" ]; then
echo "Airdropping 100 SOL to ${PUBKEY}..."
solana airdrop 100 "${PUBKEY}" --url http://127.0.0.1:8899 || true
fi
done
# Create a USDC-equivalent SPL token mint if any pubkeys are set
if [ -n "${CLIENT_PUBKEY:-}" ] || [ -n "${FACILITATOR_PUBKEY:-}" ] || [ -n "${SERVER_PUBKEY:-}" ]; then
MINT_AUTHORITY_FILE="${LEDGER_DIR}/mint-authority.json"
if [ ! -f "${MINT_AUTHORITY_FILE}" ]; then
solana-keygen new --no-bip39-passphrase --outfile "${MINT_AUTHORITY_FILE}" --force
MINT_AUTH_PUBKEY=$(solana-keygen pubkey "${MINT_AUTHORITY_FILE}")
solana airdrop 10 "${MINT_AUTH_PUBKEY}" --url http://127.0.0.1:8899
fi
MINT_ADDRESS_FILE="${LEDGER_DIR}/usdc-mint-address.txt"
if [ ! -f "${MINT_ADDRESS_FILE}" ]; then
spl-token create-token \
--decimals "${MINT_DECIMALS}" \
--mint-authority "${MINT_AUTHORITY_FILE}" \
--url http://127.0.0.1:8899 \
2>&1 | grep "Creating token" | awk '{print $3}' > "${MINT_ADDRESS_FILE}"
echo "Created USDC mint: $(cat "${MINT_ADDRESS_FILE}")"
fi
USDC_MINT=$(cat "${MINT_ADDRESS_FILE}")
# Create ATAs and mint tokens for the client
if [ -n "${CLIENT_PUBKEY:-}" ]; then
echo "Creating ATA for client ${CLIENT_PUBKEY}..."
spl-token create-account "${USDC_MINT}" \
--owner "${CLIENT_PUBKEY}" \
--fee-payer "${MINT_AUTHORITY_FILE}" \
--url http://127.0.0.1:8899 || true
echo "Minting ${MINT_AMOUNT} tokens to client..."
spl-token mint "${USDC_MINT}" "${MINT_AMOUNT}" \
--recipient-owner "${CLIENT_PUBKEY}" \
--mint-authority "${MINT_AUTHORITY_FILE}" \
--url http://127.0.0.1:8899 || true
fi
# Create ATAs for server and facilitator
for PUBKEY in "${SERVER_PUBKEY:-}" "${FACILITATOR_PUBKEY:-}"; do
if [ -n "${PUBKEY}" ]; then
echo "Creating ATA for ${PUBKEY}..."
spl-token create-account "${USDC_MINT}" \
--owner "${PUBKEY}" \
--fee-payer "${MINT_AUTHORITY_FILE}" \
--url http://127.0.0.1:8899 || true
fi
done
# Expose mint address for other containers
cp "${MINT_ADDRESS_FILE}" /tmp/usdc-mint-address.txt 2>/dev/null || true
fi
touch "${SETUP_MARKER}"
echo "Setup complete."
fi
echo "solana-test-validator running (PID ${VALIDATOR_PID})"
wait ${VALIDATOR_PID}