diff --git a/docs/providers/README.md b/docs/providers/README.md index 501039077..2f881d34d 100644 --- a/docs/providers/README.md +++ b/docs/providers/README.md @@ -61,7 +61,7 @@ selection metadata. Regenerate it with `node scripts/generate-provider-matrix.mj `scripts/check-docs.sh` fails when provider registration, metadata, docs paths, or this generated table drift. -Current built-in surface: 67 providers (39 SSH lease, 26 delegated run, 2 service control). +Current built-in surface: 68 providers (40 SSH lease, 26 delegated run, 2 service control). Access terms: @@ -133,6 +133,7 @@ Access terms: | [tenki](tenki.md) | built-in; `ssh-lease` · direct-cloud | Crabbox-managed SSH; `crabbox-sync` · direct only; features: `ssh`, `crabbox-sync` | `linux`; Tenki sandbox VM | `provider-managed`; GPU: unknown | Tenki; sandbox release | Managed Linux sandbox with SSH proxy | Gateway auth uses Tenki-managed key and certificate files | | [tensorlake](tensorlake.md) (`tl`, `tensorlake-sbx`) | built-in; `delegated-run` · delegated-sandbox | No SSH; `provider-owned` · direct only; features: `run-session` | `linux`; Tensorlake Firecracker sandbox | `provider-managed`; GPU: unknown | Tensorlake; provider sandbox cleanup | Hosted Firecracker-backed delegated execution | Does not expose raw Firecracker provisioning | | [upstash-box](upstash-box.md) (`upstash`, `box`, `upstashbox`) | built-in; `delegated-run` · delegated-sandbox | No SSH; `archive-sync` · direct only; features: `archive-sync`, `run-session` | `linux`; Upstash Box sandbox | `provider-managed`; GPU: no | Upstash; sandbox cleanup | Hosted short-lived delegated sandbox | No normal SSH access or coordinator routing | +| [vast](vast.md) (`vast-ai`, `vastai`) | built-in; `ssh-lease` · gpu-cloud | Crabbox-managed SSH; `crabbox-sync` · direct only; features: `ssh`, `crabbox-sync`, `cleanup` | `linux`; Vast.ai direct GPU instance | `provider-managed`; GPU: yes | Crabbox; destroy by default; optional stop or keep | Direct Linux GPU lease from the Vast.ai offer market | Direct-only and billable; capacity, quota, and offer availability vary | | [vercel-sandbox](vercel-sandbox.md) | built-in; `delegated-run` · delegated-sandbox | No SSH; `archive-sync` · direct only; features: `archive-sync`, `cleanup`, `run-session` | `linux`; Vercel Sandbox microVM | `provider-managed`; GPU: no | Vercel Sandbox; sandbox delete | Hosted delegated Linux microVM execution | Requires SDK bridge support and Vercel Sandbox auth | | [vultr](vultr.md) | built-in; `ssh-lease` · direct-cloud | Crabbox-managed SSH; `crabbox-sync` · direct only; features: `ssh`, `crabbox-sync`, `cleanup` | `linux`; Vultr instance | `cloud`; GPU: optional | Crabbox; instance and key delete | Direct Linux VM on Vultr | Direct-only; firewall groups and VPCs must already exist | | [wandb](wandb.md) (`weights-and-biases`) | built-in; `delegated-run` · gpu-cloud | No SSH; `provider-owned` · direct only; features: `run-session` | `linux`; Weights & Biases run sandbox | `provider-managed`; GPU: optional | Weights & Biases; run termination | Delegated ML or GPU run environment | Execution follows the W&B run contract | diff --git a/docs/providers/provider-metadata.json b/docs/providers/provider-metadata.json index d8f001dbe..d1fb8dc94 100644 --- a/docs/providers/provider-metadata.json +++ b/docs/providers/provider-metadata.json @@ -503,6 +503,20 @@ "caveat": "Direct-only; firewall groups and VPCs must already exist", "docs": "vultr.md" }, + "vast": { + "status": "built-in", + "category": "gpu-cloud", + "substrate": "Vast.ai direct GPU instance", + "location": "provider-managed", + "ssh": "crabbox-managed", + "sync": "crabbox-sync", + "gpu": "yes", + "lifecycle": "Crabbox", + "cleanup": "destroy by default; optional stop or keep", + "bestFit": "Direct Linux GPU lease from the Vast.ai offer market", + "caveat": "Direct-only and billable; capacity, quota, and offer availability vary", + "docs": "vast.md" + }, "ovh": { "status": "built-in", "category": "direct-cloud", diff --git a/docs/providers/vast.md b/docs/providers/vast.md new file mode 100644 index 000000000..9e2012ed7 --- /dev/null +++ b/docs/providers/vast.md @@ -0,0 +1,295 @@ +# Vast Provider + +Read this when you are: + +- choosing `provider: vast`; +- validating a direct Vast.ai SSH lease; +- changing `internal/providers/vast` or the guarded live smoke. + +Vast is a Linux-only **SSH lease** provider for Vast.ai GPU instances. Crabbox +searches Vast offers, creates one `ssh_direct` instance from the selected offer, +injects a per-lease SSH key, marks the instance with a compact Crabbox ownership +label, waits for the direct SSH endpoint, and then uses the normal Crabbox SSH +sync/run/status/list/stop/cleanup path. + +Vast is **direct-only** in this release. It does not run through the Crabbox +coordinator, so the local CLI must have a Vast API key and direct cleanup +remains the operator's responsibility. Vast instances are billable while they +are running. The default release action destroys the instance. + +## When To Use It + +Use Vast when you need a direct Linux GPU lease and local Vast credentials are +acceptable. Prefer AWS, Azure, GCP, or Hetzner when you need brokered team +credentials, coordinator-side cost accounting, or non-GPU cloud VM coverage. +Prefer Lambda, Nebius, RunPod, or NVIDIA Brev when those provider catalogs, +images, or account policies are a better fit for the workload. + +## Commands + +```sh +crabbox doctor --provider vast +crabbox warmup --provider vast --vast-gpu-name "RTX 4090" --keep +crabbox run --provider vast --vast-gpu-count 1 --no-sync -- nvidia-smi +crabbox ssh --provider vast --id my-app +crabbox stop --provider vast my-app +crabbox cleanup --provider vast --dry-run +``` + +Aliases: `vast-ai`, `vastai`. + +`--id` accepts the canonical lease id (`cbx_...`), the friendly slug, or the +Vast instance id when it resolves to a complete Crabbox-owned Vast instance. +`--class` and `--type` are not supported for `provider=vast`; use the +Vast-specific GPU, image, and offer-selection flags instead. + +## Configuration + +```yaml +provider: vast +target: linux +vast: + apiUrl: https://console.vast.ai/api/v0 + instanceType: ondemand + gpuName: "" + gpuCount: 0 + image: nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04 + templateId: "" + runtype: ssh_direct + diskGB: 20 + maxDphTotal: 0 + minReliability: 0 + order: dlperf_per_dphtotal desc + user: root + workRoot: /work/crabbox + releaseAction: destroy +``` + +Config keys under `vast:`: + +| Key | Maps to | Default | Notes | +| --- | --- | --- | --- | +| `apiUrl` | `cfg.Vast.APIURL` | `https://console.vast.ai/api/v0` | Absolute Vast REST API URL without credentials, query strings, or fragments. HTTPS is required except for localhost test endpoints. | +| `instanceType` | `cfg.Vast.InstanceType` | `ondemand` | Offer type, `ondemand` or `interruptible`; `on-demand` is normalized. | +| `gpuName` | `cfg.Vast.GPUName` | empty | Optional Vast GPU name selector. | +| `gpuCount` | `cfg.Vast.GPUCount` | `0` | Minimum GPU count when greater than zero. | +| `image` | `cfg.Vast.Image` | `nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04` | Docker image requested from Vast for the instance. | +| `templateId` | `cfg.Vast.TemplateID` | empty | Optional Vast template id. | +| `runtype` | `cfg.Vast.Runtype` | `ssh_direct` | Only `ssh_direct` is supported. | +| `diskGB` | `cfg.Vast.DiskGB` | `20` | Requested disk size in GB. | +| `maxDphTotal` | `cfg.Vast.MaxDphTotal` | `0` | Maximum dollars per hour when greater than zero. | +| `minReliability` | `cfg.Vast.MinReliability` | `0` | Minimum reliability score from 0 to 1 when greater than zero. | +| `order` | `cfg.Vast.Order` | `dlperf_per_dphtotal desc` | Vast offer ordering expression. | +| `user` | `cfg.Vast.User` | `root` | SSH user. Explicit generic `ssh.user` still wins. | +| `workRoot` | `cfg.Vast.WorkRoot` | `/work/crabbox` | Remote work root for Crabbox sync and commands. | +| `releaseAction` | `cfg.Vast.ReleaseAction` | `destroy` | `destroy`/`delete`, `stop`, or `keep`. | + +Provider flags: + +```text +--vast-api-url +--vast-instance-type +--vast-gpu-name +--vast-gpu-count +--vast-image +--vast-template-id +--vast-runtype +--vast-disk-gb +--vast-max-dph-total +--vast-min-reliability +--vast-order +--vast-user +--vast-work-root +--vast-release-action +``` + +Environment overrides: + +```text +CRABBOX_VAST_API_KEY Vast API key for direct mode +VAST_API_KEY Fallback Vast API key +CRABBOX_VAST_API_URL Override the API URL +VAST_API_URL Fallback API URL override +CRABBOX_VAST_INSTANCE_TYPE Override `ondemand` or `interruptible` +CRABBOX_VAST_GPU_NAME Override the GPU name selector +CRABBOX_VAST_GPU_COUNT Override the minimum GPU count +CRABBOX_VAST_IMAGE Override the Docker image +CRABBOX_VAST_TEMPLATE_ID Override the template id +CRABBOX_VAST_RUNTYPE Override the runtime type; must be `ssh_direct` +CRABBOX_VAST_DISK_GB Override disk size in GB +CRABBOX_VAST_MAX_DPH_TOTAL Override maximum dollars per hour +CRABBOX_VAST_MIN_RELIABILITY Override minimum reliability score +CRABBOX_VAST_ORDER Override offer ordering +CRABBOX_VAST_USER Override the SSH user +CRABBOX_VAST_WORK_ROOT Override the remote work root +CRABBOX_VAST_RELEASE_ACTION Override release action +``` + +Do not pass the Vast API key as a command-line argument or store it in +repository config. Crabbox reads it from `CRABBOX_VAST_API_KEY` or +`VAST_API_KEY` and sends it only in the `Authorization: Bearer ...` header. + +## Token Scope + +The provider uses Vast account identity, offer search, instances, instance +state updates, instance destroy, and instance SSH-key attach/detach APIs. +`crabbox doctor --provider vast` is read-only: it checks auth, lists instances, +counts Crabbox-owned Vast instances, and reports the default order, runtime, and +SSH user. + +Keep API keys in the environment or a local secret manager. Do not commit Vast +keys, generated private keys, instance API keys, user data, or Jupyter token +URLs. + +## Lifecycle + +1. Load the Vast API key from `CRABBOX_VAST_API_KEY` or `VAST_API_KEY`. +2. List instances and allocate a Crabbox slug. +3. Generate a per-lease SSH key in the Crabbox testbox key store. +4. Search Vast offers with the configured type, GPU name/count, reliability, + max dollars per hour, and ordering. +5. Create one `ssh_direct` Vast instance from the selected offer, with the + configured image, template, disk, user, and Crabbox environment marker. +6. Attach the per-lease public SSH key to the instance. +7. Wait until the instance is running and exposes a direct SSH host and port. +8. Update the Vast label from provisioning to ready. +9. Wait for Crabbox SSH bootstrap readiness and write a local lease claim. +10. Run normal Crabbox SSH sync, command execution, status, list, and cleanup. + +The provider requires Linux. It does not advertise desktop, browser, code-server, +Tailscale, coordinator, or provider-managed sync support in this release. +Actions hydration works only as normal command execution on the resulting Linux +SSH lease. + +If create or bootstrap becomes indeterminate after a Vast instance id is known, +Crabbox records a local recovery claim when possible. Retry +`crabbox stop --provider vast ` before deleting resources +manually so Crabbox can reconcile the instance and local key material. + +## Offer Selection + +By default Crabbox searches `ondemand` verified, rentable, not-rented offers +with at least one direct SSH port and orders by `dlperf_per_dphtotal desc`. +Narrow the search when the default catalog is too broad: + +```sh +crabbox run \ + --provider vast \ + --vast-instance-type interruptible \ + --vast-gpu-name "H100" \ + --vast-gpu-count 1 \ + --vast-max-dph-total 4.25 \ + --vast-min-reliability 0.95 \ + --no-sync \ + -- nvidia-smi +``` + +`vast.maxDphTotal` and `--vast-max-dph-total` are guardrails for offer search, +not a billing cap enforced by Crabbox after provisioning. Review the selected +offer and Vast account billing before running long jobs. + +## Release And Cleanup + +The default release action is `destroy`, which deletes the Vast instance, +detaches the Crabbox-managed instance SSH key when its key id is known, removes +the local claim, and removes the local per-lease key. + +Release actions: + +- `destroy` or `delete`: destroy the Vast instance on `stop` or one-shot release. +- `stop`: request Vast to stop the instance and keep the local Crabbox claim + with `state=stopped` so later status, cleanup, or explicit destroy can + reconcile the retained resource. +- `keep`: leave the instance and local claim untouched during release. + +Use `stop` or `keep` only when you explicitly accept the retained resource and +its billing implications. Direct mode has no coordinator alarm. + +Cleanup only mutates instances with a complete Crabbox Vast ownership label and +a matching local claim: + +```sh +crabbox list --provider vast --json +crabbox cleanup --provider vast --dry-run +crabbox cleanup --provider vast +``` + +Crabbox refuses to operate on non-Crabbox Vast instances, changed ownership +labels, stale local claims, missing local claims for destructive release, or +instances whose provider identity no longer matches the local claim. Vast labels +are compact strings beginning with `cbx1|`; they encode the lease id, slug, and +state. + +## Guarded Live Smoke + +The repeatable live check is opt-in and billable: + +```sh +CRABBOX_LIVE=1 CRABBOX_LIVE_PROVIDERS=vast scripts/live-vast-smoke.sh +``` + +The script builds `bin/crabbox`, reads `CRABBOX_VAST_API_KEY` or +`VAST_API_KEY`, requires an empty Crabbox-owned Vast inventory, creates one kept +lease, waits for ready status, runs `nvidia-smi`, verifies `list --json`, stops +the lease, runs dry-run cleanup, and verifies the Crabbox-owned inventory is +empty afterward. + +Optional live-smoke overrides: + +```text +CRABBOX_LIVE_VAST_GPU_NAME GPU name selector, default empty +CRABBOX_LIVE_VAST_GPU_COUNT Minimum GPU count, default 1 +CRABBOX_LIVE_VAST_MAX_DPH_TOTAL Max dollars per hour for offer search, default 0 +CRABBOX_LIVE_VAST_INSTANCE_TYPE Offer type, default ondemand +CRABBOX_LIVE_VAST_IMAGE Docker image, default from provider config +CRABBOX_LIVE_VAST_RELEASE_ACTION Release action, default destroy +``` + +Final classifications include: + +```text +classification=live_vast_smoke_passed +classification=environment_blocked +classification=billing_blocked +classification=quota_blocked +classification=capacity_blocked +classification=validation_failed +classification=cleanup_failed +``` + +Missing opt-in flags, missing credentials, auth failures, disabled account API +access, missing GPU offers, billing blocks, quota blocks, and capacity blocks +are reported as classified outcomes. The script redacts `VAST_API_KEY`, +`CRABBOX_VAST_API_KEY`, `instance_api_key`, Jupyter tokens, user data, private +keys, and URLs carrying token-like query parameters from diagnostics. + +If cleanup fails, use the reported slug and Vast instance id with +`crabbox list --provider vast --json`, `crabbox stop --provider vast `, +and the Vast console. Do not delete unrelated instances that only look similar. + +## Capabilities + +- **OS targets**: Linux only. +- **SSH**: yes, Crabbox-managed SSH over Vast direct SSH endpoints. +- **Crabbox sync**: yes, rsync over SSH. +- **Provider-managed sync**: no. +- **GPU**: yes, provider catalog dependent. +- **Coordinator**: no; direct CLI only. +- **Cleanup**: yes, ownership-label and local-claim guarded. +- **Desktop / browser / code-server**: not advertised in this release. +- **Tailscale**: not advertised in this release. + +## Gotchas + +- Vast is direct-only. Coordinator secrets, usage limits, and cost accounting do + not cover these instances. +- Vast offers are capacity-sensitive. No matching offer, quota, billing, or + capacity failures are external blockers, not docs-check failures. +- `ssh_direct` is required. Other Vast runtime types are rejected by config + validation. +- The default image is CUDA-oriented. If it lacks workload dependencies, install + them in your repo setup or select a different image/template. +- `stop` and `keep` can retain billable resources. Use `destroy` for normal + one-shot Crabbox validation. +- Destructive release requires local Crabbox claim state. Keep the claim until + cleanup is complete. diff --git a/docs/source-map.md b/docs/source-map.md index 515e11327..f9df46484 100644 --- a/docs/source-map.md +++ b/docs/source-map.md @@ -88,9 +88,10 @@ SSH-lease providers: - Canonical Multipass local Ubuntu VM: `internal/providers/multipass` - Cirrus Labs tart local macOS VM: `internal/providers/tart` - Microsoft Hyper-V local Windows VM: `internal/providers/hyperv` -- Daytona, Morph, exe.dev, KubeVirt, External, Tenki, Namespace devbox, RunPod, Semaphore, Sprites, Lambda: +- Daytona, Morph, exe.dev, KubeVirt, External, Tenki, Namespace devbox, RunPod, Semaphore, Sprites, Lambda, Vast: `internal/providers/daytona`, `internal/providers/morph`, `internal/providers/exedev`, `internal/providers/kubevirt`, `internal/providers/external`, `internal/providers/tenki`, `internal/providers/namespace`, - `internal/providers/runpod`, `internal/providers/semaphore`, `internal/providers/sprites`, `internal/providers/lambda` + `internal/providers/runpod`, `internal/providers/semaphore`, `internal/providers/sprites`, `internal/providers/lambda`, + `internal/providers/vast`, live smoke `scripts/live-vast-smoke.sh` Delegated-run providers (no SSH lease): @@ -148,7 +149,7 @@ Actions hydration or repo scripts. Provider docs: - Per-provider feature notes: `docs/features/aws.md`, `docs/features/azure.md`, `docs/features/hetzner.md`, `docs/features/blacksmith-testbox.md`, `docs/features/namespace-devbox.md`, `docs/features/namespace-devbox-setup.md`, `docs/features/semaphore.md`, `docs/features/sprites.md`, `docs/features/daytona.md`, `docs/features/islo.md`, `docs/features/e2b.md` -- Per-provider reference: `docs/providers/README.md` plus one file per provider under `docs/providers/`, including `docs/providers/aws-lambda-microvm.md` for Lambda MicroVM delegated execution and its runner image, `docs/providers/lambda.md` for the direct Lambda GPU SSH lease provider, `docs/providers/blaxel.md` for delegated Blaxel sandbox execution, `docs/providers/apple-vz.md` for the local Apple Silicon `Virtualization.framework` path, `docs/providers/digitalocean.md` for the direct Droplet provider, `docs/providers/vultr.md` for the direct Vultr provider, `docs/providers/ovh.md` for the direct OVHcloud provider, `docs/providers/incus.md` for the separate local live validation contract, `docs/providers/superserve.md` for delegated Superserve execution and live proof, and `docs/providers/cloudflare-sandbox.md` for Cloudflare Sandbox bridge-backed delegated Linux execution +- Per-provider reference: `docs/providers/README.md` plus one file per provider under `docs/providers/`, including `docs/providers/aws-lambda-microvm.md` for Lambda MicroVM delegated execution and its runner image, `docs/providers/lambda.md` for the direct Lambda GPU SSH lease provider, `docs/providers/vast.md` for the direct Vast.ai GPU SSH lease provider, `docs/providers/blaxel.md` for delegated Blaxel sandbox execution, `docs/providers/apple-vz.md` for the local Apple Silicon `Virtualization.framework` path, `docs/providers/digitalocean.md` for the direct Droplet provider, `docs/providers/vultr.md` for the direct Vultr provider, `docs/providers/ovh.md` for the direct OVHcloud provider, `docs/providers/incus.md` for the separate local live validation contract, `docs/providers/superserve.md` for delegated Superserve execution and live proof, and `docs/providers/cloudflare-sandbox.md` for Cloudflare Sandbox bridge-backed delegated Linux execution - Provider selection, landscape, live-smoke, and backend authoring guide: `docs/features/provider-selection.md`, `docs/features/provider-landscape.md`, `docs/features/provider-live-smoke.md`, `docs/provider-backends.md`, `docs/features/provider-authoring.md` - Tailscale contract: `docs/features/tailscale.md` @@ -236,5 +237,5 @@ Provider docs: - Release workflow and Homebrew tap fallback: `.github/workflows/release.yml` - GoReleaser archives and Homebrew formula config: `.goreleaser.yaml` - Docs command-surface check, link check, site builder, and Pages deploy: `scripts/check-command-docs.mjs`, `scripts/check-docs-links.mjs`, `scripts/build-docs-site.mjs`, `.github/workflows/pages.yml` -- Live provider smoke coverage: `scripts/live-smoke.sh`, plus provider-specific guarded smokes such as `scripts/live-blaxel-smoke.sh`, `scripts/live-digitalocean-smoke.sh`, `scripts/live-vultr-smoke.sh`, and `scripts/live-superserve-smoke.sh` +- Live provider smoke coverage: `scripts/live-smoke.sh`, plus provider-specific guarded smokes such as `scripts/live-blaxel-smoke.sh`, `scripts/live-digitalocean-smoke.sh`, `scripts/live-vultr-smoke.sh`, `scripts/live-vast-smoke.sh`, and `scripts/live-superserve-smoke.sh` - Live coordinator auth smoke coverage: `scripts/live-auth-smoke.sh` diff --git a/internal/applevzhelper/cli_darwin_arm64_test.go b/internal/applevzhelper/cli_darwin_arm64_test.go index ae34048ec..d4328b245 100644 --- a/internal/applevzhelper/cli_darwin_arm64_test.go +++ b/internal/applevzhelper/cli_darwin_arm64_test.go @@ -311,7 +311,7 @@ exit 23 "--cpus", "2", "--memory-mib", "2048", "--disk-gib", "16", - "--ready-timeout", "5s", + "--ready-timeout", "15s", }, strings.NewReader(`{"image":"test.img"}`), &bytes.Buffer{}, &bytes.Buffer{}) if err == nil || (!strings.Contains(err.Error(), "helper daemon exited before the VM reached running state") && diff --git a/internal/cli/bootstrap_test.go b/internal/cli/bootstrap_test.go index d04567b6e..2e3b566a8 100644 --- a/internal/cli/bootstrap_test.go +++ b/internal/cli/bootstrap_test.go @@ -790,7 +790,7 @@ func TestWindowsWSL2BootstrapCompleteProbeUsesWindowsMarker(t *testing.T) { target := bootstrapTarget target.WindowsMode = windowsModeWSL2 - if !probeWindowsWSL2BootstrapComplete(context.Background(), bootstrapTarget, &target, 5*time.Second) { + if !probeWindowsWSL2BootstrapComplete(context.Background(), bootstrapTarget, &target, 30*time.Second) { t.Fatal("setup marker probe should pass with fake ssh") } data, err := os.ReadFile(logPath) diff --git a/internal/cli/config.go b/internal/cli/config.go index 888dd3ca5..d31d5dc52 100644 --- a/internal/cli/config.go +++ b/internal/cli/config.go @@ -163,6 +163,8 @@ type Config struct { Railway RailwayConfig FastAPICloud FastAPICloudConfig Runpod RunpodConfig + Vast VastConfig + vastWorkRootExplicit bool NvidiaBrev NvidiaBrevConfig nvidiaBrevWorkRootExplicit bool Hostinger HostingerConfig @@ -605,6 +607,26 @@ type RunpodConfig struct { WorkRoot string } +// VastConfig contains Vast.ai provider settings. APIKey is populated only from +// CRABBOX_VAST_API_KEY / VAST_API_KEY and must not be persisted or printed. +type VastConfig struct { + APIKey string + APIURL string + InstanceType string + GPUName string + GPUCount int + Image string + TemplateID string + Runtype string + DiskGB int + MaxDphTotal float64 + MinReliability float64 + Order string + User string + WorkRoot string + ReleaseAction string +} + // NvidiaBrevConfig is intentionally non-secret. Authentication stays in the // NVIDIA Brev CLI's own credential store and is never accepted as Crabbox // config or argv. @@ -1639,6 +1661,61 @@ func applyProviderConfigDefaults(cfg *Config) error { normalizeTargetConfig(cfg) return validateTargetConfig(*cfg) } + if cfg.Provider == "vast" { + cfg.Vast.InstanceType = normalizeVastInstanceType(cfg.Vast.InstanceType) + if cfg.Vast.APIURL == "" { + cfg.Vast.APIURL = "https://console.vast.ai/api/v0" + } + if cfg.Vast.InstanceType == "" { + cfg.Vast.InstanceType = "ondemand" + } + if cfg.Vast.Image == "" { + cfg.Vast.Image = "nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04" + } + if cfg.Vast.Runtype == "" { + cfg.Vast.Runtype = "ssh_direct" + } + if cfg.Vast.DiskGB == 0 { + cfg.Vast.DiskGB = 20 + } + if cfg.Vast.Order == "" { + cfg.Vast.Order = "dlperf_per_dphtotal desc" + } + if cfg.Vast.User == "" { + cfg.Vast.User = "root" + } + if cfg.Vast.WorkRoot == "" { + cfg.Vast.WorkRoot = defaultPOSIXWorkRoot + } + if cfg.Vast.ReleaseAction == "" { + cfg.Vast.ReleaseAction = "destroy" + } + if !IsTargetExplicit(cfg) { + cfg.TargetOS = targetLinux + } + if cfg.explicitWindowsMode != "" { + cfg.WindowsMode = cfg.explicitWindowsMode + } else { + cfg.WindowsMode = windowsModeNormal + } + if cfg.explicitWorkRoot != "" && !IsVastWorkRootExplicit(cfg) { + cfg.Vast.WorkRoot = cfg.explicitWorkRoot + } + cfg.WorkRoot = cfg.Vast.WorkRoot + if cfg.explicitSSHUser != "" { + cfg.SSHUser = cfg.explicitSSHUser + } else { + cfg.SSHUser = cfg.Vast.User + } + if cfg.explicitSSHPort != "" { + cfg.SSHPort = cfg.explicitSSHPort + } else { + cfg.SSHPort = "22" + } + cfg.SSHFallbackPorts = nil + normalizeTargetConfig(cfg) + return validateTargetConfig(*cfg) + } if cfg.Provider == "nebius" { if cfg.Nebius.CLI == "" { cfg.Nebius.CLI = "nebius" @@ -2307,6 +2384,38 @@ func MarkNvidiaBrevWorkRootExplicit(cfg *Config) { cfg.nvidiaBrevWorkRootExplicit = true } +func IsVastWorkRootExplicit(cfg *Config) bool { + return cfg.vastWorkRootExplicit +} + +func MarkVastWorkRootExplicit(cfg *Config) { + cfg.vastWorkRootExplicit = true +} + +func EffectiveVastWorkRoot(cfg Config) string { + workRoot := cfg.Vast.WorkRoot + if !IsVastWorkRootExplicit(&cfg) && (workRoot == "" || workRoot == defaultPOSIXWorkRoot) && cfg.explicitWorkRoot != "" { + return cfg.explicitWorkRoot + } + if workRoot == "" { + return defaultPOSIXWorkRoot + } + return workRoot +} + +func NormalizeVastInstanceType(value string) string { + return normalizeVastInstanceType(value) +} + +func normalizeVastInstanceType(value string) string { + switch strings.ToLower(strings.TrimSpace(value)) { + case "on-demand", "on_demand": + return "ondemand" + default: + return strings.ToLower(strings.TrimSpace(value)) + } +} + func EffectiveNvidiaBrevWorkRoot(cfg Config) string { workRoot := cfg.NvidiaBrev.WorkRoot providerDefault := workRoot == "" || workRoot == "/tmp/crabbox" @@ -2541,6 +2650,17 @@ func baseConfig() Config { Image: "runpod/pytorch:2.8.0-py3.11-cuda12.8.1-cudnn-devel-ubuntu22.04", DiskGB: 20, }, + Vast: VastConfig{ + APIURL: "https://console.vast.ai/api/v0", + InstanceType: "ondemand", + Image: "nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04", + Runtype: "ssh_direct", + DiskGB: 20, + Order: "dlperf_per_dphtotal desc", + User: "root", + WorkRoot: defaultPOSIXWorkRoot, + ReleaseAction: "destroy", + }, NvidiaBrev: NvidiaBrevConfig{ CLI: "brev", GPUName: "A100", @@ -2852,6 +2972,7 @@ type fileConfig struct { Railway *fileRailwayConfig `yaml:"railway,omitempty"` FastAPICloud *fileFastAPICloudConfig `yaml:"fastapiCloud,omitempty"` Runpod *fileRunpodConfig `yaml:"runpod,omitempty"` + Vast *fileVastConfig `yaml:"vast,omitempty"` NvidiaBrev *fileNvidiaBrevConfig `yaml:"nvidiaBrev,omitempty"` Hostinger *fileHostingerConfig `yaml:"hostinger,omitempty"` Wandb *fileWandbConfig `yaml:"wandb,omitempty"` @@ -3415,6 +3536,23 @@ type fileRunpodConfig struct { WorkRoot string `yaml:"workRoot,omitempty"` } +type fileVastConfig struct { + APIURL string `yaml:"apiUrl,omitempty"` + InstanceType string `yaml:"instanceType,omitempty"` + GPUName string `yaml:"gpuName,omitempty"` + GPUCount int `yaml:"gpuCount,omitempty"` + Image string `yaml:"image,omitempty"` + TemplateID string `yaml:"templateId,omitempty"` + Runtype string `yaml:"runtype,omitempty"` + DiskGB int `yaml:"diskGB,omitempty"` + MaxDphTotal *float64 `yaml:"maxDphTotal,omitempty"` + MinReliability *float64 `yaml:"minReliability,omitempty"` + Order string `yaml:"order,omitempty"` + User string `yaml:"user,omitempty"` + WorkRoot string `yaml:"workRoot,omitempty"` + ReleaseAction string `yaml:"releaseAction,omitempty"` +} + type fileNvidiaBrevConfig struct { CLI string `yaml:"cli,omitempty"` Org string `yaml:"org,omitempty"` @@ -5470,6 +5608,53 @@ func applyFileConfigWithTrust(cfg *Config, file fileConfig, trusted bool) error cfg.Runpod.WorkRoot = file.Runpod.WorkRoot } } + if file.Vast != nil { + if file.Vast.APIURL != "" { + cfg.Vast.APIURL = file.Vast.APIURL + cfg.credentialProvenance.vastAPIURL = credentialSource + } + if file.Vast.InstanceType != "" { + cfg.Vast.InstanceType = file.Vast.InstanceType + } + if file.Vast.GPUName != "" { + cfg.Vast.GPUName = file.Vast.GPUName + } + if file.Vast.GPUCount != 0 { + cfg.Vast.GPUCount = file.Vast.GPUCount + } + if file.Vast.Image != "" { + cfg.Vast.Image = file.Vast.Image + } + if file.Vast.TemplateID != "" { + cfg.Vast.TemplateID = file.Vast.TemplateID + } + if file.Vast.Runtype != "" { + cfg.Vast.Runtype = file.Vast.Runtype + } + if file.Vast.DiskGB != 0 { + cfg.Vast.DiskGB = file.Vast.DiskGB + } + if file.Vast.MaxDphTotal != nil { + cfg.Vast.MaxDphTotal = *file.Vast.MaxDphTotal + } + if file.Vast.MinReliability != nil { + cfg.Vast.MinReliability = *file.Vast.MinReliability + } + if file.Vast.Order != "" { + cfg.Vast.Order = file.Vast.Order + } + if file.Vast.User != "" { + cfg.Vast.User = file.Vast.User + } + if file.Vast.WorkRoot != "" { + cfg.Vast.WorkRoot = file.Vast.WorkRoot + MarkVastWorkRootExplicit(cfg) + } + if file.Vast.ReleaseAction != "" { + cfg.Vast.ReleaseAction = file.Vast.ReleaseAction + MarkDeleteOnReleaseExplicit(cfg, "vast") + } + } if file.NvidiaBrev != nil { if trusted && file.NvidiaBrev.CLI != "" { cfg.NvidiaBrev.CLI = file.NvidiaBrev.CLI @@ -7379,6 +7564,33 @@ func applyEnv(cfg *Config) error { cfg.Runpod.DiskGB = getenvInt("CRABBOX_RUNPOD_DISK_GB", cfg.Runpod.DiskGB) cfg.Runpod.User = getenv("CRABBOX_RUNPOD_USER", cfg.Runpod.User) cfg.Runpod.WorkRoot = getenv("CRABBOX_RUNPOD_WORK_ROOT", cfg.Runpod.WorkRoot) + if value, ok := firstNonEmptyEnv("CRABBOX_VAST_API_KEY", "VAST_API_KEY"); ok { + cfg.Vast.APIKey = value + cfg.credentialProvenance.vastAPIKey = credentialSourceEnvironment + } + if value, ok := firstNonEmptyEnv("CRABBOX_VAST_API_URL", "VAST_API_URL"); ok { + cfg.Vast.APIURL = value + cfg.credentialProvenance.vastAPIURL = credentialSourceEnvironment + } + cfg.Vast.InstanceType = getenv("CRABBOX_VAST_INSTANCE_TYPE", cfg.Vast.InstanceType) + cfg.Vast.GPUName = getenv("CRABBOX_VAST_GPU_NAME", cfg.Vast.GPUName) + cfg.Vast.GPUCount = getenvInt("CRABBOX_VAST_GPU_COUNT", cfg.Vast.GPUCount) + cfg.Vast.Image = getenv("CRABBOX_VAST_IMAGE", cfg.Vast.Image) + cfg.Vast.TemplateID = getenv("CRABBOX_VAST_TEMPLATE_ID", cfg.Vast.TemplateID) + cfg.Vast.Runtype = getenv("CRABBOX_VAST_RUNTYPE", cfg.Vast.Runtype) + cfg.Vast.DiskGB = getenvInt("CRABBOX_VAST_DISK_GB", cfg.Vast.DiskGB) + cfg.Vast.MaxDphTotal = getenvFloat("CRABBOX_VAST_MAX_DPH_TOTAL", cfg.Vast.MaxDphTotal) + cfg.Vast.MinReliability = getenvFloat("CRABBOX_VAST_MIN_RELIABILITY", cfg.Vast.MinReliability) + cfg.Vast.Order = getenv("CRABBOX_VAST_ORDER", cfg.Vast.Order) + cfg.Vast.User = getenv("CRABBOX_VAST_USER", cfg.Vast.User) + if value := os.Getenv("CRABBOX_VAST_WORK_ROOT"); value != "" { + cfg.Vast.WorkRoot = value + MarkVastWorkRootExplicit(cfg) + } + if value := os.Getenv("CRABBOX_VAST_RELEASE_ACTION"); value != "" { + cfg.Vast.ReleaseAction = value + MarkDeleteOnReleaseExplicit(cfg, "vast") + } cfg.NvidiaBrev.CLI = getenv("CRABBOX_NVIDIA_BREV_CLI", cfg.NvidiaBrev.CLI) cfg.NvidiaBrev.Org = getenv("CRABBOX_NVIDIA_BREV_ORG", cfg.NvidiaBrev.Org) cfg.NvidiaBrev.Type = getenv("CRABBOX_NVIDIA_BREV_TYPE", cfg.NvidiaBrev.Type) diff --git a/internal/cli/config_cmd.go b/internal/cli/config_cmd.go index 52fe8ce37..06ec130db 100644 --- a/internal/cli/config_cmd.go +++ b/internal/cli/config_cmd.go @@ -50,6 +50,7 @@ func (a App) configShow(args []string) error { func effectiveConfigForShow(cfg Config) Config { cfg.Hostinger.WorkRoot = EffectiveHostingerWorkRoot(cfg) + cfg.Vast.WorkRoot = EffectiveVastWorkRoot(cfg) cfg.NvidiaBrev.WorkRoot = EffectiveNvidiaBrevWorkRoot(cfg) if cfg.Provider == "digitalocean" || cfg.Provider == "linode" { base := baseConfig() @@ -98,6 +99,13 @@ func effectiveConfigForShow(cfg Config) Config { cfg.SSHFallbackPorts = nil } switch normalizeProviderName(cfg.Provider) { + case "vast", "vast-ai", "vastai": + cfg.WorkRoot = cfg.Vast.WorkRoot + if !IsSSHUserExplicit(&cfg) { + cfg.SSHUser = cfg.Vast.User + } + cfg.SSHPort = "22" + cfg.SSHFallbackPorts = nil case "nvidia-brev", "brev", "nvidia": cfg.WorkRoot = cfg.NvidiaBrev.WorkRoot } @@ -221,6 +229,23 @@ func configShowView(cfg Config) map[string]any { "user": cfg.NvidiaBrev.User, "workRoot": cfg.NvidiaBrev.WorkRoot, }, + "vast": map[string]any{ + "apiUrl": redactedConfigURLWithoutQuery(cfg.Vast.APIURL), + "auth": tokenState(cfg.Vast.APIKey), + "instanceType": cfg.Vast.InstanceType, + "gpuName": cfg.Vast.GPUName, + "gpuCount": cfg.Vast.GPUCount, + "image": cfg.Vast.Image, + "templateId": cfg.Vast.TemplateID, + "runtype": cfg.Vast.Runtype, + "diskGB": cfg.Vast.DiskGB, + "maxDphTotal": cfg.Vast.MaxDphTotal, + "minReliability": cfg.Vast.MinReliability, + "order": cfg.Vast.Order, + "user": cfg.Vast.User, + "workRoot": cfg.Vast.WorkRoot, + "releaseAction": cfg.Vast.ReleaseAction, + }, "nebius": map[string]any{ "cli": cfg.Nebius.CLI, "auth": "cli", @@ -638,6 +663,7 @@ func writeConfigShowText(w io.Writer, cfg Config) { fmt.Fprintf(w, "vultr region=%s os=%s image=%s snapshot=%s firewall_group=%s vpc_ids=%s ssh_cidrs=%s user_scheme=%s\n", cfg.Vultr.Region, blank(cfg.Vultr.OS, "-"), blank(cfg.Vultr.Image, "-"), blank(cfg.Vultr.Snapshot, "-"), blank(cfg.Vultr.FirewallGroup, "-"), blank(strings.Join(cfg.Vultr.VPCIDs, ","), "-"), blank(strings.Join(cfg.Vultr.SSHCIDRs, ","), "-"), blank(cfg.Vultr.UserScheme, "-")) fmt.Fprintf(w, "linode region=%s image=%s type=%s firewall=%s ssh_cidrs=%s\n", cfg.Linode.Region, cfg.Linode.Image, cfg.Linode.Type, blank(cfg.Linode.FirewallID, "-"), blank(strings.Join(cfg.Linode.SSHCIDRs, ","), "-")) fmt.Fprintf(w, "lambda region=%s type=%s image=%s image_family=%s firewall_ruleset=%s ssh_cidrs=%s filesystems=%s mounts=%d auth=%s\n", cfg.Lambda.Region, cfg.Lambda.Type, blank(cfg.Lambda.Image, "-"), blank(cfg.Lambda.ImageFamily, "-"), blank(cfg.Lambda.FirewallRuleset, "-"), blank(strings.Join(cfg.Lambda.SSHCIDRs, ","), "-"), blank(strings.Join(cfg.Lambda.FilesystemNames, ","), "-"), len(cfg.Lambda.FilesystemMounts), lambdaAuthState()) + fmt.Fprintf(w, "vast api_url=%s instance_type=%s gpu_name=%s gpu_count=%d image=%s template_id=%s runtype=%s disk_gb=%d max_dph_total=%.4g min_reliability=%.4g order=%s user=%s work_root=%s release_action=%s auth=%s\n", blank(redactedConfigURLWithoutQuery(cfg.Vast.APIURL), "-"), blank(cfg.Vast.InstanceType, "-"), blank(cfg.Vast.GPUName, "-"), cfg.Vast.GPUCount, blank(cfg.Vast.Image, "-"), blank(cfg.Vast.TemplateID, "-"), blank(cfg.Vast.Runtype, "-"), cfg.Vast.DiskGB, cfg.Vast.MaxDphTotal, cfg.Vast.MinReliability, blank(cfg.Vast.Order, "-"), blank(cfg.Vast.User, "-"), blank(cfg.Vast.WorkRoot, "-"), blank(cfg.Vast.ReleaseAction, "-"), tokenState(cfg.Vast.APIKey)) fmt.Fprintf(w, "nvidia_brev cli=%s org=%s type=%s gpu_name=%s provider=%s mode=%s launchable=%s startup_script=%s release_action=%s target=%s user=%s work_root=%s auth=cli\n", blank(cfg.NvidiaBrev.CLI, "-"), blank(cfg.NvidiaBrev.Org, "-"), blank(cfg.NvidiaBrev.Type, "-"), blank(cfg.NvidiaBrev.GPUName, "-"), blank(cfg.NvidiaBrev.Provider, "-"), blank(cfg.NvidiaBrev.Mode, "-"), blank(cfg.NvidiaBrev.Launchable, "-"), blank(cfg.NvidiaBrev.StartupScript, "-"), blank(cfg.NvidiaBrev.ReleaseAction, "-"), blank(cfg.NvidiaBrev.Target, "-"), blank(cfg.NvidiaBrev.User, "-"), blank(cfg.NvidiaBrev.WorkRoot, "-")) fmt.Fprintf(w, "nebius cli=%s profile=%s parent_id=%s subnet_id=%s platform=%s preset=%s image_family=%s disk_type=%s disk_size_gib=%d user=%s public_ip=%s security_group_ids=%s service_account_id=%s recovery_policy=%s auth=cli\n", blank(cfg.Nebius.CLI, "-"), blank(cfg.Nebius.Profile, "-"), blank(cfg.Nebius.ParentID, "-"), blank(cfg.Nebius.SubnetID, "-"), blank(cfg.Nebius.Platform, "-"), blank(cfg.Nebius.Preset, "-"), blank(cfg.Nebius.ImageFamily, "-"), blank(cfg.Nebius.DiskType, "-"), cfg.Nebius.DiskSizeGiB, blank(cfg.Nebius.User, "-"), blank(cfg.Nebius.PublicIP, "-"), blank(strings.Join(cfg.Nebius.SecurityGroupIDs, ","), "-"), blank(cfg.Nebius.ServiceAccountID, "-"), blank(cfg.Nebius.RecoveryPolicy, "-")) fmt.Fprintf(w, "hostinger api_url=%s item_id=%s payment_method_id=%s template_id=%s data_center_id=%s hostname_prefix=%s user=%s work_root=%s allow_purchase=%t release_action=%s auth=%s\n", blank(cfg.Hostinger.APIURL, "-"), blank(cfg.Hostinger.ItemID, "-"), blank(cfg.Hostinger.PaymentMethodID, "-"), blank(cfg.Hostinger.TemplateID, "-"), blank(cfg.Hostinger.DataCenterID, "-"), blank(cfg.Hostinger.HostnamePrefix, "-"), blank(cfg.Hostinger.User, "-"), blank(cfg.Hostinger.WorkRoot, "-"), cfg.Hostinger.AllowPurchase, blank(cfg.Hostinger.ReleaseAction, "-"), tokenState(cfg.Hostinger.APIToken)) diff --git a/internal/cli/config_cmd_test.go b/internal/cli/config_cmd_test.go index 970c7500b..f88c447b1 100644 --- a/internal/cli/config_cmd_test.go +++ b/internal/cli/config_cmd_test.go @@ -1179,6 +1179,103 @@ func TestConfigShowIncludesNvidiaBrevWithoutSecretSurface(t *testing.T) { } } +func TestConfigShowIncludesVastWithoutSecretSurface(t *testing.T) { + clearConfigEnv(t) + home := t.TempDir() + configPath := filepath.Join(home, "config.yaml") + t.Setenv("HOME", home) + t.Setenv("XDG_CONFIG_HOME", filepath.Join(home, ".config")) + t.Setenv("CRABBOX_CONFIG", configPath) + t.Setenv("CRABBOX_VAST_API_KEY", "vast-redaction-fixture-secret") + if err := os.WriteFile(configPath, []byte(`provider: vast +ssh: + user: ubuntu +vast: + apiUrl: https://user:secret@vast.example.test/api/v0?token=hidden + instanceType: on-demand + gpuName: RTX 4090 + gpuCount: 2 + image: nvidia/cuda:vast + templateId: tpl-vast + runtype: ssh_direct + diskGB: 60 + maxDphTotal: 3.5 + minReliability: 0.9 + order: reliability desc + user: root + workRoot: /work/vast + releaseAction: stop +`), 0o600); err != nil { + t.Fatal(err) + } + + var stdout bytes.Buffer + app := App{Stdout: &stdout, Stderr: &bytes.Buffer{}} + if err := app.configShow(nil); err != nil { + t.Fatal(err) + } + text := stdout.String() + if !strings.Contains(text, "vast api_url=https://@vast.example.test/api/v0 instance_type=ondemand gpu_name=RTX 4090 gpu_count=2") || + !strings.Contains(text, "work_root=/work/vast release_action=stop auth=configured") { + t.Fatalf("config show missing vast summary: %q", text) + } + if strings.Contains(text, "vast-redaction-fixture-secret") || strings.Contains(text, "user:secret") || strings.Contains(text, "hidden") { + t.Fatalf("config show text leaked Vast secret: %q", text) + } + + stdout.Reset() + if err := app.configShow([]string{"--json"}); err != nil { + t.Fatal(err) + } + var got struct { + SSHUser string `json:"sshUser"` + SSHPort string `json:"sshPort"` + Vast struct { + APIURL string `json:"apiUrl"` + Auth string `json:"auth"` + InstanceType string `json:"instanceType"` + GPUName string `json:"gpuName"` + GPUCount int `json:"gpuCount"` + Image string `json:"image"` + TemplateID string `json:"templateId"` + Runtype string `json:"runtype"` + DiskGB int `json:"diskGB"` + MaxDphTotal float64 `json:"maxDphTotal"` + MinReliability float64 `json:"minReliability"` + Order string `json:"order"` + User string `json:"user"` + WorkRoot string `json:"workRoot"` + ReleaseAction string `json:"releaseAction"` + } `json:"vast"` + } + if err := json.Unmarshal(stdout.Bytes(), &got); err != nil { + t.Fatal(err) + } + if got.Vast.APIURL != "https://@vast.example.test/api/v0" || + got.Vast.Auth != "configured" || + got.Vast.InstanceType != "ondemand" || + got.Vast.GPUName != "RTX 4090" || + got.Vast.GPUCount != 2 || + got.Vast.Image != "nvidia/cuda:vast" || + got.Vast.TemplateID != "tpl-vast" || + got.Vast.Runtype != "ssh_direct" || + got.Vast.DiskGB != 60 || + got.Vast.MaxDphTotal != 3.5 || + got.Vast.MinReliability != 0.9 || + got.Vast.Order != "reliability desc" || + got.Vast.User != "root" || + got.Vast.WorkRoot != "/work/vast" || + got.Vast.ReleaseAction != "stop" { + t.Fatalf("unexpected vast json: %#v", got.Vast) + } + if got.SSHUser != "ubuntu" || got.SSHPort != "22" { + t.Fatalf("unexpected vast ssh json: %#v", got) + } + if strings.Contains(stdout.String(), "vast-redaction-fixture-secret") || strings.Contains(stdout.String(), "user:secret") || strings.Contains(stdout.String(), "hidden") { + t.Fatalf("config show json leaked Vast secret: %q", stdout.String()) + } +} + func TestConfigShowIncludesNebiusWithoutSecretSurface(t *testing.T) { clearConfigEnv(t) home := t.TempDir() diff --git a/internal/cli/config_test.go b/internal/cli/config_test.go index 6ba65021a..6b1f1e73a 100644 --- a/internal/cli/config_test.go +++ b/internal/cli/config_test.go @@ -446,6 +446,38 @@ func clearConfigEnv(t *testing.T) { "FASTAPI_CLOUD_APP_ID", "CRABBOX_FASTAPI_CLOUD_TEAM_ID", "FASTAPI_CLOUD_TEAM_ID", + "RUNPOD_API_KEY", + "CRABBOX_RUNPOD_API_KEY", + "RUNPOD_API_URL", + "CRABBOX_RUNPOD_API_URL", + "RUNPOD_CLOUD_TYPE", + "CRABBOX_RUNPOD_CLOUD_TYPE", + "RUNPOD_INSTANCE_ID", + "CRABBOX_RUNPOD_INSTANCE_ID", + "RUNPOD_IMAGE", + "CRABBOX_RUNPOD_IMAGE", + "RUNPOD_TEMPLATE_ID", + "CRABBOX_RUNPOD_TEMPLATE_ID", + "CRABBOX_RUNPOD_DISK_GB", + "CRABBOX_RUNPOD_USER", + "CRABBOX_RUNPOD_WORK_ROOT", + "CRABBOX_VAST_API_KEY", + "VAST_API_KEY", + "CRABBOX_VAST_API_URL", + "VAST_API_URL", + "CRABBOX_VAST_INSTANCE_TYPE", + "CRABBOX_VAST_GPU_NAME", + "CRABBOX_VAST_GPU_COUNT", + "CRABBOX_VAST_IMAGE", + "CRABBOX_VAST_TEMPLATE_ID", + "CRABBOX_VAST_RUNTYPE", + "CRABBOX_VAST_DISK_GB", + "CRABBOX_VAST_MAX_DPH_TOTAL", + "CRABBOX_VAST_MIN_RELIABILITY", + "CRABBOX_VAST_ORDER", + "CRABBOX_VAST_USER", + "CRABBOX_VAST_WORK_ROOT", + "CRABBOX_VAST_RELEASE_ACTION", "CRABBOX_NVIDIA_BREV_CLI", "CRABBOX_NVIDIA_BREV_ORG", "CRABBOX_NVIDIA_BREV_TYPE", @@ -4555,6 +4587,21 @@ runpod: diskGB: 25 user: runpod-user workRoot: /workspaces/runpod-test +vast: + apiUrl: https://vast.example.test/api/v0 + instanceType: on-demand + gpuName: RTX 4090 + gpuCount: 2 + image: nvidia/cuda:vast-file + templateId: vast-tpl-file + runtype: ssh_direct + diskGB: 60 + maxDphTotal: 3.5 + minReliability: 0.9 + order: reliability desc + user: root + workRoot: /workspaces/vast-test + releaseAction: stop islo: baseUrl: https://islo.example.test image: docker.io/library/ubuntu:24.04 @@ -4810,6 +4857,9 @@ ssh: if cfg.Runpod.APIURL != "https://runpod.example.test/v1" || cfg.Runpod.CloudType != "SECURE" || cfg.Runpod.InstanceID != "NVIDIA L4" || cfg.Runpod.Image != "runpod/pytorch:custom" || cfg.Runpod.TemplateID != "tpl-file" || cfg.Runpod.DiskGB != 25 || cfg.Runpod.User != "runpod-user" || cfg.Runpod.WorkRoot != "/workspaces/runpod-test" { t.Fatalf("runpod config not loaded: %#v", cfg.Runpod) } + if cfg.Vast.APIURL != "https://vast.example.test/api/v0" || cfg.Vast.InstanceType != "on-demand" || cfg.Vast.GPUName != "RTX 4090" || cfg.Vast.GPUCount != 2 || cfg.Vast.Image != "nvidia/cuda:vast-file" || cfg.Vast.TemplateID != "vast-tpl-file" || cfg.Vast.Runtype != "ssh_direct" || cfg.Vast.DiskGB != 60 || cfg.Vast.MaxDphTotal != 3.5 || cfg.Vast.MinReliability != 0.9 || cfg.Vast.Order != "reliability desc" || cfg.Vast.User != "root" || cfg.Vast.WorkRoot != "/workspaces/vast-test" || cfg.Vast.ReleaseAction != "stop" { + t.Fatalf("vast config not loaded: %#v", cfg.Vast) + } if cfg.Islo.BaseURL != "https://islo.example.test" || cfg.Islo.Image != "docker.io/library/ubuntu:24.04" || cfg.Islo.Workdir != "crabbox" || cfg.Islo.GatewayProfile != "default" || cfg.Islo.SnapshotName != "snap-ready" || cfg.Islo.VCPUs != 4 || cfg.Islo.MemoryMB != 8192 || cfg.Islo.DiskGB != 40 { t.Fatalf("islo config not loaded: %#v", cfg.Islo) } @@ -5144,6 +5194,23 @@ func TestEnvOverridesConfig(t *testing.T) { t.Setenv("CRABBOX_RUNPOD_DISK_GB", "30") t.Setenv("CRABBOX_RUNPOD_USER", "runpod-env-user") t.Setenv("CRABBOX_RUNPOD_WORK_ROOT", "/work/runpod-env") + t.Setenv("VAST_API_KEY", "vast-key-file") + t.Setenv("CRABBOX_VAST_API_KEY", "vast-key-env") + t.Setenv("VAST_API_URL", "https://vast-file.example/api/v0") + t.Setenv("CRABBOX_VAST_API_URL", "https://vast-env.example/api/v0") + t.Setenv("CRABBOX_VAST_INSTANCE_TYPE", "interruptible") + t.Setenv("CRABBOX_VAST_GPU_NAME", "H100") + t.Setenv("CRABBOX_VAST_GPU_COUNT", "4") + t.Setenv("CRABBOX_VAST_IMAGE", "nvidia/cuda:vast-env") + t.Setenv("CRABBOX_VAST_TEMPLATE_ID", "vast-tpl-env") + t.Setenv("CRABBOX_VAST_RUNTYPE", "ssh_direct") + t.Setenv("CRABBOX_VAST_DISK_GB", "80") + t.Setenv("CRABBOX_VAST_MAX_DPH_TOTAL", "4.25") + t.Setenv("CRABBOX_VAST_MIN_RELIABILITY", "0.95") + t.Setenv("CRABBOX_VAST_ORDER", "dlperf desc") + t.Setenv("CRABBOX_VAST_USER", "ubuntu") + t.Setenv("CRABBOX_VAST_WORK_ROOT", "/work/vast-env") + t.Setenv("CRABBOX_VAST_RELEASE_ACTION", "keep") t.Setenv("ISLO_API_KEY", "islo-api-file") t.Setenv("CRABBOX_ISLO_API_KEY", "islo-api-env") t.Setenv("ISLO_BASE_URL", "https://islo-file.example") @@ -5398,6 +5465,9 @@ func TestEnvOverridesConfig(t *testing.T) { if cfg.Runpod.APIKey != "runpod-key-env" || cfg.Runpod.APIURL != "https://runpod-env.example/v1" || cfg.Runpod.CloudType != "SECURE" || cfg.Runpod.InstanceID != "NVIDIA L4" || cfg.Runpod.Image != "runpod/pytorch:env" || cfg.Runpod.TemplateID != "tpl-env" || cfg.Runpod.DiskGB != 30 || cfg.Runpod.User != "runpod-env-user" || cfg.Runpod.WorkRoot != "/work/runpod-env" { t.Fatalf("unexpected runpod env: %#v", cfg.Runpod) } + if cfg.Vast.APIKey != "vast-key-env" || cfg.Vast.APIURL != "https://vast-env.example/api/v0" || cfg.Vast.InstanceType != "interruptible" || cfg.Vast.GPUName != "H100" || cfg.Vast.GPUCount != 4 || cfg.Vast.Image != "nvidia/cuda:vast-env" || cfg.Vast.TemplateID != "vast-tpl-env" || cfg.Vast.Runtype != "ssh_direct" || cfg.Vast.DiskGB != 80 || cfg.Vast.MaxDphTotal != 4.25 || cfg.Vast.MinReliability != 0.95 || cfg.Vast.Order != "dlperf desc" || cfg.Vast.User != "ubuntu" || cfg.Vast.WorkRoot != "/work/vast-env" || cfg.Vast.ReleaseAction != "keep" { + t.Fatalf("unexpected vast env: %#v", cfg.Vast) + } if cfg.Islo.APIKey != "islo-api-env" || cfg.Islo.BaseURL != "https://islo-env.example" || cfg.Islo.Image != "ubuntu:env" || cfg.Islo.Workdir != "env-workdir" || cfg.Islo.GatewayProfile != "env-gateway" || cfg.Islo.SnapshotName != "env-snapshot" || cfg.Islo.VCPUs != 8 || cfg.Islo.MemoryMB != 16384 || cfg.Islo.DiskGB != 80 { t.Fatalf("unexpected islo env: %#v", cfg.Islo) } diff --git a/internal/cli/connect_test.go b/internal/cli/connect_test.go index 4bf36f090..5cc2d63d8 100644 --- a/internal/cli/connect_test.go +++ b/internal/cli/connect_test.go @@ -140,7 +140,7 @@ cat DisableHostKeyChecking: true, NoControlMaster: true, } - if !probeConnectSSHTransport(context.Background(), &target, 5*time.Second) { + if !probeConnectSSHTransport(context.Background(), &target, 30*time.Second) { t.Fatal("resolve fallback port failed") } err := runInteractiveSSH(context.Background(), target, strings.NewReader(""), &stdout, &stderr) diff --git a/internal/cli/controller_subprocess_test.go b/internal/cli/controller_subprocess_test.go index 109832d2b..7f90dfc14 100644 --- a/internal/cli/controller_subprocess_test.go +++ b/internal/cli/controller_subprocess_test.go @@ -23,9 +23,9 @@ import ( func controllerSubprocessTestTimeout(base time.Duration) time.Duration { if raceEnabled { - return 5 * base + return 10 * base } - return base + return 5 * base } type fixedIdentityExecControllerRunner struct { diff --git a/internal/cli/credential_provenance.go b/internal/cli/credential_provenance.go index c19059fb6..3e4047f05 100644 --- a/internal/cli/credential_provenance.go +++ b/internal/cli/credential_provenance.go @@ -45,6 +45,8 @@ type credentialDestinationProvenance struct { fastAPICloudToken credentialValueSource runpodAPIURL credentialValueSource runpodAPIKey credentialValueSource + vastAPIURL credentialValueSource + vastAPIKey credentialValueSource isloBaseURL credentialValueSource isloAPIKey credentialValueSource tenkiEndpoint credentialValueSource @@ -128,6 +130,9 @@ func markCredentialDestinationFlagSources(cfg *Config, fs *flag.FlagSet) { if flagWasSet(fs, "runpod-url") { provenance.runpodAPIURL = credentialSourceFlag } + if flagWasSet(fs, "vast-api-url") { + provenance.vastAPIURL = credentialSourceFlag + } if flagWasSet(fs, "islo-base-url") { provenance.isloBaseURL = credentialSourceFlag } @@ -247,6 +252,11 @@ func validateProviderCredentialDestination(cfg Config) error { inheritedCredential(sourcedCredential{cfg.Runpod.APIKey, provenance.runpodAPIKey}) { return repositoryCredentialDestinationError("runpod", "runpod.apiUrl", "CRABBOX_RUNPOD_API_URL or --runpod-url") } + case "vast": + if provenance.vastAPIURL == credentialSourceRepository && + inheritedCredential(sourcedCredential{cfg.Vast.APIKey, provenance.vastAPIKey}) { + return repositoryCredentialDestinationError("vast", "vast.apiUrl", "CRABBOX_VAST_API_URL or --vast-api-url") + } case "islo": if provenance.isloBaseURL == credentialSourceRepository && inheritedCredential(sourcedCredential{cfg.Islo.APIKey, provenance.isloAPIKey}) { diff --git a/internal/cli/credential_provenance_test.go b/internal/cli/credential_provenance_test.go index 2104deb06..91ac2e4da 100644 --- a/internal/cli/credential_provenance_test.go +++ b/internal/cli/credential_provenance_test.go @@ -135,6 +135,18 @@ func TestRepositoryCredentialDestinationsRejectInheritedCredentials(t *testing.T }, want: "runpod.apiUrl", }, + { + name: "vast api", + cfg: Config{ + Provider: "vast", + Vast: VastConfig{APIURL: "https://repo.example.test", APIKey: "secret"}, + credentialProvenance: credentialDestinationProvenance{ + vastAPIURL: credentialSourceRepository, + vastAPIKey: credentialSourceEnvironment, + }, + }, + want: "vast.apiUrl", + }, { name: "islo api", cfg: Config{ @@ -389,6 +401,31 @@ func TestRepositoryCredentialDestinationAllowsExplicitFlagOverride(t *testing.T) } } +func TestVastCredentialDestinationAllowsExplicitFlagOverride(t *testing.T) { + cfg := Config{ + Provider: "vast", + Vast: VastConfig{APIURL: "https://repo.example.test", APIKey: "secret"}, + credentialProvenance: credentialDestinationProvenance{ + vastAPIURL: credentialSourceRepository, + vastAPIKey: credentialSourceEnvironment, + }, + } + fs := newFlagSet("test", io.Discard) + values := registerProviderFlags(fs, cfg) + if err := parseFlags(fs, []string{"--vast-api-url", "https://approved.example.test"}); err != nil { + t.Fatal(err) + } + if err := applyProviderFlags(&cfg, fs, values); err != nil { + t.Fatal(err) + } + if err := validateProviderCredentialDestination(cfg); err != nil { + t.Fatalf("explicit vast flag override rejected: %v", err) + } + if cfg.Vast.APIURL != "https://approved.example.test" { + t.Fatalf("vast apiUrl=%q", cfg.Vast.APIURL) + } +} + func TestAzureDynamicSessionsCredentialDestinationAllowsExplicitFlagOverride(t *testing.T) { cfg := Config{ Provider: "azure-dynamic-sessions", @@ -572,6 +609,13 @@ func TestConfigMergeSourceBindsDirectProviderCredentials(t *testing.T) { credentialEnv: "CRABBOX_RUNPOD_API_KEY", approveEnv: "CRABBOX_RUNPOD_API_URL", }, + { + name: "vast", + provider: "vast", + file: fileConfig{Vast: &fileVastConfig{APIURL: "https://repo.example.test"}}, + credentialEnv: "CRABBOX_VAST_API_KEY", + approveEnv: "CRABBOX_VAST_API_URL", + }, { name: "islo", provider: "islo", diff --git a/internal/cli/provider_categories_generated.go b/internal/cli/provider_categories_generated.go index ffa9baad1..215f8606e 100644 --- a/internal/cli/provider_categories_generated.go +++ b/internal/cli/provider_categories_generated.go @@ -65,6 +65,7 @@ var benchmarkProviderCategories = map[string]string{ "tenki": "direct-cloud", "tensorlake": "delegated-sandbox", "upstash-box": "delegated-sandbox", + "vast": "gpu-cloud", "vercel-sandbox": "delegated-sandbox", "vultr": "direct-cloud", "wandb": "gpu-cloud", diff --git a/internal/cli/providers_builtin_test.go b/internal/cli/providers_builtin_test.go index 456591bcb..5d524030c 100644 --- a/internal/cli/providers_builtin_test.go +++ b/internal/cli/providers_builtin_test.go @@ -32,6 +32,7 @@ func init() { RegisterProvider(testExternalProvider{}) RegisterProvider(testExeDevProvider{}) RegisterProvider(testRunPodProvider{}) + RegisterProvider(testVastProvider{}) RegisterProvider(testNvidiaBrevProvider{}) RegisterProvider(testBlacksmithProvider{}) RegisterProvider(testNamespaceProvider{}) @@ -1056,6 +1057,38 @@ func (p testRunPodProvider) Configure(cfg Config, rt Runtime) (Backend, error) { return testSSHBackend{spec: p.Spec()}, nil } +type testVastProvider struct{} + +func (testVastProvider) Name() string { return "vast" } +func (testVastProvider) Aliases() []string { + return []string{"vast-ai", "vastai"} +} +func (testVastProvider) Spec() ProviderSpec { + return ProviderSpec{ + Name: "vast", + Family: "vast", + Kind: ProviderKindSSHLease, + Targets: []TargetSpec{{OS: targetLinux}}, + Features: FeatureSet{FeatureSSH, FeatureCrabboxSync, FeatureCleanup}, + Coordinator: CoordinatorNever, + } +} +func (testVastProvider) RegisterFlags(fs *flag.FlagSet, defaults Config) any { + return struct{ APIURL *string }{ + APIURL: fs.String("vast-api-url", defaults.Vast.APIURL, ""), + } +} +func (testVastProvider) ApplyFlags(cfg *Config, fs *flag.FlagSet, values any) error { + v, _ := values.(struct{ APIURL *string }) + if flagWasSet(fs, "vast-api-url") && v.APIURL != nil { + cfg.Vast.APIURL = *v.APIURL + } + return nil +} +func (p testVastProvider) Configure(Config, Runtime) (Backend, error) { + return testSSHBackend{spec: p.Spec()}, nil +} + type testNvidiaBrevProvider struct{} func (testNvidiaBrevProvider) Name() string { return "nvidia-brev" } diff --git a/internal/cli/providers_test.go b/internal/cli/providers_test.go index 616f72d4a..ea061aaa6 100644 --- a/internal/cli/providers_test.go +++ b/internal/cli/providers_test.go @@ -17,6 +17,7 @@ func TestProviderMatrixIncludesCapabilities(t *testing.T) { var digitalOcean *providerMatrixEntry var vultr *providerMatrixEntry var firecracker *providerMatrixEntry + var vast *providerMatrixEntry var nvidiaBrev *providerMatrixEntry var linode *providerMatrixEntry var nebius *providerMatrixEntry @@ -43,6 +44,9 @@ func TestProviderMatrixIncludesCapabilities(t *testing.T) { if entries[i].Provider == "firecracker" { firecracker = &entries[i] } + if entries[i].Provider == "vast" { + vast = &entries[i] + } if entries[i].Provider == "nvidia-brev" { nvidiaBrev = &entries[i] } @@ -89,6 +93,9 @@ func TestProviderMatrixIncludesCapabilities(t *testing.T) { if firecracker == nil { t.Fatal("firecracker provider not found") } + if vast == nil { + t.Fatal("vast provider not found") + } if nvidiaBrev == nil { t.Fatal("nvidia-brev provider not found") } @@ -243,6 +250,18 @@ func TestProviderMatrixIncludesCapabilities(t *testing.T) { if !containsString(nvidiaBrev.Aliases, "brev") || !containsString(nvidiaBrev.Aliases, "nvidia") { t.Fatalf("nvidia-brev aliases=%v", nvidiaBrev.Aliases) } + if vast.Kind != ProviderKindSSHLease || vast.Family != "vast" || vast.Coordinator != string(CoordinatorNever) { + t.Fatalf("vast kind/family/coordinator=%q/%q/%q", vast.Kind, vast.Family, vast.Coordinator) + } + if !containsString(vast.Targets, targetLinux) { + t.Fatalf("vast targets=%v", vast.Targets) + } + if !containsFeature(vast.Features, FeatureSSH) || !containsFeature(vast.Features, FeatureCrabboxSync) || !containsFeature(vast.Features, FeatureCleanup) { + t.Fatalf("vast features=%v", vast.Features) + } + if !containsString(vast.Aliases, "vast-ai") || !containsString(vast.Aliases, "vastai") { + t.Fatalf("vast aliases=%v", vast.Aliases) + } for _, capability := range []string{"local-runtime", "ssh-host"} { if !containsString(localContainer.Runtime, capability) { t.Fatalf("local-container runtime=%v missing %s", localContainer.Runtime, capability) diff --git a/internal/providers/all/all.go b/internal/providers/all/all.go index 951183bff..451cf214d 100644 --- a/internal/providers/all/all.go +++ b/internal/providers/all/all.go @@ -63,6 +63,7 @@ import ( _ "github.com/openclaw/crabbox/internal/providers/tenki" _ "github.com/openclaw/crabbox/internal/providers/tensorlake" _ "github.com/openclaw/crabbox/internal/providers/upstashbox" + _ "github.com/openclaw/crabbox/internal/providers/vast" _ "github.com/openclaw/crabbox/internal/providers/vercelsandbox" _ "github.com/openclaw/crabbox/internal/providers/vultr" _ "github.com/openclaw/crabbox/internal/providers/wandb" diff --git a/internal/providers/all/all_test.go b/internal/providers/all/all_test.go index 467abee5c..46e40b815 100644 --- a/internal/providers/all/all_test.go +++ b/internal/providers/all/all_test.go @@ -77,6 +77,34 @@ func TestNvidiaBrevRegistersCanonicalAndAliases(t *testing.T) { } } +func TestVastRegistersCanonicalAndAliases(t *testing.T) { + for _, name := range []string{"vast", "vast-ai", "vastai"} { + provider, err := core.ProviderFor(name) + if err != nil { + t.Fatalf("ProviderFor(%q): %v", name, err) + } + if provider.Name() != "vast" { + t.Fatalf("ProviderFor(%q).Name=%q want vast", name, provider.Name()) + } + } + provider, err := core.ProviderFor("vast") + if err != nil { + t.Fatalf("ProviderFor(vast): %v", err) + } + spec := provider.Spec() + if spec.Family != "vast" || spec.Kind != core.ProviderKindSSHLease || spec.Coordinator != core.CoordinatorNever { + t.Fatalf("vast spec=%#v", spec) + } + if len(spec.Targets) != 1 || spec.Targets[0].OS != core.TargetLinux { + t.Fatalf("vast targets=%#v", spec.Targets) + } + for _, feature := range []core.Feature{core.FeatureSSH, core.FeatureCrabboxSync, core.FeatureCleanup} { + if !spec.Features.Has(feature) { + t.Fatalf("vast features=%v missing %s", spec.Features, feature) + } + } +} + func TestLambdaRegistersAsBuiltInProvider(t *testing.T) { provider, err := core.ProviderFor("lambda") if err != nil { @@ -1186,6 +1214,7 @@ func allBuiltInProviderNames() []string { "tenki", "tensorlake", "upstash-box", + "vast", "vercel-sandbox", "vultr", "wandb", diff --git a/internal/providers/vast/backend.go b/internal/providers/vast/backend.go new file mode 100644 index 000000000..f79319f38 --- /dev/null +++ b/internal/providers/vast/backend.go @@ -0,0 +1,809 @@ +package vast + +import ( + "context" + "errors" + "fmt" + "io" + "strconv" + "strings" + "time" + + core "github.com/openclaw/crabbox/internal/cli" + "github.com/openclaw/crabbox/internal/providers/shared" +) + +const ( + vastPollInterval = 3 * time.Second + vastPollTimeout = 10 * time.Minute + vastCleanupTimeout = 30 * time.Second + vastKeyIDLabel = "provider_key_id" + vastKeyOwnedLabel = "provider_key_owned" + vastOfferIDLabel = "vast_offer_id" + vastReadyCheck = "command -v git >/dev/null && command -v rsync >/dev/null && command -v tar >/dev/null && command -v python3 >/dev/null" + vastReleaseActionLabel = "release_action" +) + +type backend struct { + shared.DirectSSHBackend + cfg core.Config + rt core.Runtime + apiFactory func(core.Runtime) (vastAPI, error) + waitSSH func(context.Context, *core.SSHTarget, string, time.Duration) error + sleep func(context.Context, time.Duration) error + pollTimeout time.Duration + cleanupTimeout time.Duration +} + +func newBackend(spec core.ProviderSpec, cfg core.Config, rt core.Runtime) *backend { + applyVastDefaults(&cfg) + b := &backend{cfg: cfg, rt: rt, pollTimeout: vastPollTimeout, cleanupTimeout: vastCleanupTimeout} + b.DirectSSHBackend = shared.DirectSSHBackend{SpecValue: spec, Cfg: cfg, RT: rt, Delete: b.deleteServer, StoredLeaseKeys: true} + b.apiFactory = func(rt core.Runtime) (vastAPI, error) { return newVastClient(cfg.Vast, rt) } + b.waitSSH = func(ctx context.Context, target *core.SSHTarget, phase string, timeout time.Duration) error { + return core.WaitForSSHReady(ctx, target, b.stderr(), phase, timeout) + } + b.sleep = func(ctx context.Context, d time.Duration) error { + timer := time.NewTimer(d) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } + } + return b +} + +func applyVastDefaults(cfg *core.Config) { + cfg.Provider = providerName + if cfg.TargetOS == "" { + cfg.TargetOS = core.TargetLinux + } + if core.IsSSHUserExplicit(cfg) { + // The generic SSH user is the operator-facing override. Vast.User only + // provides the provider default when that override was not used. + } else if cfg.Vast.User != "" { + cfg.SSHUser = cfg.Vast.User + } else if cfg.SSHUser == "" { + cfg.SSHUser = "root" + } + if cfg.Vast.WorkRoot != "" { + cfg.WorkRoot = cfg.Vast.WorkRoot + } + if cfg.WorkRoot == "" { + cfg.WorkRoot = "/work/crabbox" + } + if cfg.SSHPort == "" { + cfg.SSHPort = "22" + } + if cfg.Vast.InstanceType == "" { + cfg.Vast.InstanceType = "ondemand" + } + if cfg.Vast.Runtype == "" { + cfg.Vast.Runtype = "ssh_direct" + } + if cfg.Vast.Order == "" { + cfg.Vast.Order = "dlperf_per_dphtotal desc" + } + if cfg.Vast.ReleaseAction == "" { + cfg.Vast.ReleaseAction = "destroy" + } +} + +func (b *backend) stderr() io.Writer { + if b.rt.Stderr != nil { + return b.rt.Stderr + } + return io.Discard +} + +func (b *backend) now() time.Time { + if b.rt.Clock != nil { + return b.rt.Clock.Now().UTC() + } + return time.Now().UTC() +} + +func (b *backend) api() (vastAPI, error) { + if b.apiFactory != nil { + return b.apiFactory(b.rt) + } + return newVastClient(b.cfg.Vast, b.rt) +} + +func (b *backend) Doctor(ctx context.Context, _ core.DoctorRequest) (core.DoctorResult, error) { + client, err := b.api() + if err != nil { + return core.DoctorResult{}, err + } + if _, err := client.CheckAuth(ctx); err != nil { + return core.DoctorResult{}, err + } + instances, err := client.ListInstances(ctx) + if err != nil { + return core.DoctorResult{}, err + } + count := 0 + for _, item := range instances { + if isOwnedVastInstance(item) { + count++ + } + } + result := core.InventoryDoctorResult(providerName, count) + result.Message += fmt.Sprintf(" default_order=%s runtype=%s user=%s", b.cfg.Vast.Order, b.cfg.Vast.Runtype, b.cfg.SSHUser) + return result, nil +} + +func (b *backend) Acquire(ctx context.Context, req core.AcquireRequest) (core.LeaseTarget, error) { + return shared.AcquireAttemptsRetry(b.rt, req.Keep, func() (core.LeaseTarget, error) { + return b.acquireOnce(ctx, req) + }) +} + +func (b *backend) acquireOnce(ctx context.Context, req core.AcquireRequest) (target core.LeaseTarget, err error) { + if b.cfg.TargetOS != "" && b.cfg.TargetOS != core.TargetLinux { + return core.LeaseTarget{}, exit(2, "provider=%s supports target=linux only", providerName) + } + client, err := b.api() + if err != nil { + return core.LeaseTarget{}, err + } + instances, err := client.ListInstances(ctx) + if err != nil { + return core.LeaseTarget{}, err + } + servers := serversFromInstances(instances, b.cfg, false) + leaseID := core.NewLeaseID() + slug, err := core.AllocateDirectLeaseSlug(leaseID, req.RequestedSlug, servers) + if err != nil { + return core.LeaseTarget{}, err + } + keyPath, publicKey, err := core.EnsureTestboxKeyForConfig(b.cfg, leaseID) + if err != nil { + return core.LeaseTarget{}, err + } + cfg := b.cfg + cfg.SSHKey = keyPath + cfg.ProviderKey = core.ProviderKeyForLease(leaseID) + now := b.now() + label := encodeVastOwnershipLabel(leaseID, slug, "provisioning") + var ( + instanceID int + keyID string + committed bool + ) + defer func() { + if err == nil || committed { + return + } + if instanceID == 0 { + core.RemoveStoredTestboxKey(leaseID) + return + } + _ = b.persistRecoveryClaim(leaseID, slug, cfg, req.Repo.Root, instanceID, keyID, "rollback-cleanup", req.Keep, now) + if !req.Keep { + cleanupErr := rollbackVastAcquire(client, instanceID, keyID) + if cleanupErr != nil { + err = fmt.Errorf("%v; vast cleanup failed: %w", err, cleanupErr) + return + } + core.RemoveLeaseClaim(leaseID) + core.RemoveStoredTestboxKey(leaseID) + } + }() + offers, err := client.SearchOffers(ctx, vastOfferSearchInput{Config: cfg.Vast}) + if err != nil { + return core.LeaseTarget{}, err + } + offer, err := selectVastOffer(offers) + if err != nil { + return core.LeaseTarget{}, err + } + fmt.Fprintf(b.stderr(), "provisioning provider=vast lease=%s slug=%s offer=%d gpu=%s count=%d max_dph=%.4f keep=%v\n", leaseID, slug, offer.ID, offer.GPUName, offer.GPUCount, cfg.Vast.MaxDphTotal, req.Keep) + created, err := client.CreateInstance(ctx, offer.ID, vastCreateInstanceInput{ + Config: cfg.Vast, + Label: label, + SSHKey: publicKey, + Environment: map[string]string{"CRABBOX": "1"}, + }) + if err != nil { + return core.LeaseTarget{}, err + } + instanceID = firstNonZero(created.Instance.ID, created.NewContract) + if instanceID == 0 { + err = exit(5, "vast create returned no instance id") + return core.LeaseTarget{}, err + } + if attach, attachErr := client.AttachInstanceSSHKey(ctx, instanceID, publicKey); attachErr != nil { + err = attachErr + return core.LeaseTarget{}, err + } else { + keyID = vastAttachedKeyID(attach) + } + instance, err := b.waitForInstanceReady(ctx, client, instanceID) + if err != nil { + return core.LeaseTarget{}, err + } + readyLabel := encodeVastOwnershipLabel(leaseID, slug, "ready") + if _, err := client.ManageInstance(ctx, instanceID, vastManageInstanceInput{Label: readyLabel}); err != nil { + return core.LeaseTarget{}, err + } + instance.Label = readyLabel + server := serverFromInstance(instance, cfg) + server.Labels = vastLeaseLabels(cfg, leaseID, slug, "ready", req.Keep, now) + server.Labels[vastOfferIDLabel] = strconv.Itoa(offer.ID) + if keyID != "" { + server.Labels[vastKeyIDLabel] = keyID + } + server.Labels[vastKeyOwnedLabel] = fmt.Sprint(keyID != "") + ssh, err := sshTargetFromInstance(cfg, instance) + if err != nil { + return core.LeaseTarget{}, err + } + if err := b.waitSSH(ctx, &ssh, "vast bootstrap", core.BootstrapWaitTimeout(cfg)); err != nil { + return core.LeaseTarget{}, err + } + target = core.LeaseTarget{Server: server, SSH: ssh, LeaseID: leaseID} + if req.OnAcquired != nil { + if err := req.OnAcquired(target); err != nil { + return core.LeaseTarget{}, err + } + } + if err := core.ClaimLeaseTargetForRepoConfig(leaseID, slug, cfg, server, ssh, req.Repo.Root, cfg.IdleTimeout, req.Reclaim); err != nil { + return core.LeaseTarget{}, err + } + committed = true + fmt.Fprintf(b.stderr(), "provisioned lease=%s vast=%d gpu=%s state=ready\n", leaseID, instanceID, server.ServerType.Name) + return target, nil +} + +func selectVastOffer(offers []vastOffer) (vastOffer, error) { + for _, offer := range offers { + if offer.ID != 0 && offer.Rentable && !offer.Rented { + return offer, nil + } + } + if len(offers) > 0 && offers[0].ID != 0 { + return offers[0], nil + } + return vastOffer{}, exit(4, "vast found no eligible offers") +} + +func vastAttachedKeyID(resp vastAttachSSHKeyResponse) string { + if strings.TrimSpace(resp.Key.ID) != "" { + return strings.TrimSpace(resp.Key.ID) + } + for _, key := range resp.Keys { + if strings.TrimSpace(key.ID) != "" { + return strings.TrimSpace(key.ID) + } + } + return "" +} + +func (b *backend) waitForInstanceReady(ctx context.Context, client vastAPI, id int) (vastInstance, error) { + deadline := b.now().Add(b.pollTimeout) + for { + instance, err := client.GetInstance(ctx, id) + if err != nil { + return vastInstance{}, err + } + if isVastInstanceRunning(instance) && strings.TrimSpace(instance.SSHHost) != "" && instance.SSHPort > 0 { + return instance, nil + } + if isTerminalVastStatus(instance.Status) { + return vastInstance{}, exit(5, "vast instance %d reached terminal status %s", id, instance.Status) + } + if b.now().After(deadline) { + return vastInstance{}, exit(5, "timed out waiting for Vast instance %d to expose SSH", id) + } + if err := b.sleep(ctx, vastPollInterval); err != nil { + return vastInstance{}, err + } + } +} + +func (b *backend) Resolve(ctx context.Context, req core.ResolveRequest) (core.LeaseTarget, error) { + client, err := b.api() + if err != nil { + return core.LeaseTarget{}, err + } + instances, err := client.ListInstances(ctx) + if err != nil { + return core.LeaseTarget{}, err + } + byID := map[int]vastInstance{} + for _, item := range instances { + byID[item.ID] = item + } + servers := serversFromInstances(instances, b.cfg, false) + server, leaseID, err := core.FindServerByAlias(servers, req.ID) + if err != nil { + return core.LeaseTarget{}, err + } + if leaseID != "" { + if id, ok := parseVastInstanceID(server.CloudID); ok { + return b.targetFromInstance(byID[id], req) + } + } + if claim, ok, claimErr := core.ResolveLeaseClaimForProvider(req.ID, providerName); claimErr != nil { + return core.LeaseTarget{}, claimErr + } else if ok { + if id, parseOK := parseVastInstanceID(claim.CloudID); parseOK { + item, getErr := client.GetInstance(ctx, id) + if getErr == nil { + return b.targetFromInstance(item, req) + } + if req.ReleaseOnly { + return claimTarget(claim), nil + } + return core.LeaseTarget{}, getErr + } + if req.ReleaseOnly { + return claimTarget(claim), nil + } + } + if id, ok := parseVastInstanceID(req.ID); ok { + item, found := byID[id] + if !found { + item, err = client.GetInstance(ctx, id) + if err != nil { + return b.releaseTargetFromClaim(req.ID, err, req.ReleaseOnly) + } + } + return b.targetFromInstance(item, req) + } + return core.LeaseTarget{}, exit(4, "lease/instance not found: %s", req.ID) +} + +func (b *backend) releaseTargetFromClaim(id string, cause error, releaseOnly bool) (core.LeaseTarget, error) { + if !releaseOnly { + return core.LeaseTarget{}, cause + } + claim, ok, err := core.ResolveLeaseClaimForProvider(id, providerName) + if err != nil || !ok { + if err != nil { + return core.LeaseTarget{}, err + } + return core.LeaseTarget{}, cause + } + return claimTarget(claim), nil +} + +func (b *backend) targetFromInstance(item vastInstance, req core.ResolveRequest) (core.LeaseTarget, error) { + if !isOwnedVastInstance(item) { + return core.LeaseTarget{}, exit(2, "refusing to operate on non-Crabbox Vast instance %d", item.ID) + } + if isTerminalVastStatus(item.Status) && !req.ReleaseOnly && !req.StatusOnly { + return core.LeaseTarget{}, exit(5, "vast instance %d reached terminal status %s", item.ID, item.Status) + } + server := serverFromInstance(item, b.cfg) + server = mergeVastClaimMetadata(server) + leaseID := server.Labels["lease"] + target := core.LeaseTarget{Server: server, LeaseID: leaseID} + if !req.ReleaseOnly && (!req.StatusOnly || req.ReadyProbe) { + ssh, err := sshTargetFromInstance(b.cfg, item) + if err != nil { + return core.LeaseTarget{}, err + } + core.UseStoredTestboxKey(&ssh, leaseID) + target.SSH = ssh + } + if req.Repo.Root != "" && !req.NoLocalStateMutations { + if err := core.ClaimLeaseTargetForRepoConfig(leaseID, server.Labels["slug"], b.cfg, target.Server, target.SSH, req.Repo.Root, b.cfg.IdleTimeout, req.Reclaim); err != nil { + return core.LeaseTarget{}, err + } + } + return target, nil +} + +func (b *backend) List(ctx context.Context, req core.ListRequest) ([]core.LeaseView, error) { + client, err := b.api() + if err != nil { + return nil, err + } + instances, err := client.ListInstances(ctx) + if err != nil { + return nil, err + } + return serversFromInstances(instances, b.cfg, req.All), nil +} + +func (b *backend) ReleaseLease(ctx context.Context, req core.ReleaseLeaseRequest) error { + if err := core.ValidateLeaseTargetProviderIdentity(req.Lease, req.ExpectedProviderIdentity); err != nil { + return err + } + return b.deleteServer(ctx, b.cfg, req.Lease.Server) +} + +func (b *backend) ReleaseLeaseMessage(lease core.LeaseTarget) string { + action := effectiveVastReleaseAction(b.cfg, lease.Server.Labels) + if action == "stop" || action == "keep" { + return fmt.Sprintf("%s lease=%s vast=%s name=%s", action, lease.LeaseID, lease.Server.DisplayID(), lease.Server.Name) + } + return fmt.Sprintf("destroyed lease=%s vast=%s name=%s", lease.LeaseID, lease.Server.DisplayID(), lease.Server.Name) +} + +func (b *backend) Touch(_ context.Context, req core.TouchRequest) (core.Server, error) { + server := req.Lease.Server + if err := validateVastServer(server); err != nil { + return core.Server{}, err + } + cfg := b.cfg + if req.IdleTimeout > 0 { + cfg.IdleTimeout = req.IdleTimeout + } + server.Labels = core.TouchDirectLeaseLabels(server.Labels, cfg, req.State, b.now()) + if claim, ok, err := core.ReadLeaseClaimWithPresence(req.Lease.LeaseID); err == nil && ok { + if _, err := core.UpdateLeaseClaimLabelsIfUnchanged(req.Lease.LeaseID, claim, server.Labels); err != nil { + return core.Server{}, err + } + } + return server, nil +} + +func (b *backend) Cleanup(ctx context.Context, req core.CleanupRequest) error { + servers, err := b.List(ctx, core.ListRequest{Options: req.Options}) + if err != nil { + return err + } + servers, err = b.prepareCleanupServers(servers) + if err != nil { + return err + } + return b.CleanupServers(ctx, req, servers) +} + +func (b *backend) prepareCleanupServers(servers []core.Server) ([]core.Server, error) { + for i := range servers { + updated, err := b.prepareCleanupServer(servers[i]) + if err != nil { + return nil, err + } + servers[i] = updated + } + return servers, nil +} + +func (b *backend) prepareCleanupServer(server core.Server) (core.Server, error) { + if server.Provider != providerName { + return server, nil + } + leaseID := strings.TrimSpace(server.Labels["lease"]) + if leaseID == "" { + return server, nil + } + claim, claimExists, err := core.ReadLeaseClaimWithPresence(leaseID) + if err != nil { + return core.Server{}, fmt.Errorf("read vast cleanup claim: %w", err) + } + if claimExists && claim.Provider == providerName && (claim.CloudID == "" || server.CloudID == "" || claim.CloudID == server.CloudID) { + return server, nil + } + + labels := cloneLabels(server.Labels) + labels["state"] = "expired" + labels["expires_at"] = core.LeaseLabelTime(b.now().Add(-time.Second)) + server.Labels = labels + return server, nil +} + +func (b *backend) deleteServer(ctx context.Context, _ core.Config, server core.Server) error { + if err := validateVastServer(server); err != nil { + return err + } + client, err := b.api() + if err != nil { + return err + } + leaseID := server.Labels["lease"] + claim, claimExists, err := core.ReadLeaseClaimWithPresence(leaseID) + if err != nil { + return fmt.Errorf("read vast cleanup claim: %w", err) + } + if !claimExists { + return exit(2, "lease=%s has no local Vast claim; refusing destructive cleanup", leaseID) + } + if claim.Provider != providerName { + return exit(2, "lease=%s is claimed by provider=%s; refusing Vast cleanup", leaseID, claim.Provider) + } + if claim.CloudID != "" && server.CloudID != "" && claim.CloudID != server.CloudID { + return exit(2, "refusing to release Vast instance %s from stale local claim", server.CloudID) + } + instanceID, ok := parseVastInstanceID(firstNonBlank(server.CloudID, claim.CloudID)) + if !ok { + return exit(2, "provider=%s release requires a Vast instance id", providerName) + } + if live, getErr := client.GetInstance(ctx, instanceID); getErr == nil { + if err := validateLiveVastInstance(live, server); err != nil { + return err + } + } else if !isVastNotFound(getErr) { + return getErr + } + action := effectiveVastReleaseAction(b.cfg, claim.Labels) + switch action { + case "keep": + return nil + case "stop": + if _, err := client.ManageInstance(ctx, instanceID, vastManageInstanceInput{State: "stopped", Label: encodeVastOwnershipLabel(leaseID, server.Labels["slug"], "stopped")}); err != nil { + return err + } + labels := cloneLabels(claim.Labels) + labels["state"] = "stopped" + labels[vastReleaseActionLabel] = "stop" + if _, err := core.UpdateLeaseClaimLabelsIfUnchanged(leaseID, claim, labels); err != nil { + return fmt.Errorf("finalize vast stop claim: %w", err) + } + default: + if keyID := strings.TrimSpace(claim.Labels[vastKeyIDLabel]); keyID != "" && claim.Labels[vastKeyOwnedLabel] == "true" { + if err := client.DetachInstanceSSHKey(ctx, instanceID, keyID); err != nil && !isVastNotFound(err) { + return err + } + } + if err := client.DestroyInstance(ctx, instanceID); err != nil && !isVastNotFound(err) { + return err + } + if err := core.RemoveLeaseClaimIfUnchanged(leaseID, claim); err != nil { + return fmt.Errorf("finalize vast cleanup claim: %w", err) + } + core.RemoveStoredTestboxKey(leaseID) + } + return nil +} + +func serversFromInstances(instances []vastInstance, cfg core.Config, includeAll bool) []core.Server { + out := make([]core.Server, 0, len(instances)) + for _, item := range instances { + if !includeAll && !isOwnedVastInstance(item) { + continue + } + server := serverFromInstance(item, cfg) + if isOwnedVastInstance(item) { + server = mergeVastClaimLabels(server) + } + out = append(out, server) + } + return out +} + +func mergeVastClaimLabels(server core.Server) core.Server { + leaseID := strings.TrimSpace(server.Labels["lease"]) + if leaseID == "" { + return server + } + claim, ok, err := core.ReadLeaseClaimWithPresence(leaseID) + if err != nil || !ok || claim.Provider != providerName { + return server + } + if claim.CloudID != "" && claim.CloudID != server.CloudID { + return server + } + if len(claim.Labels) > 0 { + server.Labels = claim.Labels + } + return server +} + +func mergeVastClaimMetadata(server core.Server) core.Server { + leaseID := strings.TrimSpace(server.Labels["lease"]) + if leaseID == "" { + return server + } + claim, ok, err := core.ReadLeaseClaimWithPresence(leaseID) + if err != nil || !ok || claim.Provider != providerName { + return server + } + if claim.CloudID != "" && server.CloudID != "" && claim.CloudID != server.CloudID { + return server + } + server.Labels = preserveVastClaimMetadata(server.Labels, claim.Labels) + return server +} + +func serverFromInstance(item vastInstance, cfg core.Config) core.Server { + labels := labelsFromVastInstance(item, cfg) + server := core.Server{ + CloudID: strconv.Itoa(item.ID), + Provider: providerName, + Name: firstNonBlank(labels["slug"], item.Label, strconv.Itoa(item.ID)), + Status: normalizeVastStatus(item.Status), + Labels: labels, + } + server.PublicNet.IPv4.IP = strings.TrimSpace(item.SSHHost) + server.ServerType.Name = firstNonBlank(item.GPUName, cfg.ServerType) + return server +} + +func labelsFromVastInstance(item vastInstance, cfg core.Config) map[string]string { + if owner, ok := decodeVastOwnershipLabel(item.Label); ok { + labels := vastLeaseLabels(cfg, owner.LeaseID, owner.Slug, owner.State, false, time.Now().UTC()) + labels["provider_key"] = core.ProviderKeyForLease(owner.LeaseID) + labels[vastReleaseActionLabel] = normalizeVastReleaseAction(cfg.Vast.ReleaseAction) + return labels + } + return map[string]string{"label": strings.TrimSpace(item.Label)} +} + +func vastLeaseLabels(cfg core.Config, leaseID, slug, state string, keep bool, now time.Time) map[string]string { + labels := core.DirectLeaseLabels(cfg, leaseID, slug, providerName, "", keep, now) + labels["state"] = state + labels[vastReleaseActionLabel] = normalizeVastReleaseAction(cfg.Vast.ReleaseAction) + return labels +} + +func cloneLabels(labels map[string]string) map[string]string { + out := make(map[string]string, len(labels)) + for key, value := range labels { + out[key] = value + } + return out +} + +func preserveVastClaimMetadata(labels, existing map[string]string) map[string]string { + out := cloneLabels(labels) + for _, key := range []string{ + vastReleaseActionLabel, + vastKeyIDLabel, + vastKeyOwnedLabel, + vastOfferIDLabel, + "provider_key", + "recovery", + } { + if value, ok := existing[key]; ok { + out[key] = value + } + } + return out +} + +func isOwnedVastInstance(item vastInstance) bool { + return isVastCrabboxOwnedLabel(item.Label) +} + +func validateVastServer(server core.Server) error { + if server.Provider != "" && server.Provider != providerName { + return exit(2, "refusing to operate on provider=%s server as Vast", server.Provider) + } + leaseID := strings.TrimSpace(server.Labels["lease"]) + if leaseID == "" || strings.TrimSpace(server.Labels["slug"]) == "" { + return exit(2, "refusing to operate on non-Crabbox Vast instance %s", server.DisplayID()) + } + return nil +} + +func validateLiveVastInstance(item vastInstance, expected core.Server) error { + if !isOwnedVastInstance(item) { + return exit(2, "refusing to operate on non-Crabbox Vast instance %d", item.ID) + } + owner, _ := decodeVastOwnershipLabel(item.Label) + if strconv.Itoa(item.ID) != expected.CloudID || + owner.LeaseID != expected.Labels["lease"] || + owner.Slug != expected.Labels["slug"] { + return exit(2, "refusing to operate on changed Vast instance %s", expected.CloudID) + } + return nil +} + +func sshTargetFromInstance(cfg core.Config, item vastInstance) (core.SSHTarget, error) { + host := strings.TrimSpace(item.SSHHost) + if host == "" || item.SSHPort <= 0 { + return core.SSHTarget{}, exit(5, "vast instance %d is missing SSH endpoint", item.ID) + } + ssh := core.SSHTargetFromConfig(cfg, host) + ssh.Port = strconv.Itoa(item.SSHPort) + ssh.User = firstNonBlank(cfg.SSHUser, cfg.Vast.User, "root") + ssh.TargetOS = core.TargetLinux + ssh.ReadyCheck = vastReadyCheck + return ssh, nil +} + +func isVastInstanceRunning(item vastInstance) bool { + switch strings.ToLower(strings.TrimSpace(item.Status)) { + case "running", "active", "ready": + return true + default: + return false + } +} + +func isTerminalVastStatus(status string) bool { + switch strings.ToLower(strings.TrimSpace(status)) { + case "failed", "error", "exited", "cancelled", "canceled", "destroyed", "deleted", "dead": + return true + default: + return false + } +} + +func normalizeVastStatus(status string) string { + if isVastInstanceRunning(vastInstance{Status: status}) { + return "ready" + } + if status = strings.TrimSpace(status); status != "" { + return status + } + return "unknown" +} + +func normalizeVastReleaseAction(value string) string { + switch strings.ToLower(strings.TrimSpace(value)) { + case "stop": + return "stop" + case "keep": + return "keep" + default: + return "destroy" + } +} + +func effectiveVastReleaseAction(cfg core.Config, labels map[string]string) string { + if core.DeleteOnReleaseExplicit(cfg, providerName) { + return normalizeVastReleaseAction(cfg.Vast.ReleaseAction) + } + return normalizeVastReleaseAction(firstNonBlank(labels[vastReleaseActionLabel], cfg.Vast.ReleaseAction)) +} + +func parseVastInstanceID(value string) (int, bool) { + id, err := strconv.Atoi(strings.TrimSpace(value)) + return id, err == nil && id > 0 +} + +func firstNonBlank(values ...string) string { + for _, value := range values { + if strings.TrimSpace(value) != "" { + return strings.TrimSpace(value) + } + } + return "" +} + +func claimTarget(claim core.LeaseClaim) core.LeaseTarget { + server := core.Server{ + CloudID: claim.CloudID, + Provider: providerName, + Name: claim.Slug, + Status: claim.Labels["state"], + Labels: claim.Labels, + } + server.PublicNet.IPv4.IP = claim.SSHHost + target := core.SSHTarget{Host: claim.SSHHost, Port: strconv.Itoa(claim.SSHPort), TargetOS: core.TargetLinux} + core.UseStoredTestboxKey(&target, claim.LeaseID) + return core.LeaseTarget{LeaseID: claim.LeaseID, Server: server, SSH: target} +} + +func (b *backend) persistRecoveryClaim(leaseID, slug string, cfg core.Config, repoRoot string, instanceID int, keyID, reason string, keep bool, now time.Time) error { + label := encodeVastOwnershipLabel(leaseID, slug, reason) + server := serverFromInstance(vastInstance{ID: instanceID, Label: label, Status: reason}, cfg) + server.Labels = vastLeaseLabels(cfg, leaseID, slug, reason, keep, now) + server.Labels["recovery"] = reason + if keyID != "" { + server.Labels[vastKeyIDLabel] = keyID + server.Labels[vastKeyOwnedLabel] = "true" + } + return core.ClaimLeaseTargetForRepoConfig(leaseID, slug, cfg, server, core.SSHTarget{}, repoRoot, cfg.IdleTimeout, true) +} + +func rollbackVastAcquire(client vastAPI, instanceID int, keyID string) error { + ctx, cancel := context.WithTimeout(context.Background(), vastCleanupTimeout) + defer cancel() + var errs []error + if keyID != "" { + if err := client.DetachInstanceSSHKey(ctx, instanceID, keyID); err != nil && !isVastNotFound(err) { + errs = append(errs, err) + } + } + if instanceID != 0 { + if err := client.DestroyInstance(ctx, instanceID); err != nil && !isVastNotFound(err) { + errs = append(errs, err) + } + } + return errors.Join(errs...) +} + +func isVastNotFound(err error) bool { + var apiErr *vastAPIError + return errors.As(err, &apiErr) && apiErr.StatusCode == 404 +} diff --git a/internal/providers/vast/backend_test.go b/internal/providers/vast/backend_test.go new file mode 100644 index 000000000..494a4dbc2 --- /dev/null +++ b/internal/providers/vast/backend_test.go @@ -0,0 +1,626 @@ +package vast + +import ( + "bytes" + "context" + "errors" + "io" + "strconv" + "strings" + "testing" + "time" + + core "github.com/openclaw/crabbox/internal/cli" +) + +type fakeVastAPI struct { + user vastUser + offers []vastOffer + instances []vastInstance + authErr error + listErr error + createErr error + getErr error + manageErr error + destroyErr error + attachErr error + detachErr error + + searches []vastOfferSearchInput + creates []struct { + offerID int + input vastCreateInstanceInput + } + managed []struct { + id int + input vastManageInstanceInput + } + destroyed []int + attached []struct { + id int + publicKey string + } + detached []struct { + id int + keyID string + } + events []string + nextID int +} + +func (f *fakeVastAPI) CheckAuth(context.Context) (vastUser, error) { + if f.authErr != nil { + return vastUser{}, f.authErr + } + if f.user.ID == 0 { + return vastUser{ID: 7, Username: "alice"}, nil + } + return f.user, nil +} + +func (f *fakeVastAPI) SearchOffers(_ context.Context, input vastOfferSearchInput) ([]vastOffer, error) { + f.searches = append(f.searches, input) + return append([]vastOffer(nil), f.offers...), nil +} + +func (f *fakeVastAPI) CreateInstance(_ context.Context, offerID int, input vastCreateInstanceInput) (vastCreateInstanceResponse, error) { + f.creates = append(f.creates, struct { + offerID int + input vastCreateInstanceInput + }{offerID: offerID, input: input}) + if f.createErr != nil { + return vastCreateInstanceResponse{}, f.createErr + } + if f.nextID == 0 { + f.nextID = 100 + } + item := vastInstance{ + ID: f.nextID, + Label: input.Label, + Status: "running", + SSHHost: "203.0.113.24", + SSHPort: 2201, + GPUName: "RTX 4090", + GPUCount: 1, + DphTotal: 0.75, + } + f.instances = append(f.instances, item) + f.nextID++ + return vastCreateInstanceResponse{Success: true, NewContract: item.ID, Instance: item}, nil +} + +func (f *fakeVastAPI) GetInstance(_ context.Context, id int) (vastInstance, error) { + if f.getErr != nil { + return vastInstance{}, f.getErr + } + for _, item := range f.instances { + if item.ID == id { + return item, nil + } + } + return vastInstance{}, &vastAPIError{StatusCode: 404, Status: "404 Not Found"} +} + +func (f *fakeVastAPI) ListInstances(context.Context) ([]vastInstance, error) { + if f.listErr != nil { + return nil, f.listErr + } + return append([]vastInstance(nil), f.instances...), nil +} + +func (f *fakeVastAPI) ManageInstance(_ context.Context, id int, input vastManageInstanceInput) (vastInstance, error) { + f.managed = append(f.managed, struct { + id int + input vastManageInstanceInput + }{id: id, input: input}) + if f.manageErr != nil { + return vastInstance{}, f.manageErr + } + for i := range f.instances { + if f.instances[i].ID == id { + if input.Label != "" { + f.instances[i].Label = input.Label + } + if input.State != "" { + f.instances[i].Status = input.State + } else if f.instances[i].Status == "starting" { + f.instances[i].Status = "running" + } + return f.instances[i], nil + } + } + return vastInstance{}, &vastAPIError{StatusCode: 404, Status: "404 Not Found"} +} + +func (f *fakeVastAPI) DestroyInstance(_ context.Context, id int) error { + f.destroyed = append(f.destroyed, id) + f.events = append(f.events, "destroy:"+strconv.Itoa(id)) + if f.destroyErr != nil { + return f.destroyErr + } + out := f.instances[:0] + for _, item := range f.instances { + if item.ID != id { + out = append(out, item) + } + } + f.instances = out + return nil +} + +func (f *fakeVastAPI) ListInstanceSSHKeys(context.Context, int) ([]vastInstanceSSHKey, error) { + return nil, nil +} + +func (f *fakeVastAPI) AttachInstanceSSHKey(_ context.Context, id int, publicKey string) (vastAttachSSHKeyResponse, error) { + f.attached = append(f.attached, struct { + id int + publicKey string + }{id: id, publicKey: publicKey}) + if f.attachErr != nil { + return vastAttachSSHKeyResponse{}, f.attachErr + } + return vastAttachSSHKeyResponse{Success: true, Key: vastInstanceSSHKey{ID: "key-" + strconv.Itoa(id), PublicKey: publicKey}}, nil +} + +func (f *fakeVastAPI) DetachInstanceSSHKey(_ context.Context, id int, keyID string) error { + f.detached = append(f.detached, struct { + id int + keyID string + }{id: id, keyID: keyID}) + f.events = append(f.events, "detach:"+strconv.Itoa(id)+":"+keyID) + return f.detachErr +} + +func newTestBackend(t *testing.T, api *fakeVastAPI) *backend { + t.Helper() + home := t.TempDir() + t.Setenv("HOME", home) + t.Setenv("XDG_CONFIG_HOME", home) + cfg := core.BaseConfig() + cfg.Provider = providerName + cfg.TargetOS = core.TargetLinux + cfg.SSHUser = "root" + cfg.SSHPort = "22" + cfg.WorkRoot = "/work/crabbox" + cfg.Vast.APIURL = "https://console.vast.ai/api/v0" + cfg.Vast.APIKey = "test-key" + cfg.Vast.User = "root" + cfg.Vast.WorkRoot = "/work/crabbox" + cfg.Vast.InstanceType = "ondemand" + cfg.Vast.Runtype = "ssh_direct" + cfg.Vast.Image = "nvidia/cuda:12" + cfg.Vast.Order = "dlperf_per_dphtotal desc" + cfg.Vast.ReleaseAction = "destroy" + b := newBackend(Provider{}.Spec(), cfg, core.Runtime{Stderr: io.Discard}) + b.apiFactory = func(core.Runtime) (vastAPI, error) { return api, nil } + b.waitSSH = func(context.Context, *core.SSHTarget, string, time.Duration) error { return nil } + b.sleep = func(context.Context, time.Duration) error { return nil } + return b +} + +func TestNewBackendPreservesExplicitGenericSSHUser(t *testing.T) { + cfg := core.BaseConfig() + cfg.Provider = providerName + cfg.SSHUser = "ubuntu" + core.MarkSSHUserExplicit(&cfg) + cfg.Vast.User = "root" + + b := newBackend(Provider{}.Spec(), cfg, core.Runtime{Stderr: io.Discard}) + if b.cfg.SSHUser != "ubuntu" { + t.Fatalf("backend SSHUser=%q want explicit generic user", b.cfg.SSHUser) + } + if b.DirectSSHBackend.Cfg.SSHUser != "ubuntu" { + t.Fatalf("direct SSH backend SSHUser=%q want explicit generic user", b.DirectSSHBackend.Cfg.SSHUser) + } +} + +func TestDoctorIsReadOnlyAndCountsOwnedInventory(t *testing.T) { + api := &fakeVastAPI{instances: []vastInstance{ + {ID: 1, Label: encodeVastOwnershipLabel("cbx_owned", "owned", "ready"), Status: "running"}, + {ID: 2, Label: "manual", Status: "running"}, + }} + result, err := newTestBackend(t, api).Doctor(context.Background(), core.DoctorRequest{}) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(result.Message, "leases=1") || !strings.Contains(result.Message, "mutation=false") { + t.Fatalf("doctor result=%#v", result) + } + if len(api.creates) != 0 || len(api.destroyed) != 0 || len(api.managed) != 0 { + t.Fatalf("doctor mutated api: creates=%v destroyed=%v managed=%v", api.creates, api.destroyed, api.managed) + } +} + +func TestListFiltersOwnedByDefaultAndAllShowsManual(t *testing.T) { + api := &fakeVastAPI{instances: []vastInstance{ + {ID: 1, Label: encodeVastOwnershipLabel("cbx_owned", "owned", "ready"), Status: "running"}, + {ID: 2, Label: "manual", Status: "running"}, + }} + b := newTestBackend(t, api) + owned, err := b.List(context.Background(), core.ListRequest{}) + if err != nil { + t.Fatal(err) + } + if len(owned) != 1 || owned[0].CloudID != "1" { + t.Fatalf("owned=%#v", owned) + } + all, err := b.List(context.Background(), core.ListRequest{All: true}) + if err != nil { + t.Fatal(err) + } + if len(all) != 2 { + t.Fatalf("all=%#v", all) + } +} + +func TestAcquireCreatesAttachesPollsReadinessAndClaims(t *testing.T) { + api := &fakeVastAPI{offers: []vastOffer{{ID: 42, GPUName: "RTX 4090", GPUCount: 1, Rentable: true}}} + b := newTestBackend(t, api) + var readyTarget core.SSHTarget + b.waitSSH = func(_ context.Context, target *core.SSHTarget, _ string, _ time.Duration) error { + readyTarget = *target + return nil + } + lease, err := b.Acquire(context.Background(), core.AcquireRequest{Repo: core.Repo{Root: t.TempDir()}, RequestedSlug: "gpu-box", Keep: true}) + if err != nil { + t.Fatal(err) + } + if lease.LeaseID == "" || lease.Server.CloudID != "100" || lease.SSH.Host != "203.0.113.24" || lease.SSH.Port != "2201" || lease.SSH.User != "root" || lease.SSH.Key == "" { + t.Fatalf("lease=%#v", lease) + } + if lease.SSH.ReadyCheck != vastReadyCheck || strings.Contains(lease.SSH.ReadyCheck, "crabbox-ready") { + t.Fatalf("lease ready check=%q", lease.SSH.ReadyCheck) + } + if readyTarget.ReadyCheck != vastReadyCheck { + t.Fatalf("waitSSH target ready check=%q", readyTarget.ReadyCheck) + } + if len(api.searches) != 1 || api.searches[0].Config.Order != "dlperf_per_dphtotal desc" { + t.Fatalf("searches=%#v", api.searches) + } + if len(api.creates) != 1 || api.creates[0].offerID != 42 || api.creates[0].input.Config.Runtype != "ssh_direct" || api.creates[0].input.Environment["CRABBOX"] != "1" { + t.Fatalf("creates=%#v", api.creates) + } + if !isVastCrabboxOwnedLabel(api.creates[0].input.Label) || api.creates[0].input.SSHKey == "" { + t.Fatalf("create input=%#v", api.creates[0].input) + } + if len(api.attached) != 1 || api.attached[0].id != 100 || api.attached[0].publicKey == "" { + t.Fatalf("attached=%#v", api.attached) + } + if len(api.managed) != 1 || !strings.Contains(api.managed[0].input.Label, "|ready") { + t.Fatalf("managed=%#v", api.managed) + } + claim, ok, err := core.ResolveLeaseClaimForProvider("gpu-box", providerName) + if err != nil || !ok || claim.CloudID != "100" || claim.Labels[vastKeyIDLabel] != "key-100" || claim.Labels[vastReleaseActionLabel] != "destroy" { + t.Fatalf("claim=%#v ok=%v err=%v", claim, ok, err) + } +} + +func TestResolvePreservesPersistedVastClaimMetadata(t *testing.T) { + api := &fakeVastAPI{offers: []vastOffer{{ID: 42, Rentable: true}}} + b := newTestBackend(t, api) + repoRoot := t.TempDir() + b.cfg.Vast.ReleaseAction = "stop" + b.DirectSSHBackend.Cfg = b.cfg + acquired, err := b.Acquire(context.Background(), core.AcquireRequest{Repo: core.Repo{Root: repoRoot}, RequestedSlug: "preserve-meta"}) + if err != nil { + t.Fatal(err) + } + + b.cfg.Vast.ReleaseAction = "destroy" + b.DirectSSHBackend.Cfg = b.cfg + resolved, err := b.Resolve(context.Background(), core.ResolveRequest{ID: "preserve-meta", Repo: core.Repo{Root: repoRoot}}) + if err != nil { + t.Fatal(err) + } + if resolved.Server.Labels[vastReleaseActionLabel] != "stop" || resolved.Server.Labels[vastKeyIDLabel] != "key-100" || resolved.Server.Labels[vastKeyOwnedLabel] != "true" { + t.Fatalf("resolved labels=%#v", resolved.Server.Labels) + } + if resolved.SSH.Key != acquired.SSH.Key { + t.Fatalf("resolved SSH key=%q want %q", resolved.SSH.Key, acquired.SSH.Key) + } + claim, ok, claimErr := core.ResolveLeaseClaimForProvider("preserve-meta", providerName) + if claimErr != nil || !ok { + t.Fatalf("claim ok=%v err=%v", ok, claimErr) + } + if claim.Labels[vastReleaseActionLabel] != "stop" || claim.Labels[vastKeyIDLabel] != "key-100" || claim.Labels[vastKeyOwnedLabel] != "true" { + t.Fatalf("claim labels=%#v", claim.Labels) + } +} + +func TestResolveRejectsTerminalStatusForRunButAllowsRelease(t *testing.T) { + api := &fakeVastAPI{instances: []vastInstance{{ID: 9, Label: encodeVastOwnershipLabel("cbx_failed", "failed", "ready"), Status: "failed", SSHHost: "203.0.113.9", SSHPort: 22}}} + b := newTestBackend(t, api) + _, err := b.Resolve(context.Background(), core.ResolveRequest{ID: "failed"}) + if err == nil || !strings.Contains(err.Error(), "terminal status failed") { + t.Fatalf("err=%v", err) + } + lease, err := b.Resolve(context.Background(), core.ResolveRequest{ID: "failed", ReleaseOnly: true}) + if err != nil { + t.Fatal(err) + } + if lease.LeaseID != "cbx_failed" || lease.Server.CloudID != "9" { + t.Fatalf("lease=%#v", lease) + } +} + +func TestResolveStatusOnlyAllowsInstanceWithoutSSHEndpoint(t *testing.T) { + api := &fakeVastAPI{instances: []vastInstance{{ID: 10, Label: encodeVastOwnershipLabel("cbx_status", "status-me", "stopped"), Status: "stopped"}}} + b := newTestBackend(t, api) + + lease, err := b.Resolve(context.Background(), core.ResolveRequest{ID: "status-me", StatusOnly: true}) + if err != nil { + t.Fatal(err) + } + if lease.LeaseID != "cbx_status" || lease.SSH.Host != "" { + t.Fatalf("lease=%#v", lease) + } +} + +func TestResolveStatusOnlyReadyProbeIncludesSSHTarget(t *testing.T) { + api := &fakeVastAPI{offers: []vastOffer{{ID: 42, Rentable: true}}} + b := newTestBackend(t, api) + acquired, err := b.Acquire(context.Background(), core.AcquireRequest{Repo: core.Repo{Root: t.TempDir()}, RequestedSlug: "status-wait", Keep: true}) + if err != nil { + t.Fatal(err) + } + + lease, err := b.Resolve(context.Background(), core.ResolveRequest{ID: "status-wait", StatusOnly: true, ReadyProbe: true}) + if err != nil { + t.Fatal(err) + } + if lease.LeaseID != acquired.LeaseID || lease.SSH.Host != acquired.SSH.Host || lease.SSH.Key != acquired.SSH.Key { + t.Fatalf("lease=%#v acquired=%#v", lease, acquired) + } + if lease.SSH.ReadyCheck != vastReadyCheck { + t.Fatalf("ready check=%q", lease.SSH.ReadyCheck) + } +} + +func TestResolvePrefersNumericSlugOverInstanceID(t *testing.T) { + api := &fakeVastAPI{instances: []vastInstance{ + {ID: 123, Label: encodeVastOwnershipLabel("cbx_other", "other", "ready"), Status: "running", SSHHost: "203.0.113.123", SSHPort: 22}, + {ID: 100, Label: encodeVastOwnershipLabel("cbx_numeric", "123", "ready"), Status: "running", SSHHost: "203.0.113.100", SSHPort: 2200}, + }} + b := newTestBackend(t, api) + + lease, err := b.Resolve(context.Background(), core.ResolveRequest{ID: "123"}) + if err != nil { + t.Fatal(err) + } + if lease.LeaseID != "cbx_numeric" || lease.Server.CloudID != "100" || lease.SSH.Host != "203.0.113.100" { + t.Fatalf("lease=%#v", lease) + } +} + +func TestAcquireRollsBackOnCallbackFailure(t *testing.T) { + api := &fakeVastAPI{offers: []vastOffer{{ID: 42, Rentable: true}}} + b := newTestBackend(t, api) + _, err := b.Acquire(context.Background(), core.AcquireRequest{ + Repo: core.Repo{Root: t.TempDir()}, + RequestedSlug: "rollback", + OnAcquired: func(core.LeaseTarget) error { + return errors.New("controller unavailable") + }, + }) + if err == nil || !strings.Contains(err.Error(), "controller unavailable") { + t.Fatalf("err=%v", err) + } + if len(api.destroyed) != 1 || api.destroyed[0] != 100 { + t.Fatalf("destroyed=%v", api.destroyed) + } + if len(api.detached) != 1 || api.detached[0].keyID != "key-100" { + t.Fatalf("detached=%v", api.detached) + } + if _, ok, claimErr := core.ResolveLeaseClaimForProvider("rollback", providerName); claimErr != nil || ok { + t.Fatalf("claim ok=%v err=%v", ok, claimErr) + } +} + +func TestAcquirePreservesRecoveryClaimWhenRollbackCleanupFails(t *testing.T) { + api := &fakeVastAPI{offers: []vastOffer{{ID: 42, Rentable: true}}, destroyErr: errors.New("destroy uncertain")} + b := newTestBackend(t, api) + _, err := b.Acquire(context.Background(), core.AcquireRequest{ + Repo: core.Repo{Root: t.TempDir()}, + RequestedSlug: "recover-me", + OnAcquired: func(core.LeaseTarget) error { + return errors.New("controller unavailable") + }, + }) + if err == nil || !strings.Contains(err.Error(), "vast cleanup failed") { + t.Fatalf("err=%v", err) + } + claim, ok, claimErr := core.ResolveLeaseClaimForProvider("recover-me", providerName) + if claimErr != nil || !ok || claim.Labels["recovery"] != "rollback-cleanup" || claim.CloudID != "100" { + t.Fatalf("claim=%#v ok=%v err=%v", claim, ok, claimErr) + } +} + +func TestReleaseDestroysByDefaultAndRemovesClaim(t *testing.T) { + api := &fakeVastAPI{offers: []vastOffer{{ID: 42, Rentable: true}}} + b := newTestBackend(t, api) + lease, err := b.Acquire(context.Background(), core.AcquireRequest{Repo: core.Repo{Root: t.TempDir()}, RequestedSlug: "destroy-me"}) + if err != nil { + t.Fatal(err) + } + if err := b.ReleaseLease(context.Background(), core.ReleaseLeaseRequest{Lease: lease}); err != nil { + t.Fatal(err) + } + if len(api.destroyed) != 1 || api.destroyed[0] != 100 { + t.Fatalf("destroyed=%v", api.destroyed) + } + if len(api.detached) != 1 || api.detached[0].id != 100 || api.detached[0].keyID != "key-100" { + t.Fatalf("detached=%v", api.detached) + } + if got, want := strings.Join(api.events, ","), "detach:100:key-100,destroy:100"; got != want { + t.Fatalf("events=%q want %q", got, want) + } + if _, ok, claimErr := core.ResolveLeaseClaimForProvider("destroy-me", providerName); claimErr != nil || ok { + t.Fatalf("claim ok=%v err=%v", ok, claimErr) + } +} + +func TestReleaseHonorsExplicitReleaseActionOverride(t *testing.T) { + api := &fakeVastAPI{offers: []vastOffer{{ID: 42, Rentable: true}}} + b := newTestBackend(t, api) + lease, err := b.Acquire(context.Background(), core.AcquireRequest{Repo: core.Repo{Root: t.TempDir()}, RequestedSlug: "keep-me"}) + if err != nil { + t.Fatal(err) + } + if lease.Server.Labels[vastReleaseActionLabel] != "destroy" { + t.Fatalf("lease labels=%#v", lease.Server.Labels) + } + + b.cfg.Vast.ReleaseAction = "keep" + core.MarkDeleteOnReleaseExplicit(&b.cfg, providerName) + if err := b.ReleaseLease(context.Background(), core.ReleaseLeaseRequest{Lease: lease}); err != nil { + t.Fatal(err) + } + if len(api.destroyed) != 0 || len(api.detached) != 0 { + t.Fatalf("destroyed=%v detached=%v", api.destroyed, api.detached) + } + msg := b.ReleaseLeaseMessage(lease) + if !strings.Contains(msg, "keep lease=") || strings.Contains(msg, "destroyed") { + t.Fatalf("message=%q", msg) + } + if _, ok, claimErr := core.ResolveLeaseClaimForProvider("keep-me", providerName); claimErr != nil || !ok { + t.Fatalf("claim ok=%v err=%v", ok, claimErr) + } +} + +func TestReleaseStopIsExplicitAndTested(t *testing.T) { + api := &fakeVastAPI{offers: []vastOffer{{ID: 42, Rentable: true}}} + b := newTestBackend(t, api) + b.cfg.Vast.ReleaseAction = "stop" + b.DirectSSHBackend.Cfg = b.cfg + lease, err := b.Acquire(context.Background(), core.AcquireRequest{Repo: core.Repo{Root: t.TempDir()}, RequestedSlug: "stop-me"}) + if err != nil { + t.Fatal(err) + } + if err := b.ReleaseLease(context.Background(), core.ReleaseLeaseRequest{Lease: lease}); err != nil { + t.Fatal(err) + } + if len(api.destroyed) != 0 { + t.Fatalf("destroyed=%v", api.destroyed) + } + if len(api.managed) < 2 || api.managed[len(api.managed)-1].input.State != "stopped" { + t.Fatalf("managed=%#v", api.managed) + } + claim, ok, claimErr := core.ResolveLeaseClaimForProvider("stop-me", providerName) + if claimErr != nil || !ok { + t.Fatalf("claim ok=%v err=%v", ok, claimErr) + } + if claim.Labels["state"] != "stopped" || claim.Labels[vastReleaseActionLabel] != "stop" || claim.Labels[vastKeyIDLabel] != "key-100" { + t.Fatalf("claim labels=%#v", claim.Labels) + } + + b.cfg.Vast.ReleaseAction = "destroy" + core.MarkDeleteOnReleaseExplicit(&b.cfg, providerName) + if err := b.ReleaseLease(context.Background(), core.ReleaseLeaseRequest{Lease: lease}); err != nil { + t.Fatal(err) + } + if len(api.detached) != 1 || api.detached[0].id != 100 || api.detached[0].keyID != "key-100" { + t.Fatalf("detached=%v", api.detached) + } + if len(api.destroyed) != 1 || api.destroyed[0] != 100 { + t.Fatalf("destroyed=%v", api.destroyed) + } + if _, ok, claimErr := core.ResolveLeaseClaimForProvider("stop-me", providerName); claimErr != nil || ok { + t.Fatalf("claim after destroy ok=%v err=%v", ok, claimErr) + } +} + +func TestReleaseLeaseMessageUsesPersistedReleaseAction(t *testing.T) { + b := newTestBackend(t, &fakeVastAPI{}) + b.cfg.Vast.ReleaseAction = "destroy" + lease := core.LeaseTarget{ + LeaseID: "cbx_message", + Server: core.Server{ + CloudID: "100", + Name: "message-me", + Provider: providerName, + Labels: map[string]string{ + vastReleaseActionLabel: "stop", + }, + }, + } + + msg := b.ReleaseLeaseMessage(lease) + if !strings.Contains(msg, "stop lease=cbx_message") || strings.Contains(msg, "destroyed") { + t.Fatalf("message=%q", msg) + } +} + +func TestCleanupDryRunDoesNotDestroyExpiredOwnedInstance(t *testing.T) { + api := &fakeVastAPI{instances: []vastInstance{{ID: 8, Label: encodeVastOwnershipLabel("cbx_old", "old", "ready"), Status: "running"}}} + b := newTestBackend(t, api) + server := serverFromInstance(api.instances[0], b.cfg) + server.Labels["expires_at"] = core.LeaseLabelTime(time.Now().Add(-time.Hour)) + if err := core.ClaimLeaseTargetForRepoConfig("cbx_old", "old", b.cfg, server, core.SSHTarget{}, t.TempDir(), time.Minute, false); err != nil { + t.Fatal(err) + } + var stderr bytes.Buffer + b.rt.Stderr = &stderr + b.DirectSSHBackend.RT = b.rt + if err := b.Cleanup(context.Background(), core.CleanupRequest{DryRun: true}); err != nil { + t.Fatal(err) + } + if len(api.destroyed) != 0 { + t.Fatalf("destroyed=%v", api.destroyed) + } + if !strings.Contains(stderr.String(), "delete server id=8") { + t.Fatalf("stderr=%q", stderr.String()) + } +} + +func TestCleanupReportsMissingClaimForOwnedInstance(t *testing.T) { + api := &fakeVastAPI{instances: []vastInstance{{ID: 9, Label: encodeVastOwnershipLabel("cbx_orphan", "orphan", "ready"), Status: "running"}}} + b := newTestBackend(t, api) + var stderr bytes.Buffer + b.rt.Stderr = &stderr + b.DirectSSHBackend.RT = b.rt + + err := b.Cleanup(context.Background(), core.CleanupRequest{}) + if err == nil || !strings.Contains(err.Error(), "lease=cbx_orphan has no local Vast claim") { + t.Fatalf("err=%v", err) + } + if len(api.destroyed) != 0 { + t.Fatalf("destroyed=%v", api.destroyed) + } + if !strings.Contains(stderr.String(), "delete server id=9") { + t.Fatalf("stderr=%q", stderr.String()) + } +} + +func TestManualUnownedCleanupIsRejected(t *testing.T) { + api := &fakeVastAPI{instances: []vastInstance{{ID: 5, Label: "manual-instance", Status: "running", SSHHost: "203.0.113.5", SSHPort: 22}}} + b := newTestBackend(t, api) + manual := serverFromInstance(api.instances[0], b.cfg) + err := b.ReleaseLease(context.Background(), core.ReleaseLeaseRequest{Lease: core.LeaseTarget{LeaseID: "manual", Server: manual}}) + if err == nil || !strings.Contains(err.Error(), "non-Crabbox Vast instance") { + t.Fatalf("err=%v", err) + } + if len(api.destroyed) != 0 { + t.Fatalf("destroyed=%v", api.destroyed) + } +} + +func TestTouchUpdatesLocalClaimLabels(t *testing.T) { + api := &fakeVastAPI{offers: []vastOffer{{ID: 42, Rentable: true}}} + b := newTestBackend(t, api) + lease, err := b.Acquire(context.Background(), core.AcquireRequest{Repo: core.Repo{Root: t.TempDir()}, RequestedSlug: "touch-me"}) + if err != nil { + t.Fatal(err) + } + updated, err := b.Touch(context.Background(), core.TouchRequest{Lease: lease, State: "busy", IdleTimeout: 2 * time.Hour}) + if err != nil { + t.Fatal(err) + } + if updated.Labels["state"] != "busy" { + t.Fatalf("updated=%#v", updated.Labels) + } + claim, ok, err := core.ResolveLeaseClaimForProvider("touch-me", providerName) + if err != nil || !ok || claim.Labels["state"] != "busy" { + t.Fatalf("claim=%#v ok=%v err=%v", claim, ok, err) + } +} diff --git a/internal/providers/vast/client.go b/internal/providers/vast/client.go new file mode 100644 index 000000000..35e04ab7d --- /dev/null +++ b/internal/providers/vast/client.go @@ -0,0 +1,661 @@ +package vast + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "regexp" + "sort" + "strconv" + "strings" +) + +type vastAPI interface { + CheckAuth(context.Context) (vastUser, error) + SearchOffers(context.Context, vastOfferSearchInput) ([]vastOffer, error) + CreateInstance(context.Context, int, vastCreateInstanceInput) (vastCreateInstanceResponse, error) + GetInstance(context.Context, int) (vastInstance, error) + ListInstances(context.Context) ([]vastInstance, error) + ManageInstance(context.Context, int, vastManageInstanceInput) (vastInstance, error) + DestroyInstance(context.Context, int) error + ListInstanceSSHKeys(context.Context, int) ([]vastInstanceSSHKey, error) + AttachInstanceSSHKey(context.Context, int, string) (vastAttachSSHKeyResponse, error) + DetachInstanceSSHKey(context.Context, int, string) error +} + +type vastClient struct { + apiKey string + apiURL string + httpClient *http.Client +} + +const vastMaxResponseBytes = 4 << 20 + +type vastAPIError struct { + Operation string + StatusCode int + Status string + Body string +} + +func (e *vastAPIError) Error() string { + if e.Body == "" { + return fmt.Sprintf("vast %s: %s", e.Operation, e.Status) + } + return fmt.Sprintf("vast %s: %s: %s", e.Operation, e.Status, e.Body) +} + +type vastUser struct { + ID int `json:"id"` + Email string `json:"email"` + Username string `json:"username"` +} + +type vastOffer struct { + ID int `json:"id"` + AskID int `json:"ask_contract_id"` + MachineID int `json:"machine_id"` + GPUName string `json:"gpu_name"` + GPUCount int `json:"num_gpus"` + Reliability float64 `json:"reliability2"` + DphTotal float64 `json:"dph_total"` + SSHHost string `json:"ssh_host"` + SSHPort int `json:"ssh_port"` + Rentable bool `json:"rentable"` + Rented bool `json:"rented"` + Verified bool `json:"verified"` +} + +type vastInstance struct { + ID int `json:"id"` + ContractID int `json:"contract_id"` + Label string `json:"label"` + Status string `json:"actual_status"` + IntendedStatus string `json:"intended_status"` + SSHHost string `json:"ssh_host"` + SSHPort int `json:"ssh_port"` + GPUName string `json:"gpu_name"` + GPUCount int `json:"num_gpus"` + DphTotal float64 `json:"dph_total"` + Image string `json:"image_uuid"` + InstanceAPIKey string `json:"instance_api_key,omitempty"` +} + +type vastCreateInstanceResponse struct { + NewContract int `json:"new_contract"` + Instance vastInstance `json:"instance"` + Success bool `json:"success"` +} + +type vastInstanceSSHKey struct { + ID string `json:"id"` + Name string `json:"name"` + PublicKey string `json:"ssh_key"` +} + +type vastAttachSSHKeyResponse struct { + Success bool `json:"success"` + Key vastInstanceSSHKey `json:"key"` + Keys []vastInstanceSSHKey `json:"keys"` +} + +type vastOfferSearchInput struct { + Config VastConfig +} + +type vastCreateInstanceInput struct { + Config VastConfig + Label string + SSHKey string + Environment map[string]string + OnStart string +} + +type vastManageInstanceInput struct { + State string `json:"state,omitempty"` + Label string `json:"label,omitempty"` +} + +func newVastClient(cfg VastConfig, rt Runtime) (vastAPI, error) { + apiKey := strings.TrimSpace(cfg.APIKey) + if apiKey == "" { + return nil, exit(2, "provider=%s requires CRABBOX_VAST_API_KEY or VAST_API_KEY", providerName) + } + apiURL := strings.TrimRight(strings.TrimSpace(cfg.APIURL), "/") + if apiURL == "" { + apiURL = "https://console.vast.ai/api/v0" + } + parsed, err := url.Parse(apiURL) + if err != nil || parsed.Scheme == "" || parsed.Host == "" || parsed.User != nil { + return nil, exit(2, "vast.apiUrl must be an absolute URL without credentials") + } + if parsed.Scheme != "https" && !isLoopbackHTTPURL(parsed) { + return nil, exit(2, "vast.apiUrl must use https unless it targets localhost") + } + httpClient := rt.HTTP + if httpClient == nil { + httpClient = http.DefaultClient + } + return &vastClient{apiKey: apiKey, apiURL: apiURL, httpClient: secureVastHTTPClient(httpClient, apiURL)}, nil +} + +func secureVastHTTPClient(source *http.Client, apiURL string) *http.Client { + client := *source + trusted, _ := url.Parse(apiURL) + originalCheckRedirect := source.CheckRedirect + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + if !sameVastOrigin(trusted, req.URL) { + return fmt.Errorf("%s refused cross-origin redirect to %s", providerName, req.URL.Redacted()) + } + if originalCheckRedirect != nil { + return originalCheckRedirect(req, via) + } + if len(via) >= 10 { + return errors.New("stopped after 10 redirects") + } + return nil + } + return &client +} + +func (c *vastClient) do(ctx context.Context, method, path string, body any, out any) error { + var reader io.Reader + if body != nil { + data, err := json.Marshal(body) + if err != nil { + return err + } + reader = bytes.NewReader(data) + } + endpoint := c.apiURL + path + if parsed, err := url.Parse(path); err == nil && parsed.IsAbs() { + endpoint = path + } + req, err := http.NewRequestWithContext(ctx, method, endpoint, reader) + if err != nil { + return err + } + req.Header.Set("Authorization", "Bearer "+c.apiKey) + req.Header.Set("Accept", "application/json") + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + resp, err := c.httpClient.Do(req) + if err != nil { + return redactVastString(err.Error(), c.apiKey) + } + defer resp.Body.Close() + data, readErr := io.ReadAll(io.LimitReader(resp.Body, vastMaxResponseBytes+1)) + operation := method + " " + path + if len(data) > vastMaxResponseBytes { + return fmt.Errorf("vast %s response exceeds %d bytes", operation, vastMaxResponseBytes) + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return c.decodeAPIError(operation, resp.StatusCode, resp.Status, data, readErr) + } + if readErr != nil { + return fmt.Errorf("vast %s response body: %w", operation, readErr) + } + if out != nil && len(strings.TrimSpace(string(data))) > 0 { + if err := json.Unmarshal(data, out); err != nil { + return fmt.Errorf("decode vast %s response: %w", operation, err) + } + } + return nil +} + +func (c *vastClient) decodeAPIError(operation string, statusCode int, status string, data []byte, readErr error) error { + body := strings.TrimSpace(string(data)) + if len(body) > 1600 { + body = body[:1600] + } + body = redactVastText(body, c.apiKey) + if readErr != nil { + if body != "" { + body += "; " + } + body += "response body read failed: " + readErr.Error() + } + return &vastAPIError{Operation: operation, StatusCode: statusCode, Status: status, Body: body} +} + +func (c *vastClient) CheckAuth(ctx context.Context) (vastUser, error) { + var out vastUser + err := c.do(ctx, http.MethodGet, "/users/current/", nil, &out) + return out, err +} + +func (c *vastClient) SearchOffers(ctx context.Context, input vastOfferSearchInput) ([]vastOffer, error) { + var raw json.RawMessage + if err := c.do(ctx, http.MethodPost, "/bundles/", buildVastOfferSearchPayload(input.Config), &raw); err != nil { + return nil, err + } + return decodeVastOffers(raw) +} + +func (c *vastClient) CreateInstance(ctx context.Context, offerID int, input vastCreateInstanceInput) (vastCreateInstanceResponse, error) { + var raw json.RawMessage + path := "/asks/" + strconv.Itoa(offerID) + "/" + if err := c.do(ctx, http.MethodPut, path, buildVastCreatePayload(input), &raw); err != nil { + return vastCreateInstanceResponse{}, err + } + return decodeVastCreateInstanceResponse(raw) +} + +func (c *vastClient) GetInstance(ctx context.Context, id int) (vastInstance, error) { + var raw json.RawMessage + if err := c.do(ctx, http.MethodGet, "/instances/"+strconv.Itoa(id)+"/", nil, &raw); err != nil { + return vastInstance{}, err + } + return decodeVastInstance(raw) +} + +func (c *vastClient) ListInstances(ctx context.Context) ([]vastInstance, error) { + var out []vastInstance + var afterToken string + for { + params := url.Values{} + params.Set("limit", "25") + if afterToken != "" { + params.Set("after_token", afterToken) + } + var raw json.RawMessage + endpoint := vastAPIURLForVersion(c.apiURL, "v1") + "/instances/?" + params.Encode() + if err := c.do(ctx, http.MethodGet, endpoint, nil, &raw); err != nil { + return nil, err + } + page, nextToken, err := decodeVastInstancesPage(raw) + if err != nil { + return nil, err + } + out = append(out, page...) + if strings.TrimSpace(nextToken) == "" { + return out, nil + } + afterToken = nextToken + } +} + +func (c *vastClient) ManageInstance(ctx context.Context, id int, input vastManageInstanceInput) (vastInstance, error) { + var raw json.RawMessage + if err := c.do(ctx, http.MethodPut, "/instances/"+strconv.Itoa(id)+"/", input, &raw); err != nil { + return vastInstance{}, err + } + return decodeVastInstance(raw) +} + +func (c *vastClient) DestroyInstance(ctx context.Context, id int) error { + return c.do(ctx, http.MethodDelete, "/instances/"+strconv.Itoa(id)+"/", nil, nil) +} + +func (c *vastClient) ListInstanceSSHKeys(ctx context.Context, id int) ([]vastInstanceSSHKey, error) { + var raw json.RawMessage + if err := c.do(ctx, http.MethodGet, "/instances/"+strconv.Itoa(id)+"/ssh/", nil, &raw); err != nil { + return nil, err + } + return decodeVastSSHKeys(raw) +} + +func (c *vastClient) AttachInstanceSSHKey(ctx context.Context, id int, publicKey string) (vastAttachSSHKeyResponse, error) { + var out vastAttachSSHKeyResponse + body := map[string]string{"ssh_key": publicKey} + err := c.do(ctx, http.MethodPost, "/instances/"+strconv.Itoa(id)+"/ssh/", body, &out) + return out, err +} + +func (c *vastClient) DetachInstanceSSHKey(ctx context.Context, id int, keyID string) error { + return c.do(ctx, http.MethodDelete, "/instances/"+strconv.Itoa(id)+"/ssh/"+url.PathEscape(keyID)+"/", nil, nil) +} + +func buildVastOfferSearchPayload(cfg VastConfig) map[string]any { + payload := map[string]any{ + "verified": vastFilter("eq", true), + "rentable": vastFilter("eq", true), + "rented": vastFilter("eq", false), + "direct_port_count": vastFilter("gte", 1), + } + if cfg.GPUName != "" { + payload["gpu_name"] = vastFilter("eq", cfg.GPUName) + } + if cfg.GPUCount > 0 { + payload["num_gpus"] = vastFilter("gte", cfg.GPUCount) + } + if cfg.MinReliability > 0 { + payload["reliability"] = vastFilter("gte", cfg.MinReliability) + } + if cfg.MaxDphTotal > 0 { + payload["dph_total"] = vastFilter("lte", cfg.MaxDphTotal) + } + instanceType := vastAPIInstanceType(cfg.InstanceType) + if instanceType == "" { + instanceType = "ondemand" + } + payload["type"] = instanceType + if order := strings.TrimSpace(cfg.Order); order != "" { + payload["order"] = vastOrderTuples(order) + } + return payload +} + +func vastFilter(operator string, value any) map[string]any { + return map[string]any{operator: value} +} + +func vastAPIInstanceType(value string) string { + switch normalizeInstanceType(value) { + case "interruptible": + return "bid" + default: + return normalizeInstanceType(value) + } +} + +func vastOrderTuples(order string) [][]string { + parts := strings.Split(order, ",") + out := make([][]string, 0, len(parts)) + for _, part := range parts { + fields := strings.Fields(strings.TrimSpace(part)) + if len(fields) == 0 { + continue + } + direction := "desc" + if len(fields) > 1 { + switch strings.ToLower(fields[1]) { + case "asc", "desc": + direction = strings.ToLower(fields[1]) + } + } + out = append(out, []string{fields[0], direction}) + } + return out +} + +func buildVastCreatePayload(input vastCreateInstanceInput) map[string]any { + cfg := input.Config + payload := map[string]any{ + "runtype": strings.TrimSpace(cfg.Runtype), + "target_state": "running", + "cancel_unavail": true, + "vm": false, + } + if cfg.Image != "" { + payload["image"] = cfg.Image + } + if cfg.TemplateID != "" { + payload["template_hash_id"] = cfg.TemplateID + } + if cfg.DiskGB > 0 { + payload["disk"] = cfg.DiskGB + } + if input.Label != "" { + payload["label"] = input.Label + } + if input.SSHKey != "" { + payload["ssh_key"] = input.SSHKey + } + if len(input.Environment) > 0 { + payload["env"] = vastEnvFlags(input.Environment) + } + if input.OnStart != "" { + payload["onstart"] = input.OnStart + } + return payload +} + +func vastEnvFlags(env map[string]string) string { + keys := make([]string, 0, len(env)) + for key := range env { + key = strings.TrimSpace(key) + if key != "" { + keys = append(keys, key) + } + } + sort.Strings(keys) + parts := make([]string, 0, len(keys)) + for _, key := range keys { + parts = append(parts, "-e", key+"="+shellQuoteVastEnvValue(env[key])) + } + return strings.Join(parts, " ") +} + +func shellQuoteVastEnvValue(value string) string { + if value == "" { + return "''" + } + if strings.IndexFunc(value, func(r rune) bool { + return (r < 'A' || r > 'Z') && + (r < 'a' || r > 'z') && + (r < '0' || r > '9') && + !strings.ContainsRune("_-./:", r) + }) == -1 { + return value + } + return "'" + strings.ReplaceAll(value, "'", `'\''`) + "'" +} + +func decodeVastOffers(raw json.RawMessage) ([]vastOffer, error) { + var direct []vastOffer + if err := json.Unmarshal(raw, &direct); err == nil { + return direct, nil + } + var envelope struct { + Offers []vastOffer `json:"offers"` + Bundles []vastOffer `json:"bundles"` + Data []vastOffer `json:"data"` + Results []vastOffer `json:"results"` + } + if err := json.Unmarshal(raw, &envelope); err != nil { + return nil, err + } + switch { + case envelope.Offers != nil: + return envelope.Offers, nil + case envelope.Bundles != nil: + return envelope.Bundles, nil + case envelope.Data != nil: + return envelope.Data, nil + default: + return envelope.Results, nil + } +} + +func decodeVastCreateInstanceResponse(raw json.RawMessage) (vastCreateInstanceResponse, error) { + var out vastCreateInstanceResponse + if err := json.Unmarshal(raw, &out); err != nil { + return out, err + } + if out.Instance.ID == 0 { + inst, err := decodeVastInstance(raw) + if err == nil { + out.Instance = inst + } + } + if out.NewContract == 0 { + out.NewContract = firstNonZero(out.Instance.ContractID, out.Instance.ID) + } + return out, nil +} + +func decodeVastInstance(raw json.RawMessage) (vastInstance, error) { + var direct vastInstance + if err := json.Unmarshal(raw, &direct); err != nil { + return vastInstance{}, err + } + if direct.ID != 0 || direct.ContractID != 0 || direct.SSHHost != "" { + return normalizeVastInstance(direct), nil + } + var envelope struct { + Instance vastInstance `json:"instance"` + Instances vastInstance `json:"instances"` + Data vastInstance `json:"data"` + } + if err := json.Unmarshal(raw, &envelope); err != nil { + return vastInstance{}, err + } + if envelope.Instance.ID != 0 || envelope.Instance.ContractID != 0 || envelope.Instance.SSHHost != "" { + return normalizeVastInstance(envelope.Instance), nil + } + if envelope.Instances.ID != 0 || envelope.Instances.ContractID != 0 || envelope.Instances.SSHHost != "" { + return normalizeVastInstance(envelope.Instances), nil + } + return normalizeVastInstance(envelope.Data), nil +} + +func decodeVastInstancesPage(raw json.RawMessage) ([]vastInstance, string, error) { + var direct []vastInstance + if err := json.Unmarshal(raw, &direct); err == nil { + return normalizeVastInstances(direct), "", nil + } + var envelope struct { + Instances []vastInstance `json:"instances"` + Data []vastInstance `json:"data"` + Results []vastInstance `json:"results"` + NextToken string `json:"next_token"` + } + if err := json.Unmarshal(raw, &envelope); err != nil { + return nil, "", err + } + switch { + case envelope.Instances != nil: + return normalizeVastInstances(envelope.Instances), envelope.NextToken, nil + case envelope.Data != nil: + return normalizeVastInstances(envelope.Data), envelope.NextToken, nil + default: + return normalizeVastInstances(envelope.Results), envelope.NextToken, nil + } +} + +func vastAPIURLForVersion(apiURL, version string) string { + base := strings.TrimRight(apiURL, "/") + for _, suffix := range []string{"/api/v0", "/api/v1"} { + if strings.HasSuffix(base, suffix) { + return strings.TrimSuffix(base, suffix) + "/api/" + version + } + } + return base +} + +func decodeVastSSHKeys(raw json.RawMessage) ([]vastInstanceSSHKey, error) { + var direct []vastInstanceSSHKey + if err := json.Unmarshal(raw, &direct); err == nil { + return direct, nil + } + var envelope struct { + Keys []vastInstanceSSHKey `json:"keys"` + Data []vastInstanceSSHKey `json:"data"` + } + if err := json.Unmarshal(raw, &envelope); err != nil { + return nil, err + } + if envelope.Keys != nil { + return envelope.Keys, nil + } + return envelope.Data, nil +} + +func normalizeVastInstances(instances []vastInstance) []vastInstance { + for i := range instances { + instances[i] = normalizeVastInstance(instances[i]) + } + return instances +} + +func normalizeVastInstance(instance vastInstance) vastInstance { + if instance.ContractID == 0 { + instance.ContractID = instance.ID + } + if instance.ID == 0 { + instance.ID = instance.ContractID + } + if instance.Status == "" { + instance.Status = instance.IntendedStatus + } + return instance +} + +func sameVastOrigin(a, b *url.URL) bool { + return a != nil && b != nil && + strings.EqualFold(a.Scheme, b.Scheme) && + strings.EqualFold(a.Hostname(), b.Hostname()) && + effectiveVastPort(a) == effectiveVastPort(b) +} + +func effectiveVastPort(value *url.URL) string { + if port := value.Port(); port != "" { + return port + } + switch strings.ToLower(value.Scheme) { + case "https": + return "443" + case "http": + return "80" + default: + return "" + } +} + +func isLoopbackHTTPURL(parsed *url.URL) bool { + if parsed == nil || parsed.Scheme != "http" { + return false + } + host := strings.ToLower(parsed.Hostname()) + ip := net.ParseIP(host) + return host == "localhost" || host == "127.0.0.1" || host == "::1" || (ip != nil && ip.IsLoopback()) +} + +func redactVastString(value, apiKey string) error { + return errors.New(redactVastText(value, apiKey)) +} + +func redactVastText(value, apiKey string) string { + out := value + if apiKey != "" { + out = strings.ReplaceAll(out, apiKey, "") + } + for _, field := range []string{ + "authorization", "api_key", "apiKey", "instance_api_key", "instanceApiKey", + "ssh_key", "private_key", "privateKey", "jupyter_token", "jupyterToken", + "user_data", "userData", "token", + } { + out = redactVastJSONishField(out, field) + out = redactVastInlineField(out, field) + } + out = redactVastPrivateKeyBlock(out) + out = redactVastTokenURLs(out) + return out +} + +func redactVastJSONishField(body, field string) string { + pattern := regexp.MustCompile(`(?i)("` + regexp.QuoteMeta(field) + `"\s*:\s*)("[^"]*"|[^,}\s]+)`) + return pattern.ReplaceAllString(body, `${1}""`) +} + +func redactVastInlineField(body, field string) string { + pattern := regexp.MustCompile(`(?i)(\b` + regexp.QuoteMeta(field) + `\s*[=:]\s*)[^",\s]+`) + return pattern.ReplaceAllString(body, `${1}`) +} + +func redactVastPrivateKeyBlock(body string) string { + pattern := regexp.MustCompile(`-----BEGIN [A-Z ]*PRIVATE KEY-----[\s\S]*?-----END [A-Z ]*PRIVATE KEY-----`) + return pattern.ReplaceAllString(body, "") +} + +func redactVastTokenURLs(body string) string { + pattern := regexp.MustCompile(`https?://[^\s"']*(?i:(token|api_key|instance_api_key)=)[^\s"']+`) + return pattern.ReplaceAllString(body, "") +} + +func firstNonZero(values ...int) int { + for _, value := range values { + if value != 0 { + return value + } + } + return 0 +} diff --git a/internal/providers/vast/client_test.go b/internal/providers/vast/client_test.go new file mode 100644 index 000000000..4b0715231 --- /dev/null +++ b/internal/providers/vast/client_test.go @@ -0,0 +1,313 @@ +package vast + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestClientSendsBearerAuthAndRefusesCrossOriginRedirect(t *testing.T) { + redirectTarget := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatalf("cross-origin redirect target received request with auth=%q", r.Header.Get("Authorization")) + })) + defer redirectTarget.Close() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, redirectTarget.URL+"/capture", http.StatusFound) + })) + defer server.Close() + + client, err := newVastClient(VastConfig{APIKey: "vast-secret", APIURL: server.URL}, Runtime{HTTP: server.Client()}) + if err != nil { + t.Fatal(err) + } + _, err = client.CheckAuth(context.Background()) + if err == nil || !strings.Contains(err.Error(), "refused cross-origin redirect") { + t.Fatalf("err=%v want cross-origin redirect refusal", err) + } +} + +func TestClientCheckAuthUsesCurrentUserEndpoint(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet || r.URL.Path != "/api/v0/users/current/" { + t.Fatalf("request=%s %s", r.Method, r.URL.Path) + } + if got := r.Header.Get("Authorization"); got != "Bearer vast-secret" { + t.Fatalf("Authorization=%q", got) + } + writeJSON(t, w, map[string]any{"id": 7, "username": "alice"}) + })) + defer server.Close() + + client := newTestVastClient(t, server) + user, err := client.CheckAuth(context.Background()) + if err != nil { + t.Fatal(err) + } + if user.ID != 7 || user.Username != "alice" { + t.Fatalf("user=%#v", user) + } +} + +func TestRedactVastAPIErrorSecrets(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"api_key":"vast-secret","instance_api_key":"inst-secret","jupyter_url":"https://host/?token=jupyter-secret","user_data":"secret cloud init","private_key":"-----BEGIN PRIVATE KEY-----\nabc\n-----END PRIVATE KEY-----"}`)) + })) + defer server.Close() + + client := newTestVastClient(t, server) + _, err := client.CheckAuth(context.Background()) + if err == nil { + t.Fatal("expected API error") + } + text := err.Error() + for _, secret := range []string{"vast-secret", "inst-secret", "jupyter-secret", "secret cloud init", "BEGIN PRIVATE KEY", "abc"} { + if strings.Contains(text, secret) { + t.Fatalf("error leaked %q in %q", secret, text) + } + } + if !strings.Contains(text, "") { + t.Fatalf("error was not redacted: %q", text) + } +} + +func TestOfferSearchPayloadAndDecode(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/api/v0/bundles/" { + t.Fatalf("request=%s %s", r.Method, r.URL.Path) + } + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatal(err) + } + if body["type"] != "ondemand" || + body["gpu_name"].(map[string]any)["eq"] != "H100" || + body["verified"].(map[string]any)["eq"] != true || + body["rentable"].(map[string]any)["eq"] != true || + body["rented"].(map[string]any)["eq"] != false { + t.Fatalf("unexpected search body: %#v", body) + } + order := body["order"].([]any) + if len(order) != 1 || order[0].([]any)[0] != "dlperf_per_dphtotal" || order[0].([]any)[1] != "desc" { + t.Fatalf("order=%#v", body["order"]) + } + if _, ok := body["query"]; ok { + t.Fatalf("search body should use top-level filters: %#v", body) + } + if body["direct_port_count"].(map[string]any)["gte"] != float64(1) || + body["num_gpus"].(map[string]any)["gte"] != float64(4) || + body["reliability"].(map[string]any)["gte"] != 0.95 || + body["dph_total"].(map[string]any)["lte"] != 3.5 { + t.Fatalf("unexpected search filters: %#v", body) + } + writeJSON(t, w, map[string]any{"offers": []map[string]any{{"id": 11, "gpu_name": "H100", "ssh_host": "203.0.113.10", "ssh_port": 2201}}}) + })) + defer server.Close() + + client := newTestVastClient(t, server) + offers, err := client.SearchOffers(context.Background(), vastOfferSearchInput{Config: VastConfig{ + InstanceType: "on-demand", + GPUName: "H100", + GPUCount: 4, + MaxDphTotal: 3.5, + MinReliability: 0.95, + Order: "dlperf_per_dphtotal desc", + }}) + if err != nil { + t.Fatal(err) + } + if len(offers) != 1 || offers[0].ID != 11 || offers[0].SSHPort != 2201 { + t.Fatalf("offers=%#v", offers) + } +} + +func TestOfferSearchPayloadMapsInterruptibleToBid(t *testing.T) { + body := buildVastOfferSearchPayload(VastConfig{InstanceType: "interruptible"}) + if body["type"] != "bid" { + t.Fatalf("type=%#v want bid", body["type"]) + } +} + +func TestCreateInstancePayloadAndDecodeNewContract(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPut || r.URL.Path != "/api/v0/asks/42/" { + t.Fatalf("request=%s %s", r.Method, r.URL.Path) + } + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatal(err) + } + if body["runtype"] != "ssh_direct" || body["target_state"] != "running" || + body["cancel_unavail"] != true || body["vm"] != false || + body["image"] != "nvidia/cuda:12" || body["template_hash_id"] != "tpl-123" || + body["disk"] != float64(80) || body["label"] != "cbx1|lease|slug|active" || + body["ssh_key"] != "ssh-ed25519 AAAA..." { + t.Fatalf("unexpected create body: %#v", body) + } + if _, ok := body["template_id"]; ok { + t.Fatalf("create body should use template_hash_id: %#v", body) + } + if body["env"] != "-e CRABBOX=1" { + t.Fatalf("env=%#v", body["env"]) + } + writeJSON(t, w, map[string]any{"success": true, "new_contract": 99}) + })) + defer server.Close() + + client := newTestVastClient(t, server) + resp, err := client.CreateInstance(context.Background(), 42, vastCreateInstanceInput{ + Config: VastConfig{Image: "nvidia/cuda:12", TemplateID: "tpl-123", Runtype: "ssh_direct", DiskGB: 80}, + Label: "cbx1|lease|slug|active", + SSHKey: "ssh-ed25519 AAAA...", + Environment: map[string]string{"CRABBOX": "1"}, + }) + if err != nil { + t.Fatal(err) + } + if resp.NewContract != 99 || !resp.Success { + t.Fatalf("resp=%#v", resp) + } +} + +func TestInstanceMethodsAndDecoding(t *testing.T) { + var seen []string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + seen = append(seen, r.Method+" "+r.URL.RequestURI()) + switch r.Method + " " + r.URL.Path { + case "GET /api/v0/instances/99/": + writeJSON(t, w, map[string]any{"instances": map[string]any{"id": 99, "actual_status": "running", "ssh_host": "198.51.100.8", "ssh_port": 2222, "gpu_name": "RTX 4090"}}) + case "GET /api/v1/instances/": + if r.URL.Query().Get("limit") != "25" { + t.Fatalf("list query=%s", r.URL.RawQuery) + } + switch r.URL.Query().Get("after_token") { + case "": + writeJSON(t, w, map[string]any{"next_token": "page-2", "instances": []map[string]any{{"id": 99}}}) + case "page-2": + writeJSON(t, w, map[string]any{"next_token": nil, "instances": []map[string]any{{"contract_id": 100}}}) + default: + t.Fatalf("unexpected after_token query=%s", r.URL.RawQuery) + } + case "PUT /api/v0/instances/99/": + var body vastManageInstanceInput + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatal(err) + } + if body.State != "stopped" || body.Label != "cbx1|lease|slug|stopped" { + t.Fatalf("manage body=%#v", body) + } + writeJSON(t, w, map[string]any{"instance": map[string]any{"id": 99, "intended_status": "stopped"}}) + case "DELETE /api/v0/instances/99/": + w.WriteHeader(http.StatusNoContent) + default: + t.Fatalf("unexpected request %s %s", r.Method, r.URL.Path) + } + })) + defer server.Close() + + client := newTestVastClient(t, server) + instance, err := client.GetInstance(context.Background(), 99) + if err != nil { + t.Fatal(err) + } + if instance.ID != 99 || instance.SSHHost != "198.51.100.8" || instance.SSHPort != 2222 { + t.Fatalf("instance=%#v", instance) + } + list, err := client.ListInstances(context.Background()) + if err != nil { + t.Fatal(err) + } + if len(list) != 2 || list[1].ID != 100 { + t.Fatalf("list=%#v", list) + } + managed, err := client.ManageInstance(context.Background(), 99, vastManageInstanceInput{State: "stopped", Label: "cbx1|lease|slug|stopped"}) + if err != nil { + t.Fatal(err) + } + if managed.Status != "stopped" { + t.Fatalf("managed=%#v", managed) + } + if err := client.DestroyInstance(context.Background(), 99); err != nil { + t.Fatal(err) + } + wantSeen := "GET /api/v0/instances/99/,GET /api/v1/instances/?limit=25,GET /api/v1/instances/?after_token=page-2&limit=25,PUT /api/v0/instances/99/,DELETE /api/v0/instances/99/" + if strings.Join(seen, ",") != wantSeen { + t.Fatalf("seen=%v", seen) + } +} + +func TestInstanceSSHKeyMethods(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method + " " + r.URL.Path { + case "GET /api/v0/instances/99/ssh/": + writeJSON(t, w, map[string]any{"keys": []map[string]any{{"id": "key-1", "ssh_key": "ssh-ed25519 AAAA..."}}}) + case "POST /api/v0/instances/99/ssh/": + var body map[string]string + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatal(err) + } + if body["ssh_key"] != "ssh-ed25519 AAAA..." { + t.Fatalf("attach body=%#v", body) + } + writeJSON(t, w, map[string]any{"success": true, "key": map[string]any{"id": "key-1"}}) + case "DELETE /api/v0/instances/99/ssh/key-1/": + w.WriteHeader(http.StatusNoContent) + default: + t.Fatalf("unexpected request %s %s", r.Method, r.URL.Path) + } + })) + defer server.Close() + + client := newTestVastClient(t, server) + keys, err := client.ListInstanceSSHKeys(context.Background(), 99) + if err != nil { + t.Fatal(err) + } + if len(keys) != 1 || keys[0].ID != "key-1" { + t.Fatalf("keys=%#v", keys) + } + attached, err := client.AttachInstanceSSHKey(context.Background(), 99, "ssh-ed25519 AAAA...") + if err != nil { + t.Fatal(err) + } + if !attached.Success || attached.Key.ID != "key-1" { + t.Fatalf("attached=%#v", attached) + } + if err := client.DetachInstanceSSHKey(context.Background(), 99, "key-1"); err != nil { + t.Fatal(err) + } +} + +func TestClientRejectsNonHTTPSExceptLoopback(t *testing.T) { + if _, err := newVastClient(VastConfig{APIKey: "secret", APIURL: "http://vast.example.test"}, Runtime{}); err == nil { + t.Fatal("expected non-https non-loopback rejection") + } + if _, err := newVastClient(VastConfig{APIKey: "secret", APIURL: "http://127.0.0.1:8080/api/v0"}, Runtime{}); err != nil { + t.Fatalf("loopback rejected: %v", err) + } +} + +func newTestVastClient(t *testing.T, server *httptest.Server) *vastClient { + t.Helper() + api, err := newVastClient(VastConfig{APIKey: "vast-secret", APIURL: server.URL + "/api/v0"}, Runtime{HTTP: server.Client()}) + if err != nil { + t.Fatal(err) + } + client, ok := api.(*vastClient) + if !ok { + t.Fatalf("client type=%T", api) + } + return client +} + +func writeJSON(t *testing.T, w http.ResponseWriter, value any) { + t.Helper() + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(value); err != nil { + t.Fatal(err) + } +} diff --git a/internal/providers/vast/core.go b/internal/providers/vast/core.go new file mode 100644 index 000000000..3f321e344 --- /dev/null +++ b/internal/providers/vast/core.go @@ -0,0 +1,59 @@ +package vast + +import ( + "flag" + "strings" + + core "github.com/openclaw/crabbox/internal/cli" +) + +type Config = core.Config +type VastConfig = core.VastConfig +type ProviderSpec = core.ProviderSpec +type Runtime = core.Runtime +type Backend = core.Backend +type DoctorRequest = core.DoctorRequest +type DoctorResult = core.DoctorResult +type AcquireRequest = core.AcquireRequest +type ResolveRequest = core.ResolveRequest +type ListRequest = core.ListRequest +type LeaseView = core.LeaseView +type ReleaseLeaseRequest = core.ReleaseLeaseRequest +type TouchRequest = core.TouchRequest +type LeaseTarget = core.LeaseTarget +type Server = core.Server +type SSHTarget = core.SSHTarget + +const ( + providerName = "vast" + targetLinux = core.TargetLinux +) + +func exit(code int, format string, args ...any) core.ExitError { + return core.Exit(code, format, args...) +} + +func flagWasSet(fs *flag.FlagSet, name string) bool { + return core.FlagWasSet(fs, name) +} + +func markVastWorkRootExplicit(cfg *Config) { + core.MarkVastWorkRootExplicit(cfg) +} + +func markReleaseActionExplicit(cfg *Config) { + core.MarkDeleteOnReleaseExplicit(cfg, providerName) +} + +func normalizeInstanceType(value string) string { + return core.NormalizeVastInstanceType(value) +} + +func isVastProviderName(provider string) bool { + switch strings.ToLower(strings.TrimSpace(provider)) { + case providerName, "vast-ai", "vastai": + return true + default: + return false + } +} diff --git a/internal/providers/vast/flags.go b/internal/providers/vast/flags.go new file mode 100644 index 000000000..f3f47ba1b --- /dev/null +++ b/internal/providers/vast/flags.go @@ -0,0 +1,104 @@ +package vast + +import "flag" + +type vastFlagValues struct { + APIURL *string + InstanceType *string + GPUName *string + GPUCount *int + Image *string + TemplateID *string + Runtype *string + DiskGB *int + MaxDphTotal *float64 + MinReliability *float64 + Order *string + User *string + WorkRoot *string + ReleaseAction *string +} + +// RegisterVastProviderFlags exposes only non-secret Vast settings. API keys +// are sourced from CRABBOX_VAST_API_KEY / VAST_API_KEY and never argv. +func RegisterVastProviderFlags(fs *flag.FlagSet, defaults Config) any { + return vastFlagValues{ + APIURL: fs.String("vast-api-url", defaults.Vast.APIURL, "Vast.ai REST API URL"), + InstanceType: fs.String("vast-instance-type", defaults.Vast.InstanceType, "Vast.ai offer type: ondemand or interruptible"), + GPUName: fs.String("vast-gpu-name", defaults.Vast.GPUName, "Vast.ai GPU name selector"), + GPUCount: fs.Int("vast-gpu-count", defaults.Vast.GPUCount, "Vast.ai minimum GPU count"), + Image: fs.String("vast-image", defaults.Vast.Image, "Docker image to deploy on the instance"), + TemplateID: fs.String("vast-template-id", defaults.Vast.TemplateID, "Optional Vast.ai template ID"), + Runtype: fs.String("vast-runtype", defaults.Vast.Runtype, "Vast.ai runtime type: ssh_direct"), + DiskGB: fs.Int("vast-disk-gb", defaults.Vast.DiskGB, "Instance disk size in GB"), + MaxDphTotal: fs.Float64("vast-max-dph-total", defaults.Vast.MaxDphTotal, "Maximum total dollars per hour"), + MinReliability: fs.Float64("vast-min-reliability", defaults.Vast.MinReliability, "Minimum reliability score from 0 to 1"), + Order: fs.String("vast-order", defaults.Vast.Order, "Vast.ai offer ordering expression"), + User: fs.String("vast-user", defaults.Vast.User, "SSH user for Vast.ai instances"), + WorkRoot: fs.String("vast-work-root", defaults.Vast.WorkRoot, "remote Crabbox work root on Vast.ai instances"), + ReleaseAction: fs.String("vast-release-action", defaults.Vast.ReleaseAction, "Vast.ai release action: destroy, stop, or keep"), + } +} + +func ApplyVastProviderFlags(cfg *Config, fs *flag.FlagSet, values any) error { + if isVastProviderName(cfg.Provider) { + if flagWasSet(fs, "class") { + return exit(2, "--class is not supported for provider=%s; use --vast-gpu-name or --vast-gpu-count", providerName) + } + if flagWasSet(fs, "type") { + return exit(2, "--type is not supported for provider=%s; use --vast-image", providerName) + } + } + v, ok := values.(vastFlagValues) + if !ok { + return nil + } + if flagWasSet(fs, "vast-api-url") { + cfg.Vast.APIURL = *v.APIURL + } + if flagWasSet(fs, "vast-instance-type") { + cfg.Vast.InstanceType = normalizeInstanceType(*v.InstanceType) + } + if flagWasSet(fs, "vast-gpu-name") { + cfg.Vast.GPUName = *v.GPUName + } + if flagWasSet(fs, "vast-gpu-count") { + cfg.Vast.GPUCount = *v.GPUCount + } + if flagWasSet(fs, "vast-image") { + cfg.Vast.Image = *v.Image + } + if flagWasSet(fs, "vast-template-id") { + cfg.Vast.TemplateID = *v.TemplateID + } + if flagWasSet(fs, "vast-runtype") { + cfg.Vast.Runtype = *v.Runtype + } + if flagWasSet(fs, "vast-disk-gb") { + cfg.Vast.DiskGB = *v.DiskGB + } + if flagWasSet(fs, "vast-max-dph-total") { + cfg.Vast.MaxDphTotal = *v.MaxDphTotal + } + if flagWasSet(fs, "vast-min-reliability") { + cfg.Vast.MinReliability = *v.MinReliability + } + if flagWasSet(fs, "vast-order") { + cfg.Vast.Order = *v.Order + } + if flagWasSet(fs, "vast-user") { + cfg.Vast.User = *v.User + } + if flagWasSet(fs, "vast-work-root") { + cfg.Vast.WorkRoot = *v.WorkRoot + markVastWorkRootExplicit(cfg) + } + if flagWasSet(fs, "vast-release-action") { + cfg.Vast.ReleaseAction = *v.ReleaseAction + markReleaseActionExplicit(cfg) + } + if isVastProviderName(cfg.Provider) { + return Provider{}.ValidateConfig(*cfg) + } + return nil +} diff --git a/internal/providers/vast/labels.go b/internal/providers/vast/labels.go new file mode 100644 index 000000000..872ab0b5e --- /dev/null +++ b/internal/providers/vast/labels.go @@ -0,0 +1,60 @@ +package vast + +import ( + "strings" + "unicode" +) + +const ( + vastOwnershipLabelPrefix = "cbx1|" + vastLabelMaxPart = 48 +) + +type vastOwnershipLabel struct { + LeaseID string + Slug string + State string +} + +func encodeVastOwnershipLabel(leaseID, slug, state string) string { + leaseID = sanitizeVastLabelPart(leaseID, vastLabelMaxPart) + slug = sanitizeVastLabelPart(slug, vastLabelMaxPart) + state = sanitizeVastLabelPart(state, 16) + parts := []string{leaseID, slug, state} + return vastOwnershipLabelPrefix + strings.Join(parts, "|") +} + +func decodeVastOwnershipLabel(label string) (vastOwnershipLabel, bool) { + if !strings.HasPrefix(label, vastOwnershipLabelPrefix) { + return vastOwnershipLabel{}, false + } + parts := strings.Split(strings.TrimPrefix(label, vastOwnershipLabelPrefix), "|") + if len(parts) != 3 || parts[0] == "" { + return vastOwnershipLabel{}, false + } + return vastOwnershipLabel{LeaseID: parts[0], Slug: parts[1], State: parts[2]}, true +} + +func isVastCrabboxOwnedLabel(label string) bool { + _, ok := decodeVastOwnershipLabel(label) + return ok +} + +func sanitizeVastLabelPart(value string, limit int) string { + value = strings.TrimSpace(value) + var b strings.Builder + for _, r := range value { + switch { + case unicode.IsLetter(r) || unicode.IsDigit(r): + b.WriteRune(r) + case r == '-' || r == '_' || r == '.': + b.WriteRune(r) + default: + b.WriteByte('-') + } + if b.Len() >= limit { + break + } + } + return strings.Trim(b.String(), "-_.") +} diff --git a/internal/providers/vast/labels_test.go b/internal/providers/vast/labels_test.go new file mode 100644 index 000000000..331160922 --- /dev/null +++ b/internal/providers/vast/labels_test.go @@ -0,0 +1,52 @@ +package vast + +import ( + "strings" + "testing" +) + +func TestOwnershipLabelRoundTrip(t *testing.T) { + label := encodeVastOwnershipLabel("lease_123", "gpu-box", "active") + got, ok := decodeVastOwnershipLabel(label) + if !ok { + t.Fatalf("label did not decode: %q", label) + } + if got.LeaseID != "lease_123" || got.Slug != "gpu-box" || got.State != "active" { + t.Fatalf("decoded=%#v", got) + } + if !isVastCrabboxOwnedLabel(label) { + t.Fatalf("owned label not recognized: %q", label) + } +} + +func TestOwnershipLabelRejectsMalformedAndManualLabels(t *testing.T) { + for _, label := range []string{ + "", + "crabbox lease lease_123", + "manual-crabbox-instance", + "cbx1|", + "cbx1|lease|missing-state", + "cbx2|lease|slug|active", + } { + if _, ok := decodeVastOwnershipLabel(label); ok { + t.Fatalf("label %q decoded as owned", label) + } + if isVastCrabboxOwnedLabel(label) { + t.Fatalf("label %q recognized as owned", label) + } + } +} + +func TestOwnershipLabelSanitizesAndBoundsParts(t *testing.T) { + label := encodeVastOwnershipLabel("lease/with spaces", strings.Repeat("x", 80), "active now") + if strings.ContainsAny(label, " /") { + t.Fatalf("label was not sanitized: %q", label) + } + got, ok := decodeVastOwnershipLabel(label) + if !ok { + t.Fatalf("label did not decode: %q", label) + } + if got.LeaseID != "lease-with-spaces" || len(got.Slug) != vastLabelMaxPart || got.State != "active-now" { + t.Fatalf("decoded=%#v label=%q", got, label) + } +} diff --git a/internal/providers/vast/provider.go b/internal/providers/vast/provider.go new file mode 100644 index 000000000..c8e5fd202 --- /dev/null +++ b/internal/providers/vast/provider.go @@ -0,0 +1,132 @@ +package vast + +import ( + "flag" + "net/url" + "strings" + + core "github.com/openclaw/crabbox/internal/cli" +) + +func init() { + core.RegisterProvider(Provider{}) +} + +type Provider struct{} + +func (Provider) Name() string { return providerName } + +func (Provider) Aliases() []string { + return []string{"vast-ai", "vastai"} +} + +func (Provider) Spec() core.ProviderSpec { + return core.ProviderSpec{ + Name: providerName, + Family: "vast", + Kind: core.ProviderKindSSHLease, + Targets: []core.TargetSpec{{OS: core.TargetLinux}}, + Features: core.FeatureSet{core.FeatureSSH, core.FeatureCrabboxSync, core.FeatureCleanup}, + Coordinator: core.CoordinatorNever, + } +} + +func (Provider) RegisterFlags(fs *flag.FlagSet, defaults core.Config) any { + return RegisterVastProviderFlags(fs, defaults) +} + +func (Provider) ApplyFlags(cfg *core.Config, fs *flag.FlagSet, values any) error { + return ApplyVastProviderFlags(cfg, fs, values) +} + +func (Provider) PrepareLeaseClaimEndpoint(existing core.LeaseClaim, provider, slug string, server core.Server, allowProviderMetadata bool) (core.Server, error) { + if provider != providerName { + return core.Server{}, core.Exit(2, "refusing to rewrite vast lease=%s as provider=%s", existing.LeaseID, provider) + } + if slug != existing.Slug { + return core.Server{}, core.Exit(2, "refusing to rewrite vast lease=%s with slug=%s", existing.LeaseID, slug) + } + if server.Labels["lease"] != existing.LeaseID || server.Labels["slug"] != existing.Slug { + return core.Server{}, core.Exit(2, "refusing to rewrite vast lease=%s with mismatched label identity", existing.LeaseID) + } + if existing.CloudID != "" && server.CloudID != "" && existing.CloudID != server.CloudID { + return core.Server{}, core.Exit(2, "refusing to rewrite vast lease=%s with stale instance identity", existing.LeaseID) + } + if allowProviderMetadata { + return server, nil + } + server.Labels = preserveVastClaimMetadata(server.Labels, existing.Labels) + return server, nil +} + +func (p Provider) Configure(cfg core.Config, rt core.Runtime) (core.Backend, error) { + if err := p.ValidateConfig(cfg); err != nil { + return nil, err + } + if cfg.TargetOS != "" && cfg.TargetOS != core.TargetLinux { + return nil, exit(2, "provider=%s supports target=linux only", providerName) + } + if cfg.Tailscale.Enabled || cfg.Network == core.NetworkTailscale { + return nil, exit(2, "provider=%s does not support Tailscale options", providerName) + } + return newBackend(p.Spec(), cfg, rt), nil +} + +func (p Provider) ConfigureDoctor(cfg core.Config, rt core.Runtime) (core.DoctorBackend, error) { + backend, err := p.Configure(cfg, rt) + if err != nil { + return nil, err + } + doctor, ok := backend.(core.DoctorBackend) + if !ok { + return nil, core.Exit(2, "vast doctor backend unavailable") + } + return doctor, nil +} + +func (Provider) ValidateConfig(cfg core.Config) error { + apiURL := strings.TrimSpace(cfg.Vast.APIURL) + if apiURL == "" { + return exit(2, "vast.apiUrl is required") + } + u, err := url.Parse(apiURL) + if err != nil || u.Scheme == "" || u.Host == "" || u.User != nil { + return exit(2, "vast.apiUrl must be an absolute URL without credentials") + } + if u.RawQuery != "" || u.Fragment != "" { + return exit(2, "vast.apiUrl must not include query strings or fragments") + } + switch strings.ToLower(u.Scheme) { + case "https", "http": + default: + return exit(2, "vast.apiUrl must use http or https") + } + switch normalizeInstanceType(cfg.Vast.InstanceType) { + case "ondemand", "interruptible": + default: + return exit(2, "vast.instanceType must be ondemand or interruptible") + } + switch strings.ToLower(strings.TrimSpace(cfg.Vast.Runtype)) { + case "ssh_direct": + default: + return exit(2, "vast.runtype must be ssh_direct") + } + if cfg.Vast.GPUCount < 0 { + return exit(2, "vast.gpuCount must be non-negative") + } + if cfg.Vast.DiskGB < 0 { + return exit(2, "vast.diskGB must be non-negative") + } + if cfg.Vast.MaxDphTotal < 0 { + return exit(2, "vast.maxDphTotal must be non-negative") + } + if cfg.Vast.MinReliability < 0 || cfg.Vast.MinReliability > 1 { + return exit(2, "vast.minReliability must be between 0 and 1") + } + switch strings.ToLower(strings.TrimSpace(cfg.Vast.ReleaseAction)) { + case "", "destroy", "delete", "stop", "keep": + default: + return exit(2, "vast.releaseAction must be destroy, delete, stop, or keep") + } + return nil +} diff --git a/internal/providers/vast/provider_test.go b/internal/providers/vast/provider_test.go new file mode 100644 index 000000000..628a633df --- /dev/null +++ b/internal/providers/vast/provider_test.go @@ -0,0 +1,230 @@ +package vast + +import ( + "flag" + "strings" + "testing" + + core "github.com/openclaw/crabbox/internal/cli" +) + +func TestProviderSpecAndAliases(t *testing.T) { + p := Provider{} + if p.Name() != "vast" { + t.Fatalf("Name=%q", p.Name()) + } + aliases := p.Aliases() + if len(aliases) != 2 || aliases[0] != "vast-ai" || aliases[1] != "vastai" { + t.Fatalf("aliases=%v", aliases) + } + spec := p.Spec() + if spec.Name != "vast" || spec.Family != "vast" || spec.Kind != core.ProviderKindSSHLease || spec.Coordinator != core.CoordinatorNever { + t.Fatalf("spec=%#v", spec) + } + if len(spec.Targets) != 1 || spec.Targets[0].OS != core.TargetLinux { + t.Fatalf("targets=%#v", spec.Targets) + } + for _, feature := range []core.Feature{core.FeatureSSH, core.FeatureCrabboxSync, core.FeatureCleanup} { + if !spec.Features.Has(feature) { + t.Fatalf("features=%v missing %s", spec.Features, feature) + } + } +} + +func TestProviderFlagsApplyNonSecretConfig(t *testing.T) { + cfg := core.Config{ + Provider: "vast", + Vast: core.VastConfig{ + APIURL: "https://console.vast.ai/api/v0", + InstanceType: "ondemand", + Runtype: "ssh_direct", + Image: "nvidia/cuda:default", + DiskGB: 20, + User: "root", + WorkRoot: "/work", + ReleaseAction: "destroy", + }, + } + fs := flag.NewFlagSet("test", flag.ContinueOnError) + values := RegisterVastProviderFlags(fs, cfg) + if err := fs.Parse([]string{ + "--vast-api-url", "https://approved.example.test/api/v0", + "--vast-instance-type", "on-demand", + "--vast-gpu-name", "H100", + "--vast-gpu-count", "4", + "--vast-image", "nvidia/cuda:12", + "--vast-template-id", "tpl-123", + "--vast-runtype", "ssh_direct", + "--vast-disk-gb", "80", + "--vast-max-dph-total", "4.5", + "--vast-min-reliability", "0.95", + "--vast-order", "reliability desc", + "--vast-user", "ubuntu", + "--vast-work-root", "/work/vast", + "--vast-release-action", "keep", + }); err != nil { + t.Fatal(err) + } + if err := ApplyVastProviderFlags(&cfg, fs, values); err != nil { + t.Fatal(err) + } + if cfg.Vast.APIURL != "https://approved.example.test/api/v0" || + cfg.Vast.InstanceType != "ondemand" || + cfg.Vast.GPUName != "H100" || + cfg.Vast.GPUCount != 4 || + cfg.Vast.Image != "nvidia/cuda:12" || + cfg.Vast.TemplateID != "tpl-123" || + cfg.Vast.Runtype != "ssh_direct" || + cfg.Vast.DiskGB != 80 || + cfg.Vast.MaxDphTotal != 4.5 || + cfg.Vast.MinReliability != 0.95 || + cfg.Vast.Order != "reliability desc" || + cfg.Vast.User != "ubuntu" || + cfg.Vast.WorkRoot != "/work/vast" || + cfg.Vast.ReleaseAction != "keep" { + t.Fatalf("vast config=%#v", cfg.Vast) + } +} + +func TestProviderFlagsDoNotExposeAPIKey(t *testing.T) { + fs := flag.NewFlagSet("test", flag.ContinueOnError) + RegisterVastProviderFlags(fs, core.Config{}) + fs.VisitAll(func(f *flag.Flag) { + if strings.Contains(f.Name, "api-key") { + t.Fatalf("unexpected secret flag --%s", f.Name) + } + }) + forbidden := "vast-" + "api-key" + if fs.Lookup(forbidden) != nil { + t.Fatalf("unexpected --%s flag", forbidden) + } +} + +func TestValidateConfigRejectsUnsafeValues(t *testing.T) { + base := core.Config{Vast: core.VastConfig{ + APIURL: "https://console.vast.ai/api/v0", + InstanceType: "ondemand", + Runtype: "ssh_direct", + DiskGB: 20, + ReleaseAction: "destroy", + }} + tests := []struct { + name string + mutate func(*core.Config) + want string + }{ + { + name: "credential url", + mutate: func(cfg *core.Config) { + cfg.Vast.APIURL = "https://user:secret@vast.example.test" + }, + want: "absolute URL without credentials", + }, + { + name: "query url", + mutate: func(cfg *core.Config) { + cfg.Vast.APIURL = "https://vast.example.test/api/v0?token=secret" + }, + want: "query strings or fragments", + }, + { + name: "fragment url", + mutate: func(cfg *core.Config) { + cfg.Vast.APIURL = "https://vast.example.test/api/v0#secret" + }, + want: "query strings or fragments", + }, + { + name: "instance type", + mutate: func(cfg *core.Config) { + cfg.Vast.InstanceType = "spot" + }, + want: "vast.instanceType", + }, + { + name: "runtype", + mutate: func(cfg *core.Config) { + cfg.Vast.Runtype = "ssh_proxy" + }, + want: "vast.runtype", + }, + { + name: "disk", + mutate: func(cfg *core.Config) { + cfg.Vast.DiskGB = -1 + }, + want: "vast.diskGB", + }, + { + name: "reliability", + mutate: func(cfg *core.Config) { + cfg.Vast.MinReliability = 1.1 + }, + want: "vast.minReliability", + }, + { + name: "release", + mutate: func(cfg *core.Config) { + cfg.Vast.ReleaseAction = "hibernate" + }, + want: "vast.releaseAction", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + cfg := base + test.mutate(&cfg) + err := (Provider{}).ValidateConfig(cfg) + if err == nil || !strings.Contains(err.Error(), test.want) { + t.Fatalf("err=%v want %q", err, test.want) + } + }) + } + if err := (Provider{}).ValidateConfig(base); err != nil { + t.Fatalf("valid config rejected: %v", err) + } + base.Vast.InstanceType = "on-demand" + if err := (Provider{}).ValidateConfig(base); err != nil { + t.Fatalf("on-demand alias rejected: %v", err) + } +} + +func TestConfigureRejectsTailscaleBeforeBackend(t *testing.T) { + base := core.Config{ + TargetOS: core.TargetLinux, + Vast: core.VastConfig{ + APIURL: "https://console.vast.ai/api/v0", + InstanceType: "ondemand", + Runtype: "ssh_direct", + DiskGB: 20, + ReleaseAction: "destroy", + }, + } + tests := []struct { + name string + mutate func(*core.Config) + }{ + { + name: "enabled", + mutate: func(cfg *core.Config) { + cfg.Tailscale.Enabled = true + }, + }, + { + name: "network", + mutate: func(cfg *core.Config) { + cfg.Network = core.NetworkTailscale + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + cfg := base + test.mutate(&cfg) + backend, err := (Provider{}).Configure(cfg, core.Runtime{}) + if err == nil || backend != nil || !strings.Contains(err.Error(), "does not support Tailscale") { + t.Fatalf("backend=%T err=%v, want Tailscale rejection", backend, err) + } + }) + } +} diff --git a/scripts/live-vast-smoke.sh b/scripts/live-vast-smoke.sh new file mode 100755 index 000000000..91cbb58da --- /dev/null +++ b/scripts/live-vast-smoke.sh @@ -0,0 +1,324 @@ +#!/usr/bin/env bash +set -euo pipefail + +provider_enabled() { + local list="${CRABBOX_LIVE_PROVIDERS:-}" + local item + IFS=',' read -ra items <<<"$list" + for item in "${items[@]}"; do + item="${item//[[:space:]]/}" + if [[ "$item" == "vast" || "$item" == "vast-ai" || "$item" == "vastai" ]]; then + return 0 + fi + done + return 1 +} + +redact_output() { + VAST_SMOKE_REDACT_TOKEN_1="${CRABBOX_VAST_API_KEY:-}" \ + VAST_SMOKE_REDACT_TOKEN_2="${VAST_API_KEY:-}" python3 -c ' +import os +import re +import sys + +body = sys.stdin.read() +for token in (os.environ.get("VAST_SMOKE_REDACT_TOKEN_1", ""), os.environ.get("VAST_SMOKE_REDACT_TOKEN_2", "")): + if token: + body = body.replace(token, "") +fields = ( + "api_key", + "instance_api_key", + "instanceApiKey", + "jupyter_token", + "jupyterToken", + "jupyter_url", + "jupyterUrl", + "user_data", + "userData", + "private_key", + "privateKey", + "ssh_key", +) +for field in fields: + body = re.sub(rf"(\"{re.escape(field)}\"\s*:\s*\")[^\"]*(\")", rf"\1\2", body, flags=re.IGNORECASE) + body = re.sub(rf"({re.escape(field)}\s*[=:]\s*)[^\",\s]+", rf"\1", body, flags=re.IGNORECASE) +body = re.sub(r"-----BEGIN [A-Z ]*PRIVATE KEY-----[\s\S]*?-----END [A-Z ]*PRIVATE KEY-----", "", body) +body = re.sub(r"https?://[^\s\"'\''<>]*(?:token|api_key|apikey|auth|signature|sig|access_token)=[^\s\"'\''<>]+", "", body, flags=re.IGNORECASE) +sys.stdout.write(body) +' +} + +classify_known_external_blocker() { + local command="$1" + local status="$2" + local output="$3" + local classification="" + local lower + lower="$(printf '%s' "$output" | tr '[:upper:]' '[:lower:]')" + if [[ "$lower" == *billing* || "$lower" == *fund* || "$lower" == *payment* || "$lower" == *balance* || "$lower" == *credit* ]]; then + classification="billing_blocked" + elif [[ "$lower" == *quota* || "$lower" == *"rate limit"* || "$lower" == *"account limit"* || "$lower" == *"limit exceeded"* ]]; then + classification="quota_blocked" + elif [[ "$lower" == *capacity* || "$lower" == *"no eligible offers"* || "$lower" == *"no offers"* || "$lower" == *"no matching"* || "$lower" == *"not enough"* || "$lower" == *"insufficient"* || "$lower" == *"resource exhausted"* || "$lower" == *"found no eligible offers"* ]]; then + classification="capacity_blocked" + elif [[ "$lower" == *unauthorized* || "$lower" == *forbidden* || "$lower" == *"permission denied"* || "$lower" == *"invalid api"* || "$lower" == *"invalid-api"* || "$lower" == *"missing_token"* || "$lower" == *"api key"* ]]; then + classification="environment_blocked" + else + return 1 + fi + printf 'classification=%s command=%q exit=%s\n' "$classification" "$command" "$status" >&2 + printf '%s\n' "$output" | redact_output >&2 + return 0 +} + +classify_validation_failure() { + local command="$1" + local status="$2" + local output="$3" + printf 'classification=validation_failed command=%q exit=%s\n' "$command" "$status" >&2 + printf '%s\n' "$output" | redact_output >&2 +} + +run_capture() { + local command="$1" + shift + local output + set +e + output="$("$@" 2>&1)" + local status=$? + set -e + if [[ "$status" -ne 0 ]]; then + if classify_known_external_blocker "$command" "$status" "$output"; then + exit 0 + fi + classify_validation_failure "$command" "$status" "$output" + exit "$status" + fi + CAPTURED_OUTPUT="$(printf '%s\n' "$output" | redact_output)" +} + +run_capture_validation() { + local command="$1" + shift + local output + set +e + output="$("$@" 2>&1)" + local status=$? + set -e + if [[ "$status" -ne 0 ]]; then + classify_validation_failure "$command" "$status" "$output" + exit "$status" + fi + CAPTURED_OUTPUT="$(printf '%s\n' "$output" | redact_output)" +} + +validate_list_json_contains_slug() { + local command="$1" + local output="$2" + local validation_output="" + local status=0 + set +e + validation_output="$(CRABBOX_SMOKE_SLUG="$slug" python3 -c ' +import json +import os +import sys + +slug = os.environ["CRABBOX_SMOKE_SLUG"] +try: + payload = json.load(sys.stdin) +except Exception as exc: + print(f"invalid JSON: {exc}", file=sys.stderr) + sys.exit(1) + +def has_slug(value): + if isinstance(value, dict): + labels = value.get("labels") or value.get("tags") + if isinstance(labels, dict) and (labels.get("slug") == slug or labels.get("crabbox.slug") == slug): + return True + label = str(value.get("label", "")) + if f"|{slug}|" in label or value.get("slug") == slug or value.get("name") == slug or value.get("id") == slug or value.get("leaseId") == slug: + return True + return any(has_slug(child) for child in value.values()) + if isinstance(value, list): + return any(has_slug(child) for child in value) + return False + +if not has_slug(payload): + print(f"list JSON did not include slug {slug}", file=sys.stderr) + sys.exit(1) +' <<<"$output" 2>&1)" + status=$? + set -e + if [[ "$status" -ne 0 ]]; then + classify_validation_failure "$command" "$status" "$validation_output" + exit "$status" + fi +} + +validate_list_json_empty() { + local command="$1" + local output="$2" + local validation_output="" + local status=0 + set +e + validation_output="$(python3 -c ' +import json +import sys + +try: + payload = json.load(sys.stdin) +except Exception as exc: + print(f"invalid JSON: {exc}", file=sys.stderr) + sys.exit(1) + +if payload != []: + print("Vast Crabbox inventory is not empty", file=sys.stderr) + sys.exit(1) +' <<<"$output" 2>&1)" + status=$? + set -e + if [[ "$status" -ne 0 ]]; then + classify_validation_failure "$command" "$status" "$validation_output" + exit "$status" + fi +} + +validate_nvidia_smi() { + local command="$1" + local output="$2" + if [[ "$output" != *"NVIDIA-SMI"* && "$output" != *"NVIDIA"* ]]; then + classify_validation_failure "$command" 1 "remote command output did not include NVIDIA-SMI output" + exit 1 + fi +} + +cleanup_armed=0 +lease_acquired=0 +slug="" +config_file="" + +is_pre_acquire_not_found() { + local output="$1" + local lower + lower="$(printf '%s' "$output" | tr '[:upper:]' '[:lower:]')" + [[ "$lower" == *"not found"* || "$lower" == *"no lease"* || "$lower" == *"unknown lease"* || "$lower" == *"missing lease"* ]] +} + +cleanup() { + local status=$? + if [[ "$cleanup_armed" -eq 1 && -n "$slug" ]]; then + local cleanup_output="" + local cleanup_status=1 + local attempt + for attempt in 1 2 3; do + set +e + cleanup_output="$(bin/crabbox stop --provider vast "$slug" 2>&1)" + cleanup_status=$? + set -e + if [[ "$cleanup_status" -eq 0 ]]; then + cleanup_armed=0 + break + fi + sleep 2 + done + if [[ "$cleanup_status" -ne 0 ]]; then + if [[ "$lease_acquired" -eq 0 && "$status" -eq 0 ]] && is_pre_acquire_not_found "$cleanup_output"; then + cleanup_armed=0 + else + printf 'classification=cleanup_failed command=%q exit=%s slug=%s\n' "bin/crabbox stop --provider vast $slug" "$cleanup_status" "$slug" >&2 + printf '%s\n' "$cleanup_output" | redact_output >&2 + if [[ "$status" -eq 0 ]]; then + status="$cleanup_status" + fi + fi + fi + fi + if [[ -n "$config_file" ]]; then + rm -f "$config_file" + fi + exit "$status" +} +trap cleanup EXIT + +if [[ "${CRABBOX_LIVE:-}" != "1" ]]; then + printf 'classification=environment_blocked reason=CRABBOX_LIVE_not_enabled\n' + exit 0 +fi + +if ! provider_enabled; then + printf 'classification=environment_blocked reason=vast_not_selected providers=%q\n' "${CRABBOX_LIVE_PROVIDERS:-}" + exit 0 +fi + +if [[ -z "${CRABBOX_VAST_API_KEY:-}${VAST_API_KEY:-}" ]]; then + printf 'classification=environment_blocked reason=VAST_API_KEY_missing\n' + exit 0 +fi + +mkdir -p bin +go build -trimpath -o bin/crabbox ./cmd/crabbox + +slug="vast-smoke-$(date +%Y%m%d%H%M%S)-$$" +vast_gpu_name="${CRABBOX_LIVE_VAST_GPU_NAME:-}" +vast_gpu_count="${CRABBOX_LIVE_VAST_GPU_COUNT:-1}" +vast_max_dph_total="${CRABBOX_LIVE_VAST_MAX_DPH_TOTAL:-0}" +vast_instance_type="${CRABBOX_LIVE_VAST_INSTANCE_TYPE:-ondemand}" +vast_image="${CRABBOX_LIVE_VAST_IMAGE:-nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04}" +vast_release_action="${CRABBOX_LIVE_VAST_RELEASE_ACTION:-destroy}" + +config_file="$(mktemp)" +cat >"$config_file" <"$out" <<'SCRIPT' +${scriptBody} +SCRIPT +chmod +x "$out" +`, + ); +} + +const shellArgHelper = ` +arg_after() { + local want="$1" + shift + while [[ "$#" -gt 0 ]]; do + if [[ "$1" == "$want" ]]; then + printf '%s' "$2" + return 0 + fi + shift + done + return 1 +} +`; + +test("live vast smoke skips unless globally opted in", () => { + const dir = fs.mkdtempSync(path.join(os.tmpdir(), "crabbox-live-vast-skip-")); + const binDir = path.join(dir, "bin"); + const { tempRoot, smokeScript } = prepareSmokeRepo(dir); + fs.mkdirSync(binDir, { recursive: true }); + writeExecutable(path.join(binDir, "go"), "#!/usr/bin/env bash\nexit 99\n"); + + const result = spawnSync("bash", [smokeScript], { + cwd: tempRoot, + env: { + ...process.env, + PATH: `${binDir}${path.delimiter}${process.env.PATH ?? ""}`, + CRABBOX_LIVE: "", + CRABBOX_LIVE_PROVIDERS: "vast", + CRABBOX_VAST_API_KEY: "", + VAST_API_KEY: "", + }, + encoding: "utf8", + }); + + assert.equal(result.status, 0, result.stderr); + assert.match(result.stdout, /classification=environment_blocked reason=CRABBOX_LIVE_not_enabled/); +}); + +test("live vast smoke skips unless vast is selected", () => { + const dir = fs.mkdtempSync(path.join(os.tmpdir(), "crabbox-live-vast-provider-skip-")); + const binDir = path.join(dir, "bin"); + const { tempRoot, smokeScript } = prepareSmokeRepo(dir); + fs.mkdirSync(binDir, { recursive: true }); + writeExecutable(path.join(binDir, "go"), "#!/usr/bin/env bash\nexit 99\n"); + + const result = spawnSync("bash", [smokeScript], { + cwd: tempRoot, + env: { + ...process.env, + PATH: `${binDir}${path.delimiter}${process.env.PATH ?? ""}`, + CRABBOX_LIVE: "1", + CRABBOX_LIVE_PROVIDERS: "lambda", + CRABBOX_VAST_API_KEY: "test-secret-token", + }, + encoding: "utf8", + }); + + assert.equal(result.status, 0, result.stderr); + assert.match(result.stdout, /classification=environment_blocked reason=vast_not_selected/); +}); + +test("live vast smoke requires token before building", () => { + const dir = fs.mkdtempSync(path.join(os.tmpdir(), "crabbox-live-vast-token-")); + const binDir = path.join(dir, "bin"); + const { tempRoot, smokeScript } = prepareSmokeRepo(dir); + fs.mkdirSync(binDir, { recursive: true }); + writeExecutable(path.join(binDir, "go"), "#!/usr/bin/env bash\nexit 99\n"); + + const result = spawnSync("bash", [smokeScript], { + cwd: tempRoot, + env: { + ...process.env, + PATH: `${binDir}${path.delimiter}${process.env.PATH ?? ""}`, + CRABBOX_LIVE: "1", + CRABBOX_LIVE_PROVIDERS: "vast", + CRABBOX_VAST_API_KEY: "", + VAST_API_KEY: "", + }, + encoding: "utf8", + }); + + assert.equal(result.status, 0, result.stderr); + assert.match(result.stdout, /classification=environment_blocked reason=VAST_API_KEY_missing/); +}); + +test("live vast smoke runs guarded lifecycle and redacts secret material", () => { + const dir = fs.mkdtempSync(path.join(os.tmpdir(), "crabbox-live-vast-")); + const binDir = path.join(dir, "bin"); + const { tempRoot, smokeScript } = prepareSmokeRepo(dir); + const calls = path.join(dir, "calls.log"); + const slugFile = path.join(dir, "slug.txt"); + fs.mkdirSync(binDir, { recursive: true }); + + writeGoStub( + binDir, + `#!/usr/bin/env bash +set -euo pipefail +${shellArgHelper} +printf '%s\\n' "$*" >>"${calls}" +if [[ "\${CRABBOX_VAST_API_KEY:-}" != "test-secret-token" ]]; then + printf 'missing token\\n' >&2 + exit 91 +fi +case "$1" in + doctor) + printf 'auth=ready control_plane=ready inventory=ready api=list mutation=false leases=0 runtime=unchecked api_key=test-secret-token instance_api_key=visible jupyter_url=https://example.test/?token=abc\\n' + ;; + warmup) + arg_after --slug "$@" >"${slugFile}" + ;; + status) + printf 'status=ready\\n' + ;; + run) + printf 'NVIDIA-SMI 570.00.00\\n' + ;; + list) + slug="$(cat "${slugFile}" 2>/dev/null || true)" + if [[ -z "$slug" || -f "${slugFile}.stopped" ]]; then + printf '[]\\n' + else + printf '[{"labels":{"slug":"%s"},"provider":"vast"}]\\n' "$slug" + fi + ;; + stop) + printf stopped >"${slugFile}.stopped" + ;; + cleanup) + printf 'skip vast instance id=none name=none reason=missing labels\\n' + ;; + *) + printf 'unexpected args: %s\\n' "$*" >&2 + exit 99 + ;; +esac +`, + ); + + const result = spawnSync("bash", [smokeScript], { + cwd: tempRoot, + env: { + ...process.env, + PATH: `${binDir}${path.delimiter}${process.env.PATH ?? ""}`, + CRABBOX_LIVE: "1", + CRABBOX_LIVE_PROVIDERS: " vast-ai ", + CRABBOX_VAST_API_KEY: "test-secret-token", + }, + encoding: "utf8", + }); + + assert.equal(result.status, 0, result.stdout + result.stderr); + assert.match(result.stdout, /classification=live_vast_smoke_passed/); + assert.doesNotMatch(result.stdout + result.stderr, /test-secret-token|visible|token=abc/); + + const seen = fs.readFileSync(calls, "utf8").trim().split("\n"); + assert.equal(seen[0], "doctor --provider vast"); + assert.equal(seen[1], "list --provider vast --json"); + assert.match(seen[2], /^warmup --provider vast --slug vast-smoke-\d{14}-\d+ --keep --ttl 20m --idle-timeout 5m$/); + assert.match(seen[3], /^status --provider vast --id vast-smoke-\d{14}-\d+ --wait --wait-timeout 600s$/); + assert.match(seen[4], /^run --provider vast --id vast-smoke-\d{14}-\d+ --no-sync -- nvidia-smi$/); + assert.equal(seen[5], "list --provider vast --json"); + assert.match(seen[6], /^stop --provider vast vast-smoke-\d{14}-\d+$/); + assert.equal(seen[7], "cleanup --provider vast --dry-run"); + assert.equal(seen[8], "list --provider vast --json"); +}); + +test("live vast smoke attempts cleanup after partial failure", () => { + const dir = fs.mkdtempSync(path.join(os.tmpdir(), "crabbox-live-vast-fail-")); + const binDir = path.join(dir, "bin"); + const { tempRoot, smokeScript } = prepareSmokeRepo(dir); + const calls = path.join(dir, "calls.log"); + fs.mkdirSync(binDir, { recursive: true }); + + writeGoStub( + binDir, + `#!/usr/bin/env bash +set -euo pipefail +printf '%s\\n' "$*" >>"${calls}" +if [[ "$1" == "doctor" || "$1" == "list" ]]; then + [[ "$1" == "list" ]] && printf '[]\\n' || printf 'auth=ready\\n' + exit 0 +fi +if [[ "$1" == "warmup" ]]; then + printf 'no eligible offers after partial create\\n' >&2 + exit 37 +fi +if [[ "$1" == "stop" ]]; then + exit 0 +fi +exit 99 +`, + ); + + const result = spawnSync("bash", [smokeScript], { + cwd: tempRoot, + env: { + ...process.env, + PATH: `${binDir}${path.delimiter}${process.env.PATH ?? ""}`, + CRABBOX_LIVE: "1", + CRABBOX_LIVE_PROVIDERS: "vast", + CRABBOX_VAST_API_KEY: "test-secret-token", + }, + encoding: "utf8", + }); + + assert.equal(result.status, 0, result.stdout + result.stderr); + assert.match(result.stderr, /classification=capacity_blocked/); + assert.match(result.stderr, /no eligible offers/); + assert.match(fs.readFileSync(calls, "utf8"), /warmup .* --keep /); + assert.match(fs.readFileSync(calls, "utf8"), /stop --provider vast vast-smoke-\d{14}-\d+/); +}); + +test("live vast smoke preserves capacity classification when pre-acquire cleanup finds no lease", () => { + const dir = fs.mkdtempSync(path.join(os.tmpdir(), "crabbox-live-vast-no-lease-")); + const binDir = path.join(dir, "bin"); + const { tempRoot, smokeScript } = prepareSmokeRepo(dir); + const calls = path.join(dir, "calls.log"); + fs.mkdirSync(binDir, { recursive: true }); + + writeGoStub( + binDir, + `#!/usr/bin/env bash +set -euo pipefail +printf '%s\\n' "$*" >>"${calls}" +if [[ "$1" == "doctor" || "$1" == "list" ]]; then + [[ "$1" == "list" ]] && printf '[]\\n' || printf 'auth=ready\\n' + exit 0 +fi +if [[ "$1" == "warmup" ]]; then + printf 'no eligible offers found\\n' >&2 + exit 37 +fi +if [[ "$1" == "stop" ]]; then + printf 'lease not found\\n' >&2 + exit 44 +fi +exit 99 +`, + ); + + const result = spawnSync("bash", [smokeScript], { + cwd: tempRoot, + env: { + ...process.env, + PATH: `${binDir}${path.delimiter}${process.env.PATH ?? ""}`, + CRABBOX_LIVE: "1", + CRABBOX_LIVE_PROVIDERS: "vast", + CRABBOX_VAST_API_KEY: "test-secret-token", + }, + encoding: "utf8", + }); + + assert.equal(result.status, 0, result.stdout + result.stderr); + assert.match(result.stderr, /classification=capacity_blocked/); + assert.doesNotMatch(result.stderr, /classification=cleanup_failed/); + assert.match(fs.readFileSync(calls, "utf8"), /stop --provider vast vast-smoke-\d{14}-\d+/); +}); + +test("live vast smoke validates nvidia-smi output", () => { + const dir = fs.mkdtempSync(path.join(os.tmpdir(), "crabbox-live-vast-bad-run-")); + const binDir = path.join(dir, "bin"); + const { tempRoot, smokeScript } = prepareSmokeRepo(dir); + const slugFile = path.join(dir, "slug.txt"); + fs.mkdirSync(binDir, { recursive: true }); + + writeGoStub( + binDir, + `#!/usr/bin/env bash +set -euo pipefail +${shellArgHelper} +case "$1" in + doctor) + printf 'auth=ready\\n' + ;; + warmup) + arg_after --slug "$@" >"${slugFile}" + ;; + status) + printf 'status=ready\\n' + ;; + run) + printf 'ok\\n' + ;; + list) + slug="$(cat "${slugFile}" 2>/dev/null || true)" + if [[ -z "$slug" || -f "${slugFile}.stopped" ]]; then + printf '[]\\n' + else + printf '[{"labels":{"slug":"%s"}}]\\n' "$slug" + fi + ;; + stop) + printf stopped >"${slugFile}.stopped" + ;; + cleanup) + printf 'cleanup dry-run\\n' + ;; + *) + exit 99 + ;; +esac +`, + ); + + const result = spawnSync("bash", [smokeScript], { + cwd: tempRoot, + env: { + ...process.env, + PATH: `${binDir}${path.delimiter}${process.env.PATH ?? ""}`, + CRABBOX_LIVE: "1", + CRABBOX_LIVE_PROVIDERS: "vast", + CRABBOX_VAST_API_KEY: "test-secret-token", + }, + encoding: "utf8", + }); + + assert.equal(result.status, 1, result.stdout + result.stderr); + assert.match(result.stderr, /classification=validation_failed/); + assert.match(result.stderr, /NVIDIA-SMI/); +}); diff --git a/worker/vitest.config.ts b/worker/vitest.config.ts index 4db9ee006..1907ab03f 100644 --- a/worker/vitest.config.ts +++ b/worker/vitest.config.ts @@ -1,10 +1,13 @@ +import { fileURLToPath } from "node:url"; + import { defineConfig } from "vitest/config"; export default defineConfig({ resolve: { alias: { - "cloudflare:workers": new URL("./test/cloudflare-workers-runtime.ts", import.meta.url) - .pathname, + "cloudflare:workers": fileURLToPath( + new URL("./test/cloudflare-workers-runtime.ts", import.meta.url), + ), }, }, test: {