Squashed 'scripts/agave-container/' content from commit 4b5c875

git-subtree-dir: scripts/agave-container
git-subtree-split: 4b5c875a05cbbfbde38eeb053fd5443a8a50228c
fix/kind-mount-propagation
A. F. Dudley 2026-03-08 19:13:38 +00:00
commit f4b3a46109
5 changed files with 1336 additions and 0 deletions

81
Dockerfile 100644
View File

@ -0,0 +1,81 @@
# Unified Agave/Jito Solana image
# Supports three modes via AGAVE_MODE env: test, rpc, validator
#
# Build args:
# AGAVE_REPO - git repo URL (anza-xyz/agave or jito-foundation/jito-solana)
# AGAVE_VERSION - git tag to build (e.g. v3.1.9, v3.1.8-jito)
ARG AGAVE_REPO=https://github.com/anza-xyz/agave.git
ARG AGAVE_VERSION=v3.1.9
# ---------- Stage 1: Build ----------
FROM rust:1.85-bookworm AS builder
ARG AGAVE_REPO
ARG AGAVE_VERSION
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
pkg-config \
libssl-dev \
libudev-dev \
libclang-dev \
protobuf-compiler \
ca-certificates \
git \
cmake \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /build
RUN git clone "$AGAVE_REPO" --depth 1 --branch "$AGAVE_VERSION" --recurse-submodules agave
WORKDIR /build/agave
# Cherry-pick --public-tvu-address support (anza-xyz/agave PR #6778, commit 9f4b3ae)
# This flag only exists on master, not in v3.1.9 — fetch the PR ref and cherry-pick
ARG TVU_ADDRESS_PR=6778
RUN if [ -n "$TVU_ADDRESS_PR" ]; then \
git fetch --depth 50 origin "pull/${TVU_ADDRESS_PR}/head:tvu-pr" && \
git cherry-pick --no-commit tvu-pr; \
fi
# Build all binaries using the upstream install script
RUN CI_COMMIT=$(git rev-parse HEAD) scripts/cargo-install-all.sh /solana-release
# ---------- Stage 2: Runtime ----------
FROM debian:bookworm-slim
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates \
libssl3 \
libudev1 \
curl \
sudo \
aria2 \
python3 \
&& rm -rf /var/lib/apt/lists/*
# Create non-root user with sudo
RUN useradd -m -s /bin/bash agave \
&& echo "agave ALL=(ALL) NOPASSWD: ALL" >> /etc/sudoers
# Copy all compiled binaries
COPY --from=builder /solana-release/bin/ /usr/local/bin/
# Copy entrypoint and support scripts
COPY entrypoint.py snapshot_download.py /usr/local/bin/
COPY start-test.sh /usr/local/bin/
RUN chmod +x /usr/local/bin/entrypoint.py /usr/local/bin/start-test.sh
# Create data directories
RUN mkdir -p /data/config /data/ledger /data/accounts /data/snapshots \
&& chown -R agave:agave /data
USER agave
WORKDIR /data
ENV RUST_LOG=info
ENV RUST_BACKTRACE=1
EXPOSE 8899 8900 8001 8001/udp
ENTRYPOINT ["entrypoint.py"]

17
build.sh 100644
View File

@ -0,0 +1,17 @@
#!/usr/bin/env bash
# Build laconicnetwork/agave
# Set AGAVE_REPO and AGAVE_VERSION env vars to build Jito or a different version
source ${CERC_CONTAINER_BASE_DIR}/build-base.sh
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
AGAVE_REPO="${AGAVE_REPO:-https://github.com/anza-xyz/agave.git}"
AGAVE_VERSION="${AGAVE_VERSION:-v3.1.9}"
docker build -t laconicnetwork/agave:local \
--build-arg AGAVE_REPO="$AGAVE_REPO" \
--build-arg AGAVE_VERSION="$AGAVE_VERSION" \
${build_command_args} \
-f ${SCRIPT_DIR}/Dockerfile \
${SCRIPT_DIR}

485
entrypoint.py 100644
View File

@ -0,0 +1,485 @@
#!/usr/bin/env python3
"""Agave validator entrypoint — snapshot management, arg construction, liveness probe.
Two subcommands:
entrypoint.py serve (default) snapshot freshness check + exec agave-validator
entrypoint.py probe liveness probe (slot lag check, exits 0/1)
Replaces the bash entrypoint.sh / start-rpc.sh / start-validator.sh with a single
Python module. Test mode still dispatches to start-test.sh.
All configuration comes from environment variables same vars as the original
bash scripts. See compose files for defaults.
"""
from __future__ import annotations
import json
import logging
import os
import re
import subprocess
import sys
import time
import urllib.error
import urllib.request
from pathlib import Path
from urllib.request import Request
log: logging.Logger = logging.getLogger("entrypoint")
# Directories
CONFIG_DIR = "/data/config"
LEDGER_DIR = "/data/ledger"
ACCOUNTS_DIR = "/data/accounts"
SNAPSHOTS_DIR = "/data/snapshots"
LOG_DIR = "/data/log"
IDENTITY_FILE = f"{CONFIG_DIR}/validator-identity.json"
# Snapshot filename pattern
FULL_SNAP_RE: re.Pattern[str] = re.compile(
r"^snapshot-(\d+)-[A-Za-z0-9]+\.tar\.(zst|bz2)$"
)
MAINNET_RPC = "https://api.mainnet-beta.solana.com"
# -- Helpers -------------------------------------------------------------------
def env(name: str, default: str = "") -> str:
"""Read env var with default."""
return os.environ.get(name, default)
def env_required(name: str) -> str:
"""Read required env var, exit if missing."""
val = os.environ.get(name)
if not val:
log.error("%s is required but not set", name)
sys.exit(1)
return val
def env_bool(name: str, default: bool = False) -> bool:
"""Read boolean env var (true/false/1/0)."""
val = os.environ.get(name, "").lower()
if not val:
return default
return val in ("true", "1", "yes")
def rpc_get_slot(url: str, timeout: int = 10) -> int | None:
"""Get current slot from a Solana RPC endpoint."""
payload = json.dumps({
"jsonrpc": "2.0", "id": 1,
"method": "getSlot", "params": [],
}).encode()
req = Request(url, data=payload,
headers={"Content-Type": "application/json"})
try:
with urllib.request.urlopen(req, timeout=timeout) as resp:
data = json.loads(resp.read())
result = data.get("result")
if isinstance(result, int):
return result
except (urllib.error.URLError, json.JSONDecodeError, OSError, TimeoutError):
pass
return None
# -- Snapshot management -------------------------------------------------------
def get_local_snapshot_slot(snapshots_dir: str) -> int | None:
"""Find the highest slot among local snapshot files."""
best_slot: int | None = None
snap_path = Path(snapshots_dir)
if not snap_path.is_dir():
return None
for entry in snap_path.iterdir():
m = FULL_SNAP_RE.match(entry.name)
if m:
slot = int(m.group(1))
if best_slot is None or slot > best_slot:
best_slot = slot
return best_slot
def clean_snapshots(snapshots_dir: str) -> None:
"""Remove all snapshot files from the directory."""
snap_path = Path(snapshots_dir)
if not snap_path.is_dir():
return
for entry in snap_path.iterdir():
if entry.name.startswith(("snapshot-", "incremental-snapshot-")):
log.info("Removing old snapshot: %s", entry.name)
entry.unlink(missing_ok=True)
def maybe_download_snapshot(snapshots_dir: str) -> None:
"""Check snapshot freshness and download if needed.
Controlled by env vars:
SNAPSHOT_AUTO_DOWNLOAD (default: true) enable/disable
SNAPSHOT_MAX_AGE_SLOTS (default: 20000) staleness threshold
"""
if not env_bool("SNAPSHOT_AUTO_DOWNLOAD", default=True):
log.info("Snapshot auto-download disabled")
return
max_age = int(env("SNAPSHOT_MAX_AGE_SLOTS", "20000"))
# Get mainnet current slot
mainnet_slot = rpc_get_slot(MAINNET_RPC)
if mainnet_slot is None:
log.warning("Cannot reach mainnet RPC — skipping snapshot check")
return
# Check local snapshot
local_slot = get_local_snapshot_slot(snapshots_dir)
if local_slot is not None:
age = mainnet_slot - local_slot
log.info("Local snapshot at slot %d (mainnet: %d, age: %d slots)",
local_slot, mainnet_slot, age)
if age <= max_age:
log.info("Snapshot is fresh enough (age %d <= %d), skipping download", age, max_age)
return
log.info("Snapshot is stale (age %d > %d), downloading fresh", age, max_age)
else:
log.info("No local snapshot found, downloading")
# Clean old snapshots before downloading
clean_snapshots(snapshots_dir)
# Import and call snapshot download
# snapshot_download.py is installed alongside this file in /usr/local/bin/
script_dir = Path(__file__).resolve().parent
sys.path.insert(0, str(script_dir))
from snapshot_download import download_best_snapshot
ok = download_best_snapshot(snapshots_dir)
if not ok:
log.error("Snapshot download failed — starting without fresh snapshot")
# -- Directory and identity setup ----------------------------------------------
def ensure_dirs(*dirs: str) -> None:
"""Create directories and fix ownership."""
uid = os.getuid()
gid = os.getgid()
for d in dirs:
os.makedirs(d, exist_ok=True)
try:
subprocess.run(
["sudo", "chown", "-R", f"{uid}:{gid}", d],
check=False, capture_output=True,
)
except FileNotFoundError:
pass # sudo not available — dirs already owned correctly
def ensure_identity_rpc() -> None:
"""Generate ephemeral identity keypair for RPC mode if not mounted."""
if os.path.isfile(IDENTITY_FILE):
return
log.info("Generating RPC node identity keypair...")
subprocess.run(
["solana-keygen", "new", "--no-passphrase", "--silent",
"--force", "--outfile", IDENTITY_FILE],
check=True,
)
def print_identity() -> None:
"""Print the node identity pubkey."""
result = subprocess.run(
["solana-keygen", "pubkey", IDENTITY_FILE],
capture_output=True, text=True, check=False,
)
if result.returncode == 0:
log.info("Node identity: %s", result.stdout.strip())
# -- Arg construction ----------------------------------------------------------
def build_common_args() -> list[str]:
"""Build agave-validator args common to both RPC and validator modes."""
args: list[str] = [
"--identity", IDENTITY_FILE,
"--entrypoint", env_required("VALIDATOR_ENTRYPOINT"),
"--known-validator", env_required("KNOWN_VALIDATOR"),
"--ledger", LEDGER_DIR,
"--accounts", ACCOUNTS_DIR,
"--snapshots", SNAPSHOTS_DIR,
"--rpc-port", env("RPC_PORT", "8899"),
"--rpc-bind-address", env("RPC_BIND_ADDRESS", "127.0.0.1"),
"--gossip-port", env("GOSSIP_PORT", "8001"),
"--dynamic-port-range", env("DYNAMIC_PORT_RANGE", "9000-10000"),
"--no-os-network-limits-test",
"--wal-recovery-mode", "skip_any_corrupted_record",
"--limit-ledger-size", env("LIMIT_LEDGER_SIZE", "50000000"),
]
# Snapshot generation
if env("NO_SNAPSHOTS") == "true":
args.append("--no-snapshots")
else:
args += [
"--full-snapshot-interval-slots", env("SNAPSHOT_INTERVAL_SLOTS", "100000"),
"--maximum-full-snapshots-to-retain", env("MAXIMUM_SNAPSHOTS_TO_RETAIN", "5"),
]
if env("NO_INCREMENTAL_SNAPSHOTS") != "true":
args += ["--maximum-incremental-snapshots-to-retain", "2"]
# Account indexes
account_indexes = env("ACCOUNT_INDEXES")
if account_indexes:
for idx in account_indexes.split(","):
idx = idx.strip()
if idx:
args += ["--account-index", idx]
# Additional entrypoints
for ep in env("EXTRA_ENTRYPOINTS").split():
if ep:
args += ["--entrypoint", ep]
# Additional known validators
for kv in env("EXTRA_KNOWN_VALIDATORS").split():
if kv:
args += ["--known-validator", kv]
# Cluster verification
genesis_hash = env("EXPECTED_GENESIS_HASH")
if genesis_hash:
args += ["--expected-genesis-hash", genesis_hash]
shred_version = env("EXPECTED_SHRED_VERSION")
if shred_version:
args += ["--expected-shred-version", shred_version]
# Metrics — just needs to be in the environment, agave reads it directly
# (env var is already set, nothing to pass as arg)
# Gossip host / TVU address
gossip_host = env("GOSSIP_HOST")
if gossip_host:
args += ["--gossip-host", gossip_host]
elif env("PUBLIC_TVU_ADDRESS"):
args += ["--public-tvu-address", env("PUBLIC_TVU_ADDRESS")]
# Jito flags
if env("JITO_ENABLE") == "true":
log.info("Jito MEV enabled")
jito_flags: list[tuple[str, str]] = [
("JITO_TIP_PAYMENT_PROGRAM", "--tip-payment-program-pubkey"),
("JITO_DISTRIBUTION_PROGRAM", "--tip-distribution-program-pubkey"),
("JITO_MERKLE_ROOT_AUTHORITY", "--merkle-root-upload-authority"),
("JITO_COMMISSION_BPS", "--commission-bps"),
("JITO_BLOCK_ENGINE_URL", "--block-engine-url"),
("JITO_SHRED_RECEIVER_ADDR", "--shred-receiver-address"),
]
for env_name, flag in jito_flags:
val = env(env_name)
if val:
args += [flag, val]
return args
def build_rpc_args() -> list[str]:
"""Build agave-validator args for RPC (non-voting) mode."""
args = build_common_args()
args += [
"--no-voting",
"--log", f"{LOG_DIR}/validator.log",
"--full-rpc-api",
"--enable-rpc-transaction-history",
"--rpc-pubsub-enable-block-subscription",
"--enable-extended-tx-metadata-storage",
"--no-wait-for-vote-to-start-leader",
"--no-snapshot-fetch",
]
# Public vs private RPC
public_rpc = env("PUBLIC_RPC_ADDRESS")
if public_rpc:
args += ["--public-rpc-address", public_rpc]
else:
args += ["--private-rpc", "--allow-private-addr", "--only-known-rpc"]
# Jito relayer URL (RPC mode doesn't use it, but validator mode does —
# handled in build_validator_args)
return args
def build_validator_args() -> list[str]:
"""Build agave-validator args for voting validator mode."""
vote_keypair = env("VOTE_ACCOUNT_KEYPAIR",
"/data/config/vote-account-keypair.json")
# Identity must be mounted for validator mode
if not os.path.isfile(IDENTITY_FILE):
log.error("Validator identity keypair not found at %s", IDENTITY_FILE)
log.error("Mount your validator keypair to %s", IDENTITY_FILE)
sys.exit(1)
# Vote account keypair must exist
if not os.path.isfile(vote_keypair):
log.error("Vote account keypair not found at %s", vote_keypair)
log.error("Mount your vote account keypair or set VOTE_ACCOUNT_KEYPAIR")
sys.exit(1)
# Print vote account pubkey
result = subprocess.run(
["solana-keygen", "pubkey", vote_keypair],
capture_output=True, text=True, check=False,
)
if result.returncode == 0:
log.info("Vote account: %s", result.stdout.strip())
args = build_common_args()
args += [
"--vote-account", vote_keypair,
"--log", "-",
]
# Jito relayer URL (validator-only)
relayer_url = env("JITO_RELAYER_URL")
if env("JITO_ENABLE") == "true" and relayer_url:
args += ["--relayer-url", relayer_url]
return args
def append_extra_args(args: list[str]) -> list[str]:
"""Append EXTRA_ARGS passthrough flags."""
extra = env("EXTRA_ARGS")
if extra:
args += extra.split()
return args
# -- Serve subcommand ---------------------------------------------------------
def cmd_serve() -> None:
"""Main serve flow: snapshot check, setup, exec agave-validator."""
mode = env("AGAVE_MODE", "test")
log.info("AGAVE_MODE=%s", mode)
# Test mode dispatches to start-test.sh
if mode == "test":
os.execvp("start-test.sh", ["start-test.sh"])
if mode not in ("rpc", "validator"):
log.error("Unknown AGAVE_MODE: %s (valid: test, rpc, validator)", mode)
sys.exit(1)
# Ensure directories
dirs = [CONFIG_DIR, LEDGER_DIR, ACCOUNTS_DIR, SNAPSHOTS_DIR]
if mode == "rpc":
dirs.append(LOG_DIR)
ensure_dirs(*dirs)
# Snapshot freshness check and auto-download
maybe_download_snapshot(SNAPSHOTS_DIR)
# Identity setup
if mode == "rpc":
ensure_identity_rpc()
print_identity()
# Build args
if mode == "rpc":
args = build_rpc_args()
else:
args = build_validator_args()
args = append_extra_args(args)
# Write startup timestamp for probe grace period
Path("/tmp/entrypoint-start").write_text(str(time.time()))
log.info("Starting agave-validator with %d arguments", len(args))
os.execvp("agave-validator", ["agave-validator"] + args)
# -- Probe subcommand ---------------------------------------------------------
def cmd_probe() -> None:
"""Liveness probe: check local RPC slot vs mainnet.
Exit 0 = healthy, exit 1 = unhealthy.
Grace period: PROBE_GRACE_SECONDS (default 600) probe always passes
during grace period to allow for snapshot unpacking and initial replay.
"""
grace_seconds = int(env("PROBE_GRACE_SECONDS", "600"))
max_lag = int(env("PROBE_MAX_SLOT_LAG", "20000"))
# Check grace period
start_file = Path("/tmp/entrypoint-start")
if start_file.exists():
try:
start_time = float(start_file.read_text().strip())
elapsed = time.time() - start_time
if elapsed < grace_seconds:
# Within grace period — always healthy
sys.exit(0)
except (ValueError, OSError):
pass
else:
# No start file — serve hasn't started yet, within grace
sys.exit(0)
# Query local RPC
rpc_port = env("RPC_PORT", "8899")
local_url = f"http://127.0.0.1:{rpc_port}"
local_slot = rpc_get_slot(local_url, timeout=5)
if local_slot is None:
# Local RPC unreachable after grace period — unhealthy
sys.exit(1)
# Query mainnet
mainnet_slot = rpc_get_slot(MAINNET_RPC, timeout=10)
if mainnet_slot is None:
# Can't reach mainnet to compare — assume healthy (don't penalize
# the validator for mainnet RPC being down)
sys.exit(0)
lag = mainnet_slot - local_slot
if lag > max_lag:
sys.exit(1)
sys.exit(0)
# -- Main ----------------------------------------------------------------------
def main() -> None:
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
datefmt="%H:%M:%S",
)
subcmd = sys.argv[1] if len(sys.argv) > 1 else "serve"
if subcmd == "serve":
cmd_serve()
elif subcmd == "probe":
cmd_probe()
else:
log.error("Unknown subcommand: %s (valid: serve, probe)", subcmd)
sys.exit(1)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,641 @@
#!/usr/bin/env python3
"""Download Solana snapshots using aria2c for parallel multi-connection downloads.
Discovers snapshot sources by querying getClusterNodes for all RPCs in the
cluster, probing each for available snapshots, benchmarking download speed,
and downloading from the fastest source using aria2c (16 connections by default).
Based on the discovery approach from etcusr/solana-snapshot-finder but replaces
the single-connection wget download with aria2c parallel chunked downloads.
Usage:
# Download to /srv/kind/solana/snapshots (mainnet, 16 connections)
./snapshot_download.py -o /srv/kind/solana/snapshots
# Dry run — find best source, print URL
./snapshot_download.py --dry-run
# Custom RPC for cluster discovery + 32 connections
./snapshot_download.py -r https://api.mainnet-beta.solana.com -n 32
# Testnet
./snapshot_download.py -c testnet -o /data/snapshots
# Programmatic use from entrypoint.py:
from snapshot_download import download_best_snapshot
ok = download_best_snapshot("/data/snapshots")
Requirements:
- aria2c (apt install aria2)
- python3 >= 3.10 (stdlib only, no pip dependencies)
"""
from __future__ import annotations
import argparse
import concurrent.futures
import json
import logging
import os
import re
import shutil
import subprocess
import sys
import time
import urllib.error
import urllib.request
from dataclasses import dataclass, field
from http.client import HTTPResponse
from pathlib import Path
from urllib.request import Request
log: logging.Logger = logging.getLogger("snapshot-download")
CLUSTER_RPC: dict[str, str] = {
"mainnet-beta": "https://api.mainnet-beta.solana.com",
"testnet": "https://api.testnet.solana.com",
"devnet": "https://api.devnet.solana.com",
}
# Snapshot filenames:
# snapshot-<slot>-<hash>.tar.zst
# incremental-snapshot-<base_slot>-<slot>-<hash>.tar.zst
FULL_SNAP_RE: re.Pattern[str] = re.compile(
r"^snapshot-(\d+)-([A-Za-z0-9]+)\.tar\.(zst|bz2)$"
)
INCR_SNAP_RE: re.Pattern[str] = re.compile(
r"^incremental-snapshot-(\d+)-(\d+)-([A-Za-z0-9]+)\.tar\.(zst|bz2)$"
)
@dataclass
class SnapshotSource:
"""A snapshot file available from a specific RPC node."""
rpc_address: str
# Full redirect paths as returned by the server (e.g. /snapshot-123-hash.tar.zst)
file_paths: list[str] = field(default_factory=list)
slots_diff: int = 0
latency_ms: float = 0.0
download_speed: float = 0.0 # bytes/sec
# -- JSON-RPC helpers ----------------------------------------------------------
class _NoRedirectHandler(urllib.request.HTTPRedirectHandler):
"""Handler that captures redirect Location instead of following it."""
def redirect_request(
self,
req: Request,
fp: HTTPResponse,
code: int,
msg: str,
headers: dict[str, str], # type: ignore[override]
newurl: str,
) -> None:
return None
def rpc_post(url: str, method: str, params: list[object] | None = None,
timeout: int = 25) -> object | None:
"""JSON-RPC POST. Returns parsed 'result' field or None on error."""
payload: bytes = json.dumps({
"jsonrpc": "2.0", "id": 1,
"method": method, "params": params or [],
}).encode()
req = Request(url, data=payload,
headers={"Content-Type": "application/json"})
try:
with urllib.request.urlopen(req, timeout=timeout) as resp:
data: dict[str, object] = json.loads(resp.read())
return data.get("result")
except (urllib.error.URLError, json.JSONDecodeError, OSError, TimeoutError) as e:
log.debug("rpc_post %s %s failed: %s", url, method, e)
return None
def head_no_follow(url: str, timeout: float = 3) -> tuple[str | None, float]:
"""HEAD request without following redirects.
Returns (Location header value, latency_sec) if the server returned a
3xx redirect. Returns (None, 0.0) on any error or non-redirect response.
"""
opener: urllib.request.OpenerDirector = urllib.request.build_opener(_NoRedirectHandler)
req = Request(url, method="HEAD")
try:
start: float = time.monotonic()
resp: HTTPResponse = opener.open(req, timeout=timeout) # type: ignore[assignment]
latency: float = time.monotonic() - start
# Non-redirect (2xx) — server didn't redirect, not useful for discovery
location: str | None = resp.headers.get("Location")
resp.close()
return location, latency
except urllib.error.HTTPError as e:
# 3xx redirects raise HTTPError with the redirect info
latency = time.monotonic() - start # type: ignore[possibly-undefined]
location = e.headers.get("Location")
if location and 300 <= e.code < 400:
return location, latency
return None, 0.0
except (urllib.error.URLError, OSError, TimeoutError):
return None, 0.0
# -- Discovery -----------------------------------------------------------------
def get_current_slot(rpc_url: str) -> int | None:
"""Get current slot from RPC."""
result: object | None = rpc_post(rpc_url, "getSlot")
if isinstance(result, int):
return result
return None
def get_cluster_rpc_nodes(rpc_url: str, version_filter: str | None = None) -> list[str]:
"""Get all RPC node addresses from getClusterNodes."""
result: object | None = rpc_post(rpc_url, "getClusterNodes")
if not isinstance(result, list):
return []
rpc_addrs: list[str] = []
for node in result:
if not isinstance(node, dict):
continue
if version_filter is not None:
node_version: str | None = node.get("version")
if node_version and not node_version.startswith(version_filter):
continue
rpc: str | None = node.get("rpc")
if rpc:
rpc_addrs.append(rpc)
return list(set(rpc_addrs))
def _parse_snapshot_filename(location: str) -> tuple[str, str | None]:
"""Extract filename and full redirect path from Location header.
Returns (filename, full_path). full_path includes any path prefix
the server returned (e.g. '/snapshots/snapshot-123-hash.tar.zst').
"""
# Location may be absolute URL or relative path
if location.startswith("http://") or location.startswith("https://"):
# Absolute URL — extract path
from urllib.parse import urlparse
path: str = urlparse(location).path
else:
path = location
filename: str = path.rsplit("/", 1)[-1]
return filename, path
def probe_rpc_snapshot(
rpc_address: str,
current_slot: int,
) -> SnapshotSource | None:
"""Probe a single RPC node for available snapshots.
Discovery only no filtering. Returns a SnapshotSource with all available
info so the caller can decide what to keep. Filtering happens after all
probes complete, so rejected sources are still visible for debugging.
"""
full_url: str = f"http://{rpc_address}/snapshot.tar.bz2"
# Full snapshot is required — every source must have one
full_location, full_latency = head_no_follow(full_url, timeout=2)
if not full_location:
return None
latency_ms: float = full_latency * 1000
full_filename, full_path = _parse_snapshot_filename(full_location)
fm: re.Match[str] | None = FULL_SNAP_RE.match(full_filename)
if not fm:
return None
full_snap_slot: int = int(fm.group(1))
slots_diff: int = current_slot - full_snap_slot
file_paths: list[str] = [full_path]
# Also check for incremental snapshot
inc_url: str = f"http://{rpc_address}/incremental-snapshot.tar.bz2"
inc_location, _ = head_no_follow(inc_url, timeout=2)
if inc_location:
inc_filename, inc_path = _parse_snapshot_filename(inc_location)
m: re.Match[str] | None = INCR_SNAP_RE.match(inc_filename)
if m:
inc_base_slot: int = int(m.group(1))
# Incremental must be based on this source's full snapshot
if inc_base_slot == full_snap_slot:
file_paths.append(inc_path)
return SnapshotSource(
rpc_address=rpc_address,
file_paths=file_paths,
slots_diff=slots_diff,
latency_ms=latency_ms,
)
def discover_sources(
rpc_url: str,
current_slot: int,
max_age_slots: int,
max_latency_ms: float,
threads: int,
version_filter: str | None,
) -> list[SnapshotSource]:
"""Discover all snapshot sources, then filter.
Probing and filtering are separate: all reachable sources are collected
first so we can report what exists even if filters reject everything.
"""
rpc_nodes: list[str] = get_cluster_rpc_nodes(rpc_url, version_filter)
if not rpc_nodes:
log.error("No RPC nodes found via getClusterNodes")
return []
log.info("Found %d RPC nodes, probing for snapshots...", len(rpc_nodes))
all_sources: list[SnapshotSource] = []
with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as pool:
futures: dict[concurrent.futures.Future[SnapshotSource | None], str] = {
pool.submit(probe_rpc_snapshot, addr, current_slot): addr
for addr in rpc_nodes
}
done: int = 0
for future in concurrent.futures.as_completed(futures):
done += 1
if done % 200 == 0:
log.info(" probed %d/%d nodes, %d reachable",
done, len(rpc_nodes), len(all_sources))
try:
result: SnapshotSource | None = future.result()
except (urllib.error.URLError, OSError, TimeoutError) as e:
log.debug("Probe failed for %s: %s", futures[future], e)
continue
if result:
all_sources.append(result)
log.info("Discovered %d reachable sources", len(all_sources))
# Apply filters
filtered: list[SnapshotSource] = []
rejected_age: int = 0
rejected_latency: int = 0
for src in all_sources:
if src.slots_diff > max_age_slots or src.slots_diff < -100:
rejected_age += 1
continue
if src.latency_ms > max_latency_ms:
rejected_latency += 1
continue
filtered.append(src)
if rejected_age or rejected_latency:
log.info("Filtered: %d rejected by age (>%d slots), %d by latency (>%.0fms)",
rejected_age, max_age_slots, rejected_latency, max_latency_ms)
if not filtered and all_sources:
# Show what was available so the user can adjust filters
all_sources.sort(key=lambda s: s.slots_diff)
best = all_sources[0]
log.warning("All %d sources rejected by filters. Best available: "
"%s (age=%d slots, latency=%.0fms). "
"Try --max-snapshot-age %d --max-latency %.0f",
len(all_sources), best.rpc_address,
best.slots_diff, best.latency_ms,
best.slots_diff + 500,
max(best.latency_ms * 1.5, 500))
log.info("Found %d sources after filtering", len(filtered))
return filtered
# -- Speed benchmark -----------------------------------------------------------
def measure_speed(rpc_address: str, measure_time: int = 7) -> float:
"""Measure download speed from an RPC node. Returns bytes/sec."""
url: str = f"http://{rpc_address}/snapshot.tar.bz2"
req = Request(url)
try:
with urllib.request.urlopen(req, timeout=measure_time + 5) as resp:
start: float = time.monotonic()
total: int = 0
while True:
elapsed: float = time.monotonic() - start
if elapsed >= measure_time:
break
chunk: bytes = resp.read(81920)
if not chunk:
break
total += len(chunk)
elapsed = time.monotonic() - start
if elapsed <= 0:
return 0.0
return total / elapsed
except (urllib.error.URLError, OSError, TimeoutError):
return 0.0
# -- Download ------------------------------------------------------------------
def download_aria2c(
urls: list[str],
output_dir: str,
filename: str,
connections: int = 16,
) -> bool:
"""Download a file using aria2c with parallel connections.
When multiple URLs are provided, aria2c treats them as mirrors of the
same file and distributes chunks across all of them.
"""
num_mirrors: int = len(urls)
total_splits: int = max(connections, connections * num_mirrors)
cmd: list[str] = [
"aria2c",
"--file-allocation=none",
"--continue=false",
f"--max-connection-per-server={connections}",
f"--split={total_splits}",
"--min-split-size=50M",
# aria2c retries individual chunk connections on transient network
# errors (TCP reset, timeout). This is transport-level retry analogous
# to TCP retransmit, not application-level retry of a failed operation.
"--max-tries=5",
"--retry-wait=5",
"--timeout=60",
"--connect-timeout=10",
"--summary-interval=10",
"--console-log-level=notice",
f"--dir={output_dir}",
f"--out={filename}",
"--auto-file-renaming=false",
"--allow-overwrite=true",
*urls,
]
log.info("Downloading %s", filename)
log.info(" aria2c: %d connections x %d mirrors (%d splits)",
connections, num_mirrors, total_splits)
start: float = time.monotonic()
result: subprocess.CompletedProcess[bytes] = subprocess.run(cmd)
elapsed: float = time.monotonic() - start
if result.returncode != 0:
log.error("aria2c failed with exit code %d", result.returncode)
return False
filepath: Path = Path(output_dir) / filename
if not filepath.exists():
log.error("aria2c reported success but %s does not exist", filepath)
return False
size_bytes: int = filepath.stat().st_size
size_gb: float = size_bytes / (1024 ** 3)
avg_mb: float = size_bytes / elapsed / (1024 ** 2) if elapsed > 0 else 0
log.info(" Done: %.1f GB in %.0fs (%.1f MiB/s avg)", size_gb, elapsed, avg_mb)
return True
# -- Public API ----------------------------------------------------------------
def download_best_snapshot(
output_dir: str,
*,
cluster: str = "mainnet-beta",
rpc_url: str | None = None,
connections: int = 16,
threads: int = 500,
max_snapshot_age: int = 10000,
max_latency: float = 500,
min_download_speed: int = 20,
measurement_time: int = 7,
max_speed_checks: int = 15,
version_filter: str | None = None,
full_only: bool = False,
) -> bool:
"""Download the best available snapshot to output_dir.
This is the programmatic API called by entrypoint.py for automatic
snapshot download. Returns True on success, False on failure.
All parameters have sensible defaults matching the CLI interface.
"""
resolved_rpc: str = rpc_url or CLUSTER_RPC[cluster]
if not shutil.which("aria2c"):
log.error("aria2c not found. Install with: apt install aria2")
return False
log.info("Cluster: %s | RPC: %s", cluster, resolved_rpc)
current_slot: int | None = get_current_slot(resolved_rpc)
if current_slot is None:
log.error("Cannot get current slot from %s", resolved_rpc)
return False
log.info("Current slot: %d", current_slot)
sources: list[SnapshotSource] = discover_sources(
resolved_rpc, current_slot,
max_age_slots=max_snapshot_age,
max_latency_ms=max_latency,
threads=threads,
version_filter=version_filter,
)
if not sources:
log.error("No snapshot sources found")
return False
# Sort by latency (lowest first) for speed benchmarking
sources.sort(key=lambda s: s.latency_ms)
# Benchmark top candidates
log.info("Benchmarking download speed on top %d sources...", max_speed_checks)
fast_sources: list[SnapshotSource] = []
checked: int = 0
min_speed_bytes: int = min_download_speed * 1024 * 1024
for source in sources:
if checked >= max_speed_checks:
break
checked += 1
speed: float = measure_speed(source.rpc_address, measurement_time)
source.download_speed = speed
speed_mib: float = speed / (1024 ** 2)
if speed < min_speed_bytes:
log.info(" %s: %.1f MiB/s (too slow, need >=%d MiB/s)",
source.rpc_address, speed_mib, min_download_speed)
continue
log.info(" %s: %.1f MiB/s (latency: %.0fms, age: %d slots)",
source.rpc_address, speed_mib,
source.latency_ms, source.slots_diff)
fast_sources.append(source)
if not fast_sources:
log.error("No source met minimum speed requirement (%d MiB/s)",
min_download_speed)
return False
# Use the fastest source as primary, collect mirrors for each file
best: SnapshotSource = fast_sources[0]
file_paths: list[str] = best.file_paths
if full_only:
file_paths = [fp for fp in file_paths
if fp.rsplit("/", 1)[-1].startswith("snapshot-")]
# Build mirror URL lists
download_plan: list[tuple[str, list[str]]] = []
for fp in file_paths:
filename: str = fp.rsplit("/", 1)[-1]
mirror_urls: list[str] = [f"http://{best.rpc_address}{fp}"]
for other in fast_sources[1:]:
for other_fp in other.file_paths:
if other_fp.rsplit("/", 1)[-1] == filename:
mirror_urls.append(f"http://{other.rpc_address}{other_fp}")
break
download_plan.append((filename, mirror_urls))
speed_mib: float = best.download_speed / (1024 ** 2)
log.info("Best source: %s (%.1f MiB/s), %d mirrors total",
best.rpc_address, speed_mib, len(fast_sources))
for filename, mirror_urls in download_plan:
log.info(" %s (%d mirrors)", filename, len(mirror_urls))
# Download
os.makedirs(output_dir, exist_ok=True)
total_start: float = time.monotonic()
for filename, mirror_urls in download_plan:
filepath: Path = Path(output_dir) / filename
if filepath.exists() and filepath.stat().st_size > 0:
log.info("Skipping %s (already exists: %.1f GB)",
filename, filepath.stat().st_size / (1024 ** 3))
continue
if not download_aria2c(mirror_urls, output_dir, filename, connections):
log.error("Failed to download %s", filename)
return False
total_elapsed: float = time.monotonic() - total_start
log.info("All downloads complete in %.0fs", total_elapsed)
for filename, _ in download_plan:
fp_path: Path = Path(output_dir) / filename
if fp_path.exists():
log.info(" %s (%.1f GB)", fp_path.name, fp_path.stat().st_size / (1024 ** 3))
return True
# -- Main (CLI) ----------------------------------------------------------------
def main() -> int:
p: argparse.ArgumentParser = argparse.ArgumentParser(
description="Download Solana snapshots with aria2c parallel downloads",
)
p.add_argument("-o", "--output", default="/srv/kind/solana/snapshots",
help="Snapshot output directory (default: /srv/kind/solana/snapshots)")
p.add_argument("-c", "--cluster", default="mainnet-beta",
choices=list(CLUSTER_RPC),
help="Solana cluster (default: mainnet-beta)")
p.add_argument("-r", "--rpc", default=None,
help="RPC URL for cluster discovery (default: public RPC)")
p.add_argument("-n", "--connections", type=int, default=16,
help="aria2c connections per download (default: 16)")
p.add_argument("-t", "--threads", type=int, default=500,
help="Threads for parallel RPC probing (default: 500)")
p.add_argument("--max-snapshot-age", type=int, default=10000,
help="Max snapshot age in slots (default: 10000)")
p.add_argument("--max-latency", type=float, default=500,
help="Max RPC probe latency in ms (default: 500)")
p.add_argument("--min-download-speed", type=int, default=20,
help="Min download speed in MiB/s (default: 20)")
p.add_argument("--measurement-time", type=int, default=7,
help="Speed measurement duration in seconds (default: 7)")
p.add_argument("--max-speed-checks", type=int, default=15,
help="Max nodes to benchmark before giving up (default: 15)")
p.add_argument("--version", default=None,
help="Filter nodes by version prefix (e.g. '2.2')")
p.add_argument("--full-only", action="store_true",
help="Download only full snapshot, skip incremental")
p.add_argument("--dry-run", action="store_true",
help="Find best source and print URL, don't download")
p.add_argument("--post-cmd",
help="Shell command to run after successful download "
"(e.g. 'kubectl scale deployment ... --replicas=1')")
p.add_argument("-v", "--verbose", action="store_true")
args: argparse.Namespace = p.parse_args()
logging.basicConfig(
level=logging.DEBUG if args.verbose else logging.INFO,
format="%(asctime)s %(levelname)s %(message)s",
datefmt="%H:%M:%S",
)
# Dry-run uses the original inline flow (needs access to sources for URL printing)
if args.dry_run:
rpc_url: str = args.rpc or CLUSTER_RPC[args.cluster]
current_slot: int | None = get_current_slot(rpc_url)
if current_slot is None:
log.error("Cannot get current slot from %s", rpc_url)
return 1
sources: list[SnapshotSource] = discover_sources(
rpc_url, current_slot,
max_age_slots=args.max_snapshot_age,
max_latency_ms=args.max_latency,
threads=args.threads,
version_filter=args.version,
)
if not sources:
log.error("No snapshot sources found")
return 1
sources.sort(key=lambda s: s.latency_ms)
best = sources[0]
for fp in best.file_paths:
print(f"http://{best.rpc_address}{fp}")
return 0
ok: bool = download_best_snapshot(
args.output,
cluster=args.cluster,
rpc_url=args.rpc,
connections=args.connections,
threads=args.threads,
max_snapshot_age=args.max_snapshot_age,
max_latency=args.max_latency,
min_download_speed=args.min_download_speed,
measurement_time=args.measurement_time,
max_speed_checks=args.max_speed_checks,
version_filter=args.version,
full_only=args.full_only,
)
if ok and args.post_cmd:
log.info("Running post-download command: %s", args.post_cmd)
result: subprocess.CompletedProcess[bytes] = subprocess.run(
args.post_cmd, shell=True,
)
if result.returncode != 0:
log.error("Post-download command failed with exit code %d",
result.returncode)
return 1
log.info("Post-download command completed successfully")
return 0 if ok else 1
if __name__ == "__main__":
sys.exit(main())

112
start-test.sh 100644
View File

@ -0,0 +1,112 @@
#!/usr/bin/env bash
set -euo pipefail
# -----------------------------------------------------------------------
# Start solana-test-validator with optional SPL token setup
#
# Environment variables:
# FACILITATOR_PUBKEY - facilitator fee-payer public key (base58)
# SERVER_PUBKEY - server/payee wallet public key (base58)
# CLIENT_PUBKEY - client/payer wallet public key (base58)
# MINT_DECIMALS - token decimals (default: 6, matching USDC)
# MINT_AMOUNT - amount to mint to client (default: 1000000000)
# LEDGER_DIR - ledger directory (default: /data/ledger)
# -----------------------------------------------------------------------
LEDGER_DIR="${LEDGER_DIR:-/data/ledger}"
MINT_DECIMALS="${MINT_DECIMALS:-6}"
MINT_AMOUNT="${MINT_AMOUNT:-1000000000}"
SETUP_MARKER="${LEDGER_DIR}/.setup-done"
sudo chown -R "$(id -u):$(id -g)" "$LEDGER_DIR" 2>/dev/null || true
# Start test-validator in the background
solana-test-validator \
--ledger "${LEDGER_DIR}" \
--rpc-port 8899 \
--bind-address 0.0.0.0 \
--quiet &
VALIDATOR_PID=$!
# Wait for RPC to become available
echo "Waiting for test-validator RPC..."
for i in $(seq 1 60); do
if solana cluster-version --url http://127.0.0.1:8899 >/dev/null 2>&1; then
echo "Test-validator is ready (attempt ${i})"
break
fi
sleep 1
done
solana config set --url http://127.0.0.1:8899
# Only run setup once (idempotent via marker file)
if [ ! -f "${SETUP_MARKER}" ]; then
echo "Running first-time setup..."
# Airdrop SOL to all wallets for gas
for PUBKEY in "${FACILITATOR_PUBKEY:-}" "${SERVER_PUBKEY:-}" "${CLIENT_PUBKEY:-}"; do
if [ -n "${PUBKEY}" ]; then
echo "Airdropping 100 SOL to ${PUBKEY}..."
solana airdrop 100 "${PUBKEY}" --url http://127.0.0.1:8899 || true
fi
done
# Create a USDC-equivalent SPL token mint if any pubkeys are set
if [ -n "${CLIENT_PUBKEY:-}" ] || [ -n "${FACILITATOR_PUBKEY:-}" ] || [ -n "${SERVER_PUBKEY:-}" ]; then
MINT_AUTHORITY_FILE="${LEDGER_DIR}/mint-authority.json"
if [ ! -f "${MINT_AUTHORITY_FILE}" ]; then
solana-keygen new --no-bip39-passphrase --outfile "${MINT_AUTHORITY_FILE}" --force
MINT_AUTH_PUBKEY=$(solana-keygen pubkey "${MINT_AUTHORITY_FILE}")
solana airdrop 10 "${MINT_AUTH_PUBKEY}" --url http://127.0.0.1:8899
fi
MINT_ADDRESS_FILE="${LEDGER_DIR}/usdc-mint-address.txt"
if [ ! -f "${MINT_ADDRESS_FILE}" ]; then
spl-token create-token \
--decimals "${MINT_DECIMALS}" \
--mint-authority "${MINT_AUTHORITY_FILE}" \
--url http://127.0.0.1:8899 \
2>&1 | grep "Creating token" | awk '{print $3}' > "${MINT_ADDRESS_FILE}"
echo "Created USDC mint: $(cat "${MINT_ADDRESS_FILE}")"
fi
USDC_MINT=$(cat "${MINT_ADDRESS_FILE}")
# Create ATAs and mint tokens for the client
if [ -n "${CLIENT_PUBKEY:-}" ]; then
echo "Creating ATA for client ${CLIENT_PUBKEY}..."
spl-token create-account "${USDC_MINT}" \
--owner "${CLIENT_PUBKEY}" \
--fee-payer "${MINT_AUTHORITY_FILE}" \
--url http://127.0.0.1:8899 || true
echo "Minting ${MINT_AMOUNT} tokens to client..."
spl-token mint "${USDC_MINT}" "${MINT_AMOUNT}" \
--recipient-owner "${CLIENT_PUBKEY}" \
--mint-authority "${MINT_AUTHORITY_FILE}" \
--url http://127.0.0.1:8899 || true
fi
# Create ATAs for server and facilitator
for PUBKEY in "${SERVER_PUBKEY:-}" "${FACILITATOR_PUBKEY:-}"; do
if [ -n "${PUBKEY}" ]; then
echo "Creating ATA for ${PUBKEY}..."
spl-token create-account "${USDC_MINT}" \
--owner "${PUBKEY}" \
--fee-payer "${MINT_AUTHORITY_FILE}" \
--url http://127.0.0.1:8899 || true
fi
done
# Expose mint address for other containers
cp "${MINT_ADDRESS_FILE}" /tmp/usdc-mint-address.txt 2>/dev/null || true
fi
touch "${SETUP_MARKER}"
echo "Setup complete."
fi
echo "solana-test-validator running (PID ${VALIDATOR_PID})"
wait ${VALIDATOR_PID}