stack-orchestrator/scripts/agave-container/entrypoint.py

486 lines
15 KiB
Python

#!/usr/bin/env python3
"""Agave validator entrypoint — snapshot management, arg construction, liveness probe.
Two subcommands:
entrypoint.py serve (default) — snapshot freshness check + exec agave-validator
entrypoint.py probe — liveness probe (slot lag check, exits 0/1)
Replaces the bash entrypoint.sh / start-rpc.sh / start-validator.sh with a single
Python module. Test mode still dispatches to start-test.sh.
All configuration comes from environment variables — same vars as the original
bash scripts. See compose files for defaults.
"""
from __future__ import annotations
import json
import logging
import os
import re
import subprocess
import sys
import time
import urllib.error
import urllib.request
from pathlib import Path
from urllib.request import Request
log: logging.Logger = logging.getLogger("entrypoint")
# Directories
CONFIG_DIR = "/data/config"
LEDGER_DIR = "/data/ledger"
ACCOUNTS_DIR = "/data/accounts"
SNAPSHOTS_DIR = "/data/snapshots"
LOG_DIR = "/data/log"
IDENTITY_FILE = f"{CONFIG_DIR}/validator-identity.json"
# Snapshot filename pattern
FULL_SNAP_RE: re.Pattern[str] = re.compile(
r"^snapshot-(\d+)-[A-Za-z0-9]+\.tar\.(zst|bz2)$"
)
MAINNET_RPC = "https://api.mainnet-beta.solana.com"
# -- Helpers -------------------------------------------------------------------
def env(name: str, default: str = "") -> str:
"""Read env var with default."""
return os.environ.get(name, default)
def env_required(name: str) -> str:
"""Read required env var, exit if missing."""
val = os.environ.get(name)
if not val:
log.error("%s is required but not set", name)
sys.exit(1)
return val
def env_bool(name: str, default: bool = False) -> bool:
"""Read boolean env var (true/false/1/0)."""
val = os.environ.get(name, "").lower()
if not val:
return default
return val in ("true", "1", "yes")
def rpc_get_slot(url: str, timeout: int = 10) -> int | None:
"""Get current slot from a Solana RPC endpoint."""
payload = json.dumps({
"jsonrpc": "2.0", "id": 1,
"method": "getSlot", "params": [],
}).encode()
req = Request(url, data=payload,
headers={"Content-Type": "application/json"})
try:
with urllib.request.urlopen(req, timeout=timeout) as resp:
data = json.loads(resp.read())
result = data.get("result")
if isinstance(result, int):
return result
except (urllib.error.URLError, json.JSONDecodeError, OSError, TimeoutError):
pass
return None
# -- Snapshot management -------------------------------------------------------
def get_local_snapshot_slot(snapshots_dir: str) -> int | None:
"""Find the highest slot among local snapshot files."""
best_slot: int | None = None
snap_path = Path(snapshots_dir)
if not snap_path.is_dir():
return None
for entry in snap_path.iterdir():
m = FULL_SNAP_RE.match(entry.name)
if m:
slot = int(m.group(1))
if best_slot is None or slot > best_slot:
best_slot = slot
return best_slot
def clean_snapshots(snapshots_dir: str) -> None:
"""Remove all snapshot files from the directory."""
snap_path = Path(snapshots_dir)
if not snap_path.is_dir():
return
for entry in snap_path.iterdir():
if entry.name.startswith(("snapshot-", "incremental-snapshot-")):
log.info("Removing old snapshot: %s", entry.name)
entry.unlink(missing_ok=True)
def maybe_download_snapshot(snapshots_dir: str) -> None:
"""Check snapshot freshness and download if needed.
Controlled by env vars:
SNAPSHOT_AUTO_DOWNLOAD (default: true) — enable/disable
SNAPSHOT_MAX_AGE_SLOTS (default: 20000) — staleness threshold
"""
if not env_bool("SNAPSHOT_AUTO_DOWNLOAD", default=True):
log.info("Snapshot auto-download disabled")
return
max_age = int(env("SNAPSHOT_MAX_AGE_SLOTS", "20000"))
# Get mainnet current slot
mainnet_slot = rpc_get_slot(MAINNET_RPC)
if mainnet_slot is None:
log.warning("Cannot reach mainnet RPC — skipping snapshot check")
return
# Check local snapshot
local_slot = get_local_snapshot_slot(snapshots_dir)
if local_slot is not None:
age = mainnet_slot - local_slot
log.info("Local snapshot at slot %d (mainnet: %d, age: %d slots)",
local_slot, mainnet_slot, age)
if age <= max_age:
log.info("Snapshot is fresh enough (age %d <= %d), skipping download", age, max_age)
return
log.info("Snapshot is stale (age %d > %d), downloading fresh", age, max_age)
else:
log.info("No local snapshot found, downloading")
# Clean old snapshots before downloading
clean_snapshots(snapshots_dir)
# Import and call snapshot download
# snapshot_download.py is installed alongside this file in /usr/local/bin/
script_dir = Path(__file__).resolve().parent
sys.path.insert(0, str(script_dir))
from snapshot_download import download_best_snapshot
ok = download_best_snapshot(snapshots_dir)
if not ok:
log.error("Snapshot download failed — starting without fresh snapshot")
# -- Directory and identity setup ----------------------------------------------
def ensure_dirs(*dirs: str) -> None:
"""Create directories and fix ownership."""
uid = os.getuid()
gid = os.getgid()
for d in dirs:
os.makedirs(d, exist_ok=True)
try:
subprocess.run(
["sudo", "chown", "-R", f"{uid}:{gid}", d],
check=False, capture_output=True,
)
except FileNotFoundError:
pass # sudo not available — dirs already owned correctly
def ensure_identity_rpc() -> None:
"""Generate ephemeral identity keypair for RPC mode if not mounted."""
if os.path.isfile(IDENTITY_FILE):
return
log.info("Generating RPC node identity keypair...")
subprocess.run(
["solana-keygen", "new", "--no-passphrase", "--silent",
"--force", "--outfile", IDENTITY_FILE],
check=True,
)
def print_identity() -> None:
"""Print the node identity pubkey."""
result = subprocess.run(
["solana-keygen", "pubkey", IDENTITY_FILE],
capture_output=True, text=True, check=False,
)
if result.returncode == 0:
log.info("Node identity: %s", result.stdout.strip())
# -- Arg construction ----------------------------------------------------------
def build_common_args() -> list[str]:
"""Build agave-validator args common to both RPC and validator modes."""
args: list[str] = [
"--identity", IDENTITY_FILE,
"--entrypoint", env_required("VALIDATOR_ENTRYPOINT"),
"--known-validator", env_required("KNOWN_VALIDATOR"),
"--ledger", LEDGER_DIR,
"--accounts", ACCOUNTS_DIR,
"--snapshots", SNAPSHOTS_DIR,
"--rpc-port", env("RPC_PORT", "8899"),
"--rpc-bind-address", env("RPC_BIND_ADDRESS", "127.0.0.1"),
"--gossip-port", env("GOSSIP_PORT", "8001"),
"--dynamic-port-range", env("DYNAMIC_PORT_RANGE", "9000-10000"),
"--no-os-network-limits-test",
"--wal-recovery-mode", "skip_any_corrupted_record",
"--limit-ledger-size", env("LIMIT_LEDGER_SIZE", "50000000"),
]
# Snapshot generation
if env("NO_SNAPSHOTS") == "true":
args.append("--no-snapshots")
else:
args += [
"--full-snapshot-interval-slots", env("SNAPSHOT_INTERVAL_SLOTS", "100000"),
"--maximum-full-snapshots-to-retain", env("MAXIMUM_SNAPSHOTS_TO_RETAIN", "5"),
]
if env("NO_INCREMENTAL_SNAPSHOTS") != "true":
args += ["--maximum-incremental-snapshots-to-retain", "2"]
# Account indexes
account_indexes = env("ACCOUNT_INDEXES")
if account_indexes:
for idx in account_indexes.split(","):
idx = idx.strip()
if idx:
args += ["--account-index", idx]
# Additional entrypoints
for ep in env("EXTRA_ENTRYPOINTS").split():
if ep:
args += ["--entrypoint", ep]
# Additional known validators
for kv in env("EXTRA_KNOWN_VALIDATORS").split():
if kv:
args += ["--known-validator", kv]
# Cluster verification
genesis_hash = env("EXPECTED_GENESIS_HASH")
if genesis_hash:
args += ["--expected-genesis-hash", genesis_hash]
shred_version = env("EXPECTED_SHRED_VERSION")
if shred_version:
args += ["--expected-shred-version", shred_version]
# Metrics — just needs to be in the environment, agave reads it directly
# (env var is already set, nothing to pass as arg)
# Gossip host / TVU address
gossip_host = env("GOSSIP_HOST")
if gossip_host:
args += ["--gossip-host", gossip_host]
elif env("PUBLIC_TVU_ADDRESS"):
args += ["--public-tvu-address", env("PUBLIC_TVU_ADDRESS")]
# Jito flags
if env("JITO_ENABLE") == "true":
log.info("Jito MEV enabled")
jito_flags: list[tuple[str, str]] = [
("JITO_TIP_PAYMENT_PROGRAM", "--tip-payment-program-pubkey"),
("JITO_DISTRIBUTION_PROGRAM", "--tip-distribution-program-pubkey"),
("JITO_MERKLE_ROOT_AUTHORITY", "--merkle-root-upload-authority"),
("JITO_COMMISSION_BPS", "--commission-bps"),
("JITO_BLOCK_ENGINE_URL", "--block-engine-url"),
("JITO_SHRED_RECEIVER_ADDR", "--shred-receiver-address"),
]
for env_name, flag in jito_flags:
val = env(env_name)
if val:
args += [flag, val]
return args
def build_rpc_args() -> list[str]:
"""Build agave-validator args for RPC (non-voting) mode."""
args = build_common_args()
args += [
"--no-voting",
"--log", f"{LOG_DIR}/validator.log",
"--full-rpc-api",
"--enable-rpc-transaction-history",
"--rpc-pubsub-enable-block-subscription",
"--enable-extended-tx-metadata-storage",
"--no-wait-for-vote-to-start-leader",
"--no-snapshot-fetch",
]
# Public vs private RPC
public_rpc = env("PUBLIC_RPC_ADDRESS")
if public_rpc:
args += ["--public-rpc-address", public_rpc]
else:
args += ["--private-rpc", "--allow-private-addr", "--only-known-rpc"]
# Jito relayer URL (RPC mode doesn't use it, but validator mode does —
# handled in build_validator_args)
return args
def build_validator_args() -> list[str]:
"""Build agave-validator args for voting validator mode."""
vote_keypair = env("VOTE_ACCOUNT_KEYPAIR",
"/data/config/vote-account-keypair.json")
# Identity must be mounted for validator mode
if not os.path.isfile(IDENTITY_FILE):
log.error("Validator identity keypair not found at %s", IDENTITY_FILE)
log.error("Mount your validator keypair to %s", IDENTITY_FILE)
sys.exit(1)
# Vote account keypair must exist
if not os.path.isfile(vote_keypair):
log.error("Vote account keypair not found at %s", vote_keypair)
log.error("Mount your vote account keypair or set VOTE_ACCOUNT_KEYPAIR")
sys.exit(1)
# Print vote account pubkey
result = subprocess.run(
["solana-keygen", "pubkey", vote_keypair],
capture_output=True, text=True, check=False,
)
if result.returncode == 0:
log.info("Vote account: %s", result.stdout.strip())
args = build_common_args()
args += [
"--vote-account", vote_keypair,
"--log", "-",
]
# Jito relayer URL (validator-only)
relayer_url = env("JITO_RELAYER_URL")
if env("JITO_ENABLE") == "true" and relayer_url:
args += ["--relayer-url", relayer_url]
return args
def append_extra_args(args: list[str]) -> list[str]:
"""Append EXTRA_ARGS passthrough flags."""
extra = env("EXTRA_ARGS")
if extra:
args += extra.split()
return args
# -- Serve subcommand ---------------------------------------------------------
def cmd_serve() -> None:
"""Main serve flow: snapshot check, setup, exec agave-validator."""
mode = env("AGAVE_MODE", "test")
log.info("AGAVE_MODE=%s", mode)
# Test mode dispatches to start-test.sh
if mode == "test":
os.execvp("start-test.sh", ["start-test.sh"])
if mode not in ("rpc", "validator"):
log.error("Unknown AGAVE_MODE: %s (valid: test, rpc, validator)", mode)
sys.exit(1)
# Ensure directories
dirs = [CONFIG_DIR, LEDGER_DIR, ACCOUNTS_DIR, SNAPSHOTS_DIR]
if mode == "rpc":
dirs.append(LOG_DIR)
ensure_dirs(*dirs)
# Snapshot freshness check and auto-download
maybe_download_snapshot(SNAPSHOTS_DIR)
# Identity setup
if mode == "rpc":
ensure_identity_rpc()
print_identity()
# Build args
if mode == "rpc":
args = build_rpc_args()
else:
args = build_validator_args()
args = append_extra_args(args)
# Write startup timestamp for probe grace period
Path("/tmp/entrypoint-start").write_text(str(time.time()))
log.info("Starting agave-validator with %d arguments", len(args))
os.execvp("agave-validator", ["agave-validator"] + args)
# -- Probe subcommand ---------------------------------------------------------
def cmd_probe() -> None:
"""Liveness probe: check local RPC slot vs mainnet.
Exit 0 = healthy, exit 1 = unhealthy.
Grace period: PROBE_GRACE_SECONDS (default 600) — probe always passes
during grace period to allow for snapshot unpacking and initial replay.
"""
grace_seconds = int(env("PROBE_GRACE_SECONDS", "600"))
max_lag = int(env("PROBE_MAX_SLOT_LAG", "20000"))
# Check grace period
start_file = Path("/tmp/entrypoint-start")
if start_file.exists():
try:
start_time = float(start_file.read_text().strip())
elapsed = time.time() - start_time
if elapsed < grace_seconds:
# Within grace period — always healthy
sys.exit(0)
except (ValueError, OSError):
pass
else:
# No start file — serve hasn't started yet, within grace
sys.exit(0)
# Query local RPC
rpc_port = env("RPC_PORT", "8899")
local_url = f"http://127.0.0.1:{rpc_port}"
local_slot = rpc_get_slot(local_url, timeout=5)
if local_slot is None:
# Local RPC unreachable after grace period — unhealthy
sys.exit(1)
# Query mainnet
mainnet_slot = rpc_get_slot(MAINNET_RPC, timeout=10)
if mainnet_slot is None:
# Can't reach mainnet to compare — assume healthy (don't penalize
# the validator for mainnet RPC being down)
sys.exit(0)
lag = mainnet_slot - local_slot
if lag > max_lag:
sys.exit(1)
sys.exit(0)
# -- Main ----------------------------------------------------------------------
def main() -> None:
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
datefmt="%H:%M:%S",
)
subcmd = sys.argv[1] if len(sys.argv) > 1 else "serve"
if subcmd == "serve":
cmd_serve()
elif subcmd == "probe":
cmd_probe()
else:
log.error("Unknown subcommand: %s (valid: serve, probe)", subcmd)
sys.exit(1)
if __name__ == "__main__":
main()