chore: remove scripts/agave-container before subtree add
Moving container scripts into agave-stack subtree (correct direction). The source of truth will be agave-stack/ in this repo, pushed out to LaconicNetwork/agave-stack via git subtree push. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>fix/kind-mount-propagation
parent
08380ec070
commit
7c58809cc1
|
|
@ -1,81 +0,0 @@
|
||||||
# Unified Agave/Jito Solana image
# Supports three modes via AGAVE_MODE env: test, rpc, validator
#
# Build args:
# AGAVE_REPO - git repo URL (anza-xyz/agave or jito-foundation/jito-solana)
# AGAVE_VERSION - git tag to build (e.g. v3.1.9, v3.1.8-jito)

ARG AGAVE_REPO=https://github.com/anza-xyz/agave.git
ARG AGAVE_VERSION=v3.1.9

# ---------- Stage 1: Build ----------
FROM rust:1.85-bookworm AS builder

# Re-declare so the pre-FROM args are visible inside this stage.
ARG AGAVE_REPO
ARG AGAVE_VERSION

# Native build dependencies: libclang for bindgen, protobuf-compiler for
# gRPC codegen, libssl/libudev headers for the validator crates.
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    pkg-config \
    libssl-dev \
    libudev-dev \
    libclang-dev \
    protobuf-compiler \
    ca-certificates \
    git \
    cmake \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /build
# Shallow clone of just the requested tag keeps the build context small.
RUN git clone "$AGAVE_REPO" --depth 1 --branch "$AGAVE_VERSION" --recurse-submodules agave
WORKDIR /build/agave

# Cherry-pick --public-tvu-address support (anza-xyz/agave PR #6778, commit 9f4b3ae)
# This flag only exists on master, not in v3.1.9 — fetch the PR ref and cherry-pick
# Set TVU_ADDRESS_PR to empty to skip the cherry-pick entirely.
ARG TVU_ADDRESS_PR=6778
RUN if [ -n "$TVU_ADDRESS_PR" ]; then \
    git fetch --depth 50 origin "pull/${TVU_ADDRESS_PR}/head:tvu-pr" && \
    git cherry-pick --no-commit tvu-pr; \
    fi

# Build all binaries using the upstream install script
RUN CI_COMMIT=$(git rev-parse HEAD) scripts/cargo-install-all.sh /solana-release

# ---------- Stage 2: Runtime ----------
FROM debian:bookworm-slim

# Runtime shared libs plus operational tooling (aria2 for snapshot download,
# python3 for the entrypoint scripts, sudo for ownership fixups at startup).
RUN apt-get update && apt-get install -y --no-install-recommends \
    ca-certificates \
    libssl3 \
    libudev1 \
    curl \
    sudo \
    aria2 \
    python3 \
    && rm -rf /var/lib/apt/lists/*

# Create non-root user with sudo
RUN useradd -m -s /bin/bash agave \
    && echo "agave ALL=(ALL) NOPASSWD: ALL" >> /etc/sudoers

# Copy all compiled binaries
COPY --from=builder /solana-release/bin/ /usr/local/bin/

# Copy entrypoint and support scripts
COPY entrypoint.py snapshot_download.py ip_echo_preflight.py /usr/local/bin/
COPY start-test.sh /usr/local/bin/
RUN chmod +x /usr/local/bin/entrypoint.py /usr/local/bin/start-test.sh

# Create data directories
RUN mkdir -p /data/config /data/ledger /data/accounts /data/snapshots \
    && chown -R agave:agave /data

USER agave
WORKDIR /data

ENV RUST_LOG=info
ENV RUST_BACKTRACE=1

# 8899 RPC, 8900 RPC pubsub, 8001 gossip (tcp + udp)
EXPOSE 8899 8900 8001 8001/udp

ENTRYPOINT ["entrypoint.py"]
|
|
||||||
|
|
@ -1,17 +0,0 @@
|
||||||
#!/usr/bin/env bash
# Build laconicnetwork/agave
# Set AGAVE_REPO and AGAVE_VERSION env vars to build Jito or a different version

# Stack helper: provides ${build_command_args} (extra docker-build flags).
# Fix: quote path expansions so directories containing spaces don't
# word-split; ${build_command_args} stays unquoted on purpose below.
source "${CERC_CONTAINER_BASE_DIR}/build-base.sh"

# Directory containing this script (and the Dockerfile).
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )

AGAVE_REPO="${AGAVE_REPO:-https://github.com/anza-xyz/agave.git}"
AGAVE_VERSION="${AGAVE_VERSION:-v3.1.9}"

# NOTE: ${build_command_args} is intentionally unquoted — it may expand to
# several flags and relies on word-splitting.
docker build -t laconicnetwork/agave:local \
  --build-arg AGAVE_REPO="$AGAVE_REPO" \
  --build-arg AGAVE_VERSION="$AGAVE_VERSION" \
  ${build_command_args} \
  -f "${SCRIPT_DIR}/Dockerfile" \
  "${SCRIPT_DIR}"
|
|
||||||
|
|
@ -1,686 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
"""Agave validator entrypoint — snapshot management, arg construction, liveness probe.
|
|
||||||
|
|
||||||
Two subcommands:
|
|
||||||
entrypoint.py serve (default) — snapshot freshness check + run agave-validator
|
|
||||||
entrypoint.py probe — liveness probe (slot lag check, exits 0/1)
|
|
||||||
|
|
||||||
Replaces the bash entrypoint.sh / start-rpc.sh / start-validator.sh with a single
|
|
||||||
Python module. Test mode still dispatches to start-test.sh.
|
|
||||||
|
|
||||||
Python stays as PID 1 and traps SIGTERM. On SIGTERM, it runs
|
|
||||||
``agave-validator exit --force --ledger /data/ledger`` which connects to the
|
|
||||||
admin RPC Unix socket and tells the validator to flush I/O and exit cleanly.
|
|
||||||
This avoids the io_uring/ZFS deadlock that occurs when the process is killed.
|
|
||||||
|
|
||||||
All configuration comes from environment variables — same vars as the original
|
|
||||||
bash scripts. See compose files for defaults.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import signal
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
import urllib.error
|
|
||||||
import urllib.request
|
|
||||||
from pathlib import Path
|
|
||||||
from urllib.request import Request
|
|
||||||
|
|
||||||
# Module-level logger for the entrypoint.
log: logging.Logger = logging.getLogger("entrypoint")

# Directories (fixed container paths; volumes are mounted here by compose/k8s)
CONFIG_DIR = "/data/config"
LEDGER_DIR = "/data/ledger"
ACCOUNTS_DIR = "/data/accounts"
SNAPSHOTS_DIR = "/data/snapshots"
LOG_DIR = "/data/log"
# Validator identity keypair: mounted by the operator in validator mode,
# generated ephemerally in rpc mode (see ensure_identity_rpc).
IDENTITY_FILE = f"{CONFIG_DIR}/validator-identity.json"

# Snapshot filename patterns
# Full:        snapshot-<slot>-<hash>.tar.{zst,bz2}  (group 1 = slot)
FULL_SNAP_RE: re.Pattern[str] = re.compile(
    r"^snapshot-(\d+)-[A-Za-z0-9]+\.tar\.(zst|bz2)$"
)
# Incremental: incremental-snapshot-<base>-<slot>-<hash>.tar.{zst,bz2}
# (group 1 = base/full slot, group 2 = incremental slot)
INCR_SNAP_RE: re.Pattern[str] = re.compile(
    r"^incremental-snapshot-(\d+)-(\d+)-[A-Za-z0-9]+\.tar\.(zst|bz2)$"
)

# Public RPC endpoint used as the slot-height reference for freshness checks.
MAINNET_RPC = "https://api.mainnet-beta.solana.com"
|
|
||||||
|
|
||||||
|
|
||||||
# -- Helpers -------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def env(name: str, default: str = "") -> str:
    """Return the value of environment variable *name*, or *default* if unset."""
    value = os.environ.get(name)
    return default if value is None else value
|
|
||||||
|
|
||||||
|
|
||||||
def env_required(name: str) -> str:
    """Return the value of required env var *name*; log and exit(1) when unset/empty."""
    value = os.environ.get(name)
    if value:
        return value
    log.error("%s is required but not set", name)
    sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
def env_bool(name: str, default: bool = False) -> bool:
    """Interpret env var *name* as a boolean.

    Unset or empty yields *default*; otherwise the result is True iff the
    lower-cased value is one of "true", "1", "yes".
    """
    raw = os.environ.get(name, "").lower()
    return default if not raw else raw in ("true", "1", "yes")
|
|
||||||
|
|
||||||
|
|
||||||
def rpc_get_slot(url: str, timeout: int = 10) -> int | None:
|
|
||||||
"""Get current slot from a Solana RPC endpoint."""
|
|
||||||
payload = json.dumps({
|
|
||||||
"jsonrpc": "2.0", "id": 1,
|
|
||||||
"method": "getSlot", "params": [],
|
|
||||||
}).encode()
|
|
||||||
req = Request(url, data=payload,
|
|
||||||
headers={"Content-Type": "application/json"})
|
|
||||||
try:
|
|
||||||
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
|
||||||
data = json.loads(resp.read())
|
|
||||||
result = data.get("result")
|
|
||||||
if isinstance(result, int):
|
|
||||||
return result
|
|
||||||
except (urllib.error.URLError, json.JSONDecodeError, OSError, TimeoutError):
|
|
||||||
pass
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
# -- Snapshot management -------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def get_local_snapshot_slot(snapshots_dir: str) -> int | None:
    """Return the highest slot among local full-snapshot files, or None.

    Non-existent directories and directories without matching files both
    yield None.
    """
    root = Path(snapshots_dir)
    if not root.is_dir():
        return None
    slots = [
        int(match.group(1))
        for entry in root.iterdir()
        if (match := FULL_SNAP_RE.match(entry.name))
    ]
    return max(slots) if slots else None
|
|
||||||
|
|
||||||
|
|
||||||
def clean_snapshots(snapshots_dir: str) -> None:
    """Delete every full and incremental snapshot file under *snapshots_dir*.

    Missing directories are a no-op; each removal is logged.
    """
    root = Path(snapshots_dir)
    if not root.is_dir():
        return
    snapshot_prefixes = ("snapshot-", "incremental-snapshot-")
    for candidate in root.iterdir():
        if candidate.name.startswith(snapshot_prefixes):
            log.info("Removing old snapshot: %s", candidate.name)
            candidate.unlink(missing_ok=True)
|
|
||||||
|
|
||||||
|
|
||||||
def get_incremental_slot(snapshots_dir: str, full_slot: int | None) -> int | None:
    """Return the newest incremental-snapshot slot built on *full_slot*.

    Only incrementals whose base slot equals *full_slot* count. Returns None
    when *full_slot* is None, the directory is missing, or nothing matches.
    """
    if full_slot is None:
        return None
    root = Path(snapshots_dir)
    if not root.is_dir():
        return None
    matching = [
        int(match.group(2))
        for entry in root.iterdir()
        if (match := INCR_SNAP_RE.match(entry.name))
        and int(match.group(1)) == full_slot
    ]
    return max(matching) if matching else None
|
|
||||||
|
|
||||||
|
|
||||||
def maybe_download_snapshot(snapshots_dir: str) -> None:
    """Ensure full + incremental snapshots exist before starting.

    The validator should always start from a full + incremental pair to
    minimize replay time. If either is missing or the full is too old,
    download fresh ones via download_best_snapshot (which does rolling
    incremental convergence after downloading the full).

    Controlled by env vars:
      SNAPSHOT_AUTO_DOWNLOAD (default: true) — enable/disable
      SNAPSHOT_MAX_AGE_SLOTS (default: 100000) — full snapshot staleness threshold
        (one full snapshot generation, ~11 hours)
      SNAPSHOT_CONVERGENCE_SLOTS (default: 500) — max acceptable incremental gap
      SNAPSHOT_RETRY_DELAY (default: 60) — seconds between download retries

    Blocks (retrying forever) until a suitable snapshot set is present.
    """
    if not env_bool("SNAPSHOT_AUTO_DOWNLOAD", default=True):
        log.info("Snapshot auto-download disabled")
        return

    max_age = int(env("SNAPSHOT_MAX_AGE_SLOTS", "100000"))

    # Mainnet tip is the freshness reference; without it no staleness
    # judgement is possible, so start with whatever is on disk.
    mainnet_slot = rpc_get_slot(MAINNET_RPC)
    if mainnet_slot is None:
        log.warning("Cannot reach mainnet RPC — skipping snapshot check")
        return

    # Deferred import: snapshot_download.py is installed alongside this
    # script, not on the default sys.path.
    script_dir = Path(__file__).resolve().parent
    sys.path.insert(0, str(script_dir))
    from snapshot_download import download_best_snapshot, download_incremental_for_slot

    convergence = int(env("SNAPSHOT_CONVERGENCE_SLOTS", "500"))
    retry_delay = int(env("SNAPSHOT_RETRY_DELAY", "60"))

    # Check local full snapshot
    local_slot = get_local_snapshot_slot(snapshots_dir)
    have_fresh_full = (local_slot is not None
                       and (mainnet_slot - local_slot) <= max_age)

    if have_fresh_full:
        # Narrowing for the type checker: have_fresh_full implies non-None.
        assert local_slot is not None
        inc_slot = get_incremental_slot(snapshots_dir, local_slot)
        if inc_slot is not None:
            inc_gap = mainnet_slot - inc_slot
            if inc_gap <= convergence:
                log.info("Full (slot %d) + incremental (slot %d, gap %d) "
                         "within convergence, starting",
                         local_slot, inc_slot, inc_gap)
                return
            log.info("Incremental too stale (slot %d, gap %d > %d)",
                     inc_slot, inc_gap, convergence)
        # Fresh full, need a fresh incremental
        log.info("Downloading incremental for full at slot %d", local_slot)
        while True:
            if download_incremental_for_slot(snapshots_dir, local_slot,
                                             convergence_slots=convergence):
                return
            log.warning("Incremental download failed — retrying in %ds",
                        retry_delay)
            time.sleep(retry_delay)

    # No full or full too old — download both
    log.info("Downloading full + incremental")
    clean_snapshots(snapshots_dir)
    while True:
        if download_best_snapshot(snapshots_dir, convergence_slots=convergence):
            return
        log.warning("Snapshot download failed — retrying in %ds", retry_delay)
        time.sleep(retry_delay)
|
|
||||||
|
|
||||||
|
|
||||||
# -- Directory and identity setup ----------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def ensure_dirs(*dirs: str) -> None:
    """Create each directory (if missing) and chown it to the current user.

    The ownership fixup runs ``sudo chown`` best-effort: a non-zero exit is
    ignored (check=False) and a missing sudo binary is swallowed, in which
    case the directories are assumed to be correctly owned already.
    """
    owner = f"{os.getuid()}:{os.getgid()}"
    for directory in dirs:
        os.makedirs(directory, exist_ok=True)
        try:
            subprocess.run(
                ["sudo", "chown", "-R", owner, directory],
                check=False, capture_output=True,
            )
        except FileNotFoundError:
            # sudo not installed — nothing to fix up
            pass
|
|
||||||
|
|
||||||
|
|
||||||
def ensure_identity_rpc() -> None:
    """Create an ephemeral RPC-node identity keypair unless one is mounted.

    Raises CalledProcessError if solana-keygen fails (check=True).
    """
    if os.path.isfile(IDENTITY_FILE):
        # Operator supplied a keypair — keep it.
        return
    log.info("Generating RPC node identity keypair...")
    keygen_cmd = [
        "solana-keygen", "new", "--no-passphrase", "--silent",
        "--force", "--outfile", IDENTITY_FILE,
    ]
    subprocess.run(keygen_cmd, check=True)
|
|
||||||
|
|
||||||
|
|
||||||
def print_identity() -> None:
    """Log the node identity pubkey (best-effort; silent when keygen fails)."""
    keygen = subprocess.run(
        ["solana-keygen", "pubkey", IDENTITY_FILE],
        capture_output=True, text=True, check=False,
    )
    if keygen.returncode == 0:
        log.info("Node identity: %s", keygen.stdout.strip())
|
|
||||||
|
|
||||||
|
|
||||||
# -- Arg construction ----------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def build_common_args() -> list[str]:
    """Build agave-validator args common to both RPC and validator modes.

    Reads all configuration from environment variables (defaults inline).
    VALIDATOR_ENTRYPOINT and KNOWN_VALIDATOR are mandatory — env_required
    exits the process when either is missing.
    """
    args: list[str] = [
        "--identity", IDENTITY_FILE,
        "--entrypoint", env_required("VALIDATOR_ENTRYPOINT"),
        "--known-validator", env_required("KNOWN_VALIDATOR"),
        "--ledger", LEDGER_DIR,
        "--accounts", ACCOUNTS_DIR,
        "--snapshots", SNAPSHOTS_DIR,
        "--rpc-port", env("RPC_PORT", "8899"),
        "--rpc-bind-address", env("RPC_BIND_ADDRESS", "127.0.0.1"),
        "--gossip-port", env("GOSSIP_PORT", "8001"),
        "--dynamic-port-range", env("DYNAMIC_PORT_RANGE", "9000-10000"),
        "--no-os-network-limits-test",
        "--wal-recovery-mode", "skip_any_corrupted_record",
        "--limit-ledger-size", env("LIMIT_LEDGER_SIZE", "50000000"),
        "--no-snapshot-fetch",  # entrypoint handles snapshot download
    ]

    # Snapshot generation
    if env("NO_SNAPSHOTS") == "true":
        args.append("--no-snapshots")
    else:
        args += [
            "--full-snapshot-interval-slots", env("SNAPSHOT_INTERVAL_SLOTS", "100000"),
            "--maximum-full-snapshots-to-retain", env("MAXIMUM_SNAPSHOTS_TO_RETAIN", "1"),
        ]
        if env("NO_INCREMENTAL_SNAPSHOTS") != "true":
            args += ["--maximum-incremental-snapshots-to-retain", "2"]

    # Account indexes: comma-separated list, blanks skipped
    account_indexes = env("ACCOUNT_INDEXES")
    if account_indexes:
        for idx in account_indexes.split(","):
            idx = idx.strip()
            if idx:
                args += ["--account-index", idx]

    # Additional entrypoints (whitespace-separated)
    for ep in env("EXTRA_ENTRYPOINTS").split():
        if ep:
            args += ["--entrypoint", ep]

    # Additional known validators (whitespace-separated)
    for kv in env("EXTRA_KNOWN_VALIDATORS").split():
        if kv:
            args += ["--known-validator", kv]

    # Cluster verification
    genesis_hash = env("EXPECTED_GENESIS_HASH")
    if genesis_hash:
        args += ["--expected-genesis-hash", genesis_hash]
    shred_version = env("EXPECTED_SHRED_VERSION")
    if shred_version:
        args += ["--expected-shred-version", shred_version]

    # Metrics — just needs to be in the environment, agave reads it directly
    # (env var is already set, nothing to pass as arg)

    # Gossip host / TVU address — GOSSIP_HOST wins when both are set
    gossip_host = env("GOSSIP_HOST")
    if gossip_host:
        args += ["--gossip-host", gossip_host]
    elif env("PUBLIC_TVU_ADDRESS"):
        args += ["--public-tvu-address", env("PUBLIC_TVU_ADDRESS")]

    # Jito flags — each mapped env var is forwarded only when non-empty
    if env("JITO_ENABLE") == "true":
        log.info("Jito MEV enabled")
        jito_flags: list[tuple[str, str]] = [
            ("JITO_TIP_PAYMENT_PROGRAM", "--tip-payment-program-pubkey"),
            ("JITO_DISTRIBUTION_PROGRAM", "--tip-distribution-program-pubkey"),
            ("JITO_MERKLE_ROOT_AUTHORITY", "--merkle-root-upload-authority"),
            ("JITO_COMMISSION_BPS", "--commission-bps"),
            ("JITO_BLOCK_ENGINE_URL", "--block-engine-url"),
            ("JITO_SHRED_RECEIVER_ADDR", "--shred-receiver-address"),
        ]
        for env_name, flag in jito_flags:
            val = env(env_name)
            if val:
                args += [flag, val]

    return args
|
|
||||||
|
|
||||||
|
|
||||||
def build_rpc_args() -> list[str]:
    """Assemble agave-validator flags for RPC (non-voting) mode.

    Extends the common flag set with full-history RPC options, then selects
    public vs. private RPC exposure based on PUBLIC_RPC_ADDRESS.
    """
    flags = build_common_args()
    flags.extend([
        "--no-voting",
        "--log", f"{LOG_DIR}/validator.log",
        "--full-rpc-api",
        "--enable-rpc-transaction-history",
        "--rpc-pubsub-enable-block-subscription",
        "--enable-extended-tx-metadata-storage",
        "--no-wait-for-vote-to-start-leader",
    ])

    # Public vs private RPC
    advertised = env("PUBLIC_RPC_ADDRESS")
    if advertised:
        flags.extend(["--public-rpc-address", advertised])
    else:
        flags.extend(["--private-rpc", "--allow-private-addr", "--only-known-rpc"])

    # Jito relayer URL (RPC mode doesn't use it, but validator mode does —
    # handled in build_validator_args)
    return flags
|
|
||||||
|
|
||||||
|
|
||||||
def build_validator_args() -> list[str]:
    """Build agave-validator args for voting validator mode.

    Requires the identity keypair at IDENTITY_FILE and a vote-account
    keypair (path from VOTE_ACCOUNT_KEYPAIR, default below); exits with
    status 1 when either file is missing.
    """
    vote_keypair = env("VOTE_ACCOUNT_KEYPAIR",
                       "/data/config/vote-account-keypair.json")

    # Identity must be mounted for validator mode
    if not os.path.isfile(IDENTITY_FILE):
        log.error("Validator identity keypair not found at %s", IDENTITY_FILE)
        log.error("Mount your validator keypair to %s", IDENTITY_FILE)
        sys.exit(1)

    # Vote account keypair must exist
    if not os.path.isfile(vote_keypair):
        log.error("Vote account keypair not found at %s", vote_keypair)
        log.error("Mount your vote account keypair or set VOTE_ACCOUNT_KEYPAIR")
        sys.exit(1)

    # Print vote account pubkey (best-effort; check=False so a keygen
    # failure doesn't abort startup)
    result = subprocess.run(
        ["solana-keygen", "pubkey", vote_keypair],
        capture_output=True, text=True, check=False,
    )
    if result.returncode == 0:
        log.info("Vote account: %s", result.stdout.strip())

    args = build_common_args()
    args += [
        "--vote-account", vote_keypair,
        "--log", "-",  # log to stdout in validator mode
    ]

    # Jito relayer URL (validator-only)
    relayer_url = env("JITO_RELAYER_URL")
    if env("JITO_ENABLE") == "true" and relayer_url:
        args += ["--relayer-url", relayer_url]

    return args
|
|
||||||
|
|
||||||
|
|
||||||
def append_extra_args(args: list[str]) -> list[str]:
    """Append operator-supplied EXTRA_ARGS passthrough flags to *args*.

    Fix: uses shlex.split instead of str.split so quoted values containing
    spaces (e.g. ``--log '/path with space'``) survive as single arguments.
    Plain whitespace-separated flags behave exactly as before.

    Returns the same list object, mutated in place.
    """
    import shlex  # local import: only this helper needs shell-style parsing

    extra = os.environ.get("EXTRA_ARGS", "")
    if extra:
        args += shlex.split(extra)
    return args
|
|
||||||
|
|
||||||
|
|
||||||
# -- Graceful shutdown --------------------------------------------------------
|
|
||||||
|
|
||||||
# Timeout for graceful exit via admin RPC. Leave 30s margin for k8s
|
|
||||||
# terminationGracePeriodSeconds (300s).
|
|
||||||
GRACEFUL_EXIT_TIMEOUT = 270
|
|
||||||
|
|
||||||
|
|
||||||
def graceful_exit(child: subprocess.Popen[bytes], reason: str = "SIGTERM") -> None:
    """Request graceful shutdown via the admin RPC Unix socket.

    Runs ``agave-validator exit --force --ledger /data/ledger`` which connects
    to the admin RPC socket at ``/data/ledger/admin.rpc`` and sets the
    validator's exit flag. The validator flushes all I/O and exits cleanly,
    avoiding the io_uring/ZFS deadlock.

    If the admin RPC exit fails or the child doesn't exit within the timeout,
    falls back to SIGTERM then SIGKILL.
    """
    log.info("%s — requesting graceful exit via admin RPC", reason)
    try:
        result = subprocess.run(
            ["agave-validator", "exit", "--force", "--ledger", LEDGER_DIR],
            capture_output=True, text=True, timeout=30,
        )
        if result.returncode == 0:
            log.info("Admin RPC exit requested successfully")
        else:
            log.warning(
                "Admin RPC exit returned %d: %s",
                result.returncode, result.stderr.strip(),
            )
    except subprocess.TimeoutExpired:
        log.warning("Admin RPC exit command timed out after 30s")
    except FileNotFoundError:
        log.warning("agave-validator binary not found for exit command")

    # Wait for child to exit — even when the exit request failed, the child
    # may already be on its way down, so always wait before escalating.
    try:
        child.wait(timeout=GRACEFUL_EXIT_TIMEOUT)
        log.info("Validator exited cleanly with code %d", child.returncode)
        return
    except subprocess.TimeoutExpired:
        log.warning(
            "Validator did not exit within %ds — sending SIGTERM",
            GRACEFUL_EXIT_TIMEOUT,
        )

    # Fallback: SIGTERM
    child.terminate()
    try:
        child.wait(timeout=15)
        log.info("Validator exited after SIGTERM with code %d", child.returncode)
        return
    except subprocess.TimeoutExpired:
        log.warning("Validator did not exit after SIGTERM — sending SIGKILL")

    # Last resort: SIGKILL (unconditional wait reaps the zombie)
    child.kill()
    child.wait()
    log.info("Validator killed with SIGKILL, code %d", child.returncode)
|
|
||||||
|
|
||||||
|
|
||||||
# -- Serve subcommand ---------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _gap_monitor(
    child: subprocess.Popen[bytes],
    leapfrog: threading.Event,
    shutting_down: threading.Event,
) -> None:
    """Background thread: poll slot gap and trigger leapfrog if too far behind.

    Waits for a grace period (SNAPSHOT_MONITOR_GRACE, default 600s) before
    monitoring — the validator needs time to extract snapshots and catch up.
    Then polls every SNAPSHOT_MONITOR_INTERVAL (default 30s). If the gap
    exceeds SNAPSHOT_LEAPFROG_SLOTS (default 5000) for SNAPSHOT_LEAPFROG_CHECKS
    (default 3) consecutive checks, triggers graceful shutdown and sets the
    leapfrog event so cmd_serve loops back to download a fresh incremental.
    """
    threshold = int(env("SNAPSHOT_LEAPFROG_SLOTS", "5000"))
    required_checks = int(env("SNAPSHOT_LEAPFROG_CHECKS", "3"))
    interval = int(env("SNAPSHOT_MONITOR_INTERVAL", "30"))
    grace = int(env("SNAPSHOT_MONITOR_GRACE", "600"))
    rpc_port = env("RPC_PORT", "8899")
    local_url = f"http://127.0.0.1:{rpc_port}"

    # Grace period — don't monitor during initial catch-up.
    # Event.wait returns True when the event was set, i.e. we're shutting
    # down before the grace period even elapsed.
    if shutting_down.wait(grace):
        return

    consecutive = 0
    while not shutting_down.is_set():
        local_slot = rpc_get_slot(local_url, timeout=5)
        mainnet_slot = rpc_get_slot(MAINNET_RPC, timeout=10)

        # When either RPC query fails (None), the counter is neither
        # incremented nor reset — the check is simply skipped this round.
        if local_slot is not None and mainnet_slot is not None:
            gap = mainnet_slot - local_slot
            if gap > threshold:
                consecutive += 1
                log.warning("Gap %d > %d (%d/%d consecutive)",
                            gap, threshold, consecutive, required_checks)
                if consecutive >= required_checks:
                    log.warning("Leapfrog triggered: gap %d", gap)
                    leapfrog.set()
                    graceful_exit(child, reason="Leapfrog")
                    return
            else:
                if consecutive > 0:
                    log.info("Gap %d within threshold, resetting counter", gap)
                consecutive = 0

        # Sleep that wakes immediately on shutdown.
        shutting_down.wait(interval)
|
|
||||||
|
|
||||||
def cmd_serve() -> None:
    """Main serve flow: snapshot download, run validator, monitor gap, leapfrog.

    Python stays as PID 1. On each iteration:
      1. Download full + incremental snapshots (if needed)
      2. Start agave-validator as child process
      3. Monitor slot gap in background thread
      4. If gap exceeds threshold → graceful stop → loop back to step 1
      5. If SIGTERM → graceful stop → exit
      6. If validator crashes → exit with its return code
    """
    mode = env("AGAVE_MODE", "test")
    log.info("AGAVE_MODE=%s", mode)

    if mode == "test":
        # Test mode: replace this process entirely with the test harness.
        os.execvp("start-test.sh", ["start-test.sh"])

    if mode not in ("rpc", "validator"):
        log.error("Unknown AGAVE_MODE: %s (valid: test, rpc, validator)", mode)
        sys.exit(1)

    # One-time setup: data dirs (plus the file-log dir in rpc mode)
    dirs = [CONFIG_DIR, LEDGER_DIR, ACCOUNTS_DIR, SNAPSHOTS_DIR]
    if mode == "rpc":
        dirs.append(LOG_DIR)
    ensure_dirs(*dirs)

    # Port-reachability preflight: fail fast before spending hours on a
    # snapshot download. Deferred import — the module sits next to this file.
    if not env_bool("SKIP_IP_ECHO_PREFLIGHT"):
        script_dir = Path(__file__).resolve().parent
        sys.path.insert(0, str(script_dir))
        from ip_echo_preflight import main as ip_echo_main
        if ip_echo_main() != 0:
            sys.exit(1)

    if mode == "rpc":
        ensure_identity_rpc()
        print_identity()

    if mode == "rpc":
        args = build_rpc_args()
    else:
        args = build_validator_args()
    args = append_extra_args(args)

    # Main loop: download → run → monitor → leapfrog if needed
    while True:
        maybe_download_snapshot(SNAPSHOTS_DIR)

        # Timestamp read by cmd_probe to implement its grace period.
        Path("/tmp/entrypoint-start").write_text(str(time.time()))
        log.info("Starting agave-validator with %d arguments", len(args))
        child = subprocess.Popen(["agave-validator"] + args)

        shutting_down = threading.Event()
        leapfrog = threading.Event()

        # Forward SIGUSR1 (log rotation convention) to the child.
        signal.signal(signal.SIGUSR1,
                      lambda _sig, _frame: child.send_signal(signal.SIGUSR1))

        # SIGTERM: run the graceful-exit sequence off the signal handler's
        # thread so the handler itself returns immediately.
        def _on_sigterm(_sig: int, _frame: object) -> None:
            shutting_down.set()
            threading.Thread(
                target=graceful_exit, args=(child,), daemon=True,
            ).start()

        signal.signal(signal.SIGTERM, _on_sigterm)

        # Start gap monitor
        monitor = threading.Thread(
            target=_gap_monitor,
            args=(child, leapfrog, shutting_down),
            daemon=True,
        )
        monitor.start()

        child.wait()

        if leapfrog.is_set():
            log.info("Leapfrog: restarting with fresh incremental")
            continue

        # Normal exit or crash: propagate the child's return code.
        sys.exit(child.returncode)
|
|
||||||
|
|
||||||
|
|
||||||
# -- Probe subcommand ---------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def cmd_probe() -> None:
    """Liveness probe: check local RPC slot vs mainnet.

    Exit 0 = healthy, exit 1 = unhealthy.

    Grace period: PROBE_GRACE_SECONDS (default 600) — probe always passes
    during grace period to allow for snapshot unpacking and initial replay.
    The grace clock starts at the timestamp cmd_serve wrote to
    /tmp/entrypoint-start.
    """
    grace_seconds = int(env("PROBE_GRACE_SECONDS", "600"))
    max_lag = int(env("PROBE_MAX_SLOT_LAG", "20000"))

    # Check grace period
    start_file = Path("/tmp/entrypoint-start")
    if start_file.exists():
        try:
            start_time = float(start_file.read_text().strip())
            elapsed = time.time() - start_time
            if elapsed < grace_seconds:
                # Within grace period — always healthy
                sys.exit(0)
        except (ValueError, OSError):
            # Unreadable/corrupt timestamp — fall through to the real check
            pass
    else:
        # No start file — serve hasn't started yet, within grace
        sys.exit(0)

    # Query local RPC
    rpc_port = env("RPC_PORT", "8899")
    local_url = f"http://127.0.0.1:{rpc_port}"
    local_slot = rpc_get_slot(local_url, timeout=5)
    if local_slot is None:
        # Local RPC unreachable after grace period — unhealthy
        sys.exit(1)

    # Query mainnet
    mainnet_slot = rpc_get_slot(MAINNET_RPC, timeout=10)
    if mainnet_slot is None:
        # Can't reach mainnet to compare — assume healthy (don't penalize
        # the validator for mainnet RPC being down)
        sys.exit(0)

    lag = mainnet_slot - local_slot
    if lag > max_lag:
        sys.exit(1)

    sys.exit(0)
|
|
||||||
|
|
||||||
|
|
||||||
# -- Main ----------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
    """CLI entry point: configure logging, then dispatch the subcommand.

    Defaults to ``serve`` when no subcommand is given on the command line.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
        datefmt="%H:%M:%S",
    )

    subcmd = "serve" if len(sys.argv) <= 1 else sys.argv[1]

    dispatch = {"serve": cmd_serve, "probe": cmd_probe}
    handler = dispatch.get(subcmd)
    if handler is None:
        log.error("Unknown subcommand: %s (valid: serve, probe)", subcmd)
        sys.exit(1)
    handler()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
|
@ -1,249 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
"""ip_echo preflight — verify UDP port reachability before starting the validator.
|
|
||||||
|
|
||||||
Implements the Solana ip_echo client protocol exactly:
|
|
||||||
1. Bind UDP sockets on the ports the validator will use
|
|
||||||
2. TCP connect to entrypoint gossip port, send IpEchoServerMessage
|
|
||||||
3. Parse IpEchoServerResponse (our IP as seen by entrypoint)
|
|
||||||
4. Wait for entrypoint's UDP probes on each port
|
|
||||||
5. Exit 0 if all ports reachable, exit 1 if any fail
|
|
||||||
|
|
||||||
Wire format (from agave net-utils/src/):
|
|
||||||
Request: 4 null bytes + [u16; 4] tcp_ports LE + [u16; 4] udp_ports LE + \n
|
|
||||||
Response: 4 null bytes + bincode IpAddr (variant byte + addr) + optional shred_version
|
|
||||||
|
|
||||||
Called from entrypoint.py before snapshot download. Prevents wasting hours
|
|
||||||
downloading a snapshot only to crash-loop on port reachability.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import socket
|
|
||||||
import struct
|
|
||||||
import sys
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
|
|
||||||
log = logging.getLogger("ip_echo_preflight")

# ip_echo wire-protocol framing (full layout in the module docstring).
HEADER = b"\x00\x00\x00\x00"  # 4-byte null prefix on both request and response
TERMINUS = b"\x0a"            # newline terminating the request
RESPONSE_BUF = 27             # max response: 4 hdr + 4 variant + 16 IPv6 + 1 tag + 2 shred
IO_TIMEOUT = 5.0              # seconds for the TCP connect/send/recv exchange
PROBE_TIMEOUT = 10.0          # seconds to wait for the entrypoint's UDP probes
MAX_RETRIES = 3               # full check attempts before giving up
RETRY_DELAY = 2.0             # seconds to sleep between attempts
|
|
||||||
|
|
||||||
|
|
||||||
def build_request(tcp_ports: list[int], udp_ports: list[int]) -> bytes:
    """Serialize an IpEchoServerMessage.

    Layout: 4-byte null header, four u16 LE TCP ports, four u16 LE UDP
    ports (each list zero-padded and truncated to exactly four), newline.
    """
    def four_u16(ports: list[int]) -> bytes:
        # Always exactly four slots on the wire.
        padded = (ports + [0, 0, 0, 0])[:4]
        return struct.pack("<4H", *padded)

    return b"".join((HEADER, four_u16(tcp_ports), four_u16(udp_ports), TERMINUS))
|
|
||||||
|
|
||||||
|
|
||||||
def parse_response(data: bytes) -> tuple[str, int | None]:
    """Decode an IpEchoServerResponse into (ip_string, shred_version | None).

    Bincode wire format:
        4 bytes  header (\0\0\0\0)
        4 bytes  IpAddr enum variant (u32 LE: 0=IPv4, 1=IPv6)
        4|16     address octets
        1 byte   Option tag (0=None, 1=Some)
        2 bytes  shred_version (u16 LE, only when Some)

    Raises ValueError on any framing problem.
    """
    if len(data) < 8:
        raise ValueError(f"response too short: {len(data)} bytes")
    prefix = data[:4]
    if prefix == b"HTTP":
        # Hitting an HTTP port by mistake is a common misconfiguration.
        raise ValueError("got HTTP response — not an ip_echo server")
    if prefix != HEADER:
        raise ValueError(f"unexpected header: {prefix.hex()}")

    (variant,) = struct.unpack("<I", data[4:8])
    if variant == 0:
        addr_len = 4
    elif variant == 1:
        addr_len = 16
    else:
        raise ValueError(f"unknown IpAddr variant: {variant}")

    end = 8 + addr_len
    if len(data) < end:
        kind = "IPv4" if variant == 0 else "IPv6"
        raise ValueError(f"{kind} response truncated: {len(data)} bytes")
    if variant == 0:
        ip = socket.inet_ntoa(data[8:end])
    else:
        ip = socket.inet_ntop(socket.AF_INET6, data[8:end])
    rest = data[end:]

    # Optional shred_version: 1-byte Some tag followed by a u16 LE.
    shred_version = None
    if len(rest) >= 3 and rest[0] == 1:
        (shred_version,) = struct.unpack("<H", rest[1:3])
    return ip, shred_version
|
|
||||||
|
|
||||||
|
|
||||||
def _listen_udp(port: int, results: dict, stop: threading.Event) -> None:
|
|
||||||
"""Bind a UDP socket and wait for a probe packet."""
|
|
||||||
try:
|
|
||||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
|
||||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
|
||||||
sock.bind(("0.0.0.0", port))
|
|
||||||
sock.settimeout(0.5)
|
|
||||||
try:
|
|
||||||
while not stop.is_set():
|
|
||||||
try:
|
|
||||||
_data, addr = sock.recvfrom(64)
|
|
||||||
results[port] = ("ok", addr)
|
|
||||||
return
|
|
||||||
except socket.timeout:
|
|
||||||
continue
|
|
||||||
finally:
|
|
||||||
sock.close()
|
|
||||||
except OSError as exc:
|
|
||||||
results[port] = ("bind_error", str(exc))
|
|
||||||
|
|
||||||
|
|
||||||
def ip_echo_check(
    entrypoint_host: str,
    entrypoint_port: int,
    udp_ports: list[int],
) -> tuple[str, dict[int, bool]]:
    """Run one ip_echo exchange and return (seen_ip, {port: reachable}).

    Raises on TCP failure (caller retries).
    """
    # The wire format carries at most four ports; drop zeros and truncate.
    udp_ports = [p for p in udp_ports if p != 0][:4]

    # Start UDP listeners before sending the TCP request — the server sends
    # its probes right after answering, so the sockets must be bound first.
    results: dict[int, tuple] = {}
    stop = threading.Event()
    threads = []
    for port in udp_ports:
        t = threading.Thread(target=_listen_udp, args=(port, results, stop), daemon=True)
        t.start()
        threads.append(t)
    time.sleep(0.1)  # let listeners bind

    # TCP: send request, read response
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.settimeout(IO_TIMEOUT)
    try:
        sock.connect((entrypoint_host, entrypoint_port))
        # Only UDP reachability is verified here; the TCP port list is empty.
        sock.sendall(build_request([], udp_ports))
        resp = sock.recv(RESPONSE_BUF)
    finally:
        sock.close()

    seen_ip, shred_version = parse_response(resp)
    log.info(
        "entrypoint %s:%d sees us as %s (shred_version=%s)",
        entrypoint_host, entrypoint_port, seen_ip, shred_version,
    )

    # Wait for UDP probes — listener threads fill `results` as packets land.
    deadline = time.monotonic() + PROBE_TIMEOUT
    while time.monotonic() < deadline:
        if all(p in results for p in udp_ports):
            break
        time.sleep(0.2)

    stop.set()
    for t in threads:
        t.join(timeout=1)

    # Fold raw listener results into a simple reachable/unreachable map.
    port_ok: dict[int, bool] = {}
    for port in udp_ports:
        if port not in results:
            log.error("port %d: no probe received within %.0fs", port, PROBE_TIMEOUT)
            port_ok[port] = False
        else:
            status, detail = results[port]
            if status == "ok":
                log.info("port %d: probe received from %s", port, detail)
                port_ok[port] = True
            else:
                log.error("port %d: %s: %s", port, status, detail)
                port_ok[port] = False

    return seen_ip, port_ok
|
|
||||||
|
|
||||||
|
|
||||||
def run_preflight(
    entrypoint_host: str,
    entrypoint_port: int,
    udp_ports: list[int],
    expected_ip: str = "",
) -> bool:
    """Run ip_echo check with retries. Returns True if all ports pass.

    Args:
        entrypoint_host: entrypoint gossip host.
        entrypoint_port: entrypoint gossip (ip_echo) TCP port.
        udp_ports: UDP ports to verify (the protocol checks at most 4).
        expected_ip: if set, the IP the entrypoint reports should match it.
            A mismatch triggers a retry; on the final attempt the port
            results still decide the outcome (preserves prior behavior).
    """
    for attempt in range(1, MAX_RETRIES + 1):
        log.info("ip_echo attempt %d/%d → %s:%d, ports %s",
                 attempt, MAX_RETRIES, entrypoint_host, entrypoint_port, udp_ports)
        try:
            seen_ip, port_ok = ip_echo_check(entrypoint_host, entrypoint_port, udp_ports)
        except Exception as exc:
            log.error("attempt %d TCP failed: %s", attempt, exc)
            if attempt < MAX_RETRIES:
                time.sleep(RETRY_DELAY)
                continue
            # FIX: on the final attempt, go straight to the failure path.
            # Previously control fell through to code that read seen_ip and
            # port_ok, which are unbound after an exception → NameError.
            break

        if expected_ip and seen_ip != expected_ip:
            log.error(
                "IP MISMATCH: entrypoint sees %s, expected %s (GOSSIP_HOST). "
                "Outbound mangle/SNAT path is broken.",
                seen_ip, expected_ip,
            )
            if attempt < MAX_RETRIES:
                time.sleep(RETRY_DELAY)
                continue

        reachable = [p for p, ok in port_ok.items() if ok]
        unreachable = [p for p, ok in port_ok.items() if not ok]

        if not unreachable:
            log.info("PASS: all ports reachable %s, seen as %s", reachable, seen_ip)
            return True

        log.error(
            "attempt %d: unreachable %s, reachable %s, seen as %s",
            attempt, unreachable, reachable, seen_ip,
        )
        if attempt < MAX_RETRIES:
            time.sleep(RETRY_DELAY)

    log.error("FAIL: ip_echo preflight exhausted %d attempts", MAX_RETRIES)
    return False
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> int:
    """CLI entry: read config from the environment and run the preflight.

    The entrypoint address comes from VALIDATOR_ENTRYPOINT ("host[:port]")
    or argv[1]. Returns 0 when every probed UDP port is reachable, else 1.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
        datefmt="%H:%M:%S",
    )

    env = os.environ.get
    raw = env("VALIDATOR_ENTRYPOINT", "") or (sys.argv[1] if len(sys.argv) > 1 else "")
    if not raw:
        log.error("set VALIDATOR_ENTRYPOINT or pass host:port as argument")
        return 1

    # Split on the LAST colon so "host:port" parses; bare host gets 8001.
    host, sep, port_str = raw.rpartition(":")
    if sep:
        ep_port = int(port_str)
    else:
        host, ep_port = raw, 8001

    gossip_port = int(env("GOSSIP_PORT", "8001"))
    range_start = int(env("DYNAMIC_PORT_RANGE", "9000-10000").split("-")[0])
    expected_ip = env("GOSSIP_HOST", "")

    # Probe gossip plus three ports from the dynamic range (the ip_echo
    # message carries at most four). NOTE(review): range_start + 1 is
    # skipped — presumably deliberate; confirm against the validator's
    # port assignments.
    udp_ports = [gossip_port, range_start, range_start + 2, range_start + 3]

    return 0 if run_preflight(host, ep_port, udp_ports, expected_ip) else 1
|
|
||||||
|
|
||||||
|
|
||||||
# Propagate the preflight result as the process exit status.
if __name__ == "__main__":
    sys.exit(main())
|
|
||||||
|
|
@ -1,878 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
"""Download Solana snapshots using aria2c for parallel multi-connection downloads.
|
|
||||||
|
|
||||||
Discovers snapshot sources by querying getClusterNodes for all RPCs in the
|
|
||||||
cluster, probing each for available snapshots, benchmarking download speed,
|
|
||||||
and downloading from the fastest source using aria2c (16 connections by default).
|
|
||||||
|
|
||||||
Based on the discovery approach from etcusr/solana-snapshot-finder but replaces
|
|
||||||
the single-connection wget download with aria2c parallel chunked downloads.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
# Download to /srv/kind/solana/snapshots (mainnet, 16 connections)
|
|
||||||
./snapshot_download.py -o /srv/kind/solana/snapshots
|
|
||||||
|
|
||||||
# Dry run — find best source, print URL
|
|
||||||
./snapshot_download.py --dry-run
|
|
||||||
|
|
||||||
# Custom RPC for cluster discovery + 32 connections
|
|
||||||
./snapshot_download.py -r https://api.mainnet-beta.solana.com -n 32
|
|
||||||
|
|
||||||
# Testnet
|
|
||||||
./snapshot_download.py -c testnet -o /data/snapshots
|
|
||||||
|
|
||||||
# Programmatic use from entrypoint.py:
|
|
||||||
from snapshot_download import download_best_snapshot
|
|
||||||
ok = download_best_snapshot("/data/snapshots")
|
|
||||||
|
|
||||||
Requirements:
|
|
||||||
- aria2c (apt install aria2)
|
|
||||||
- python3 >= 3.10 (stdlib only, no pip dependencies)
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import concurrent.futures
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import shutil
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
import urllib.error
|
|
||||||
import urllib.request
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from http.client import HTTPResponse
|
|
||||||
from pathlib import Path
|
|
||||||
from urllib.request import Request
|
|
||||||
|
|
||||||
log: logging.Logger = logging.getLogger("snapshot-download")

# Public RPC endpoints used for cluster discovery, keyed by cluster name.
CLUSTER_RPC: dict[str, str] = {
    "mainnet-beta": "https://api.mainnet-beta.solana.com",
    "testnet": "https://api.testnet.solana.com",
    "devnet": "https://api.devnet.solana.com",
}

# Snapshot filenames:
#   snapshot-<slot>-<hash>.tar.zst
#   incremental-snapshot-<base_slot>-<slot>-<hash>.tar.zst
# Both .zst and the legacy .bz2 suffix are accepted by the patterns.
FULL_SNAP_RE: re.Pattern[str] = re.compile(
    r"^snapshot-(\d+)-([A-Za-z0-9]+)\.tar\.(zst|bz2)$"
)
INCR_SNAP_RE: re.Pattern[str] = re.compile(
    r"^incremental-snapshot-(\d+)-(\d+)-([A-Za-z0-9]+)\.tar\.(zst|bz2)$"
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class SnapshotSource:
    """A snapshot file available from a specific RPC node."""

    # "host:port" of the RPC node serving the snapshot
    rpc_address: str
    # Full redirect paths as returned by the server (e.g. /snapshot-123-hash.tar.zst)
    file_paths: list[str] = field(default_factory=list)
    # current_slot minus the full snapshot's slot (staleness)
    slots_diff: int = 0
    # latency of the discovery HEAD request
    latency_ms: float = 0.0
    # measured throughput, filled in later by the benchmark step
    download_speed: float = 0.0  # bytes/sec
|
|
||||||
|
|
||||||
|
|
||||||
# -- JSON-RPC helpers ----------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class _NoRedirectHandler(urllib.request.HTTPRedirectHandler):
    """Handler that captures redirect Location instead of following it.

    Returning None from redirect_request makes urllib raise HTTPError for
    3xx responses, letting the caller read the Location header itself.
    """

    def redirect_request(
        self,
        req: Request,
        fp: HTTPResponse,
        code: int,
        msg: str,
        headers: dict[str, str],  # type: ignore[override]
        newurl: str,
    ) -> None:
        # Never follow: surface the 3xx to the caller as an HTTPError.
        return None
|
|
||||||
|
|
||||||
|
|
||||||
def rpc_post(url: str, method: str, params: list[object] | None = None,
             timeout: int = 25) -> object | None:
    """JSON-RPC POST. Returns the parsed 'result' field, or None on error."""
    body = {
        "jsonrpc": "2.0", "id": 1,
        "method": method, "params": params or [],
    }
    request = Request(
        url,
        data=json.dumps(body).encode(),
        headers={"Content-Type": "application/json"},
    )
    try:
        with urllib.request.urlopen(request, timeout=timeout) as resp:
            parsed: dict[str, object] = json.loads(resp.read())
    except (urllib.error.URLError, json.JSONDecodeError, OSError, TimeoutError) as e:
        # Best-effort: callers treat None as "node unreachable/unusable".
        log.debug("rpc_post %s %s failed: %s", url, method, e)
        return None
    return parsed.get("result")
|
|
||||||
|
|
||||||
|
|
||||||
def head_no_follow(url: str, timeout: float = 3) -> tuple[str | None, float]:
    """HEAD request without following redirects.

    Returns (Location header value, latency_sec) if the server returned a
    3xx redirect. Returns (None, 0.0) on any error or non-redirect response.
    """
    opener: urllib.request.OpenerDirector = urllib.request.build_opener(_NoRedirectHandler)
    req = Request(url, method="HEAD")
    # Start the clock before entering the try block so the HTTPError path
    # below can compute latency without relying on a possibly-unbound local
    # (previously needed a type-ignore for exactly that reason).
    start: float = time.monotonic()
    try:
        resp: HTTPResponse = opener.open(req, timeout=timeout)  # type: ignore[assignment]
        latency: float = time.monotonic() - start
        # Non-redirect (2xx) — server didn't redirect, not useful for discovery
        location: str | None = resp.headers.get("Location")
        resp.close()
        return location, latency
    except urllib.error.HTTPError as e:
        # _NoRedirectHandler converts 3xx responses into HTTPError, which
        # carries the redirect headers we want.
        latency = time.monotonic() - start
        location = e.headers.get("Location")
        if location and 300 <= e.code < 400:
            return location, latency
        return None, 0.0
    except (urllib.error.URLError, OSError, TimeoutError):
        return None, 0.0
|
|
||||||
|
|
||||||
|
|
||||||
# -- Discovery -----------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def get_current_slot(rpc_url: str) -> int | None:
    """Get the current slot from RPC, or None on failure/bad payload."""
    slot = rpc_post(rpc_url, "getSlot")
    # Anything other than an int (error dict, None) is treated as failure.
    return slot if isinstance(slot, int) else None
|
|
||||||
|
|
||||||
|
|
||||||
def get_cluster_rpc_nodes(rpc_url: str, version_filter: str | None = None) -> list[str]:
    """Get all RPC node addresses from getClusterNodes.

    When version_filter is given, nodes whose reported version does not
    start with it are skipped (nodes reporting no version are kept).
    Duplicates are removed; order is unspecified.
    """
    nodes = rpc_post(rpc_url, "getClusterNodes")
    if not isinstance(nodes, list):
        return []

    addrs: set[str] = set()
    for entry in nodes:
        if not isinstance(entry, dict):
            continue
        if version_filter is not None:
            version: str | None = entry.get("version")
            if version and not version.startswith(version_filter):
                continue
        rpc_addr: str | None = entry.get("rpc")
        if rpc_addr:
            addrs.add(rpc_addr)
    return list(addrs)
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_snapshot_filename(location: str) -> tuple[str, str | None]:
|
|
||||||
"""Extract filename and full redirect path from Location header.
|
|
||||||
|
|
||||||
Returns (filename, full_path). full_path includes any path prefix
|
|
||||||
the server returned (e.g. '/snapshots/snapshot-123-hash.tar.zst').
|
|
||||||
"""
|
|
||||||
# Location may be absolute URL or relative path
|
|
||||||
if location.startswith("http://") or location.startswith("https://"):
|
|
||||||
# Absolute URL — extract path
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
path: str = urlparse(location).path
|
|
||||||
else:
|
|
||||||
path = location
|
|
||||||
|
|
||||||
filename: str = path.rsplit("/", 1)[-1]
|
|
||||||
return filename, path
|
|
||||||
|
|
||||||
|
|
||||||
def probe_rpc_snapshot(
    rpc_address: str,
    current_slot: int,
) -> SnapshotSource | None:
    """Probe a single RPC node for available snapshots.

    Discovery only — no filtering. Returns a SnapshotSource with all
    available info so the caller can decide what to keep, or None when the
    node serves no parseable full snapshot. Filtering happens after all
    probes complete, so rejected sources remain visible for debugging.
    """
    # A full snapshot is mandatory; its redirect target names the file.
    location, latency_sec = head_no_follow(
        f"http://{rpc_address}/snapshot.tar.bz2", timeout=2)
    if not location:
        return None

    filename, path = _parse_snapshot_filename(location)
    full_match = FULL_SNAP_RE.match(filename)
    if full_match is None:
        return None

    base_slot = int(full_match.group(1))
    paths = [path]

    # An incremental is optional, and only useful when it builds on the
    # same full snapshot this node serves.
    inc_location, _unused = head_no_follow(
        f"http://{rpc_address}/incremental-snapshot.tar.bz2", timeout=2)
    if inc_location:
        inc_name, inc_path = _parse_snapshot_filename(inc_location)
        inc_match = INCR_SNAP_RE.match(inc_name)
        if inc_match and int(inc_match.group(1)) == base_slot:
            paths.append(inc_path)

    return SnapshotSource(
        rpc_address=rpc_address,
        file_paths=paths,
        slots_diff=current_slot - base_slot,
        latency_ms=latency_sec * 1000,
    )
|
|
||||||
|
|
||||||
|
|
||||||
def discover_sources(
    rpc_url: str,
    current_slot: int,
    max_age_slots: int,
    max_latency_ms: float,
    threads: int,
    version_filter: str | None,
) -> list[SnapshotSource]:
    """Discover all snapshot sources, then filter.

    Probing and filtering are separate: all reachable sources are collected
    first so we can report what exists even if filters reject everything.
    """
    rpc_nodes: list[str] = get_cluster_rpc_nodes(rpc_url, version_filter)
    if not rpc_nodes:
        log.error("No RPC nodes found via getClusterNodes")
        return []

    log.info("Found %d RPC nodes, probing for snapshots...", len(rpc_nodes))

    # Probe every node concurrently; the cluster can expose thousands.
    all_sources: list[SnapshotSource] = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as pool:
        futures: dict[concurrent.futures.Future[SnapshotSource | None], str] = {
            pool.submit(probe_rpc_snapshot, addr, current_slot): addr
            for addr in rpc_nodes
        }
        done: int = 0
        for future in concurrent.futures.as_completed(futures):
            done += 1
            if done % 200 == 0:  # periodic progress for long probe runs
                log.info(" probed %d/%d nodes, %d reachable",
                         done, len(rpc_nodes), len(all_sources))
            try:
                result: SnapshotSource | None = future.result()
            except (urllib.error.URLError, OSError, TimeoutError) as e:
                log.debug("Probe failed for %s: %s", futures[future], e)
                continue
            if result:
                all_sources.append(result)

    log.info("Discovered %d reachable sources", len(all_sources))

    # Apply filters. slots_diff < -100 rejects nodes claiming a snapshot
    # far AHEAD of the reference slot (slot skew / bogus reporting).
    filtered: list[SnapshotSource] = []
    rejected_age: int = 0
    rejected_latency: int = 0
    for src in all_sources:
        if src.slots_diff > max_age_slots or src.slots_diff < -100:
            rejected_age += 1
            continue
        if src.latency_ms > max_latency_ms:
            rejected_latency += 1
            continue
        filtered.append(src)

    if rejected_age or rejected_latency:
        log.info("Filtered: %d rejected by age (>%d slots), %d by latency (>%.0fms)",
                 rejected_age, max_age_slots, rejected_latency, max_latency_ms)

    if not filtered and all_sources:
        # Show what was available so the user can adjust filters
        all_sources.sort(key=lambda s: s.slots_diff)
        best = all_sources[0]
        log.warning("All %d sources rejected by filters. Best available: "
                    "%s (age=%d slots, latency=%.0fms). "
                    "Try --max-snapshot-age %d --max-latency %.0f",
                    len(all_sources), best.rpc_address,
                    best.slots_diff, best.latency_ms,
                    best.slots_diff + 500,
                    max(best.latency_ms * 1.5, 500))

    log.info("Found %d sources after filtering", len(filtered))
    return filtered
|
|
||||||
|
|
||||||
|
|
||||||
# -- Speed benchmark -----------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def measure_speed(rpc_address: str, measure_time: int = 7) -> float:
    """Measure download throughput from an RPC node. Returns bytes/sec.

    Streams the full-snapshot endpoint for roughly *measure_time* seconds
    and divides bytes received by elapsed time; returns 0.0 on any error.
    """
    request = Request(f"http://{rpc_address}/snapshot.tar.bz2")
    try:
        with urllib.request.urlopen(request, timeout=measure_time + 5) as stream:
            started = time.monotonic()
            received = 0
            while (time.monotonic() - started) < measure_time:
                block = stream.read(81920)
                if not block:  # server closed early — measure what we got
                    break
                received += len(block)
            duration = time.monotonic() - started
            return received / duration if duration > 0 else 0.0
    except (urllib.error.URLError, OSError, TimeoutError):
        return 0.0
|
|
||||||
|
|
||||||
|
|
||||||
# -- Incremental probing -------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def probe_incremental(
    fast_sources: list[SnapshotSource],
    full_snap_slot: int,
) -> tuple[str | None, list[str]]:
    """Probe fast sources for the best incremental matching full_snap_slot.

    Returns (filename, mirror_urls) or (None, []) if no match found.
    The "best" incremental is the one with the highest slot (closest to head).
    """
    best_filename: str | None = None
    best_slot: int = 0
    best_source: SnapshotSource | None = None
    best_path: str | None = None

    # Pass 1: find the freshest incremental built on our full snapshot.
    for source in fast_sources:
        inc_url: str = f"http://{source.rpc_address}/incremental-snapshot.tar.bz2"
        inc_location, _ = head_no_follow(inc_url, timeout=2)
        if not inc_location:
            continue
        inc_fn, inc_fp = _parse_snapshot_filename(inc_location)
        m: re.Match[str] | None = INCR_SNAP_RE.match(inc_fn)
        if not m:
            continue
        # An incremental only applies on top of the exact full-snapshot slot.
        if int(m.group(1)) != full_snap_slot:
            log.debug(" %s: incremental base slot %s != full %d, skipping",
                      source.rpc_address, m.group(1), full_snap_slot)
            continue
        inc_slot: int = int(m.group(2))
        if inc_slot > best_slot:
            best_slot = inc_slot
            best_filename = inc_fn
            best_source = source
            best_path = inc_fp

    if best_filename is None or best_source is None or best_path is None:
        return None, []

    # Pass 2: build the mirror list — other sources serving the exact same
    # filename can act as additional aria2c mirrors.
    mirror_urls: list[str] = [f"http://{best_source.rpc_address}{best_path}"]
    for other in fast_sources:
        if other.rpc_address == best_source.rpc_address:
            continue
        other_loc, _ = head_no_follow(
            f"http://{other.rpc_address}/incremental-snapshot.tar.bz2", timeout=2)
        if other_loc:
            other_fn, other_fp = _parse_snapshot_filename(other_loc)
            if other_fn == best_filename:
                mirror_urls.append(f"http://{other.rpc_address}{other_fp}")

    return best_filename, mirror_urls
|
|
||||||
|
|
||||||
|
|
||||||
# -- Download ------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def download_aria2c(
    urls: list[str],
    output_dir: str,
    filename: str,
    connections: int = 16,
) -> bool:
    """Download a file using aria2c with parallel connections.

    When multiple URLs are provided, aria2c treats them as mirrors of the
    same file and distributes chunks across all of them.

    Args:
        urls: one or more mirror URLs for the same file.
        output_dir: directory to download into.
        filename: output filename within output_dir.
        connections: per-server connection count.

    Returns True when aria2c exits 0 and the output file exists.
    """
    num_mirrors: int = len(urls)
    total_splits: int = max(connections, connections * num_mirrors)
    cmd: list[str] = [
        "aria2c",
        "--file-allocation=none",
        "--continue=false",
        f"--max-connection-per-server={connections}",
        f"--split={total_splits}",
        "--min-split-size=50M",
        # aria2c retries individual chunk connections on transient network
        # errors (TCP reset, timeout). This is transport-level retry analogous
        # to TCP retransmit, not application-level retry of a failed operation.
        "--max-tries=5",
        "--retry-wait=5",
        "--timeout=60",
        "--connect-timeout=10",
        "--summary-interval=10",
        "--console-log-level=notice",
        f"--dir={output_dir}",
        # FIX: the output name must be the requested filename. Previously the
        # f-string contained a literal placeholder instead of {filename}, so
        # aria2c wrote a misnamed file and the existence check below failed.
        f"--out={filename}",
        "--auto-file-renaming=false",
        "--allow-overwrite=true",
        *urls,
    ]

    log.info("Downloading %s", filename)
    log.info(" aria2c: %d connections x %d mirrors (%d splits)",
             connections, num_mirrors, total_splits)

    start: float = time.monotonic()
    result: subprocess.CompletedProcess[bytes] = subprocess.run(cmd)
    elapsed: float = time.monotonic() - start

    if result.returncode != 0:
        log.error("aria2c failed with exit code %d", result.returncode)
        return False

    filepath: Path = Path(output_dir) / filename
    if not filepath.exists():
        log.error("aria2c reported success but %s does not exist", filepath)
        return False

    size_bytes: int = filepath.stat().st_size
    size_gb: float = size_bytes / (1024 ** 3)
    avg_mb: float = size_bytes / elapsed / (1024 ** 2) if elapsed > 0 else 0
    log.info(" Done: %.1f GB in %.0fs (%.1f MiB/s avg)", size_gb, elapsed, avg_mb)
    return True
|
|
||||||
|
|
||||||
|
|
||||||
# -- Shared helpers ------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _discover_and_benchmark(
    rpc_url: str,
    current_slot: int,
    *,
    max_snapshot_age: int = 10000,
    max_latency: float = 500,
    threads: int = 500,
    min_download_speed: int = 20,
    measurement_time: int = 7,
    max_speed_checks: int = 15,
    version_filter: str | None = None,
) -> list[SnapshotSource]:
    """Discover snapshot sources and benchmark download speed.

    Returns the sources that meet the minimum speed requirement, in the
    order they were benchmarked (ascending latency).

    Args:
        rpc_url: RPC endpoint for getClusterNodes discovery.
        current_slot: reference slot for snapshot-age filtering.
        max_snapshot_age: reject snapshots older than this many slots.
        max_latency: reject sources above this HEAD-request latency (ms).
        threads: concurrency for discovery probes.
        min_download_speed: minimum accepted throughput, MiB/s.
        measurement_time: seconds to stream per speed measurement.
        max_speed_checks: benchmark at most this many lowest-latency sources.
        version_filter: optional node-version prefix filter.
    """
    sources: list[SnapshotSource] = discover_sources(
        rpc_url, current_slot,
        max_age_slots=max_snapshot_age,
        max_latency_ms=max_latency,
        threads=threads,
        version_filter=version_filter,
    )
    if not sources:
        return []

    # Benchmark the lowest-latency candidates first — latency is a cheap
    # proxy for proximity, and benchmarking every source would be slow.
    sources.sort(key=lambda s: s.latency_ms)

    log.info("Benchmarking download speed on top %d sources...", max_speed_checks)
    fast_sources: list[SnapshotSource] = []
    checked: int = 0
    min_speed_bytes: int = min_download_speed * 1024 * 1024

    for source in sources:
        if checked >= max_speed_checks:
            break
        checked += 1

        speed: float = measure_speed(source.rpc_address, measurement_time)
        source.download_speed = speed
        speed_mib: float = speed / (1024 ** 2)

        if speed < min_speed_bytes:
            log.info(" %s: %.1f MiB/s (too slow, need >=%d MiB/s)",
                     source.rpc_address, speed_mib, min_download_speed)
            continue

        log.info(" %s: %.1f MiB/s (latency: %.0fms, age: %d slots)",
                 source.rpc_address, speed_mib,
                 source.latency_ms, source.slots_diff)
        fast_sources.append(source)

    return fast_sources
|
|
||||||
|
|
||||||
|
|
||||||
def _rolling_incremental_download(
    fast_sources: list[SnapshotSource],
    full_snap_slot: int,
    output_dir: str,
    convergence_slots: int,
    connections: int,
    rpc_url: str,
) -> str | None:
    """Download incrementals in a loop until converged.

    Probes fast_sources for incrementals matching full_snap_slot, downloads
    the freshest one, then re-probes until the gap to head is within
    convergence_slots. Returns the filename of the final incremental,
    or None if no incremental was found.
    """
    prev_inc_filename: str | None = None
    loop_start: float = time.monotonic()
    max_convergence_time: float = 1800.0  # 30 min wall-clock limit

    while True:
        # Hard wall-clock cap: if the cluster head advances faster than we
        # can download, bail out with whatever we already have.
        if time.monotonic() - loop_start > max_convergence_time:
            if prev_inc_filename:
                log.warning("Convergence timeout (%.0fs) — using %s",
                            max_convergence_time, prev_inc_filename)
            else:
                log.warning("Convergence timeout (%.0fs) — no incremental downloaded",
                            max_convergence_time)
            break

        inc_fn, inc_mirrors = probe_incremental(fast_sources, full_snap_slot)
        if inc_fn is None:
            if prev_inc_filename is None:
                log.error("No matching incremental found for base slot %d",
                          full_snap_slot)
            else:
                log.info("No newer incremental available, using %s", prev_inc_filename)
            break

        m_inc: re.Match[str] | None = INCR_SNAP_RE.match(inc_fn)
        if m_inc is None:
            # Defensive guard: probe_incremental is expected to return only
            # filenames matching INCR_SNAP_RE. Previously this was an
            # `assert`, which is stripped under `python -O`; check explicitly.
            log.error("Unexpected incremental filename %s — aborting", inc_fn)
            break
        inc_slot: int = int(m_inc.group(2))

        head_slot: int | None = get_current_slot(rpc_url)
        if head_slot is None:
            log.warning("Cannot get current slot — downloading best available incremental")
            # Force the "not converged" path so the freshest incremental
            # still gets downloaded once.
            gap: int = convergence_slots + 1
        else:
            gap = head_slot - inc_slot

        if inc_fn == prev_inc_filename:
            # Nothing newer has been published since the last download.
            if gap <= convergence_slots:
                log.info("Incremental %s already downloaded (gap %d slots, converged)",
                         inc_fn, gap)
                break
            log.info("No newer incremental yet (slot %d, gap %d slots), waiting...",
                     inc_slot, gap)
            time.sleep(10)
            continue

        # A newer incremental supersedes the previous one — reclaim disk.
        if prev_inc_filename is not None:
            old_path: Path = Path(output_dir) / prev_inc_filename
            if old_path.exists():
                log.info("Removing superseded incremental %s", prev_inc_filename)
                old_path.unlink()

        log.info("Downloading incremental %s (%d mirrors, slot %d, gap %d slots)",
                 inc_fn, len(inc_mirrors), inc_slot, gap)
        if not download_aria2c(inc_mirrors, output_dir, inc_fn, connections):
            log.warning("Failed to download incremental %s — re-probing in 10s", inc_fn)
            time.sleep(10)
            continue

        prev_inc_filename = inc_fn

        if gap <= convergence_slots:
            log.info("Converged: incremental slot %d is %d slots behind head",
                     inc_slot, gap)
            break

        # Without a head slot we cannot measure convergence; one download of
        # the best available incremental is the most we can do.
        if head_slot is None:
            break

        log.info("Not converged (gap %d > %d), re-probing in 10s...",
                 gap, convergence_slots)
        time.sleep(10)

    return prev_inc_filename
|
|
||||||
|
|
||||||
|
|
||||||
# -- Public API ----------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def download_incremental_for_slot(
    output_dir: str,
    full_snap_slot: int,
    *,
    cluster: str = "mainnet-beta",
    rpc_url: str | None = None,
    connections: int = 16,
    threads: int = 500,
    max_snapshot_age: int = 10000,
    max_latency: float = 500,
    min_download_speed: int = 20,
    measurement_time: int = 7,
    max_speed_checks: int = 15,
    version_filter: str | None = None,
    convergence_slots: int = 500,
) -> bool:
    """Fetch an incremental snapshot matching an already-present full snapshot.

    Runs source discovery and speed benchmarking, then drives the rolling
    incremental loop against the given base slot. The full snapshot itself
    is never downloaded here.

    Returns True if an incremental was downloaded, False otherwise.
    """
    effective_rpc: str = rpc_url or CLUSTER_RPC[cluster]

    if shutil.which("aria2c") is None:
        log.error("aria2c not found. Install with: apt install aria2")
        return False

    log.info("Incremental download for base slot %d", full_snap_slot)
    slot_now: int | None = get_current_slot(effective_rpc)
    if slot_now is None:
        log.error("Cannot get current slot from %s", effective_rpc)
        return False

    candidates: list[SnapshotSource] = _discover_and_benchmark(
        effective_rpc, slot_now,
        max_snapshot_age=max_snapshot_age,
        max_latency=max_latency,
        threads=threads,
        min_download_speed=min_download_speed,
        measurement_time=measurement_time,
        max_speed_checks=max_speed_checks,
        version_filter=version_filter,
    )
    if not candidates:
        log.error("No fast sources found")
        return False

    os.makedirs(output_dir, exist_ok=True)
    final_inc: str | None = _rolling_incremental_download(
        candidates, full_snap_slot, output_dir,
        convergence_slots, connections, effective_rpc,
    )
    return final_inc is not None
|
|
||||||
|
|
||||||
|
|
||||||
def download_best_snapshot(
    output_dir: str,
    *,
    cluster: str = "mainnet-beta",
    rpc_url: str | None = None,
    connections: int = 16,
    threads: int = 500,
    max_snapshot_age: int = 10000,
    max_latency: float = 500,
    min_download_speed: int = 20,
    measurement_time: int = 7,
    max_speed_checks: int = 15,
    version_filter: str | None = None,
    full_only: bool = False,
    convergence_slots: int = 500,
) -> bool:
    """Fetch the best available full (and optionally incremental) snapshot.

    Programmatic API used by entrypoint.py for automatic snapshot download;
    defaults mirror the CLI flags. Returns True on success, False on failure.
    """
    effective_rpc: str = rpc_url or CLUSTER_RPC[cluster]

    if shutil.which("aria2c") is None:
        log.error("aria2c not found. Install with: apt install aria2")
        return False

    log.info("Cluster: %s | RPC: %s", cluster, effective_rpc)
    head_slot: int | None = get_current_slot(effective_rpc)
    if head_slot is None:
        log.error("Cannot get current slot from %s", effective_rpc)
        return False
    log.info("Current slot: %d", head_slot)

    candidates: list[SnapshotSource] = _discover_and_benchmark(
        effective_rpc, head_slot,
        max_snapshot_age=max_snapshot_age,
        max_latency=max_latency,
        threads=threads,
        min_download_speed=min_download_speed,
        measurement_time=measurement_time,
        max_speed_checks=max_speed_checks,
        version_filter=version_filter,
    )
    if not candidates:
        log.error("No fast sources found")
        return False

    # Use the fastest source as primary; the full snapshot plan is built
    # from its advertised file paths.
    primary: SnapshotSource = candidates[0]
    full_paths: list[str] = [
        fp for fp in primary.file_paths
        if fp.rsplit("/", 1)[-1].startswith("snapshot-")
    ]
    if not full_paths:
        log.error("Best source has no full snapshot")
        return False

    # Collect mirror URLs: any other fast source advertising the same
    # full-snapshot filename can serve as an aria2c mirror.
    full_filename: str = full_paths[0].rsplit("/", 1)[-1]
    full_mirrors: list[str] = [f"http://{primary.rpc_address}{full_paths[0]}"]
    for mirror_source in candidates[1:]:
        matching: str | None = next(
            (fp for fp in mirror_source.file_paths
             if fp.rsplit("/", 1)[-1] == full_filename),
            None,
        )
        if matching is not None:
            full_mirrors.append(f"http://{mirror_source.rpc_address}{matching}")

    primary_speed_mib: float = primary.download_speed / (1024 ** 2)
    log.info("Best source: %s (%.1f MiB/s), %d mirrors",
             primary.rpc_address, primary_speed_mib, len(full_mirrors))

    # Full snapshot download (skipped when a non-empty file already exists).
    os.makedirs(output_dir, exist_ok=True)
    total_start: float = time.monotonic()

    target: Path = Path(output_dir) / full_filename
    if target.exists() and target.stat().st_size > 0:
        log.info("Skipping %s (already exists: %.1f GB)",
                 full_filename, target.stat().st_size / (1024 ** 3))
    elif not download_aria2c(full_mirrors, output_dir, full_filename, connections):
        log.error("Failed to download %s", full_filename)
        return False

    # Download incremental separately — the full download took minutes,
    # so any incremental from discovery is stale. Re-probe for fresh ones.
    if not full_only:
        fm: re.Match[str] | None = FULL_SNAP_RE.match(full_filename)
        if fm:
            base_slot: int = int(fm.group(1))
            log.info("Downloading incremental for base slot %d...", base_slot)
            _rolling_incremental_download(
                candidates, base_slot, output_dir,
                convergence_slots, connections, effective_rpc,
            )

    total_elapsed: float = time.monotonic() - total_start
    log.info("All downloads complete in %.0fs", total_elapsed)

    return True
|
|
||||||
|
|
||||||
|
|
||||||
# -- Main (CLI) ----------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> int:
    """CLI entry point: parse flags, download snapshots, return exit code."""
    parser: argparse.ArgumentParser = argparse.ArgumentParser(
        description="Download Solana snapshots with aria2c parallel downloads",
    )
    parser.add_argument("-o", "--output", default="/srv/kind/solana/snapshots",
                        help="Snapshot output directory (default: /srv/kind/solana/snapshots)")
    parser.add_argument("-c", "--cluster", default="mainnet-beta",
                        choices=list(CLUSTER_RPC),
                        help="Solana cluster (default: mainnet-beta)")
    parser.add_argument("-r", "--rpc", default=None,
                        help="RPC URL for cluster discovery (default: public RPC)")
    parser.add_argument("-n", "--connections", type=int, default=16,
                        help="aria2c connections per download (default: 16)")
    parser.add_argument("-t", "--threads", type=int, default=500,
                        help="Threads for parallel RPC probing (default: 500)")
    parser.add_argument("--max-snapshot-age", type=int, default=10000,
                        help="Max snapshot age in slots (default: 10000)")
    parser.add_argument("--max-latency", type=float, default=500,
                        help="Max RPC probe latency in ms (default: 500)")
    parser.add_argument("--min-download-speed", type=int, default=20,
                        help="Min download speed in MiB/s (default: 20)")
    parser.add_argument("--measurement-time", type=int, default=7,
                        help="Speed measurement duration in seconds (default: 7)")
    parser.add_argument("--max-speed-checks", type=int, default=15,
                        help="Max nodes to benchmark before giving up (default: 15)")
    parser.add_argument("--version", default=None,
                        help="Filter nodes by version prefix (e.g. '2.2')")
    parser.add_argument("--convergence-slots", type=int, default=500,
                        help="Max slot gap for incremental convergence (default: 500)")
    parser.add_argument("--full-only", action="store_true",
                        help="Download only full snapshot, skip incremental")
    parser.add_argument("--dry-run", action="store_true",
                        help="Find best source and print URL, don't download")
    parser.add_argument("--post-cmd",
                        help="Shell command to run after successful download "
                             "(e.g. 'kubectl scale deployment ... --replicas=1')")
    parser.add_argument("-v", "--verbose", action="store_true")
    args: argparse.Namespace = parser.parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
        datefmt="%H:%M:%S",
    )

    # Dry-run uses the original inline flow (needs access to sources for URL printing)
    if args.dry_run:
        rpc_endpoint: str = args.rpc or CLUSTER_RPC[args.cluster]
        slot_now: int | None = get_current_slot(rpc_endpoint)
        if slot_now is None:
            log.error("Cannot get current slot from %s", rpc_endpoint)
            return 1

        discovered: list[SnapshotSource] = discover_sources(
            rpc_endpoint, slot_now,
            max_age_slots=args.max_snapshot_age,
            max_latency_ms=args.max_latency,
            threads=args.threads,
            version_filter=args.version,
        )
        if not discovered:
            log.error("No snapshot sources found")
            return 1

        # Print every snapshot URL advertised by the lowest-latency source.
        discovered.sort(key=lambda s: s.latency_ms)
        fastest = discovered[0]
        for path in fastest.file_paths:
            print(f"http://{fastest.rpc_address}{path}")
        return 0

    ok: bool = download_best_snapshot(
        args.output,
        cluster=args.cluster,
        rpc_url=args.rpc,
        connections=args.connections,
        threads=args.threads,
        max_snapshot_age=args.max_snapshot_age,
        max_latency=args.max_latency,
        min_download_speed=args.min_download_speed,
        measurement_time=args.measurement_time,
        max_speed_checks=args.max_speed_checks,
        version_filter=args.version,
        full_only=args.full_only,
        convergence_slots=args.convergence_slots,
    )

    # Optional hook, e.g. to scale a deployment back up once data is ready.
    if ok and args.post_cmd:
        log.info("Running post-download command: %s", args.post_cmd)
        completed: subprocess.CompletedProcess[bytes] = subprocess.run(
            args.post_cmd, shell=True,
        )
        if completed.returncode != 0:
            log.error("Post-download command failed with exit code %d",
                      completed.returncode)
            return 1
        log.info("Post-download command completed successfully")

    return 0 if ok else 1
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Propagate main()'s return code as the process exit status.
    raise SystemExit(main())
|
|
||||||
|
|
@ -1,112 +0,0 @@
|
||||||
#!/usr/bin/env bash
set -euo pipefail

# -----------------------------------------------------------------------
# Start solana-test-validator with optional SPL token setup
#
# Environment variables:
#   FACILITATOR_PUBKEY - facilitator fee-payer public key (base58)
#   SERVER_PUBKEY      - server/payee wallet public key (base58)
#   CLIENT_PUBKEY      - client/payer wallet public key (base58)
#   MINT_DECIMALS      - token decimals (default: 6, matching USDC)
#   MINT_AMOUNT        - amount to mint to client (default: 1000000000)
#   LEDGER_DIR         - ledger directory (default: /data/ledger)
# -----------------------------------------------------------------------

LEDGER_DIR="${LEDGER_DIR:-/data/ledger}"
MINT_DECIMALS="${MINT_DECIMALS:-6}"
MINT_AMOUNT="${MINT_AMOUNT:-1000000000}"
SETUP_MARKER="${LEDGER_DIR}/.setup-done"

# Ledger may be a root-owned volume mount; best-effort ownership fix.
sudo chown -R "$(id -u):$(id -g)" "$LEDGER_DIR" 2>/dev/null || true

# Start test-validator in the background
solana-test-validator \
  --ledger "${LEDGER_DIR}" \
  --rpc-port 8899 \
  --bind-address 0.0.0.0 \
  --quiet &

VALIDATOR_PID=$!

# Wait for RPC to become available.
# BUG FIX: previously the script fell through after 60 failed probes and ran
# setup against a dead validator; now it fails fast when RPC never comes up.
echo "Waiting for test-validator RPC..."
READY=0
for i in $(seq 1 60); do
  if solana cluster-version --url http://127.0.0.1:8899 >/dev/null 2>&1; then
    echo "Test-validator is ready (attempt ${i})"
    READY=1
    break
  fi
  sleep 1
done
if [ "${READY}" -ne 1 ]; then
  echo "ERROR: test-validator RPC did not become ready within 60s" >&2
  kill "${VALIDATOR_PID}" 2>/dev/null || true
  exit 1
fi

solana config set --url http://127.0.0.1:8899

# Only run setup once (idempotent via marker file)
if [ ! -f "${SETUP_MARKER}" ]; then
    echo "Running first-time setup..."

    # Airdrop SOL to all wallets for gas
    for PUBKEY in "${FACILITATOR_PUBKEY:-}" "${SERVER_PUBKEY:-}" "${CLIENT_PUBKEY:-}"; do
        if [ -n "${PUBKEY}" ]; then
            echo "Airdropping 100 SOL to ${PUBKEY}..."
            solana airdrop 100 "${PUBKEY}" --url http://127.0.0.1:8899 || true
        fi
    done

    # Create a USDC-equivalent SPL token mint if any pubkeys are set
    if [ -n "${CLIENT_PUBKEY:-}" ] || [ -n "${FACILITATOR_PUBKEY:-}" ] || [ -n "${SERVER_PUBKEY:-}" ]; then
        # Mint authority keypair lives in the ledger dir so it survives restarts.
        MINT_AUTHORITY_FILE="${LEDGER_DIR}/mint-authority.json"
        if [ ! -f "${MINT_AUTHORITY_FILE}" ]; then
            solana-keygen new --no-bip39-passphrase --outfile "${MINT_AUTHORITY_FILE}" --force
            MINT_AUTH_PUBKEY=$(solana-keygen pubkey "${MINT_AUTHORITY_FILE}")
            solana airdrop 10 "${MINT_AUTH_PUBKEY}" --url http://127.0.0.1:8899
        fi

        # Create the mint once; the address file makes this step idempotent.
        MINT_ADDRESS_FILE="${LEDGER_DIR}/usdc-mint-address.txt"
        if [ ! -f "${MINT_ADDRESS_FILE}" ]; then
            spl-token create-token \
                --decimals "${MINT_DECIMALS}" \
                --mint-authority "${MINT_AUTHORITY_FILE}" \
                --url http://127.0.0.1:8899 \
                2>&1 | grep "Creating token" | awk '{print $3}' > "${MINT_ADDRESS_FILE}"
            echo "Created USDC mint: $(cat "${MINT_ADDRESS_FILE}")"
        fi

        USDC_MINT=$(cat "${MINT_ADDRESS_FILE}")

        # Create ATAs and mint tokens for the client
        if [ -n "${CLIENT_PUBKEY:-}" ]; then
            echo "Creating ATA for client ${CLIENT_PUBKEY}..."
            spl-token create-account "${USDC_MINT}" \
                --owner "${CLIENT_PUBKEY}" \
                --fee-payer "${MINT_AUTHORITY_FILE}" \
                --url http://127.0.0.1:8899 || true

            echo "Minting ${MINT_AMOUNT} tokens to client..."
            spl-token mint "${USDC_MINT}" "${MINT_AMOUNT}" \
                --recipient-owner "${CLIENT_PUBKEY}" \
                --mint-authority "${MINT_AUTHORITY_FILE}" \
                --url http://127.0.0.1:8899 || true
        fi

        # Create ATAs for server and facilitator
        for PUBKEY in "${SERVER_PUBKEY:-}" "${FACILITATOR_PUBKEY:-}"; do
            if [ -n "${PUBKEY}" ]; then
                echo "Creating ATA for ${PUBKEY}..."
                spl-token create-account "${USDC_MINT}" \
                    --owner "${PUBKEY}" \
                    --fee-payer "${MINT_AUTHORITY_FILE}" \
                    --url http://127.0.0.1:8899 || true
            fi
        done

        # Expose mint address for other containers
        cp "${MINT_ADDRESS_FILE}" /tmp/usdc-mint-address.txt 2>/dev/null || true
    fi

    touch "${SETUP_MARKER}"
    echo "Setup complete."
fi

# Keep the container alive for as long as the validator runs.
echo "solana-test-validator running (PID ${VALIDATOR_PID})"
wait ${VALIDATOR_PID}
|
|
||||||
Loading…
Reference in New Issue