From 96cbe310aa6e82145aae0b9a8bef7b538c69caeb Mon Sep 17 00:00:00 2001 From: Coy Geek <65363919+coygeek@users.noreply.github.com> Date: Wed, 24 Jun 2026 15:06:30 -0700 Subject: [PATCH 01/18] feat: register vast provider config Add the first-party Vast provider configuration surface with env-only API-key sourcing, safe config output, provider flags, built-in registration, and offline placeholder backend contracts for later lifecycle plans. Add focused tests for Vast config precedence, credential destination provenance, provider discovery, flag registration, validation, and secret redaction. --- internal/cli/config.go | 212 ++++++++++++++++++ internal/cli/config_cmd.go | 24 ++ internal/cli/config_cmd_test.go | 95 ++++++++ internal/cli/config_test.go | 70 ++++++ internal/cli/credential_provenance.go | 10 + internal/cli/credential_provenance_test.go | 44 ++++ internal/cli/provider_categories_generated.go | 1 + internal/cli/providers_builtin_test.go | 33 +++ internal/cli/providers_test.go | 19 ++ internal/providers/all/all.go | 1 + internal/providers/all/all_test.go | 29 +++ internal/providers/vast/core.go | 62 +++++ internal/providers/vast/flags.go | 104 +++++++++ internal/providers/vast/provider.go | 145 ++++++++++++ internal/providers/vast/provider_test.go | 176 +++++++++++++++ 15 files changed, 1025 insertions(+) create mode 100644 internal/providers/vast/core.go create mode 100644 internal/providers/vast/flags.go create mode 100644 internal/providers/vast/provider.go create mode 100644 internal/providers/vast/provider_test.go diff --git a/internal/cli/config.go b/internal/cli/config.go index 888dd3ca5..d31d5dc52 100644 --- a/internal/cli/config.go +++ b/internal/cli/config.go @@ -163,6 +163,8 @@ type Config struct { Railway RailwayConfig FastAPICloud FastAPICloudConfig Runpod RunpodConfig + Vast VastConfig + vastWorkRootExplicit bool NvidiaBrev NvidiaBrevConfig nvidiaBrevWorkRootExplicit bool Hostinger HostingerConfig @@ -605,6 +607,26 @@ type RunpodConfig struct { 
WorkRoot string } +// VastConfig contains Vast.ai provider settings. APIKey is populated only from +// CRABBOX_VAST_API_KEY / VAST_API_KEY and must not be persisted or printed. +type VastConfig struct { + APIKey string + APIURL string + InstanceType string + GPUName string + GPUCount int + Image string + TemplateID string + Runtype string + DiskGB int + MaxDphTotal float64 + MinReliability float64 + Order string + User string + WorkRoot string + ReleaseAction string +} + // NvidiaBrevConfig is intentionally non-secret. Authentication stays in the // NVIDIA Brev CLI's own credential store and is never accepted as Crabbox // config or argv. @@ -1639,6 +1661,61 @@ func applyProviderConfigDefaults(cfg *Config) error { normalizeTargetConfig(cfg) return validateTargetConfig(*cfg) } + if cfg.Provider == "vast" { + cfg.Vast.InstanceType = normalizeVastInstanceType(cfg.Vast.InstanceType) + if cfg.Vast.APIURL == "" { + cfg.Vast.APIURL = "https://console.vast.ai/api/v0" + } + if cfg.Vast.InstanceType == "" { + cfg.Vast.InstanceType = "ondemand" + } + if cfg.Vast.Image == "" { + cfg.Vast.Image = "nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04" + } + if cfg.Vast.Runtype == "" { + cfg.Vast.Runtype = "ssh_direct" + } + if cfg.Vast.DiskGB == 0 { + cfg.Vast.DiskGB = 20 + } + if cfg.Vast.Order == "" { + cfg.Vast.Order = "dlperf_per_dphtotal desc" + } + if cfg.Vast.User == "" { + cfg.Vast.User = "root" + } + if cfg.Vast.WorkRoot == "" { + cfg.Vast.WorkRoot = defaultPOSIXWorkRoot + } + if cfg.Vast.ReleaseAction == "" { + cfg.Vast.ReleaseAction = "destroy" + } + if !IsTargetExplicit(cfg) { + cfg.TargetOS = targetLinux + } + if cfg.explicitWindowsMode != "" { + cfg.WindowsMode = cfg.explicitWindowsMode + } else { + cfg.WindowsMode = windowsModeNormal + } + if cfg.explicitWorkRoot != "" && !IsVastWorkRootExplicit(cfg) { + cfg.Vast.WorkRoot = cfg.explicitWorkRoot + } + cfg.WorkRoot = cfg.Vast.WorkRoot + if cfg.explicitSSHUser != "" { + cfg.SSHUser = cfg.explicitSSHUser + } else { + 
cfg.SSHUser = cfg.Vast.User + } + if cfg.explicitSSHPort != "" { + cfg.SSHPort = cfg.explicitSSHPort + } else { + cfg.SSHPort = "22" + } + cfg.SSHFallbackPorts = nil + normalizeTargetConfig(cfg) + return validateTargetConfig(*cfg) + } if cfg.Provider == "nebius" { if cfg.Nebius.CLI == "" { cfg.Nebius.CLI = "nebius" @@ -2307,6 +2384,38 @@ func MarkNvidiaBrevWorkRootExplicit(cfg *Config) { cfg.nvidiaBrevWorkRootExplicit = true } +func IsVastWorkRootExplicit(cfg *Config) bool { + return cfg.vastWorkRootExplicit +} + +func MarkVastWorkRootExplicit(cfg *Config) { + cfg.vastWorkRootExplicit = true +} + +func EffectiveVastWorkRoot(cfg Config) string { + workRoot := cfg.Vast.WorkRoot + if !IsVastWorkRootExplicit(&cfg) && (workRoot == "" || workRoot == defaultPOSIXWorkRoot) && cfg.explicitWorkRoot != "" { + return cfg.explicitWorkRoot + } + if workRoot == "" { + return defaultPOSIXWorkRoot + } + return workRoot +} + +func NormalizeVastInstanceType(value string) string { + return normalizeVastInstanceType(value) +} + +func normalizeVastInstanceType(value string) string { + switch strings.ToLower(strings.TrimSpace(value)) { + case "on-demand", "on_demand": + return "ondemand" + default: + return strings.ToLower(strings.TrimSpace(value)) + } +} + func EffectiveNvidiaBrevWorkRoot(cfg Config) string { workRoot := cfg.NvidiaBrev.WorkRoot providerDefault := workRoot == "" || workRoot == "/tmp/crabbox" @@ -2541,6 +2650,17 @@ func baseConfig() Config { Image: "runpod/pytorch:2.8.0-py3.11-cuda12.8.1-cudnn-devel-ubuntu22.04", DiskGB: 20, }, + Vast: VastConfig{ + APIURL: "https://console.vast.ai/api/v0", + InstanceType: "ondemand", + Image: "nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04", + Runtype: "ssh_direct", + DiskGB: 20, + Order: "dlperf_per_dphtotal desc", + User: "root", + WorkRoot: defaultPOSIXWorkRoot, + ReleaseAction: "destroy", + }, NvidiaBrev: NvidiaBrevConfig{ CLI: "brev", GPUName: "A100", @@ -2852,6 +2972,7 @@ type fileConfig struct { Railway *fileRailwayConfig 
`yaml:"railway,omitempty"` FastAPICloud *fileFastAPICloudConfig `yaml:"fastapiCloud,omitempty"` Runpod *fileRunpodConfig `yaml:"runpod,omitempty"` + Vast *fileVastConfig `yaml:"vast,omitempty"` NvidiaBrev *fileNvidiaBrevConfig `yaml:"nvidiaBrev,omitempty"` Hostinger *fileHostingerConfig `yaml:"hostinger,omitempty"` Wandb *fileWandbConfig `yaml:"wandb,omitempty"` @@ -3415,6 +3536,23 @@ type fileRunpodConfig struct { WorkRoot string `yaml:"workRoot,omitempty"` } +type fileVastConfig struct { + APIURL string `yaml:"apiUrl,omitempty"` + InstanceType string `yaml:"instanceType,omitempty"` + GPUName string `yaml:"gpuName,omitempty"` + GPUCount int `yaml:"gpuCount,omitempty"` + Image string `yaml:"image,omitempty"` + TemplateID string `yaml:"templateId,omitempty"` + Runtype string `yaml:"runtype,omitempty"` + DiskGB int `yaml:"diskGB,omitempty"` + MaxDphTotal *float64 `yaml:"maxDphTotal,omitempty"` + MinReliability *float64 `yaml:"minReliability,omitempty"` + Order string `yaml:"order,omitempty"` + User string `yaml:"user,omitempty"` + WorkRoot string `yaml:"workRoot,omitempty"` + ReleaseAction string `yaml:"releaseAction,omitempty"` +} + type fileNvidiaBrevConfig struct { CLI string `yaml:"cli,omitempty"` Org string `yaml:"org,omitempty"` @@ -5470,6 +5608,53 @@ func applyFileConfigWithTrust(cfg *Config, file fileConfig, trusted bool) error cfg.Runpod.WorkRoot = file.Runpod.WorkRoot } } + if file.Vast != nil { + if file.Vast.APIURL != "" { + cfg.Vast.APIURL = file.Vast.APIURL + cfg.credentialProvenance.vastAPIURL = credentialSource + } + if file.Vast.InstanceType != "" { + cfg.Vast.InstanceType = file.Vast.InstanceType + } + if file.Vast.GPUName != "" { + cfg.Vast.GPUName = file.Vast.GPUName + } + if file.Vast.GPUCount != 0 { + cfg.Vast.GPUCount = file.Vast.GPUCount + } + if file.Vast.Image != "" { + cfg.Vast.Image = file.Vast.Image + } + if file.Vast.TemplateID != "" { + cfg.Vast.TemplateID = file.Vast.TemplateID + } + if file.Vast.Runtype != "" { + cfg.Vast.Runtype = 
file.Vast.Runtype + } + if file.Vast.DiskGB != 0 { + cfg.Vast.DiskGB = file.Vast.DiskGB + } + if file.Vast.MaxDphTotal != nil { + cfg.Vast.MaxDphTotal = *file.Vast.MaxDphTotal + } + if file.Vast.MinReliability != nil { + cfg.Vast.MinReliability = *file.Vast.MinReliability + } + if file.Vast.Order != "" { + cfg.Vast.Order = file.Vast.Order + } + if file.Vast.User != "" { + cfg.Vast.User = file.Vast.User + } + if file.Vast.WorkRoot != "" { + cfg.Vast.WorkRoot = file.Vast.WorkRoot + MarkVastWorkRootExplicit(cfg) + } + if file.Vast.ReleaseAction != "" { + cfg.Vast.ReleaseAction = file.Vast.ReleaseAction + MarkDeleteOnReleaseExplicit(cfg, "vast") + } + } if file.NvidiaBrev != nil { if trusted && file.NvidiaBrev.CLI != "" { cfg.NvidiaBrev.CLI = file.NvidiaBrev.CLI @@ -7379,6 +7564,33 @@ func applyEnv(cfg *Config) error { cfg.Runpod.DiskGB = getenvInt("CRABBOX_RUNPOD_DISK_GB", cfg.Runpod.DiskGB) cfg.Runpod.User = getenv("CRABBOX_RUNPOD_USER", cfg.Runpod.User) cfg.Runpod.WorkRoot = getenv("CRABBOX_RUNPOD_WORK_ROOT", cfg.Runpod.WorkRoot) + if value, ok := firstNonEmptyEnv("CRABBOX_VAST_API_KEY", "VAST_API_KEY"); ok { + cfg.Vast.APIKey = value + cfg.credentialProvenance.vastAPIKey = credentialSourceEnvironment + } + if value, ok := firstNonEmptyEnv("CRABBOX_VAST_API_URL", "VAST_API_URL"); ok { + cfg.Vast.APIURL = value + cfg.credentialProvenance.vastAPIURL = credentialSourceEnvironment + } + cfg.Vast.InstanceType = getenv("CRABBOX_VAST_INSTANCE_TYPE", cfg.Vast.InstanceType) + cfg.Vast.GPUName = getenv("CRABBOX_VAST_GPU_NAME", cfg.Vast.GPUName) + cfg.Vast.GPUCount = getenvInt("CRABBOX_VAST_GPU_COUNT", cfg.Vast.GPUCount) + cfg.Vast.Image = getenv("CRABBOX_VAST_IMAGE", cfg.Vast.Image) + cfg.Vast.TemplateID = getenv("CRABBOX_VAST_TEMPLATE_ID", cfg.Vast.TemplateID) + cfg.Vast.Runtype = getenv("CRABBOX_VAST_RUNTYPE", cfg.Vast.Runtype) + cfg.Vast.DiskGB = getenvInt("CRABBOX_VAST_DISK_GB", cfg.Vast.DiskGB) + cfg.Vast.MaxDphTotal = getenvFloat("CRABBOX_VAST_MAX_DPH_TOTAL", 
cfg.Vast.MaxDphTotal) + cfg.Vast.MinReliability = getenvFloat("CRABBOX_VAST_MIN_RELIABILITY", cfg.Vast.MinReliability) + cfg.Vast.Order = getenv("CRABBOX_VAST_ORDER", cfg.Vast.Order) + cfg.Vast.User = getenv("CRABBOX_VAST_USER", cfg.Vast.User) + if value := os.Getenv("CRABBOX_VAST_WORK_ROOT"); value != "" { + cfg.Vast.WorkRoot = value + MarkVastWorkRootExplicit(cfg) + } + if value := os.Getenv("CRABBOX_VAST_RELEASE_ACTION"); value != "" { + cfg.Vast.ReleaseAction = value + MarkDeleteOnReleaseExplicit(cfg, "vast") + } cfg.NvidiaBrev.CLI = getenv("CRABBOX_NVIDIA_BREV_CLI", cfg.NvidiaBrev.CLI) cfg.NvidiaBrev.Org = getenv("CRABBOX_NVIDIA_BREV_ORG", cfg.NvidiaBrev.Org) cfg.NvidiaBrev.Type = getenv("CRABBOX_NVIDIA_BREV_TYPE", cfg.NvidiaBrev.Type) diff --git a/internal/cli/config_cmd.go b/internal/cli/config_cmd.go index 52fe8ce37..b9bdf9311 100644 --- a/internal/cli/config_cmd.go +++ b/internal/cli/config_cmd.go @@ -50,6 +50,7 @@ func (a App) configShow(args []string) error { func effectiveConfigForShow(cfg Config) Config { cfg.Hostinger.WorkRoot = EffectiveHostingerWorkRoot(cfg) + cfg.Vast.WorkRoot = EffectiveVastWorkRoot(cfg) cfg.NvidiaBrev.WorkRoot = EffectiveNvidiaBrevWorkRoot(cfg) if cfg.Provider == "digitalocean" || cfg.Provider == "linode" { base := baseConfig() @@ -98,6 +99,11 @@ func effectiveConfigForShow(cfg Config) Config { cfg.SSHFallbackPorts = nil } switch normalizeProviderName(cfg.Provider) { + case "vast", "vast-ai", "vastai": + cfg.WorkRoot = cfg.Vast.WorkRoot + cfg.SSHUser = cfg.Vast.User + cfg.SSHPort = "22" + cfg.SSHFallbackPorts = nil case "nvidia-brev", "brev", "nvidia": cfg.WorkRoot = cfg.NvidiaBrev.WorkRoot } @@ -221,6 +227,23 @@ func configShowView(cfg Config) map[string]any { "user": cfg.NvidiaBrev.User, "workRoot": cfg.NvidiaBrev.WorkRoot, }, + "vast": map[string]any{ + "apiUrl": redactedConfigURLWithoutQuery(cfg.Vast.APIURL), + "auth": tokenState(cfg.Vast.APIKey), + "instanceType": cfg.Vast.InstanceType, + "gpuName": cfg.Vast.GPUName, + 
"gpuCount": cfg.Vast.GPUCount, + "image": cfg.Vast.Image, + "templateId": cfg.Vast.TemplateID, + "runtype": cfg.Vast.Runtype, + "diskGB": cfg.Vast.DiskGB, + "maxDphTotal": cfg.Vast.MaxDphTotal, + "minReliability": cfg.Vast.MinReliability, + "order": cfg.Vast.Order, + "user": cfg.Vast.User, + "workRoot": cfg.Vast.WorkRoot, + "releaseAction": cfg.Vast.ReleaseAction, + }, "nebius": map[string]any{ "cli": cfg.Nebius.CLI, "auth": "cli", @@ -638,6 +661,7 @@ func writeConfigShowText(w io.Writer, cfg Config) { fmt.Fprintf(w, "vultr region=%s os=%s image=%s snapshot=%s firewall_group=%s vpc_ids=%s ssh_cidrs=%s user_scheme=%s\n", cfg.Vultr.Region, blank(cfg.Vultr.OS, "-"), blank(cfg.Vultr.Image, "-"), blank(cfg.Vultr.Snapshot, "-"), blank(cfg.Vultr.FirewallGroup, "-"), blank(strings.Join(cfg.Vultr.VPCIDs, ","), "-"), blank(strings.Join(cfg.Vultr.SSHCIDRs, ","), "-"), blank(cfg.Vultr.UserScheme, "-")) fmt.Fprintf(w, "linode region=%s image=%s type=%s firewall=%s ssh_cidrs=%s\n", cfg.Linode.Region, cfg.Linode.Image, cfg.Linode.Type, blank(cfg.Linode.FirewallID, "-"), blank(strings.Join(cfg.Linode.SSHCIDRs, ","), "-")) fmt.Fprintf(w, "lambda region=%s type=%s image=%s image_family=%s firewall_ruleset=%s ssh_cidrs=%s filesystems=%s mounts=%d auth=%s\n", cfg.Lambda.Region, cfg.Lambda.Type, blank(cfg.Lambda.Image, "-"), blank(cfg.Lambda.ImageFamily, "-"), blank(cfg.Lambda.FirewallRuleset, "-"), blank(strings.Join(cfg.Lambda.SSHCIDRs, ","), "-"), blank(strings.Join(cfg.Lambda.FilesystemNames, ","), "-"), len(cfg.Lambda.FilesystemMounts), lambdaAuthState()) + fmt.Fprintf(w, "vast api_url=%s instance_type=%s gpu_name=%s gpu_count=%d image=%s template_id=%s runtype=%s disk_gb=%d max_dph_total=%.4g min_reliability=%.4g order=%s user=%s work_root=%s release_action=%s auth=%s\n", blank(redactedConfigURLWithoutQuery(cfg.Vast.APIURL), "-"), blank(cfg.Vast.InstanceType, "-"), blank(cfg.Vast.GPUName, "-"), cfg.Vast.GPUCount, blank(cfg.Vast.Image, "-"), blank(cfg.Vast.TemplateID, "-"), 
blank(cfg.Vast.Runtype, "-"), cfg.Vast.DiskGB, cfg.Vast.MaxDphTotal, cfg.Vast.MinReliability, blank(cfg.Vast.Order, "-"), blank(cfg.Vast.User, "-"), blank(cfg.Vast.WorkRoot, "-"), blank(cfg.Vast.ReleaseAction, "-"), tokenState(cfg.Vast.APIKey)) fmt.Fprintf(w, "nvidia_brev cli=%s org=%s type=%s gpu_name=%s provider=%s mode=%s launchable=%s startup_script=%s release_action=%s target=%s user=%s work_root=%s auth=cli\n", blank(cfg.NvidiaBrev.CLI, "-"), blank(cfg.NvidiaBrev.Org, "-"), blank(cfg.NvidiaBrev.Type, "-"), blank(cfg.NvidiaBrev.GPUName, "-"), blank(cfg.NvidiaBrev.Provider, "-"), blank(cfg.NvidiaBrev.Mode, "-"), blank(cfg.NvidiaBrev.Launchable, "-"), blank(cfg.NvidiaBrev.StartupScript, "-"), blank(cfg.NvidiaBrev.ReleaseAction, "-"), blank(cfg.NvidiaBrev.Target, "-"), blank(cfg.NvidiaBrev.User, "-"), blank(cfg.NvidiaBrev.WorkRoot, "-")) fmt.Fprintf(w, "nebius cli=%s profile=%s parent_id=%s subnet_id=%s platform=%s preset=%s image_family=%s disk_type=%s disk_size_gib=%d user=%s public_ip=%s security_group_ids=%s service_account_id=%s recovery_policy=%s auth=cli\n", blank(cfg.Nebius.CLI, "-"), blank(cfg.Nebius.Profile, "-"), blank(cfg.Nebius.ParentID, "-"), blank(cfg.Nebius.SubnetID, "-"), blank(cfg.Nebius.Platform, "-"), blank(cfg.Nebius.Preset, "-"), blank(cfg.Nebius.ImageFamily, "-"), blank(cfg.Nebius.DiskType, "-"), cfg.Nebius.DiskSizeGiB, blank(cfg.Nebius.User, "-"), blank(cfg.Nebius.PublicIP, "-"), blank(strings.Join(cfg.Nebius.SecurityGroupIDs, ","), "-"), blank(cfg.Nebius.ServiceAccountID, "-"), blank(cfg.Nebius.RecoveryPolicy, "-")) fmt.Fprintf(w, "hostinger api_url=%s item_id=%s payment_method_id=%s template_id=%s data_center_id=%s hostname_prefix=%s user=%s work_root=%s allow_purchase=%t release_action=%s auth=%s\n", blank(cfg.Hostinger.APIURL, "-"), blank(cfg.Hostinger.ItemID, "-"), blank(cfg.Hostinger.PaymentMethodID, "-"), blank(cfg.Hostinger.TemplateID, "-"), blank(cfg.Hostinger.DataCenterID, "-"), blank(cfg.Hostinger.HostnamePrefix, "-"), 
blank(cfg.Hostinger.User, "-"), blank(cfg.Hostinger.WorkRoot, "-"), cfg.Hostinger.AllowPurchase, blank(cfg.Hostinger.ReleaseAction, "-"), tokenState(cfg.Hostinger.APIToken)) diff --git a/internal/cli/config_cmd_test.go b/internal/cli/config_cmd_test.go index 970c7500b..d3bc800c7 100644 --- a/internal/cli/config_cmd_test.go +++ b/internal/cli/config_cmd_test.go @@ -1179,6 +1179,101 @@ func TestConfigShowIncludesNvidiaBrevWithoutSecretSurface(t *testing.T) { } } +func TestConfigShowIncludesVastWithoutSecretSurface(t *testing.T) { + clearConfigEnv(t) + home := t.TempDir() + configPath := filepath.Join(home, "config.yaml") + t.Setenv("HOME", home) + t.Setenv("XDG_CONFIG_HOME", filepath.Join(home, ".config")) + t.Setenv("CRABBOX_CONFIG", configPath) + t.Setenv("CRABBOX_VAST_API_KEY", "vast-redaction-fixture-secret") + if err := os.WriteFile(configPath, []byte(`provider: vast +vast: + apiUrl: https://user:secret@vast.example.test/api/v0?token=hidden + instanceType: on-demand + gpuName: RTX 4090 + gpuCount: 2 + image: nvidia/cuda:vast + templateId: tpl-vast + runtype: ssh_direct + diskGB: 60 + maxDphTotal: 3.5 + minReliability: 0.9 + order: reliability desc + user: root + workRoot: /work/vast + releaseAction: stop +`), 0o600); err != nil { + t.Fatal(err) + } + + var stdout bytes.Buffer + app := App{Stdout: &stdout, Stderr: &bytes.Buffer{}} + if err := app.configShow(nil); err != nil { + t.Fatal(err) + } + text := stdout.String() + if !strings.Contains(text, "vast api_url=https://@vast.example.test/api/v0 instance_type=ondemand gpu_name=RTX 4090 gpu_count=2") || + !strings.Contains(text, "work_root=/work/vast release_action=stop auth=configured") { + t.Fatalf("config show missing vast summary: %q", text) + } + if strings.Contains(text, "vast-redaction-fixture-secret") || strings.Contains(text, "user:secret") || strings.Contains(text, "hidden") { + t.Fatalf("config show text leaked Vast secret: %q", text) + } + + stdout.Reset() + if err := 
app.configShow([]string{"--json"}); err != nil { + t.Fatal(err) + } + var got struct { + SSHUser string `json:"sshUser"` + SSHPort string `json:"sshPort"` + Vast struct { + APIURL string `json:"apiUrl"` + Auth string `json:"auth"` + InstanceType string `json:"instanceType"` + GPUName string `json:"gpuName"` + GPUCount int `json:"gpuCount"` + Image string `json:"image"` + TemplateID string `json:"templateId"` + Runtype string `json:"runtype"` + DiskGB int `json:"diskGB"` + MaxDphTotal float64 `json:"maxDphTotal"` + MinReliability float64 `json:"minReliability"` + Order string `json:"order"` + User string `json:"user"` + WorkRoot string `json:"workRoot"` + ReleaseAction string `json:"releaseAction"` + } `json:"vast"` + } + if err := json.Unmarshal(stdout.Bytes(), &got); err != nil { + t.Fatal(err) + } + if got.Vast.APIURL != "https://@vast.example.test/api/v0" || + got.Vast.Auth != "configured" || + got.Vast.InstanceType != "ondemand" || + got.Vast.GPUName != "RTX 4090" || + got.Vast.GPUCount != 2 || + got.Vast.Image != "nvidia/cuda:vast" || + got.Vast.TemplateID != "tpl-vast" || + got.Vast.Runtype != "ssh_direct" || + got.Vast.DiskGB != 60 || + got.Vast.MaxDphTotal != 3.5 || + got.Vast.MinReliability != 0.9 || + got.Vast.Order != "reliability desc" || + got.Vast.User != "root" || + got.Vast.WorkRoot != "/work/vast" || + got.Vast.ReleaseAction != "stop" { + t.Fatalf("unexpected vast json: %#v", got.Vast) + } + if got.SSHUser != "root" || got.SSHPort != "22" { + t.Fatalf("unexpected vast ssh json: %#v", got) + } + if strings.Contains(stdout.String(), "vast-redaction-fixture-secret") || strings.Contains(stdout.String(), "user:secret") || strings.Contains(stdout.String(), "hidden") { + t.Fatalf("config show json leaked Vast secret: %q", stdout.String()) + } +} + func TestConfigShowIncludesNebiusWithoutSecretSurface(t *testing.T) { clearConfigEnv(t) home := t.TempDir() diff --git a/internal/cli/config_test.go b/internal/cli/config_test.go index 6ba65021a..6b1f1e73a 
100644 --- a/internal/cli/config_test.go +++ b/internal/cli/config_test.go @@ -446,6 +446,38 @@ func clearConfigEnv(t *testing.T) { "FASTAPI_CLOUD_APP_ID", "CRABBOX_FASTAPI_CLOUD_TEAM_ID", "FASTAPI_CLOUD_TEAM_ID", + "RUNPOD_API_KEY", + "CRABBOX_RUNPOD_API_KEY", + "RUNPOD_API_URL", + "CRABBOX_RUNPOD_API_URL", + "RUNPOD_CLOUD_TYPE", + "CRABBOX_RUNPOD_CLOUD_TYPE", + "RUNPOD_INSTANCE_ID", + "CRABBOX_RUNPOD_INSTANCE_ID", + "RUNPOD_IMAGE", + "CRABBOX_RUNPOD_IMAGE", + "RUNPOD_TEMPLATE_ID", + "CRABBOX_RUNPOD_TEMPLATE_ID", + "CRABBOX_RUNPOD_DISK_GB", + "CRABBOX_RUNPOD_USER", + "CRABBOX_RUNPOD_WORK_ROOT", + "CRABBOX_VAST_API_KEY", + "VAST_API_KEY", + "CRABBOX_VAST_API_URL", + "VAST_API_URL", + "CRABBOX_VAST_INSTANCE_TYPE", + "CRABBOX_VAST_GPU_NAME", + "CRABBOX_VAST_GPU_COUNT", + "CRABBOX_VAST_IMAGE", + "CRABBOX_VAST_TEMPLATE_ID", + "CRABBOX_VAST_RUNTYPE", + "CRABBOX_VAST_DISK_GB", + "CRABBOX_VAST_MAX_DPH_TOTAL", + "CRABBOX_VAST_MIN_RELIABILITY", + "CRABBOX_VAST_ORDER", + "CRABBOX_VAST_USER", + "CRABBOX_VAST_WORK_ROOT", + "CRABBOX_VAST_RELEASE_ACTION", "CRABBOX_NVIDIA_BREV_CLI", "CRABBOX_NVIDIA_BREV_ORG", "CRABBOX_NVIDIA_BREV_TYPE", @@ -4555,6 +4587,21 @@ runpod: diskGB: 25 user: runpod-user workRoot: /workspaces/runpod-test +vast: + apiUrl: https://vast.example.test/api/v0 + instanceType: on-demand + gpuName: RTX 4090 + gpuCount: 2 + image: nvidia/cuda:vast-file + templateId: vast-tpl-file + runtype: ssh_direct + diskGB: 60 + maxDphTotal: 3.5 + minReliability: 0.9 + order: reliability desc + user: root + workRoot: /workspaces/vast-test + releaseAction: stop islo: baseUrl: https://islo.example.test image: docker.io/library/ubuntu:24.04 @@ -4810,6 +4857,9 @@ ssh: if cfg.Runpod.APIURL != "https://runpod.example.test/v1" || cfg.Runpod.CloudType != "SECURE" || cfg.Runpod.InstanceID != "NVIDIA L4" || cfg.Runpod.Image != "runpod/pytorch:custom" || cfg.Runpod.TemplateID != "tpl-file" || cfg.Runpod.DiskGB != 25 || cfg.Runpod.User != "runpod-user" || cfg.Runpod.WorkRoot != 
"/workspaces/runpod-test" { t.Fatalf("runpod config not loaded: %#v", cfg.Runpod) } + if cfg.Vast.APIURL != "https://vast.example.test/api/v0" || cfg.Vast.InstanceType != "on-demand" || cfg.Vast.GPUName != "RTX 4090" || cfg.Vast.GPUCount != 2 || cfg.Vast.Image != "nvidia/cuda:vast-file" || cfg.Vast.TemplateID != "vast-tpl-file" || cfg.Vast.Runtype != "ssh_direct" || cfg.Vast.DiskGB != 60 || cfg.Vast.MaxDphTotal != 3.5 || cfg.Vast.MinReliability != 0.9 || cfg.Vast.Order != "reliability desc" || cfg.Vast.User != "root" || cfg.Vast.WorkRoot != "/workspaces/vast-test" || cfg.Vast.ReleaseAction != "stop" { + t.Fatalf("vast config not loaded: %#v", cfg.Vast) + } if cfg.Islo.BaseURL != "https://islo.example.test" || cfg.Islo.Image != "docker.io/library/ubuntu:24.04" || cfg.Islo.Workdir != "crabbox" || cfg.Islo.GatewayProfile != "default" || cfg.Islo.SnapshotName != "snap-ready" || cfg.Islo.VCPUs != 4 || cfg.Islo.MemoryMB != 8192 || cfg.Islo.DiskGB != 40 { t.Fatalf("islo config not loaded: %#v", cfg.Islo) } @@ -5144,6 +5194,23 @@ func TestEnvOverridesConfig(t *testing.T) { t.Setenv("CRABBOX_RUNPOD_DISK_GB", "30") t.Setenv("CRABBOX_RUNPOD_USER", "runpod-env-user") t.Setenv("CRABBOX_RUNPOD_WORK_ROOT", "/work/runpod-env") + t.Setenv("VAST_API_KEY", "vast-key-file") + t.Setenv("CRABBOX_VAST_API_KEY", "vast-key-env") + t.Setenv("VAST_API_URL", "https://vast-file.example/api/v0") + t.Setenv("CRABBOX_VAST_API_URL", "https://vast-env.example/api/v0") + t.Setenv("CRABBOX_VAST_INSTANCE_TYPE", "interruptible") + t.Setenv("CRABBOX_VAST_GPU_NAME", "H100") + t.Setenv("CRABBOX_VAST_GPU_COUNT", "4") + t.Setenv("CRABBOX_VAST_IMAGE", "nvidia/cuda:vast-env") + t.Setenv("CRABBOX_VAST_TEMPLATE_ID", "vast-tpl-env") + t.Setenv("CRABBOX_VAST_RUNTYPE", "ssh_direct") + t.Setenv("CRABBOX_VAST_DISK_GB", "80") + t.Setenv("CRABBOX_VAST_MAX_DPH_TOTAL", "4.25") + t.Setenv("CRABBOX_VAST_MIN_RELIABILITY", "0.95") + t.Setenv("CRABBOX_VAST_ORDER", "dlperf desc") + t.Setenv("CRABBOX_VAST_USER", "ubuntu") + 
t.Setenv("CRABBOX_VAST_WORK_ROOT", "/work/vast-env") + t.Setenv("CRABBOX_VAST_RELEASE_ACTION", "keep") t.Setenv("ISLO_API_KEY", "islo-api-file") t.Setenv("CRABBOX_ISLO_API_KEY", "islo-api-env") t.Setenv("ISLO_BASE_URL", "https://islo-file.example") @@ -5398,6 +5465,9 @@ func TestEnvOverridesConfig(t *testing.T) { if cfg.Runpod.APIKey != "runpod-key-env" || cfg.Runpod.APIURL != "https://runpod-env.example/v1" || cfg.Runpod.CloudType != "SECURE" || cfg.Runpod.InstanceID != "NVIDIA L4" || cfg.Runpod.Image != "runpod/pytorch:env" || cfg.Runpod.TemplateID != "tpl-env" || cfg.Runpod.DiskGB != 30 || cfg.Runpod.User != "runpod-env-user" || cfg.Runpod.WorkRoot != "/work/runpod-env" { t.Fatalf("unexpected runpod env: %#v", cfg.Runpod) } + if cfg.Vast.APIKey != "vast-key-env" || cfg.Vast.APIURL != "https://vast-env.example/api/v0" || cfg.Vast.InstanceType != "interruptible" || cfg.Vast.GPUName != "H100" || cfg.Vast.GPUCount != 4 || cfg.Vast.Image != "nvidia/cuda:vast-env" || cfg.Vast.TemplateID != "vast-tpl-env" || cfg.Vast.Runtype != "ssh_direct" || cfg.Vast.DiskGB != 80 || cfg.Vast.MaxDphTotal != 4.25 || cfg.Vast.MinReliability != 0.95 || cfg.Vast.Order != "dlperf desc" || cfg.Vast.User != "ubuntu" || cfg.Vast.WorkRoot != "/work/vast-env" || cfg.Vast.ReleaseAction != "keep" { + t.Fatalf("unexpected vast env: %#v", cfg.Vast) + } if cfg.Islo.APIKey != "islo-api-env" || cfg.Islo.BaseURL != "https://islo-env.example" || cfg.Islo.Image != "ubuntu:env" || cfg.Islo.Workdir != "env-workdir" || cfg.Islo.GatewayProfile != "env-gateway" || cfg.Islo.SnapshotName != "env-snapshot" || cfg.Islo.VCPUs != 8 || cfg.Islo.MemoryMB != 16384 || cfg.Islo.DiskGB != 80 { t.Fatalf("unexpected islo env: %#v", cfg.Islo) } diff --git a/internal/cli/credential_provenance.go b/internal/cli/credential_provenance.go index c19059fb6..3e4047f05 100644 --- a/internal/cli/credential_provenance.go +++ b/internal/cli/credential_provenance.go @@ -45,6 +45,8 @@ type credentialDestinationProvenance struct { 
fastAPICloudToken credentialValueSource runpodAPIURL credentialValueSource runpodAPIKey credentialValueSource + vastAPIURL credentialValueSource + vastAPIKey credentialValueSource isloBaseURL credentialValueSource isloAPIKey credentialValueSource tenkiEndpoint credentialValueSource @@ -128,6 +130,9 @@ func markCredentialDestinationFlagSources(cfg *Config, fs *flag.FlagSet) { if flagWasSet(fs, "runpod-url") { provenance.runpodAPIURL = credentialSourceFlag } + if flagWasSet(fs, "vast-api-url") { + provenance.vastAPIURL = credentialSourceFlag + } if flagWasSet(fs, "islo-base-url") { provenance.isloBaseURL = credentialSourceFlag } @@ -247,6 +252,11 @@ func validateProviderCredentialDestination(cfg Config) error { inheritedCredential(sourcedCredential{cfg.Runpod.APIKey, provenance.runpodAPIKey}) { return repositoryCredentialDestinationError("runpod", "runpod.apiUrl", "CRABBOX_RUNPOD_API_URL or --runpod-url") } + case "vast": + if provenance.vastAPIURL == credentialSourceRepository && + inheritedCredential(sourcedCredential{cfg.Vast.APIKey, provenance.vastAPIKey}) { + return repositoryCredentialDestinationError("vast", "vast.apiUrl", "CRABBOX_VAST_API_URL or --vast-api-url") + } case "islo": if provenance.isloBaseURL == credentialSourceRepository && inheritedCredential(sourcedCredential{cfg.Islo.APIKey, provenance.isloAPIKey}) { diff --git a/internal/cli/credential_provenance_test.go b/internal/cli/credential_provenance_test.go index 2104deb06..91ac2e4da 100644 --- a/internal/cli/credential_provenance_test.go +++ b/internal/cli/credential_provenance_test.go @@ -135,6 +135,18 @@ func TestRepositoryCredentialDestinationsRejectInheritedCredentials(t *testing.T }, want: "runpod.apiUrl", }, + { + name: "vast api", + cfg: Config{ + Provider: "vast", + Vast: VastConfig{APIURL: "https://repo.example.test", APIKey: "secret"}, + credentialProvenance: credentialDestinationProvenance{ + vastAPIURL: credentialSourceRepository, + vastAPIKey: credentialSourceEnvironment, + }, + }, + 
want: "vast.apiUrl", + }, { name: "islo api", cfg: Config{ @@ -389,6 +401,31 @@ func TestRepositoryCredentialDestinationAllowsExplicitFlagOverride(t *testing.T) } } +func TestVastCredentialDestinationAllowsExplicitFlagOverride(t *testing.T) { + cfg := Config{ + Provider: "vast", + Vast: VastConfig{APIURL: "https://repo.example.test", APIKey: "secret"}, + credentialProvenance: credentialDestinationProvenance{ + vastAPIURL: credentialSourceRepository, + vastAPIKey: credentialSourceEnvironment, + }, + } + fs := newFlagSet("test", io.Discard) + values := registerProviderFlags(fs, cfg) + if err := parseFlags(fs, []string{"--vast-api-url", "https://approved.example.test"}); err != nil { + t.Fatal(err) + } + if err := applyProviderFlags(&cfg, fs, values); err != nil { + t.Fatal(err) + } + if err := validateProviderCredentialDestination(cfg); err != nil { + t.Fatalf("explicit vast flag override rejected: %v", err) + } + if cfg.Vast.APIURL != "https://approved.example.test" { + t.Fatalf("vast apiUrl=%q", cfg.Vast.APIURL) + } +} + func TestAzureDynamicSessionsCredentialDestinationAllowsExplicitFlagOverride(t *testing.T) { cfg := Config{ Provider: "azure-dynamic-sessions", @@ -572,6 +609,13 @@ func TestConfigMergeSourceBindsDirectProviderCredentials(t *testing.T) { credentialEnv: "CRABBOX_RUNPOD_API_KEY", approveEnv: "CRABBOX_RUNPOD_API_URL", }, + { + name: "vast", + provider: "vast", + file: fileConfig{Vast: &fileVastConfig{APIURL: "https://repo.example.test"}}, + credentialEnv: "CRABBOX_VAST_API_KEY", + approveEnv: "CRABBOX_VAST_API_URL", + }, { name: "islo", provider: "islo", diff --git a/internal/cli/provider_categories_generated.go b/internal/cli/provider_categories_generated.go index ffa9baad1..215f8606e 100644 --- a/internal/cli/provider_categories_generated.go +++ b/internal/cli/provider_categories_generated.go @@ -65,6 +65,7 @@ var benchmarkProviderCategories = map[string]string{ "tenki": "direct-cloud", "tensorlake": "delegated-sandbox", "upstash-box": 
"delegated-sandbox", + "vast": "gpu-cloud", "vercel-sandbox": "delegated-sandbox", "vultr": "direct-cloud", "wandb": "gpu-cloud", diff --git a/internal/cli/providers_builtin_test.go b/internal/cli/providers_builtin_test.go index 456591bcb..5d524030c 100644 --- a/internal/cli/providers_builtin_test.go +++ b/internal/cli/providers_builtin_test.go @@ -32,6 +32,7 @@ func init() { RegisterProvider(testExternalProvider{}) RegisterProvider(testExeDevProvider{}) RegisterProvider(testRunPodProvider{}) + RegisterProvider(testVastProvider{}) RegisterProvider(testNvidiaBrevProvider{}) RegisterProvider(testBlacksmithProvider{}) RegisterProvider(testNamespaceProvider{}) @@ -1056,6 +1057,38 @@ func (p testRunPodProvider) Configure(cfg Config, rt Runtime) (Backend, error) { return testSSHBackend{spec: p.Spec()}, nil } +type testVastProvider struct{} + +func (testVastProvider) Name() string { return "vast" } +func (testVastProvider) Aliases() []string { + return []string{"vast-ai", "vastai"} +} +func (testVastProvider) Spec() ProviderSpec { + return ProviderSpec{ + Name: "vast", + Family: "vast", + Kind: ProviderKindSSHLease, + Targets: []TargetSpec{{OS: targetLinux}}, + Features: FeatureSet{FeatureSSH, FeatureCrabboxSync, FeatureCleanup}, + Coordinator: CoordinatorNever, + } +} +func (testVastProvider) RegisterFlags(fs *flag.FlagSet, defaults Config) any { + return struct{ APIURL *string }{ + APIURL: fs.String("vast-api-url", defaults.Vast.APIURL, ""), + } +} +func (testVastProvider) ApplyFlags(cfg *Config, fs *flag.FlagSet, values any) error { + v, _ := values.(struct{ APIURL *string }) + if flagWasSet(fs, "vast-api-url") && v.APIURL != nil { + cfg.Vast.APIURL = *v.APIURL + } + return nil +} +func (p testVastProvider) Configure(Config, Runtime) (Backend, error) { + return testSSHBackend{spec: p.Spec()}, nil +} + type testNvidiaBrevProvider struct{} func (testNvidiaBrevProvider) Name() string { return "nvidia-brev" } diff --git a/internal/cli/providers_test.go 
b/internal/cli/providers_test.go index 616f72d4a..ea061aaa6 100644 --- a/internal/cli/providers_test.go +++ b/internal/cli/providers_test.go @@ -17,6 +17,7 @@ func TestProviderMatrixIncludesCapabilities(t *testing.T) { var digitalOcean *providerMatrixEntry var vultr *providerMatrixEntry var firecracker *providerMatrixEntry + var vast *providerMatrixEntry var nvidiaBrev *providerMatrixEntry var linode *providerMatrixEntry var nebius *providerMatrixEntry @@ -43,6 +44,9 @@ func TestProviderMatrixIncludesCapabilities(t *testing.T) { if entries[i].Provider == "firecracker" { firecracker = &entries[i] } + if entries[i].Provider == "vast" { + vast = &entries[i] + } if entries[i].Provider == "nvidia-brev" { nvidiaBrev = &entries[i] } @@ -89,6 +93,9 @@ func TestProviderMatrixIncludesCapabilities(t *testing.T) { if firecracker == nil { t.Fatal("firecracker provider not found") } + if vast == nil { + t.Fatal("vast provider not found") + } if nvidiaBrev == nil { t.Fatal("nvidia-brev provider not found") } @@ -243,6 +250,18 @@ func TestProviderMatrixIncludesCapabilities(t *testing.T) { if !containsString(nvidiaBrev.Aliases, "brev") || !containsString(nvidiaBrev.Aliases, "nvidia") { t.Fatalf("nvidia-brev aliases=%v", nvidiaBrev.Aliases) } + if vast.Kind != ProviderKindSSHLease || vast.Family != "vast" || vast.Coordinator != string(CoordinatorNever) { + t.Fatalf("vast kind/family/coordinator=%q/%q/%q", vast.Kind, vast.Family, vast.Coordinator) + } + if !containsString(vast.Targets, targetLinux) { + t.Fatalf("vast targets=%v", vast.Targets) + } + if !containsFeature(vast.Features, FeatureSSH) || !containsFeature(vast.Features, FeatureCrabboxSync) || !containsFeature(vast.Features, FeatureCleanup) { + t.Fatalf("vast features=%v", vast.Features) + } + if !containsString(vast.Aliases, "vast-ai") || !containsString(vast.Aliases, "vastai") { + t.Fatalf("vast aliases=%v", vast.Aliases) + } for _, capability := range []string{"local-runtime", "ssh-host"} { if 
!containsString(localContainer.Runtime, capability) { t.Fatalf("local-container runtime=%v missing %s", localContainer.Runtime, capability) diff --git a/internal/providers/all/all.go b/internal/providers/all/all.go index 951183bff..451cf214d 100644 --- a/internal/providers/all/all.go +++ b/internal/providers/all/all.go @@ -63,6 +63,7 @@ import ( _ "github.com/openclaw/crabbox/internal/providers/tenki" _ "github.com/openclaw/crabbox/internal/providers/tensorlake" _ "github.com/openclaw/crabbox/internal/providers/upstashbox" + _ "github.com/openclaw/crabbox/internal/providers/vast" _ "github.com/openclaw/crabbox/internal/providers/vercelsandbox" _ "github.com/openclaw/crabbox/internal/providers/vultr" _ "github.com/openclaw/crabbox/internal/providers/wandb" diff --git a/internal/providers/all/all_test.go b/internal/providers/all/all_test.go index 467abee5c..46e40b815 100644 --- a/internal/providers/all/all_test.go +++ b/internal/providers/all/all_test.go @@ -77,6 +77,34 @@ func TestNvidiaBrevRegistersCanonicalAndAliases(t *testing.T) { } } +func TestVastRegistersCanonicalAndAliases(t *testing.T) { + for _, name := range []string{"vast", "vast-ai", "vastai"} { + provider, err := core.ProviderFor(name) + if err != nil { + t.Fatalf("ProviderFor(%q): %v", name, err) + } + if provider.Name() != "vast" { + t.Fatalf("ProviderFor(%q).Name=%q want vast", name, provider.Name()) + } + } + provider, err := core.ProviderFor("vast") + if err != nil { + t.Fatalf("ProviderFor(vast): %v", err) + } + spec := provider.Spec() + if spec.Family != "vast" || spec.Kind != core.ProviderKindSSHLease || spec.Coordinator != core.CoordinatorNever { + t.Fatalf("vast spec=%#v", spec) + } + if len(spec.Targets) != 1 || spec.Targets[0].OS != core.TargetLinux { + t.Fatalf("vast targets=%#v", spec.Targets) + } + for _, feature := range []core.Feature{core.FeatureSSH, core.FeatureCrabboxSync, core.FeatureCleanup} { + if !spec.Features.Has(feature) { + t.Fatalf("vast features=%v missing %s", 
spec.Features, feature) + } + } +} + func TestLambdaRegistersAsBuiltInProvider(t *testing.T) { provider, err := core.ProviderFor("lambda") if err != nil { @@ -1186,6 +1214,7 @@ func allBuiltInProviderNames() []string { "tenki", "tensorlake", "upstash-box", + "vast", "vercel-sandbox", "vultr", "wandb", diff --git a/internal/providers/vast/core.go b/internal/providers/vast/core.go new file mode 100644 index 000000000..f0214e9cc --- /dev/null +++ b/internal/providers/vast/core.go @@ -0,0 +1,62 @@ +package vast + +import ( + "flag" + "strings" + + core "github.com/openclaw/crabbox/internal/cli" +) + +type Config = core.Config +type VastConfig = core.VastConfig +type ProviderSpec = core.ProviderSpec +type Runtime = core.Runtime +type Backend = core.Backend +type DoctorRequest = core.DoctorRequest +type DoctorResult = core.DoctorResult +type AcquireRequest = core.AcquireRequest +type ResolveRequest = core.ResolveRequest +type ListRequest = core.ListRequest +type LeaseView = core.LeaseView +type ReleaseLeaseRequest = core.ReleaseLeaseRequest +type TouchRequest = core.TouchRequest +type LeaseTarget = core.LeaseTarget +type Server = core.Server + +const ( + providerName = "vast" + targetLinux = core.TargetLinux +) + +func exit(code int, format string, args ...any) core.ExitError { + return core.Exit(code, format, args...) 
+} + +func flagWasSet(fs *flag.FlagSet, name string) bool { + return core.FlagWasSet(fs, name) +} + +func markVastWorkRootExplicit(cfg *Config) { + core.MarkVastWorkRootExplicit(cfg) +} + +func markReleaseActionExplicit(cfg *Config) { + core.MarkDeleteOnReleaseExplicit(cfg, providerName) +} + +func normalizeInstanceType(value string) string { + return core.NormalizeVastInstanceType(value) +} + +func isVastProviderName(provider string) bool { + switch strings.ToLower(strings.TrimSpace(provider)) { + case providerName, "vast-ai", "vastai": + return true + default: + return false + } +} + +func notImplemented(operation string) error { + return exit(2, "provider=%s %s is not implemented yet", providerName, operation) +} diff --git a/internal/providers/vast/flags.go b/internal/providers/vast/flags.go new file mode 100644 index 000000000..f3f47ba1b --- /dev/null +++ b/internal/providers/vast/flags.go @@ -0,0 +1,104 @@ +package vast + +import "flag" + +type vastFlagValues struct { + APIURL *string + InstanceType *string + GPUName *string + GPUCount *int + Image *string + TemplateID *string + Runtype *string + DiskGB *int + MaxDphTotal *float64 + MinReliability *float64 + Order *string + User *string + WorkRoot *string + ReleaseAction *string +} + +// RegisterVastProviderFlags exposes only non-secret Vast settings. API keys +// are sourced from CRABBOX_VAST_API_KEY / VAST_API_KEY and never argv. 
+func RegisterVastProviderFlags(fs *flag.FlagSet, defaults Config) any { + return vastFlagValues{ + APIURL: fs.String("vast-api-url", defaults.Vast.APIURL, "Vast.ai REST API URL"), + InstanceType: fs.String("vast-instance-type", defaults.Vast.InstanceType, "Vast.ai offer type: ondemand or interruptible"), + GPUName: fs.String("vast-gpu-name", defaults.Vast.GPUName, "Vast.ai GPU name selector"), + GPUCount: fs.Int("vast-gpu-count", defaults.Vast.GPUCount, "Vast.ai minimum GPU count"), + Image: fs.String("vast-image", defaults.Vast.Image, "Docker image to deploy on the instance"), + TemplateID: fs.String("vast-template-id", defaults.Vast.TemplateID, "Optional Vast.ai template ID"), + Runtype: fs.String("vast-runtype", defaults.Vast.Runtype, "Vast.ai runtime type: ssh_direct"), + DiskGB: fs.Int("vast-disk-gb", defaults.Vast.DiskGB, "Instance disk size in GB"), + MaxDphTotal: fs.Float64("vast-max-dph-total", defaults.Vast.MaxDphTotal, "Maximum total dollars per hour"), + MinReliability: fs.Float64("vast-min-reliability", defaults.Vast.MinReliability, "Minimum reliability score from 0 to 1"), + Order: fs.String("vast-order", defaults.Vast.Order, "Vast.ai offer ordering expression"), + User: fs.String("vast-user", defaults.Vast.User, "SSH user for Vast.ai instances"), + WorkRoot: fs.String("vast-work-root", defaults.Vast.WorkRoot, "remote Crabbox work root on Vast.ai instances"), + ReleaseAction: fs.String("vast-release-action", defaults.Vast.ReleaseAction, "Vast.ai release action: destroy, stop, or keep"), + } +} + +func ApplyVastProviderFlags(cfg *Config, fs *flag.FlagSet, values any) error { + if isVastProviderName(cfg.Provider) { + if flagWasSet(fs, "class") { + return exit(2, "--class is not supported for provider=%s; use --vast-gpu-name or --vast-gpu-count", providerName) + } + if flagWasSet(fs, "type") { + return exit(2, "--type is not supported for provider=%s; use --vast-image", providerName) + } + } + v, ok := values.(vastFlagValues) + if !ok { + return nil + } 
+ if flagWasSet(fs, "vast-api-url") { + cfg.Vast.APIURL = *v.APIURL + } + if flagWasSet(fs, "vast-instance-type") { + cfg.Vast.InstanceType = normalizeInstanceType(*v.InstanceType) + } + if flagWasSet(fs, "vast-gpu-name") { + cfg.Vast.GPUName = *v.GPUName + } + if flagWasSet(fs, "vast-gpu-count") { + cfg.Vast.GPUCount = *v.GPUCount + } + if flagWasSet(fs, "vast-image") { + cfg.Vast.Image = *v.Image + } + if flagWasSet(fs, "vast-template-id") { + cfg.Vast.TemplateID = *v.TemplateID + } + if flagWasSet(fs, "vast-runtype") { + cfg.Vast.Runtype = *v.Runtype + } + if flagWasSet(fs, "vast-disk-gb") { + cfg.Vast.DiskGB = *v.DiskGB + } + if flagWasSet(fs, "vast-max-dph-total") { + cfg.Vast.MaxDphTotal = *v.MaxDphTotal + } + if flagWasSet(fs, "vast-min-reliability") { + cfg.Vast.MinReliability = *v.MinReliability + } + if flagWasSet(fs, "vast-order") { + cfg.Vast.Order = *v.Order + } + if flagWasSet(fs, "vast-user") { + cfg.Vast.User = *v.User + } + if flagWasSet(fs, "vast-work-root") { + cfg.Vast.WorkRoot = *v.WorkRoot + markVastWorkRootExplicit(cfg) + } + if flagWasSet(fs, "vast-release-action") { + cfg.Vast.ReleaseAction = *v.ReleaseAction + markReleaseActionExplicit(cfg) + } + if isVastProviderName(cfg.Provider) { + return Provider{}.ValidateConfig(*cfg) + } + return nil +} diff --git a/internal/providers/vast/provider.go b/internal/providers/vast/provider.go new file mode 100644 index 000000000..6bf6de77c --- /dev/null +++ b/internal/providers/vast/provider.go @@ -0,0 +1,145 @@ +package vast + +import ( + "context" + "flag" + "net/url" + "strings" + + core "github.com/openclaw/crabbox/internal/cli" +) + +func init() { + core.RegisterProvider(Provider{}) +} + +type Provider struct{} + +func (Provider) Name() string { return providerName } + +func (Provider) Aliases() []string { + return []string{"vast-ai", "vastai"} +} + +func (Provider) Spec() core.ProviderSpec { + return core.ProviderSpec{ + Name: providerName, + Family: "vast", + Kind: core.ProviderKindSSHLease, + 
Targets: []core.TargetSpec{{OS: core.TargetLinux}}, + Features: core.FeatureSet{core.FeatureSSH, core.FeatureCrabboxSync, core.FeatureCleanup}, + Coordinator: core.CoordinatorNever, + } +} + +func (Provider) RegisterFlags(fs *flag.FlagSet, defaults core.Config) any { + return RegisterVastProviderFlags(fs, defaults) +} + +func (Provider) ApplyFlags(cfg *core.Config, fs *flag.FlagSet, values any) error { + return ApplyVastProviderFlags(cfg, fs, values) +} + +func (p Provider) Configure(cfg core.Config, _ core.Runtime) (core.Backend, error) { + if err := p.ValidateConfig(cfg); err != nil { + return nil, err + } + if cfg.TargetOS != "" && cfg.TargetOS != core.TargetLinux { + return nil, exit(2, "provider=%s supports target=linux only", providerName) + } + return backend{spec: p.Spec()}, nil +} + +func (p Provider) ConfigureDoctor(cfg core.Config, rt core.Runtime) (core.DoctorBackend, error) { + backend, err := p.Configure(cfg, rt) + if err != nil { + return nil, err + } + doctor, ok := backend.(core.DoctorBackend) + if !ok { + return nil, core.Exit(2, "vast doctor backend unavailable") + } + return doctor, nil +} + +func (Provider) ValidateConfig(cfg core.Config) error { + apiURL := strings.TrimSpace(cfg.Vast.APIURL) + if apiURL == "" { + return exit(2, "vast.apiUrl is required") + } + u, err := url.Parse(apiURL) + if err != nil || u.Scheme == "" || u.Host == "" || u.User != nil { + return exit(2, "vast.apiUrl must be an absolute URL without credentials") + } + switch strings.ToLower(u.Scheme) { + case "https", "http": + default: + return exit(2, "vast.apiUrl must use http or https") + } + switch normalizeInstanceType(cfg.Vast.InstanceType) { + case "ondemand", "interruptible": + default: + return exit(2, "vast.instanceType must be ondemand or interruptible") + } + switch strings.ToLower(strings.TrimSpace(cfg.Vast.Runtype)) { + case "ssh_direct": + default: + return exit(2, "vast.runtype must be ssh_direct") + } + if cfg.Vast.GPUCount < 0 { + return exit(2, 
"vast.gpuCount must be non-negative") + } + if cfg.Vast.DiskGB < 0 { + return exit(2, "vast.diskGB must be non-negative") + } + if cfg.Vast.MaxDphTotal < 0 { + return exit(2, "vast.maxDphTotal must be non-negative") + } + if cfg.Vast.MinReliability < 0 || cfg.Vast.MinReliability > 1 { + return exit(2, "vast.minReliability must be between 0 and 1") + } + switch strings.ToLower(strings.TrimSpace(cfg.Vast.ReleaseAction)) { + case "", "destroy", "delete", "stop", "keep": + default: + return exit(2, "vast.releaseAction must be destroy, delete, stop, or keep") + } + return nil +} + +type backend struct { + spec core.ProviderSpec +} + +func (b backend) Spec() core.ProviderSpec { return b.spec } + +func (b backend) Doctor(context.Context, core.DoctorRequest) (core.DoctorResult, error) { + return core.DoctorResult{ + Provider: providerName, + Status: "unsupported", + Message: "vast lifecycle is not implemented yet", + }, nil +} + +func (b backend) Acquire(context.Context, core.AcquireRequest) (core.LeaseTarget, error) { + return core.LeaseTarget{}, notImplemented("acquire") +} + +func (b backend) Resolve(context.Context, core.ResolveRequest) (core.LeaseTarget, error) { + return core.LeaseTarget{}, notImplemented("resolve") +} + +func (b backend) List(context.Context, core.ListRequest) ([]core.LeaseView, error) { + return nil, notImplemented("list") +} + +func (b backend) ReleaseLease(context.Context, core.ReleaseLeaseRequest) error { + return notImplemented("release") +} + +func (b backend) Touch(context.Context, core.TouchRequest) (core.Server, error) { + return core.Server{}, notImplemented("touch") +} + +func (b backend) Cleanup(context.Context, core.CleanupRequest) error { + return notImplemented("cleanup") +} diff --git a/internal/providers/vast/provider_test.go b/internal/providers/vast/provider_test.go new file mode 100644 index 000000000..206f05678 --- /dev/null +++ b/internal/providers/vast/provider_test.go @@ -0,0 +1,176 @@ +package vast + +import ( + "flag" + 
"strings" + "testing" + + core "github.com/openclaw/crabbox/internal/cli" +) + +func TestProviderSpecAndAliases(t *testing.T) { + p := Provider{} + if p.Name() != "vast" { + t.Fatalf("Name=%q", p.Name()) + } + aliases := p.Aliases() + if len(aliases) != 2 || aliases[0] != "vast-ai" || aliases[1] != "vastai" { + t.Fatalf("aliases=%v", aliases) + } + spec := p.Spec() + if spec.Name != "vast" || spec.Family != "vast" || spec.Kind != core.ProviderKindSSHLease || spec.Coordinator != core.CoordinatorNever { + t.Fatalf("spec=%#v", spec) + } + if len(spec.Targets) != 1 || spec.Targets[0].OS != core.TargetLinux { + t.Fatalf("targets=%#v", spec.Targets) + } + for _, feature := range []core.Feature{core.FeatureSSH, core.FeatureCrabboxSync, core.FeatureCleanup} { + if !spec.Features.Has(feature) { + t.Fatalf("features=%v missing %s", spec.Features, feature) + } + } +} + +func TestProviderFlagsApplyNonSecretConfig(t *testing.T) { + cfg := core.Config{ + Provider: "vast", + Vast: core.VastConfig{ + APIURL: "https://console.vast.ai/api/v0", + InstanceType: "ondemand", + Runtype: "ssh_direct", + Image: "nvidia/cuda:default", + DiskGB: 20, + User: "root", + WorkRoot: "/work", + ReleaseAction: "destroy", + }, + } + fs := flag.NewFlagSet("test", flag.ContinueOnError) + values := RegisterVastProviderFlags(fs, cfg) + if err := fs.Parse([]string{ + "--vast-api-url", "https://approved.example.test/api/v0", + "--vast-instance-type", "on-demand", + "--vast-gpu-name", "H100", + "--vast-gpu-count", "4", + "--vast-image", "nvidia/cuda:12", + "--vast-template-id", "tpl-123", + "--vast-runtype", "ssh_direct", + "--vast-disk-gb", "80", + "--vast-max-dph-total", "4.5", + "--vast-min-reliability", "0.95", + "--vast-order", "reliability desc", + "--vast-user", "ubuntu", + "--vast-work-root", "/work/vast", + "--vast-release-action", "keep", + }); err != nil { + t.Fatal(err) + } + if err := ApplyVastProviderFlags(&cfg, fs, values); err != nil { + t.Fatal(err) + } + if cfg.Vast.APIURL != 
"https://approved.example.test/api/v0" || + cfg.Vast.InstanceType != "ondemand" || + cfg.Vast.GPUName != "H100" || + cfg.Vast.GPUCount != 4 || + cfg.Vast.Image != "nvidia/cuda:12" || + cfg.Vast.TemplateID != "tpl-123" || + cfg.Vast.Runtype != "ssh_direct" || + cfg.Vast.DiskGB != 80 || + cfg.Vast.MaxDphTotal != 4.5 || + cfg.Vast.MinReliability != 0.95 || + cfg.Vast.Order != "reliability desc" || + cfg.Vast.User != "ubuntu" || + cfg.Vast.WorkRoot != "/work/vast" || + cfg.Vast.ReleaseAction != "keep" { + t.Fatalf("vast config=%#v", cfg.Vast) + } +} + +func TestProviderFlagsDoNotExposeAPIKey(t *testing.T) { + fs := flag.NewFlagSet("test", flag.ContinueOnError) + RegisterVastProviderFlags(fs, core.Config{}) + fs.VisitAll(func(f *flag.Flag) { + if strings.Contains(f.Name, "api-key") { + t.Fatalf("unexpected secret flag --%s", f.Name) + } + }) + forbidden := "vast-" + "api-key" + if fs.Lookup(forbidden) != nil { + t.Fatalf("unexpected --%s flag", forbidden) + } +} + +func TestValidateConfigRejectsUnsafeValues(t *testing.T) { + base := core.Config{Vast: core.VastConfig{ + APIURL: "https://console.vast.ai/api/v0", + InstanceType: "ondemand", + Runtype: "ssh_direct", + DiskGB: 20, + ReleaseAction: "destroy", + }} + tests := []struct { + name string + mutate func(*core.Config) + want string + }{ + { + name: "credential url", + mutate: func(cfg *core.Config) { + cfg.Vast.APIURL = "https://user:secret@vast.example.test" + }, + want: "absolute URL without credentials", + }, + { + name: "instance type", + mutate: func(cfg *core.Config) { + cfg.Vast.InstanceType = "spot" + }, + want: "vast.instanceType", + }, + { + name: "runtype", + mutate: func(cfg *core.Config) { + cfg.Vast.Runtype = "ssh_proxy" + }, + want: "vast.runtype", + }, + { + name: "disk", + mutate: func(cfg *core.Config) { + cfg.Vast.DiskGB = -1 + }, + want: "vast.diskGB", + }, + { + name: "reliability", + mutate: func(cfg *core.Config) { + cfg.Vast.MinReliability = 1.1 + }, + want: "vast.minReliability", + }, + { + 
name: "release", + mutate: func(cfg *core.Config) { + cfg.Vast.ReleaseAction = "hibernate" + }, + want: "vast.releaseAction", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + cfg := base + test.mutate(&cfg) + err := (Provider{}).ValidateConfig(cfg) + if err == nil || !strings.Contains(err.Error(), test.want) { + t.Fatalf("err=%v want %q", err, test.want) + } + }) + } + if err := (Provider{}).ValidateConfig(base); err != nil { + t.Fatalf("valid config rejected: %v", err) + } + base.Vast.InstanceType = "on-demand" + if err := (Provider{}).ValidateConfig(base); err != nil { + t.Fatalf("on-demand alias rejected: %v", err) + } +} From fb2eb8f5ae8433b3ddebb3bbd13b9ae54513e907 Mon Sep 17 00:00:00 2001 From: Coy Geek <65363919+coygeek@users.noreply.github.com> Date: Wed, 24 Jun 2026 15:21:26 -0700 Subject: [PATCH 02/18] feat(vast): add REST client and ownership labels Add the provider-local Vast.ai HTTP client, request and response models, payload builders, redacted API error handling, and compact ownership label helpers for the upcoming lifecycle implementation. Cover the client contract with fake HTTP tests for auth, redirects, offer search, instance creation and management, SSH-key methods, response decoding, redaction, and label ownership safety. 
--- internal/providers/vast/client.go | 561 +++++++++++++++++++++++++ internal/providers/vast/client_test.go | 284 +++++++++++++ internal/providers/vast/labels.go | 60 +++ internal/providers/vast/labels_test.go | 52 +++ 4 files changed, 957 insertions(+) create mode 100644 internal/providers/vast/client.go create mode 100644 internal/providers/vast/client_test.go create mode 100644 internal/providers/vast/labels.go create mode 100644 internal/providers/vast/labels_test.go diff --git a/internal/providers/vast/client.go b/internal/providers/vast/client.go new file mode 100644 index 000000000..406637ed2 --- /dev/null +++ b/internal/providers/vast/client.go @@ -0,0 +1,561 @@ +package vast + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "regexp" + "strconv" + "strings" +) + +type vastAPI interface { + CheckAuth(context.Context) (vastUser, error) + SearchOffers(context.Context, vastOfferSearchInput) ([]vastOffer, error) + CreateInstance(context.Context, int, vastCreateInstanceInput) (vastCreateInstanceResponse, error) + GetInstance(context.Context, int) (vastInstance, error) + ListInstances(context.Context) ([]vastInstance, error) + ManageInstance(context.Context, int, vastManageInstanceInput) (vastInstance, error) + DestroyInstance(context.Context, int) error + ListInstanceSSHKeys(context.Context, int) ([]vastInstanceSSHKey, error) + AttachInstanceSSHKey(context.Context, int, string) (vastAttachSSHKeyResponse, error) + DetachInstanceSSHKey(context.Context, int, string) error +} + +type vastClient struct { + apiKey string + apiURL string + httpClient *http.Client +} + +const vastMaxResponseBytes = 4 << 20 + +type vastAPIError struct { + Operation string + StatusCode int + Status string + Body string +} + +func (e *vastAPIError) Error() string { + if e.Body == "" { + return fmt.Sprintf("vast %s: %s", e.Operation, e.Status) + } + return fmt.Sprintf("vast %s: %s: %s", e.Operation, e.Status, e.Body) +} + +type 
vastUser struct { + ID int `json:"id"` + Email string `json:"email"` + Username string `json:"username"` +} + +type vastOffer struct { + ID int `json:"id"` + AskID int `json:"ask_contract_id"` + MachineID int `json:"machine_id"` + GPUName string `json:"gpu_name"` + GPUCount int `json:"num_gpus"` + Reliability float64 `json:"reliability2"` + DphTotal float64 `json:"dph_total"` + SSHHost string `json:"ssh_host"` + SSHPort int `json:"ssh_port"` + Rentable bool `json:"rentable"` + Rented bool `json:"rented"` + Verified bool `json:"verified"` +} + +type vastInstance struct { + ID int `json:"id"` + ContractID int `json:"contract_id"` + Label string `json:"label"` + Status string `json:"actual_status"` + IntendedStatus string `json:"intended_status"` + SSHHost string `json:"ssh_host"` + SSHPort int `json:"ssh_port"` + GPUName string `json:"gpu_name"` + GPUCount int `json:"num_gpus"` + DphTotal float64 `json:"dph_total"` + Image string `json:"image_uuid"` + InstanceAPIKey string `json:"instance_api_key,omitempty"` +} + +type vastCreateInstanceResponse struct { + NewContract int `json:"new_contract"` + Instance vastInstance `json:"instance"` + Success bool `json:"success"` +} + +type vastInstanceSSHKey struct { + ID string `json:"id"` + Name string `json:"name"` + PublicKey string `json:"ssh_key"` +} + +type vastAttachSSHKeyResponse struct { + Success bool `json:"success"` + Key vastInstanceSSHKey `json:"key"` + Keys []vastInstanceSSHKey `json:"keys"` +} + +type vastOfferSearchInput struct { + Config VastConfig +} + +type vastCreateInstanceInput struct { + Config VastConfig + Label string + SSHKey string + Environment map[string]string + OnStart string +} + +type vastManageInstanceInput struct { + State string `json:"state,omitempty"` + Label string `json:"label,omitempty"` +} + +func newVastClient(cfg VastConfig, rt Runtime) (vastAPI, error) { + apiKey := strings.TrimSpace(cfg.APIKey) + if apiKey == "" { + return nil, exit(2, "provider=%s requires CRABBOX_VAST_API_KEY or 
VAST_API_KEY", providerName) + } + apiURL := strings.TrimRight(strings.TrimSpace(cfg.APIURL), "/") + if apiURL == "" { + apiURL = "https://console.vast.ai/api/v0" + } + parsed, err := url.Parse(apiURL) + if err != nil || parsed.Scheme == "" || parsed.Host == "" || parsed.User != nil { + return nil, exit(2, "vast.apiUrl must be an absolute URL without credentials") + } + if parsed.Scheme != "https" && !isLoopbackHTTPURL(parsed) { + return nil, exit(2, "vast.apiUrl must use https unless it targets localhost") + } + httpClient := rt.HTTP + if httpClient == nil { + httpClient = http.DefaultClient + } + return &vastClient{apiKey: apiKey, apiURL: apiURL, httpClient: secureVastHTTPClient(httpClient, apiURL)}, nil +} + +func secureVastHTTPClient(source *http.Client, apiURL string) *http.Client { + client := *source + trusted, _ := url.Parse(apiURL) + originalCheckRedirect := source.CheckRedirect + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + if !sameVastOrigin(trusted, req.URL) { + return fmt.Errorf("%s refused cross-origin redirect to %s", providerName, req.URL.Redacted()) + } + if originalCheckRedirect != nil { + return originalCheckRedirect(req, via) + } + if len(via) >= 10 { + return errors.New("stopped after 10 redirects") + } + return nil + } + return &client +} + +func (c *vastClient) do(ctx context.Context, method, path string, body any, out any) error { + var reader io.Reader + if body != nil { + data, err := json.Marshal(body) + if err != nil { + return err + } + reader = bytes.NewReader(data) + } + req, err := http.NewRequestWithContext(ctx, method, c.apiURL+path, reader) + if err != nil { + return err + } + req.Header.Set("Authorization", "Bearer "+c.apiKey) + req.Header.Set("Accept", "application/json") + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + resp, err := c.httpClient.Do(req) + if err != nil { + return redactVastString(err.Error(), c.apiKey) + } + defer resp.Body.Close() + data, readErr := 
io.ReadAll(io.LimitReader(resp.Body, vastMaxResponseBytes+1)) + operation := method + " " + path + if len(data) > vastMaxResponseBytes { + return fmt.Errorf("vast %s response exceeds %d bytes", operation, vastMaxResponseBytes) + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return c.decodeAPIError(operation, resp.StatusCode, resp.Status, data, readErr) + } + if readErr != nil { + return fmt.Errorf("vast %s response body: %w", operation, readErr) + } + if out != nil && len(strings.TrimSpace(string(data))) > 0 { + if err := json.Unmarshal(data, out); err != nil { + return fmt.Errorf("decode vast %s response: %w", operation, err) + } + } + return nil +} + +func (c *vastClient) decodeAPIError(operation string, statusCode int, status string, data []byte, readErr error) error { + body := strings.TrimSpace(string(data)) + if len(body) > 1600 { + body = body[:1600] + } + body = redactVastText(body, c.apiKey) + if readErr != nil { + if body != "" { + body += "; " + } + body += "response body read failed: " + readErr.Error() + } + return &vastAPIError{Operation: operation, StatusCode: statusCode, Status: status, Body: body} +} + +func (c *vastClient) CheckAuth(ctx context.Context) (vastUser, error) { + var out vastUser + err := c.do(ctx, http.MethodGet, "/users/current/", nil, &out) + return out, err +} + +func (c *vastClient) SearchOffers(ctx context.Context, input vastOfferSearchInput) ([]vastOffer, error) { + var raw json.RawMessage + if err := c.do(ctx, http.MethodPost, "/bundles/", buildVastOfferSearchPayload(input.Config), &raw); err != nil { + return nil, err + } + return decodeVastOffers(raw) +} + +func (c *vastClient) CreateInstance(ctx context.Context, offerID int, input vastCreateInstanceInput) (vastCreateInstanceResponse, error) { + var raw json.RawMessage + path := "/asks/" + strconv.Itoa(offerID) + "/" + if err := c.do(ctx, http.MethodPut, path, buildVastCreatePayload(input), &raw); err != nil { + return vastCreateInstanceResponse{}, err + } + 
return decodeVastCreateInstanceResponse(raw) +} + +func (c *vastClient) GetInstance(ctx context.Context, id int) (vastInstance, error) { + var raw json.RawMessage + if err := c.do(ctx, http.MethodGet, "/instances/"+strconv.Itoa(id)+"/", nil, &raw); err != nil { + return vastInstance{}, err + } + return decodeVastInstance(raw) +} + +func (c *vastClient) ListInstances(ctx context.Context) ([]vastInstance, error) { + var raw json.RawMessage + if err := c.do(ctx, http.MethodGet, "/instances/", nil, &raw); err != nil { + return nil, err + } + return decodeVastInstances(raw) +} + +func (c *vastClient) ManageInstance(ctx context.Context, id int, input vastManageInstanceInput) (vastInstance, error) { + var raw json.RawMessage + if err := c.do(ctx, http.MethodPut, "/instances/"+strconv.Itoa(id)+"/", input, &raw); err != nil { + return vastInstance{}, err + } + return decodeVastInstance(raw) +} + +func (c *vastClient) DestroyInstance(ctx context.Context, id int) error { + return c.do(ctx, http.MethodDelete, "/instances/"+strconv.Itoa(id)+"/", nil, nil) +} + +func (c *vastClient) ListInstanceSSHKeys(ctx context.Context, id int) ([]vastInstanceSSHKey, error) { + var raw json.RawMessage + if err := c.do(ctx, http.MethodGet, "/instances/"+strconv.Itoa(id)+"/ssh/", nil, &raw); err != nil { + return nil, err + } + return decodeVastSSHKeys(raw) +} + +func (c *vastClient) AttachInstanceSSHKey(ctx context.Context, id int, publicKey string) (vastAttachSSHKeyResponse, error) { + var out vastAttachSSHKeyResponse + body := map[string]string{"ssh_key": publicKey} + err := c.do(ctx, http.MethodPost, "/instances/"+strconv.Itoa(id)+"/ssh/", body, &out) + return out, err +} + +func (c *vastClient) DetachInstanceSSHKey(ctx context.Context, id int, keyID string) error { + return c.do(ctx, http.MethodDelete, "/instances/"+strconv.Itoa(id)+"/ssh/"+url.PathEscape(keyID)+"/", nil, nil) +} + +func buildVastOfferSearchPayload(cfg VastConfig) map[string]any { + query := map[string]any{ + "verified": 
true, + "rentable": true, + "rented": false, + "direct_port_count": map[string]any{ + "gte": 1, + }, + } + if cfg.GPUName != "" { + query["gpu_name"] = cfg.GPUName + } + if cfg.GPUCount > 0 { + query["num_gpus"] = map[string]any{"gte": cfg.GPUCount} + } + if cfg.MinReliability > 0 { + query["reliability2"] = map[string]any{"gte": cfg.MinReliability} + } + if cfg.MaxDphTotal > 0 { + query["dph_total"] = map[string]any{"lte": cfg.MaxDphTotal} + } + instanceType := normalizeInstanceType(strings.TrimSpace(cfg.InstanceType)) + if instanceType == "" { + instanceType = "ondemand" + } + return map[string]any{ + "type": instanceType, + "query": query, + "order": strings.TrimSpace(cfg.Order), + } +} + +func buildVastCreatePayload(input vastCreateInstanceInput) map[string]any { + cfg := input.Config + payload := map[string]any{ + "runtype": strings.TrimSpace(cfg.Runtype), + "target_state": "running", + "cancel_unavail": true, + "vm": false, + } + if cfg.Image != "" { + payload["image"] = cfg.Image + } + if cfg.TemplateID != "" { + payload["template_id"] = cfg.TemplateID + } + if cfg.DiskGB > 0 { + payload["disk"] = cfg.DiskGB + } + if input.Label != "" { + payload["label"] = input.Label + } + if input.SSHKey != "" { + payload["ssh_key"] = input.SSHKey + } + if len(input.Environment) > 0 { + payload["env"] = input.Environment + } + if input.OnStart != "" { + payload["onstart"] = input.OnStart + } + return payload +} + +func decodeVastOffers(raw json.RawMessage) ([]vastOffer, error) { + var direct []vastOffer + if err := json.Unmarshal(raw, &direct); err == nil { + return direct, nil + } + var envelope struct { + Offers []vastOffer `json:"offers"` + Bundles []vastOffer `json:"bundles"` + Data []vastOffer `json:"data"` + Results []vastOffer `json:"results"` + } + if err := json.Unmarshal(raw, &envelope); err != nil { + return nil, err + } + switch { + case envelope.Offers != nil: + return envelope.Offers, nil + case envelope.Bundles != nil: + return envelope.Bundles, nil + case 
envelope.Data != nil: + return envelope.Data, nil + default: + return envelope.Results, nil + } +} + +func decodeVastCreateInstanceResponse(raw json.RawMessage) (vastCreateInstanceResponse, error) { + var out vastCreateInstanceResponse + if err := json.Unmarshal(raw, &out); err != nil { + return out, err + } + if out.Instance.ID == 0 { + inst, err := decodeVastInstance(raw) + if err == nil { + out.Instance = inst + } + } + if out.NewContract == 0 { + out.NewContract = firstNonZero(out.Instance.ContractID, out.Instance.ID) + } + return out, nil +} + +func decodeVastInstance(raw json.RawMessage) (vastInstance, error) { + var direct vastInstance + if err := json.Unmarshal(raw, &direct); err != nil { + return vastInstance{}, err + } + if direct.ID != 0 || direct.ContractID != 0 || direct.SSHHost != "" { + return normalizeVastInstance(direct), nil + } + var envelope struct { + Instance vastInstance `json:"instance"` + Data vastInstance `json:"data"` + } + if err := json.Unmarshal(raw, &envelope); err != nil { + return vastInstance{}, err + } + if envelope.Instance.ID != 0 || envelope.Instance.ContractID != 0 || envelope.Instance.SSHHost != "" { + return normalizeVastInstance(envelope.Instance), nil + } + return normalizeVastInstance(envelope.Data), nil +} + +func decodeVastInstances(raw json.RawMessage) ([]vastInstance, error) { + var direct []vastInstance + if err := json.Unmarshal(raw, &direct); err == nil { + return normalizeVastInstances(direct), nil + } + var envelope struct { + Instances []vastInstance `json:"instances"` + Data []vastInstance `json:"data"` + Results []vastInstance `json:"results"` + } + if err := json.Unmarshal(raw, &envelope); err != nil { + return nil, err + } + switch { + case envelope.Instances != nil: + return normalizeVastInstances(envelope.Instances), nil + case envelope.Data != nil: + return normalizeVastInstances(envelope.Data), nil + default: + return normalizeVastInstances(envelope.Results), nil + } +} + +func decodeVastSSHKeys(raw 
// vastSecretFields lists the field names whose values are stripped from any
// text that may reach logs or error messages. Matching is case-insensitive.
var vastSecretFields = []string{
	"authorization", "api_key", "apiKey", "instance_api_key", "instanceApiKey",
	"ssh_key", "private_key", "privateKey", "jupyter_token", "jupyterToken",
	"user_data", "userData", "token",
}

// vastFieldPatterns holds the two compiled patterns used for one field name.
type vastFieldPatterns struct {
	jsonish *regexp.Regexp
	inline  *regexp.Regexp
}

// The patterns below are compiled once at package initialization instead of
// on every redaction call.
var (
	vastSecretFieldPatterns = compileVastSecretFieldPatterns(vastSecretFields)

	// A block whose END marker is missing (for example because the text was
	// cut short) is matched through to the end of the text, so a partial key
	// is never left behind.
	vastPrivateKeyBlockPattern = regexp.MustCompile(`-----BEGIN [A-Z ]*PRIVATE KEY-----[\s\S]*?(?:-----END [A-Z ]*PRIVATE KEY-----|\z)`)

	vastTokenURLPattern = regexp.MustCompile(`https?://[^\s"']*(?i:(token|api_key|instance_api_key)=)[^\s"']+`)
)

// compileVastSecretFieldPatterns builds the pattern pair for every name in
// fields, preserving order.
func compileVastSecretFieldPatterns(fields []string) []vastFieldPatterns {
	compiled := make([]vastFieldPatterns, 0, len(fields))
	for _, field := range fields {
		compiled = append(compiled, vastFieldPatterns{
			jsonish: compileVastJSONishPattern(field),
			inline:  compileVastInlinePattern(field),
		})
	}
	return compiled
}

// compileVastJSONishPattern matches `"field": <value>` where the value is a
// JSON string or a bare token. The string form honours backslash escapes so
// an embedded \" does not end the match early, and an unterminated string is
// consumed through to the end of the text.
func compileVastJSONishPattern(field string) *regexp.Regexp {
	return regexp.MustCompile(`(?i)("` + regexp.QuoteMeta(field) + `"\s*:\s*)(?:"(?:[^"\\]|\\[\s\S])*(?:"|\z)|[^,}\s]+)`)
}

// compileVastInlinePattern matches `field=value` / `field: value` outside of
// JSON quoting. An optional Bearer/Basic scheme word is consumed together
// with the credential that follows it, so "Authorization: Bearer <token>"
// loses the token and not merely the word "Bearer".
func compileVastInlinePattern(field string) *regexp.Regexp {
	return regexp.MustCompile(`(?i)(\b` + regexp.QuoteMeta(field) + `\s*[=:]\s*)(?:(?:bearer|basic)\s+)?[^",\s]+`)
}

// redactVastString returns a new error whose message is value with secrets
// removed; see redactVastText for what is removed.
func redactVastString(value, apiKey string) error {
	return errors.New(redactVastText(value, apiKey))
}

// redactVastText removes secrets from value and returns the result.
//
// In order: every literal occurrence of apiKey (when non-empty), the values
// of all vastSecretFields in both JSON and inline form, PEM private-key
// blocks, and URLs carrying a token/api_key/instance_api_key query parameter.
func redactVastText(value, apiKey string) string {
	out := value
	if apiKey != "" {
		out = strings.ReplaceAll(out, apiKey, "")
	}
	for _, patterns := range vastSecretFieldPatterns {
		out = patterns.jsonish.ReplaceAllString(out, `${1}""`)
		out = patterns.inline.ReplaceAllString(out, `${1}`)
	}
	out = redactVastPrivateKeyBlock(out)
	out = redactVastTokenURLs(out)
	return out
}

// redactVastJSONishField blanks the value of one JSON-style field in body,
// leaving `"field":""` behind. field is matched case-insensitively.
func redactVastJSONishField(body, field string) string {
	return compileVastJSONishPattern(field).ReplaceAllString(body, `${1}""`)
}

// redactVastInlineField removes the value of one `field=value` or
// `field: value` pair in body, keeping the field name and separator.
func redactVastInlineField(body, field string) string {
	return compileVastInlinePattern(field).ReplaceAllString(body, `${1}`)
}

// redactVastPrivateKeyBlock removes PEM private-key blocks from body,
// including a trailing block that has no END marker.
func redactVastPrivateKeyBlock(body string) string {
	return vastPrivateKeyBlockPattern.ReplaceAllString(body, "")
}

// redactVastTokenURLs removes whole http(s) URLs that carry a token, api_key
// or instance_api_key query parameter.
func redactVastTokenURLs(body string) string {
	return vastTokenURLPattern.ReplaceAllString(body, "")
}
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, redirectTarget.URL+"/capture", http.StatusFound) + })) + defer server.Close() + + client, err := newVastClient(VastConfig{APIKey: "vast-secret", APIURL: server.URL}, Runtime{HTTP: server.Client()}) + if err != nil { + t.Fatal(err) + } + _, err = client.CheckAuth(context.Background()) + if err == nil || !strings.Contains(err.Error(), "refused cross-origin redirect") { + t.Fatalf("err=%v want cross-origin redirect refusal", err) + } +} + +func TestClientCheckAuthUsesCurrentUserEndpoint(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet || r.URL.Path != "/api/v0/users/current/" { + t.Fatalf("request=%s %s", r.Method, r.URL.Path) + } + if got := r.Header.Get("Authorization"); got != "Bearer vast-secret" { + t.Fatalf("Authorization=%q", got) + } + writeJSON(t, w, map[string]any{"id": 7, "username": "alice"}) + })) + defer server.Close() + + client := newTestVastClient(t, server) + user, err := client.CheckAuth(context.Background()) + if err != nil { + t.Fatal(err) + } + if user.ID != 7 || user.Username != "alice" { + t.Fatalf("user=%#v", user) + } +} + +func TestRedactVastAPIErrorSecrets(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"api_key":"vast-secret","instance_api_key":"inst-secret","jupyter_url":"https://host/?token=jupyter-secret","user_data":"secret cloud init","private_key":"-----BEGIN PRIVATE KEY-----\nabc\n-----END PRIVATE KEY-----"}`)) + })) + defer server.Close() + + client := newTestVastClient(t, server) + _, err := client.CheckAuth(context.Background()) + if err == nil { + t.Fatal("expected API error") + } + text := err.Error() + for _, secret := range []string{"vast-secret", "inst-secret", "jupyter-secret", "secret cloud init", 
"BEGIN PRIVATE KEY", "abc"} { + if strings.Contains(text, secret) { + t.Fatalf("error leaked %q in %q", secret, text) + } + } + if !strings.Contains(text, "") { + t.Fatalf("error was not redacted: %q", text) + } +} + +func TestOfferSearchPayloadAndDecode(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/api/v0/bundles/" { + t.Fatalf("request=%s %s", r.Method, r.URL.Path) + } + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatal(err) + } + query := body["query"].(map[string]any) + if body["type"] != "ondemand" || body["order"] != "dlperf_per_dphtotal desc" || + query["gpu_name"] != "H100" || query["verified"] != true || query["rentable"] != true || query["rented"] != false { + t.Fatalf("unexpected search body: %#v", body) + } + if query["direct_port_count"].(map[string]any)["gte"] != float64(1) || + query["num_gpus"].(map[string]any)["gte"] != float64(4) || + query["reliability2"].(map[string]any)["gte"] != 0.95 || + query["dph_total"].(map[string]any)["lte"] != 3.5 { + t.Fatalf("unexpected query filters: %#v", query) + } + writeJSON(t, w, map[string]any{"offers": []map[string]any{{"id": 11, "gpu_name": "H100", "ssh_host": "203.0.113.10", "ssh_port": 2201}}}) + })) + defer server.Close() + + client := newTestVastClient(t, server) + offers, err := client.SearchOffers(context.Background(), vastOfferSearchInput{Config: VastConfig{ + InstanceType: "on-demand", + GPUName: "H100", + GPUCount: 4, + MaxDphTotal: 3.5, + MinReliability: 0.95, + Order: "dlperf_per_dphtotal desc", + }}) + if err != nil { + t.Fatal(err) + } + if len(offers) != 1 || offers[0].ID != 11 || offers[0].SSHPort != 2201 { + t.Fatalf("offers=%#v", offers) + } +} + +func TestCreateInstancePayloadAndDecodeNewContract(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != 
http.MethodPut || r.URL.Path != "/api/v0/asks/42/" { + t.Fatalf("request=%s %s", r.Method, r.URL.Path) + } + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatal(err) + } + if body["runtype"] != "ssh_direct" || body["target_state"] != "running" || + body["cancel_unavail"] != true || body["vm"] != false || + body["image"] != "nvidia/cuda:12" || body["template_id"] != "tpl-123" || + body["disk"] != float64(80) || body["label"] != "cbx1|lease|slug|active" || + body["ssh_key"] != "ssh-ed25519 AAAA..." { + t.Fatalf("unexpected create body: %#v", body) + } + env := body["env"].(map[string]any) + if env["CRABBOX"] != "1" { + t.Fatalf("env=%#v", env) + } + writeJSON(t, w, map[string]any{"success": true, "new_contract": 99}) + })) + defer server.Close() + + client := newTestVastClient(t, server) + resp, err := client.CreateInstance(context.Background(), 42, vastCreateInstanceInput{ + Config: VastConfig{Image: "nvidia/cuda:12", TemplateID: "tpl-123", Runtype: "ssh_direct", DiskGB: 80}, + Label: "cbx1|lease|slug|active", + SSHKey: "ssh-ed25519 AAAA...", + Environment: map[string]string{"CRABBOX": "1"}, + }) + if err != nil { + t.Fatal(err) + } + if resp.NewContract != 99 || !resp.Success { + t.Fatalf("resp=%#v", resp) + } +} + +func TestInstanceMethodsAndDecoding(t *testing.T) { + var seen []string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + seen = append(seen, r.Method+" "+r.URL.Path) + switch r.Method + " " + r.URL.Path { + case "GET /api/v0/instances/99/": + writeJSON(t, w, map[string]any{"id": 99, "actual_status": "running", "ssh_host": "198.51.100.8", "ssh_port": 2222, "gpu_name": "RTX 4090"}) + case "GET /api/v0/instances/": + writeJSON(t, w, map[string]any{"instances": []map[string]any{{"id": 99}, {"contract_id": 100}}}) + case "PUT /api/v0/instances/99/": + var body vastManageInstanceInput + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatal(err) + } 
+ if body.State != "stopped" || body.Label != "cbx1|lease|slug|stopped" { + t.Fatalf("manage body=%#v", body) + } + writeJSON(t, w, map[string]any{"instance": map[string]any{"id": 99, "intended_status": "stopped"}}) + case "DELETE /api/v0/instances/99/": + w.WriteHeader(http.StatusNoContent) + default: + t.Fatalf("unexpected request %s %s", r.Method, r.URL.Path) + } + })) + defer server.Close() + + client := newTestVastClient(t, server) + instance, err := client.GetInstance(context.Background(), 99) + if err != nil { + t.Fatal(err) + } + if instance.ID != 99 || instance.SSHHost != "198.51.100.8" || instance.SSHPort != 2222 { + t.Fatalf("instance=%#v", instance) + } + list, err := client.ListInstances(context.Background()) + if err != nil { + t.Fatal(err) + } + if len(list) != 2 || list[1].ID != 100 { + t.Fatalf("list=%#v", list) + } + managed, err := client.ManageInstance(context.Background(), 99, vastManageInstanceInput{State: "stopped", Label: "cbx1|lease|slug|stopped"}) + if err != nil { + t.Fatal(err) + } + if managed.Status != "stopped" { + t.Fatalf("managed=%#v", managed) + } + if err := client.DestroyInstance(context.Background(), 99); err != nil { + t.Fatal(err) + } + if strings.Join(seen, ",") != "GET /api/v0/instances/99/,GET /api/v0/instances/,PUT /api/v0/instances/99/,DELETE /api/v0/instances/99/" { + t.Fatalf("seen=%v", seen) + } +} + +func TestInstanceSSHKeyMethods(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method + " " + r.URL.Path { + case "GET /api/v0/instances/99/ssh/": + writeJSON(t, w, map[string]any{"keys": []map[string]any{{"id": "key-1", "ssh_key": "ssh-ed25519 AAAA..."}}}) + case "POST /api/v0/instances/99/ssh/": + var body map[string]string + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatal(err) + } + if body["ssh_key"] != "ssh-ed25519 AAAA..." 
// writeJSON sets a JSON content type on w and encodes value as the response
// body.
//
// The helper is called from httptest handler functions, which run on server
// goroutines rather than the test goroutine. FailNow (and so Fatal) is only
// valid on the test goroutine, so an encoding failure is reported with Errorf:
// the test is still marked failed, without misusing FailNow.
func writeJSON(t *testing.T, w http.ResponseWriter, value any) {
	t.Helper()
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(value); err != nil {
		t.Errorf("write JSON response: %v", err)
	}
}
// sanitizeVastLabelPart turns value into a string that is safe to embed as one
// "|"-separated part of an ownership label.
//
// Surrounding whitespace is trimmed; letters, digits, '-', '_' and '.' are
// kept; every other rune (including '|') becomes '-'. The result never exceeds
// limit bytes: the size of the next rune is checked before it is written, so a
// multi-byte letter cannot push the output past the bound. Leading and
// trailing '-', '_' and '.' are stripped last, so the result may be shorter
// than limit and may be empty. A non-positive limit yields "".
func sanitizeVastLabelPart(value string, limit int) string {
	value = strings.TrimSpace(value)
	var b strings.Builder
	for _, r := range value {
		if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '-' && r != '_' && r != '.' {
			r = '-'
		}
		// len(string(r)) is the UTF-8 encoded size of the rune about to be
		// written; stop before the byte budget would be exceeded.
		if b.Len()+len(string(r)) > limit {
			break
		}
		b.WriteRune(r)
	}
	return strings.Trim(b.String(), "-_.")
}
decodeVastOwnershipLabel(label) + if !ok { + t.Fatalf("label did not decode: %q", label) + } + if got.LeaseID != "lease_123" || got.Slug != "gpu-box" || got.State != "active" { + t.Fatalf("decoded=%#v", got) + } + if !isVastCrabboxOwnedLabel(label) { + t.Fatalf("owned label not recognized: %q", label) + } +} + +func TestOwnershipLabelRejectsMalformedAndManualLabels(t *testing.T) { + for _, label := range []string{ + "", + "crabbox lease lease_123", + "manual-crabbox-instance", + "cbx1|", + "cbx1|lease|missing-state", + "cbx2|lease|slug|active", + } { + if _, ok := decodeVastOwnershipLabel(label); ok { + t.Fatalf("label %q decoded as owned", label) + } + if isVastCrabboxOwnedLabel(label) { + t.Fatalf("label %q recognized as owned", label) + } + } +} + +func TestOwnershipLabelSanitizesAndBoundsParts(t *testing.T) { + label := encodeVastOwnershipLabel("lease/with spaces", strings.Repeat("x", 80), "active now") + if strings.ContainsAny(label, " /") { + t.Fatalf("label was not sanitized: %q", label) + } + got, ok := decodeVastOwnershipLabel(label) + if !ok { + t.Fatalf("label did not decode: %q", label) + } + if got.LeaseID != "lease-with-spaces" || len(got.Slug) != vastLabelMaxPart || got.State != "active-now" { + t.Fatalf("decoded=%#v label=%q", got, label) + } +} From 3f9c6221ce35713752830acf6caad2925e259016 Mon Sep 17 00:00:00 2001 From: Coy Geek <65363919+coygeek@users.noreply.github.com> Date: Wed, 24 Jun 2026 15:40:57 -0700 Subject: [PATCH 03/18] feat(vast): implement SSH lease lifecycle Add the Vast SSH-lease backend for acquire, resolve, list, release, touch, doctor, and cleanup using generated per-lease keys and strict cbx1 ownership labels. Cleanup is guarded by local claims plus provider labels, release destroys by default, and lifecycle tests use a fake Vast API/readiness path without live credentials. 
// backend is the Vast provider's lease backend. It embeds
// shared.DirectSSHBackend and adds the configuration, runtime, and a set of
// function-valued fields that newBackend populates and that the lifecycle
// methods call through.
type backend struct {
	shared.DirectSSHBackend
	// cfg is the configuration after applyVastDefaults has been applied by
	// newBackend.
	cfg core.Config
	// rt supplies the output streams and optional clock read by stderr, stdout
	// and now.
	rt core.Runtime
	// apiFactory builds the Vast API client; api falls back to newVastClient
	// when it is nil.
	apiFactory func(core.Runtime) (vastAPI, error)
	// waitSSH blocks until the SSH target is ready; newBackend wires it to
	// core.WaitForSSHReady.
	waitSSH func(context.Context, *core.SSHTarget, string, time.Duration) error
	// sleep pauses between polls and returns early with ctx.Err() when the
	// context is cancelled.
	sleep func(context.Context, time.Duration) error
	// pollTimeout bounds waitForInstanceReady; newBackend sets it to
	// vastPollTimeout.
	pollTimeout time.Duration
	// cleanupTimeout is initialised to vastCleanupTimeout by newBackend.
	// NOTE(review): its consumer is outside this view — presumably the
	// rollback/cleanup path; confirm.
	cleanupTimeout time.Duration
}
target, b.stderr(), phase, timeout) + } + b.sleep = func(ctx context.Context, d time.Duration) error { + timer := time.NewTimer(d) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } + } + return b +} + +func applyVastDefaults(cfg *core.Config) { + cfg.Provider = providerName + if cfg.TargetOS == "" { + cfg.TargetOS = core.TargetLinux + } + if cfg.Vast.User != "" { + cfg.SSHUser = cfg.Vast.User + } else if cfg.SSHUser == "" { + cfg.SSHUser = "root" + } + if cfg.Vast.WorkRoot != "" { + cfg.WorkRoot = cfg.Vast.WorkRoot + } + if cfg.WorkRoot == "" { + cfg.WorkRoot = "/work/crabbox" + } + if cfg.SSHPort == "" { + cfg.SSHPort = "22" + } + if cfg.Vast.InstanceType == "" { + cfg.Vast.InstanceType = "ondemand" + } + if cfg.Vast.Runtype == "" { + cfg.Vast.Runtype = "ssh_direct" + } + if cfg.Vast.Order == "" { + cfg.Vast.Order = "dlperf_per_dphtotal desc" + } + if cfg.Vast.ReleaseAction == "" { + cfg.Vast.ReleaseAction = "destroy" + } +} + +func (b *backend) stderr() io.Writer { + if b.rt.Stderr != nil { + return b.rt.Stderr + } + return io.Discard +} + +func (b *backend) stdout() io.Writer { + if b.rt.Stdout != nil { + return b.rt.Stdout + } + return io.Discard +} + +func (b *backend) now() time.Time { + if b.rt.Clock != nil { + return b.rt.Clock.Now().UTC() + } + return time.Now().UTC() +} + +func (b *backend) api() (vastAPI, error) { + if b.apiFactory != nil { + return b.apiFactory(b.rt) + } + return newVastClient(b.cfg.Vast, b.rt) +} + +func (b *backend) Doctor(ctx context.Context, _ core.DoctorRequest) (core.DoctorResult, error) { + client, err := b.api() + if err != nil { + return core.DoctorResult{}, err + } + if _, err := client.CheckAuth(ctx); err != nil { + return core.DoctorResult{}, err + } + instances, err := client.ListInstances(ctx) + if err != nil { + return core.DoctorResult{}, err + } + count := 0 + for _, item := range instances { + if isOwnedVastInstance(item) { + count++ + } + } + result := 
// Acquire provisions a new Vast lease for req.
//
// It hands the runtime, req.Keep, and a closure that runs one acquireOnce
// attempt to shared.AcquireAttemptsRetry and returns whatever that helper
// returns. The attempt count and retry conditions are decided by
// shared.AcquireAttemptsRetry, not here.
func (b *backend) Acquire(ctx context.Context, req core.AcquireRequest) (core.LeaseTarget, error) {
	return shared.AcquireAttemptsRetry(b.rt, req.Keep, func() (core.LeaseTarget, error) {
		return b.acquireOnce(ctx, req)
	})
}
+ }() + offers, err := client.SearchOffers(ctx, vastOfferSearchInput{Config: cfg.Vast}) + if err != nil { + return core.LeaseTarget{}, err + } + offer, err := selectVastOffer(offers) + if err != nil { + return core.LeaseTarget{}, err + } + fmt.Fprintf(b.stderr(), "provisioning provider=vast lease=%s slug=%s offer=%d gpu=%s count=%d max_dph=%.4f keep=%v\n", leaseID, slug, offer.ID, offer.GPUName, offer.GPUCount, cfg.Vast.MaxDphTotal, req.Keep) + created, err := client.CreateInstance(ctx, offer.ID, vastCreateInstanceInput{ + Config: cfg.Vast, + Label: label, + SSHKey: publicKey, + Environment: map[string]string{"CRABBOX": "1"}, + }) + if err != nil { + return core.LeaseTarget{}, err + } + instanceID = firstNonZero(created.Instance.ID, created.NewContract) + if instanceID == 0 { + err = exit(5, "vast create returned no instance id") + return core.LeaseTarget{}, err + } + if attach, attachErr := client.AttachInstanceSSHKey(ctx, instanceID, publicKey); attachErr != nil { + err = attachErr + return core.LeaseTarget{}, err + } else { + keyID = vastAttachedKeyID(attach) + } + instance, err := b.waitForInstanceReady(ctx, client, instanceID) + if err != nil { + return core.LeaseTarget{}, err + } + readyLabel := encodeVastOwnershipLabel(leaseID, slug, "ready") + if _, err := client.ManageInstance(ctx, instanceID, vastManageInstanceInput{Label: readyLabel}); err != nil { + return core.LeaseTarget{}, err + } + instance.Label = readyLabel + server := serverFromInstance(instance, cfg) + server.Labels = vastLeaseLabels(cfg, leaseID, slug, "ready", req.Keep, now) + server.Labels[vastOfferIDLabel] = strconv.Itoa(offer.ID) + if keyID != "" { + server.Labels[vastKeyIDLabel] = keyID + } + server.Labels[vastKeyOwnedLabel] = fmt.Sprint(keyID != "") + ssh, err := sshTargetFromInstance(cfg, instance) + if err != nil { + return core.LeaseTarget{}, err + } + if err := b.waitSSH(ctx, &ssh, "vast bootstrap", core.BootstrapWaitTimeout(cfg)); err != nil { + return core.LeaseTarget{}, err + } + 
target = core.LeaseTarget{Server: server, SSH: ssh, LeaseID: leaseID} + if req.OnAcquired != nil { + if err := req.OnAcquired(target); err != nil { + return core.LeaseTarget{}, err + } + } + if err := core.ClaimLeaseTargetForRepoConfig(leaseID, slug, cfg, server, ssh, req.Repo.Root, cfg.IdleTimeout, req.Reclaim); err != nil { + return core.LeaseTarget{}, err + } + committed = true + fmt.Fprintf(b.stderr(), "provisioned lease=%s vast=%d gpu=%s state=ready\n", leaseID, instanceID, server.ServerType.Name) + return target, nil +} + +func selectVastOffer(offers []vastOffer) (vastOffer, error) { + for _, offer := range offers { + if offer.ID != 0 && offer.Rentable && !offer.Rented { + return offer, nil + } + } + if len(offers) > 0 && offers[0].ID != 0 { + return offers[0], nil + } + return vastOffer{}, exit(4, "vast found no eligible offers") +} + +func vastAttachedKeyID(resp vastAttachSSHKeyResponse) string { + if strings.TrimSpace(resp.Key.ID) != "" { + return strings.TrimSpace(resp.Key.ID) + } + for _, key := range resp.Keys { + if strings.TrimSpace(key.ID) != "" { + return strings.TrimSpace(key.ID) + } + } + return "" +} + +func (b *backend) waitForInstanceReady(ctx context.Context, client vastAPI, id int) (vastInstance, error) { + deadline := b.now().Add(b.pollTimeout) + for { + instance, err := client.GetInstance(ctx, id) + if err != nil { + return vastInstance{}, err + } + if isVastInstanceRunning(instance) && strings.TrimSpace(instance.SSHHost) != "" && instance.SSHPort > 0 { + return instance, nil + } + if isTerminalVastStatus(instance.Status) { + return vastInstance{}, exit(5, "vast instance %d reached terminal status %s", id, instance.Status) + } + if b.now().After(deadline) { + return vastInstance{}, exit(5, "timed out waiting for Vast instance %d to expose SSH", id) + } + if err := b.sleep(ctx, vastPollInterval); err != nil { + return vastInstance{}, err + } + } +} + +func (b *backend) Resolve(ctx context.Context, req core.ResolveRequest) (core.LeaseTarget, 
error) { + client, err := b.api() + if err != nil { + return core.LeaseTarget{}, err + } + instances, err := client.ListInstances(ctx) + if err != nil { + return core.LeaseTarget{}, err + } + byID := map[int]vastInstance{} + for _, item := range instances { + byID[item.ID] = item + } + if id, ok := parseVastInstanceID(req.ID); ok { + item, found := byID[id] + if !found { + item, err = client.GetInstance(ctx, id) + if err != nil { + return b.releaseTargetFromClaim(req.ID, err, req.ReleaseOnly) + } + } + return b.targetFromInstance(item, req) + } + servers := serversFromInstances(instances, b.cfg, false) + server, leaseID, err := core.FindServerByAlias(servers, req.ID) + if err == nil && leaseID != "" { + if id, ok := parseVastInstanceID(server.CloudID); ok { + return b.targetFromInstance(byID[id], req) + } + } + if claim, ok, claimErr := core.ResolveLeaseClaimForProvider(req.ID, providerName); claimErr != nil { + return core.LeaseTarget{}, claimErr + } else if ok { + if id, parseOK := parseVastInstanceID(claim.CloudID); parseOK { + item, getErr := client.GetInstance(ctx, id) + if getErr == nil { + return b.targetFromInstance(item, req) + } + if req.ReleaseOnly { + return claimTarget(claim), nil + } + return core.LeaseTarget{}, getErr + } + if req.ReleaseOnly { + return claimTarget(claim), nil + } + } + if err != nil { + return core.LeaseTarget{}, err + } + return core.LeaseTarget{}, exit(4, "lease/instance not found: %s", req.ID) +} + +func (b *backend) releaseTargetFromClaim(id string, cause error, releaseOnly bool) (core.LeaseTarget, error) { + if !releaseOnly { + return core.LeaseTarget{}, cause + } + claim, ok, err := core.ResolveLeaseClaimForProvider(id, providerName) + if err != nil || !ok { + if err != nil { + return core.LeaseTarget{}, err + } + return core.LeaseTarget{}, cause + } + return claimTarget(claim), nil +} + +func (b *backend) targetFromInstance(item vastInstance, req core.ResolveRequest) (core.LeaseTarget, error) { + if !isOwnedVastInstance(item) 
{ + return core.LeaseTarget{}, exit(2, "refusing to operate on non-Crabbox Vast instance %d", item.ID) + } + if isTerminalVastStatus(item.Status) && !req.ReleaseOnly && !req.StatusOnly { + return core.LeaseTarget{}, exit(5, "vast instance %d reached terminal status %s", item.ID, item.Status) + } + server := serverFromInstance(item, b.cfg) + leaseID := server.Labels["lease"] + target := core.LeaseTarget{Server: server, LeaseID: leaseID} + if !req.ReleaseOnly { + ssh, err := sshTargetFromInstance(b.cfg, item) + if err != nil { + return core.LeaseTarget{}, err + } + target.SSH = ssh + } + if req.Repo.Root != "" && !req.NoLocalStateMutations { + if err := core.ClaimLeaseTargetForRepoConfig(leaseID, server.Labels["slug"], b.cfg, target.Server, target.SSH, req.Repo.Root, b.cfg.IdleTimeout, req.Reclaim); err != nil { + return core.LeaseTarget{}, err + } + } + return target, nil +} + +func (b *backend) List(ctx context.Context, req core.ListRequest) ([]core.LeaseView, error) { + client, err := b.api() + if err != nil { + return nil, err + } + instances, err := client.ListInstances(ctx) + if err != nil { + return nil, err + } + return serversFromInstances(instances, b.cfg, req.All), nil +} + +func (b *backend) ReleaseLease(ctx context.Context, req core.ReleaseLeaseRequest) error { + if err := core.ValidateLeaseTargetProviderIdentity(req.Lease, req.ExpectedProviderIdentity); err != nil { + return err + } + return b.deleteServer(ctx, b.cfg, req.Lease.Server) +} + +func (b *backend) ReleaseLeaseMessage(lease core.LeaseTarget) string { + action := normalizeVastReleaseAction(b.cfg.Vast.ReleaseAction) + if action == "stop" || action == "keep" { + return fmt.Sprintf("%s lease=%s vast=%s name=%s", action, lease.LeaseID, lease.Server.DisplayID(), lease.Server.Name) + } + return fmt.Sprintf("destroyed lease=%s vast=%s name=%s", lease.LeaseID, lease.Server.DisplayID(), lease.Server.Name) +} + +func (b *backend) Touch(_ context.Context, req core.TouchRequest) (core.Server, error) { + 
// deleteServer releases the Vast instance behind server according to the
// lease's release action: "keep", "stop", or (the default) destroy.
//
// Nothing is mutated until every guard has passed: the server must carry
// lease and slug labels, a local claim for the lease must exist and belong to
// this provider, the claim's cloud ID must not contradict the server's, and —
// when the instance still exists remotely — its ownership label and ID must
// match what the caller expects. The Config argument is ignored; b.cfg is
// used instead.
func (b *backend) deleteServer(ctx context.Context, _ core.Config, server core.Server) error {
	if err := validateVastServer(server); err != nil {
		return err
	}
	client, err := b.api()
	if err != nil {
		return err
	}
	leaseID := server.Labels["lease"]
	// A local claim is required as a second proof of ownership besides the
	// provider-side label.
	claim, claimExists, err := core.ReadLeaseClaimWithPresence(leaseID)
	if err != nil {
		return fmt.Errorf("read vast cleanup claim: %w", err)
	}
	if !claimExists {
		return exit(2, "lease=%s has no local Vast claim; refusing destructive cleanup", leaseID)
	}
	if claim.Provider != providerName {
		return exit(2, "lease=%s is claimed by provider=%s; refusing Vast cleanup", leaseID, claim.Provider)
	}
	// Only a positive mismatch is rejected; an empty ID on either side is
	// tolerated and resolved by firstNonBlank below.
	if claim.CloudID != "" && server.CloudID != "" && claim.CloudID != server.CloudID {
		return exit(2, "refusing to release Vast instance %s from stale local claim", server.CloudID)
	}
	instanceID, ok := parseVastInstanceID(firstNonBlank(server.CloudID, claim.CloudID))
	if !ok {
		return exit(2, "provider=%s release requires a Vast instance id", providerName)
	}
	// Re-check the live instance when it still exists. A not-found answer is
	// accepted so that an already-removed instance can still be finalized.
	if live, getErr := client.GetInstance(ctx, instanceID); getErr == nil {
		if err := validateLiveVastInstance(live, server); err != nil {
			return err
		}
	} else if !isVastNotFound(getErr) {
		return getErr
	}
	// The action recorded in the claim wins over the current configuration.
	action := normalizeVastReleaseAction(firstNonBlank(claim.Labels[vastReleaseActionLabel], b.cfg.Vast.ReleaseAction))
	switch action {
	case "keep":
		// Leaves both the instance and the local claim untouched.
		return nil
	case "stop":
		if _, err := client.ManageInstance(ctx, instanceID, vastManageInstanceInput{State: "stopped", Label: encodeVastOwnershipLabel(leaseID, server.Labels["slug"], "stopped")}); err != nil {
			return err
		}
		// The claim is dropped, but the stored testbox key is not removed in
		// this branch.
		if err := core.RemoveLeaseClaimIfUnchanged(leaseID, claim); err != nil {
			return fmt.Errorf("finalize vast stop claim: %w", err)
		}
	default:
		if err := client.DestroyInstance(ctx, instanceID); err != nil && !isVastNotFound(err) {
			return err
		}
		// NOTE(review): the key detach is issued after the instance has been
		// destroyed. Any non-not-found error here returns before the claim is
		// removed, even though the instance is already gone — confirm the API
		// answers not-found for a destroyed instance so a retry can finalize.
		if keyID := strings.TrimSpace(claim.Labels[vastKeyIDLabel]); keyID != "" && claim.Labels[vastKeyOwnedLabel] == "true" {
			if err := client.DetachInstanceSSHKey(ctx, instanceID, keyID); err != nil && !isVastNotFound(err) {
				return err
			}
		}
		if err := core.RemoveLeaseClaimIfUnchanged(leaseID, claim); err != nil {
			return fmt.Errorf("finalize vast cleanup claim: %w", err)
		}
		core.RemoveStoredTestboxKey(leaseID)
	}
	return nil
}
server +} + +func serverFromInstance(item vastInstance, cfg core.Config) core.Server { + labels := labelsFromVastInstance(item, cfg) + server := core.Server{ + CloudID: strconv.Itoa(item.ID), + Provider: providerName, + Name: firstNonBlank(labels["slug"], item.Label, strconv.Itoa(item.ID)), + Status: normalizeVastStatus(item.Status), + Labels: labels, + } + server.PublicNet.IPv4.IP = strings.TrimSpace(item.SSHHost) + server.ServerType.Name = firstNonBlank(item.GPUName, cfg.ServerType) + return server +} + +func labelsFromVastInstance(item vastInstance, cfg core.Config) map[string]string { + if owner, ok := decodeVastOwnershipLabel(item.Label); ok { + labels := vastLeaseLabels(cfg, owner.LeaseID, owner.Slug, owner.State, false, time.Now().UTC()) + labels["provider_key"] = core.ProviderKeyForLease(owner.LeaseID) + labels[vastReleaseActionLabel] = normalizeVastReleaseAction(cfg.Vast.ReleaseAction) + return labels + } + return map[string]string{"label": strings.TrimSpace(item.Label)} +} + +func vastLeaseLabels(cfg core.Config, leaseID, slug, state string, keep bool, now time.Time) map[string]string { + labels := core.DirectLeaseLabels(cfg, leaseID, slug, providerName, "", keep, now) + labels["state"] = state + labels[vastReleaseActionLabel] = normalizeVastReleaseAction(cfg.Vast.ReleaseAction) + return labels +} + +func isOwnedVastInstance(item vastInstance) bool { + return isVastCrabboxOwnedLabel(item.Label) +} + +func validateVastServer(server core.Server) error { + if server.Provider != "" && server.Provider != providerName { + return exit(2, "refusing to operate on provider=%s server as Vast", server.Provider) + } + leaseID := strings.TrimSpace(server.Labels["lease"]) + if leaseID == "" || strings.TrimSpace(server.Labels["slug"]) == "" { + return exit(2, "refusing to operate on non-Crabbox Vast instance %s", server.DisplayID()) + } + return nil +} + +func validateLiveVastInstance(item vastInstance, expected core.Server) error { + if !isOwnedVastInstance(item) { + 
return exit(2, "refusing to operate on non-Crabbox Vast instance %d", item.ID) + } + owner, _ := decodeVastOwnershipLabel(item.Label) + if strconv.Itoa(item.ID) != expected.CloudID || + owner.LeaseID != expected.Labels["lease"] || + owner.Slug != expected.Labels["slug"] { + return exit(2, "refusing to operate on changed Vast instance %s", expected.CloudID) + } + return nil +} + +func sshTargetFromInstance(cfg core.Config, item vastInstance) (core.SSHTarget, error) { + host := strings.TrimSpace(item.SSHHost) + if host == "" || item.SSHPort <= 0 { + return core.SSHTarget{}, exit(5, "vast instance %d is missing SSH endpoint", item.ID) + } + ssh := core.SSHTargetFromConfig(cfg, host) + ssh.Port = strconv.Itoa(item.SSHPort) + ssh.User = firstNonBlank(cfg.SSHUser, cfg.Vast.User, "root") + ssh.TargetOS = core.TargetLinux + return ssh, nil +} + +func isVastInstanceRunning(item vastInstance) bool { + switch strings.ToLower(strings.TrimSpace(item.Status)) { + case "running", "active", "ready": + return true + default: + return false + } +} + +func isTerminalVastStatus(status string) bool { + switch strings.ToLower(strings.TrimSpace(status)) { + case "failed", "error", "exited", "cancelled", "canceled", "destroyed", "deleted", "dead": + return true + default: + return false + } +} + +func normalizeVastStatus(status string) string { + if isVastInstanceRunning(vastInstance{Status: status}) { + return "ready" + } + if status = strings.TrimSpace(status); status != "" { + return status + } + return "unknown" +} + +func normalizeVastReleaseAction(value string) string { + switch strings.ToLower(strings.TrimSpace(value)) { + case "stop": + return "stop" + case "keep": + return "keep" + default: + return "destroy" + } +} + +func parseVastInstanceID(value string) (int, bool) { + id, err := strconv.Atoi(strings.TrimSpace(value)) + return id, err == nil && id > 0 +} + +func firstNonBlank(values ...string) string { + for _, value := range values { + if strings.TrimSpace(value) != "" { + 
return strings.TrimSpace(value) + } + } + return "" +} + +func claimTarget(claim core.LeaseClaim) core.LeaseTarget { + server := core.Server{ + CloudID: claim.CloudID, + Provider: providerName, + Name: claim.Slug, + Status: claim.Labels["state"], + Labels: claim.Labels, + } + server.PublicNet.IPv4.IP = claim.SSHHost + target := core.SSHTarget{Host: claim.SSHHost, Port: strconv.Itoa(claim.SSHPort), TargetOS: core.TargetLinux} + core.UseStoredTestboxKey(&target, claim.LeaseID) + return core.LeaseTarget{LeaseID: claim.LeaseID, Server: server, SSH: target} +} + +func (b *backend) persistRecoveryClaim(leaseID, slug string, cfg core.Config, repoRoot string, instanceID int, keyID, reason string, keep bool, now time.Time) error { + label := encodeVastOwnershipLabel(leaseID, slug, reason) + server := serverFromInstance(vastInstance{ID: instanceID, Label: label, Status: reason}, cfg) + server.Labels = vastLeaseLabels(cfg, leaseID, slug, reason, keep, now) + server.Labels["recovery"] = reason + if keyID != "" { + server.Labels[vastKeyIDLabel] = keyID + server.Labels[vastKeyOwnedLabel] = "true" + } + return core.ClaimLeaseTargetForRepoConfig(leaseID, slug, cfg, server, core.SSHTarget{}, repoRoot, cfg.IdleTimeout, true) +} + +func rollbackVastAcquire(client vastAPI, instanceID int, keyID string) error { + ctx, cancel := context.WithTimeout(context.Background(), vastCleanupTimeout) + defer cancel() + var errs []error + if keyID != "" { + if err := client.DetachInstanceSSHKey(ctx, instanceID, keyID); err != nil && !isVastNotFound(err) { + errs = append(errs, err) + } + } + if instanceID != 0 { + if err := client.DestroyInstance(ctx, instanceID); err != nil && !isVastNotFound(err) { + errs = append(errs, err) + } + } + return errors.Join(errs...) 
+} + +func isVastNotFound(err error) bool { + var apiErr *vastAPIError + return errors.As(err, &apiErr) && apiErr.StatusCode == 404 +} diff --git a/internal/providers/vast/backend_test.go b/internal/providers/vast/backend_test.go new file mode 100644 index 000000000..ace2fd1ba --- /dev/null +++ b/internal/providers/vast/backend_test.go @@ -0,0 +1,419 @@ +package vast + +import ( + "bytes" + "context" + "errors" + "io" + "strconv" + "strings" + "testing" + "time" + + core "github.com/openclaw/crabbox/internal/cli" +) + +type fakeVastAPI struct { + user vastUser + offers []vastOffer + instances []vastInstance + authErr error + listErr error + createErr error + getErr error + manageErr error + destroyErr error + attachErr error + detachErr error + + searches []vastOfferSearchInput + creates []struct { + offerID int + input vastCreateInstanceInput + } + managed []struct { + id int + input vastManageInstanceInput + } + destroyed []int + attached []struct { + id int + publicKey string + } + detached []struct { + id int + keyID string + } + nextID int +} + +func (f *fakeVastAPI) CheckAuth(context.Context) (vastUser, error) { + if f.authErr != nil { + return vastUser{}, f.authErr + } + if f.user.ID == 0 { + return vastUser{ID: 7, Username: "alice"}, nil + } + return f.user, nil +} + +func (f *fakeVastAPI) SearchOffers(_ context.Context, input vastOfferSearchInput) ([]vastOffer, error) { + f.searches = append(f.searches, input) + return append([]vastOffer(nil), f.offers...), nil +} + +func (f *fakeVastAPI) CreateInstance(_ context.Context, offerID int, input vastCreateInstanceInput) (vastCreateInstanceResponse, error) { + f.creates = append(f.creates, struct { + offerID int + input vastCreateInstanceInput + }{offerID: offerID, input: input}) + if f.createErr != nil { + return vastCreateInstanceResponse{}, f.createErr + } + if f.nextID == 0 { + f.nextID = 100 + } + item := vastInstance{ + ID: f.nextID, + Label: input.Label, + Status: "running", + SSHHost: "203.0.113.24", + 
SSHPort: 2201, + GPUName: "RTX 4090", + GPUCount: 1, + DphTotal: 0.75, + } + f.instances = append(f.instances, item) + f.nextID++ + return vastCreateInstanceResponse{Success: true, NewContract: item.ID, Instance: item}, nil +} + +func (f *fakeVastAPI) GetInstance(_ context.Context, id int) (vastInstance, error) { + if f.getErr != nil { + return vastInstance{}, f.getErr + } + for _, item := range f.instances { + if item.ID == id { + return item, nil + } + } + return vastInstance{}, &vastAPIError{StatusCode: 404, Status: "404 Not Found"} +} + +func (f *fakeVastAPI) ListInstances(context.Context) ([]vastInstance, error) { + if f.listErr != nil { + return nil, f.listErr + } + return append([]vastInstance(nil), f.instances...), nil +} + +func (f *fakeVastAPI) ManageInstance(_ context.Context, id int, input vastManageInstanceInput) (vastInstance, error) { + f.managed = append(f.managed, struct { + id int + input vastManageInstanceInput + }{id: id, input: input}) + if f.manageErr != nil { + return vastInstance{}, f.manageErr + } + for i := range f.instances { + if f.instances[i].ID == id { + if input.Label != "" { + f.instances[i].Label = input.Label + } + if input.State != "" { + f.instances[i].Status = input.State + } else if f.instances[i].Status == "starting" { + f.instances[i].Status = "running" + } + return f.instances[i], nil + } + } + return vastInstance{}, &vastAPIError{StatusCode: 404, Status: "404 Not Found"} +} + +func (f *fakeVastAPI) DestroyInstance(_ context.Context, id int) error { + f.destroyed = append(f.destroyed, id) + if f.destroyErr != nil { + return f.destroyErr + } + out := f.instances[:0] + for _, item := range f.instances { + if item.ID != id { + out = append(out, item) + } + } + f.instances = out + return nil +} + +func (f *fakeVastAPI) ListInstanceSSHKeys(context.Context, int) ([]vastInstanceSSHKey, error) { + return nil, nil +} + +func (f *fakeVastAPI) AttachInstanceSSHKey(_ context.Context, id int, publicKey string) (vastAttachSSHKeyResponse, 
error) { + f.attached = append(f.attached, struct { + id int + publicKey string + }{id: id, publicKey: publicKey}) + if f.attachErr != nil { + return vastAttachSSHKeyResponse{}, f.attachErr + } + return vastAttachSSHKeyResponse{Success: true, Key: vastInstanceSSHKey{ID: "key-" + strconv.Itoa(id), PublicKey: publicKey}}, nil +} + +func (f *fakeVastAPI) DetachInstanceSSHKey(_ context.Context, id int, keyID string) error { + f.detached = append(f.detached, struct { + id int + keyID string + }{id: id, keyID: keyID}) + return f.detachErr +} + +func newTestBackend(t *testing.T, api *fakeVastAPI) *backend { + t.Helper() + home := t.TempDir() + t.Setenv("HOME", home) + t.Setenv("XDG_CONFIG_HOME", home) + cfg := core.BaseConfig() + cfg.Provider = providerName + cfg.TargetOS = core.TargetLinux + cfg.SSHUser = "root" + cfg.SSHPort = "22" + cfg.WorkRoot = "/work/crabbox" + cfg.Vast.APIURL = "https://console.vast.ai/api/v0" + cfg.Vast.APIKey = "test-key" + cfg.Vast.User = "root" + cfg.Vast.WorkRoot = "/work/crabbox" + cfg.Vast.InstanceType = "ondemand" + cfg.Vast.Runtype = "ssh_direct" + cfg.Vast.Image = "nvidia/cuda:12" + cfg.Vast.Order = "dlperf_per_dphtotal desc" + cfg.Vast.ReleaseAction = "destroy" + b := newBackend(Provider{}.Spec(), cfg, core.Runtime{Stderr: io.Discard}) + b.apiFactory = func(core.Runtime) (vastAPI, error) { return api, nil } + b.waitSSH = func(context.Context, *core.SSHTarget, string, time.Duration) error { return nil } + b.sleep = func(context.Context, time.Duration) error { return nil } + return b +} + +func TestDoctorIsReadOnlyAndCountsOwnedInventory(t *testing.T) { + api := &fakeVastAPI{instances: []vastInstance{ + {ID: 1, Label: encodeVastOwnershipLabel("cbx_owned", "owned", "ready"), Status: "running"}, + {ID: 2, Label: "manual", Status: "running"}, + }} + result, err := newTestBackend(t, api).Doctor(context.Background(), core.DoctorRequest{}) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(result.Message, "leases=1") || 
!strings.Contains(result.Message, "mutation=false") { + t.Fatalf("doctor result=%#v", result) + } + if len(api.creates) != 0 || len(api.destroyed) != 0 || len(api.managed) != 0 { + t.Fatalf("doctor mutated api: creates=%v destroyed=%v managed=%v", api.creates, api.destroyed, api.managed) + } +} + +func TestListFiltersOwnedByDefaultAndAllShowsManual(t *testing.T) { + api := &fakeVastAPI{instances: []vastInstance{ + {ID: 1, Label: encodeVastOwnershipLabel("cbx_owned", "owned", "ready"), Status: "running"}, + {ID: 2, Label: "manual", Status: "running"}, + }} + b := newTestBackend(t, api) + owned, err := b.List(context.Background(), core.ListRequest{}) + if err != nil { + t.Fatal(err) + } + if len(owned) != 1 || owned[0].CloudID != "1" { + t.Fatalf("owned=%#v", owned) + } + all, err := b.List(context.Background(), core.ListRequest{All: true}) + if err != nil { + t.Fatal(err) + } + if len(all) != 2 { + t.Fatalf("all=%#v", all) + } +} + +func TestAcquireCreatesAttachesPollsReadinessAndClaims(t *testing.T) { + api := &fakeVastAPI{offers: []vastOffer{{ID: 42, GPUName: "RTX 4090", GPUCount: 1, Rentable: true}}} + b := newTestBackend(t, api) + lease, err := b.Acquire(context.Background(), core.AcquireRequest{Repo: core.Repo{Root: t.TempDir()}, RequestedSlug: "gpu-box", Keep: true}) + if err != nil { + t.Fatal(err) + } + if lease.LeaseID == "" || lease.Server.CloudID != "100" || lease.SSH.Host != "203.0.113.24" || lease.SSH.Port != "2201" || lease.SSH.User != "root" || lease.SSH.Key == "" { + t.Fatalf("lease=%#v", lease) + } + if len(api.searches) != 1 || api.searches[0].Config.Order != "dlperf_per_dphtotal desc" { + t.Fatalf("searches=%#v", api.searches) + } + if len(api.creates) != 1 || api.creates[0].offerID != 42 || api.creates[0].input.Config.Runtype != "ssh_direct" || api.creates[0].input.Environment["CRABBOX"] != "1" { + t.Fatalf("creates=%#v", api.creates) + } + if !isVastCrabboxOwnedLabel(api.creates[0].input.Label) || api.creates[0].input.SSHKey == "" { + 
t.Fatalf("create input=%#v", api.creates[0].input) + } + if len(api.attached) != 1 || api.attached[0].id != 100 || api.attached[0].publicKey == "" { + t.Fatalf("attached=%#v", api.attached) + } + if len(api.managed) != 1 || !strings.Contains(api.managed[0].input.Label, "|ready") { + t.Fatalf("managed=%#v", api.managed) + } + claim, ok, err := core.ResolveLeaseClaimForProvider("gpu-box", providerName) + if err != nil || !ok || claim.CloudID != "100" || claim.Labels[vastKeyIDLabel] != "key-100" || claim.Labels[vastReleaseActionLabel] != "destroy" { + t.Fatalf("claim=%#v ok=%v err=%v", claim, ok, err) + } +} + +func TestResolveRejectsTerminalStatusForRunButAllowsRelease(t *testing.T) { + api := &fakeVastAPI{instances: []vastInstance{{ID: 9, Label: encodeVastOwnershipLabel("cbx_failed", "failed", "ready"), Status: "failed", SSHHost: "203.0.113.9", SSHPort: 22}}} + b := newTestBackend(t, api) + _, err := b.Resolve(context.Background(), core.ResolveRequest{ID: "failed"}) + if err == nil || !strings.Contains(err.Error(), "terminal status failed") { + t.Fatalf("err=%v", err) + } + lease, err := b.Resolve(context.Background(), core.ResolveRequest{ID: "failed", ReleaseOnly: true}) + if err != nil { + t.Fatal(err) + } + if lease.LeaseID != "cbx_failed" || lease.Server.CloudID != "9" { + t.Fatalf("lease=%#v", lease) + } +} + +func TestAcquireRollsBackOnCallbackFailure(t *testing.T) { + api := &fakeVastAPI{offers: []vastOffer{{ID: 42, Rentable: true}}} + b := newTestBackend(t, api) + _, err := b.Acquire(context.Background(), core.AcquireRequest{ + Repo: core.Repo{Root: t.TempDir()}, + RequestedSlug: "rollback", + OnAcquired: func(core.LeaseTarget) error { + return errors.New("controller unavailable") + }, + }) + if err == nil || !strings.Contains(err.Error(), "controller unavailable") { + t.Fatalf("err=%v", err) + } + if len(api.destroyed) != 1 || api.destroyed[0] != 100 { + t.Fatalf("destroyed=%v", api.destroyed) + } + if len(api.detached) != 1 || api.detached[0].keyID != 
"key-100" { + t.Fatalf("detached=%v", api.detached) + } + if _, ok, claimErr := core.ResolveLeaseClaimForProvider("rollback", providerName); claimErr != nil || ok { + t.Fatalf("claim ok=%v err=%v", ok, claimErr) + } +} + +func TestAcquirePreservesRecoveryClaimWhenRollbackCleanupFails(t *testing.T) { + api := &fakeVastAPI{offers: []vastOffer{{ID: 42, Rentable: true}}, destroyErr: errors.New("destroy uncertain")} + b := newTestBackend(t, api) + _, err := b.Acquire(context.Background(), core.AcquireRequest{ + Repo: core.Repo{Root: t.TempDir()}, + RequestedSlug: "recover-me", + OnAcquired: func(core.LeaseTarget) error { + return errors.New("controller unavailable") + }, + }) + if err == nil || !strings.Contains(err.Error(), "vast cleanup failed") { + t.Fatalf("err=%v", err) + } + claim, ok, claimErr := core.ResolveLeaseClaimForProvider("recover-me", providerName) + if claimErr != nil || !ok || claim.Labels["recovery"] != "rollback-cleanup" || claim.CloudID != "100" { + t.Fatalf("claim=%#v ok=%v err=%v", claim, ok, claimErr) + } +} + +func TestReleaseDestroysByDefaultAndRemovesClaim(t *testing.T) { + api := &fakeVastAPI{offers: []vastOffer{{ID: 42, Rentable: true}}} + b := newTestBackend(t, api) + lease, err := b.Acquire(context.Background(), core.AcquireRequest{Repo: core.Repo{Root: t.TempDir()}, RequestedSlug: "destroy-me"}) + if err != nil { + t.Fatal(err) + } + if err := b.ReleaseLease(context.Background(), core.ReleaseLeaseRequest{Lease: lease}); err != nil { + t.Fatal(err) + } + if len(api.destroyed) != 1 || api.destroyed[0] != 100 { + t.Fatalf("destroyed=%v", api.destroyed) + } + if _, ok, claimErr := core.ResolveLeaseClaimForProvider("destroy-me", providerName); claimErr != nil || ok { + t.Fatalf("claim ok=%v err=%v", ok, claimErr) + } +} + +func TestReleaseStopIsExplicitAndTested(t *testing.T) { + api := &fakeVastAPI{offers: []vastOffer{{ID: 42, Rentable: true}}} + b := newTestBackend(t, api) + b.cfg.Vast.ReleaseAction = "stop" + b.DirectSSHBackend.Cfg = b.cfg 
+ lease, err := b.Acquire(context.Background(), core.AcquireRequest{Repo: core.Repo{Root: t.TempDir()}, RequestedSlug: "stop-me"}) + if err != nil { + t.Fatal(err) + } + if err := b.ReleaseLease(context.Background(), core.ReleaseLeaseRequest{Lease: lease}); err != nil { + t.Fatal(err) + } + if len(api.destroyed) != 0 { + t.Fatalf("destroyed=%v", api.destroyed) + } + if len(api.managed) < 2 || api.managed[len(api.managed)-1].input.State != "stopped" { + t.Fatalf("managed=%#v", api.managed) + } +} + +func TestCleanupDryRunDoesNotDestroyExpiredOwnedInstance(t *testing.T) { + api := &fakeVastAPI{instances: []vastInstance{{ID: 8, Label: encodeVastOwnershipLabel("cbx_old", "old", "ready"), Status: "running"}}} + b := newTestBackend(t, api) + server := serverFromInstance(api.instances[0], b.cfg) + server.Labels["expires_at"] = core.LeaseLabelTime(time.Now().Add(-time.Hour)) + if err := core.ClaimLeaseTargetForRepoConfig("cbx_old", "old", b.cfg, server, core.SSHTarget{}, t.TempDir(), time.Minute, false); err != nil { + t.Fatal(err) + } + var stderr bytes.Buffer + b.rt.Stderr = &stderr + b.DirectSSHBackend.RT = b.rt + if err := b.Cleanup(context.Background(), core.CleanupRequest{DryRun: true}); err != nil { + t.Fatal(err) + } + if len(api.destroyed) != 0 { + t.Fatalf("destroyed=%v", api.destroyed) + } + if !strings.Contains(stderr.String(), "delete server id=8") { + t.Fatalf("stderr=%q", stderr.String()) + } +} + +func TestManualUnownedCleanupIsRejected(t *testing.T) { + api := &fakeVastAPI{instances: []vastInstance{{ID: 5, Label: "manual-instance", Status: "running", SSHHost: "203.0.113.5", SSHPort: 22}}} + b := newTestBackend(t, api) + manual := serverFromInstance(api.instances[0], b.cfg) + err := b.ReleaseLease(context.Background(), core.ReleaseLeaseRequest{Lease: core.LeaseTarget{LeaseID: "manual", Server: manual}}) + if err == nil || !strings.Contains(err.Error(), "non-Crabbox Vast instance") { + t.Fatalf("err=%v", err) + } + if len(api.destroyed) != 0 { + 
t.Fatalf("destroyed=%v", api.destroyed) + } +} + +func TestTouchUpdatesLocalClaimLabels(t *testing.T) { + api := &fakeVastAPI{offers: []vastOffer{{ID: 42, Rentable: true}}} + b := newTestBackend(t, api) + lease, err := b.Acquire(context.Background(), core.AcquireRequest{Repo: core.Repo{Root: t.TempDir()}, RequestedSlug: "touch-me"}) + if err != nil { + t.Fatal(err) + } + updated, err := b.Touch(context.Background(), core.TouchRequest{Lease: lease, State: "busy", IdleTimeout: 2 * time.Hour}) + if err != nil { + t.Fatal(err) + } + if updated.Labels["state"] != "busy" { + t.Fatalf("updated=%#v", updated.Labels) + } + claim, ok, err := core.ResolveLeaseClaimForProvider("touch-me", providerName) + if err != nil || !ok || claim.Labels["state"] != "busy" { + t.Fatalf("claim=%#v ok=%v err=%v", claim, ok, err) + } +} diff --git a/internal/providers/vast/core.go b/internal/providers/vast/core.go index f0214e9cc..30d0c42b3 100644 --- a/internal/providers/vast/core.go +++ b/internal/providers/vast/core.go @@ -22,6 +22,7 @@ type ReleaseLeaseRequest = core.ReleaseLeaseRequest type TouchRequest = core.TouchRequest type LeaseTarget = core.LeaseTarget type Server = core.Server +type SSHTarget = core.SSHTarget const ( providerName = "vast" diff --git a/internal/providers/vast/provider.go b/internal/providers/vast/provider.go index 6bf6de77c..0675cf44b 100644 --- a/internal/providers/vast/provider.go +++ b/internal/providers/vast/provider.go @@ -1,7 +1,6 @@ package vast import ( - "context" "flag" "net/url" "strings" @@ -40,14 +39,14 @@ func (Provider) ApplyFlags(cfg *core.Config, fs *flag.FlagSet, values any) error return ApplyVastProviderFlags(cfg, fs, values) } -func (p Provider) Configure(cfg core.Config, _ core.Runtime) (core.Backend, error) { +func (p Provider) Configure(cfg core.Config, rt core.Runtime) (core.Backend, error) { if err := p.ValidateConfig(cfg); err != nil { return nil, err } if cfg.TargetOS != "" && cfg.TargetOS != core.TargetLinux { return nil, exit(2, 
"provider=%s supports target=linux only", providerName) } - return backend{spec: p.Spec()}, nil + return newBackend(p.Spec(), cfg, rt), nil } func (p Provider) ConfigureDoctor(cfg core.Config, rt core.Runtime) (core.DoctorBackend, error) { @@ -105,41 +104,3 @@ func (Provider) ValidateConfig(cfg core.Config) error { } return nil } - -type backend struct { - spec core.ProviderSpec -} - -func (b backend) Spec() core.ProviderSpec { return b.spec } - -func (b backend) Doctor(context.Context, core.DoctorRequest) (core.DoctorResult, error) { - return core.DoctorResult{ - Provider: providerName, - Status: "unsupported", - Message: "vast lifecycle is not implemented yet", - }, nil -} - -func (b backend) Acquire(context.Context, core.AcquireRequest) (core.LeaseTarget, error) { - return core.LeaseTarget{}, notImplemented("acquire") -} - -func (b backend) Resolve(context.Context, core.ResolveRequest) (core.LeaseTarget, error) { - return core.LeaseTarget{}, notImplemented("resolve") -} - -func (b backend) List(context.Context, core.ListRequest) ([]core.LeaseView, error) { - return nil, notImplemented("list") -} - -func (b backend) ReleaseLease(context.Context, core.ReleaseLeaseRequest) error { - return notImplemented("release") -} - -func (b backend) Touch(context.Context, core.TouchRequest) (core.Server, error) { - return core.Server{}, notImplemented("touch") -} - -func (b backend) Cleanup(context.Context, core.CleanupRequest) error { - return notImplemented("cleanup") -} From e011540944bb6c0d42f4a496e315aaf1f7690b37 Mon Sep 17 00:00:00 2001 From: Coy Geek <65363919+coygeek@users.noreply.github.com> Date: Wed, 24 Jun 2026 15:59:52 -0700 Subject: [PATCH 04/18] docs: add vast provider docs and live smoke --- docs/providers/README.md | 3 +- docs/providers/provider-metadata.json | 14 ++ docs/providers/vast.md | 293 +++++++++++++++++++++++ docs/source-map.md | 9 +- scripts/live-vast-smoke.sh | 311 +++++++++++++++++++++++++ scripts/live-vast-smoke.test.js | 322 
++++++++++++++++++++++++++ 6 files changed, 947 insertions(+), 5 deletions(-) create mode 100644 docs/providers/vast.md create mode 100755 scripts/live-vast-smoke.sh create mode 100644 scripts/live-vast-smoke.test.js diff --git a/docs/providers/README.md b/docs/providers/README.md index 501039077..2f881d34d 100644 --- a/docs/providers/README.md +++ b/docs/providers/README.md @@ -61,7 +61,7 @@ selection metadata. Regenerate it with `node scripts/generate-provider-matrix.mj `scripts/check-docs.sh` fails when provider registration, metadata, docs paths, or this generated table drift. -Current built-in surface: 67 providers (39 SSH lease, 26 delegated run, 2 service control). +Current built-in surface: 68 providers (40 SSH lease, 26 delegated run, 2 service control). Access terms: @@ -133,6 +133,7 @@ Access terms: | [tenki](tenki.md) | built-in; `ssh-lease` · direct-cloud | Crabbox-managed SSH; `crabbox-sync` · direct only; features: `ssh`, `crabbox-sync` | `linux`; Tenki sandbox VM | `provider-managed`; GPU: unknown | Tenki; sandbox release | Managed Linux sandbox with SSH proxy | Gateway auth uses Tenki-managed key and certificate files | | [tensorlake](tensorlake.md) (`tl`, `tensorlake-sbx`) | built-in; `delegated-run` · delegated-sandbox | No SSH; `provider-owned` · direct only; features: `run-session` | `linux`; Tensorlake Firecracker sandbox | `provider-managed`; GPU: unknown | Tensorlake; provider sandbox cleanup | Hosted Firecracker-backed delegated execution | Does not expose raw Firecracker provisioning | | [upstash-box](upstash-box.md) (`upstash`, `box`, `upstashbox`) | built-in; `delegated-run` · delegated-sandbox | No SSH; `archive-sync` · direct only; features: `archive-sync`, `run-session` | `linux`; Upstash Box sandbox | `provider-managed`; GPU: no | Upstash; sandbox cleanup | Hosted short-lived delegated sandbox | No normal SSH access or coordinator routing | +| [vast](vast.md) (`vast-ai`, `vastai`) | built-in; `ssh-lease` · gpu-cloud | Crabbox-managed 
SSH; `crabbox-sync` · direct only; features: `ssh`, `crabbox-sync`, `cleanup` | `linux`; Vast.ai direct GPU instance | `provider-managed`; GPU: yes | Crabbox; destroy by default; optional stop or keep | Direct Linux GPU lease from the Vast.ai offer market | Direct-only and billable; capacity, quota, and offer availability vary | | [vercel-sandbox](vercel-sandbox.md) | built-in; `delegated-run` · delegated-sandbox | No SSH; `archive-sync` · direct only; features: `archive-sync`, `cleanup`, `run-session` | `linux`; Vercel Sandbox microVM | `provider-managed`; GPU: no | Vercel Sandbox; sandbox delete | Hosted delegated Linux microVM execution | Requires SDK bridge support and Vercel Sandbox auth | | [vultr](vultr.md) | built-in; `ssh-lease` · direct-cloud | Crabbox-managed SSH; `crabbox-sync` · direct only; features: `ssh`, `crabbox-sync`, `cleanup` | `linux`; Vultr instance | `cloud`; GPU: optional | Crabbox; instance and key delete | Direct Linux VM on Vultr | Direct-only; firewall groups and VPCs must already exist | | [wandb](wandb.md) (`weights-and-biases`) | built-in; `delegated-run` · gpu-cloud | No SSH; `provider-owned` · direct only; features: `run-session` | `linux`; Weights & Biases run sandbox | `provider-managed`; GPU: optional | Weights & Biases; run termination | Delegated ML or GPU run environment | Execution follows the W&B run contract | diff --git a/docs/providers/provider-metadata.json b/docs/providers/provider-metadata.json index d8f001dbe..d1fb8dc94 100644 --- a/docs/providers/provider-metadata.json +++ b/docs/providers/provider-metadata.json @@ -503,6 +503,20 @@ "caveat": "Direct-only; firewall groups and VPCs must already exist", "docs": "vultr.md" }, + "vast": { + "status": "built-in", + "category": "gpu-cloud", + "substrate": "Vast.ai direct GPU instance", + "location": "provider-managed", + "ssh": "crabbox-managed", + "sync": "crabbox-sync", + "gpu": "yes", + "lifecycle": "Crabbox", + "cleanup": "destroy by default; optional stop or keep", + 
"bestFit": "Direct Linux GPU lease from the Vast.ai offer market", + "caveat": "Direct-only and billable; capacity, quota, and offer availability vary", + "docs": "vast.md" + }, "ovh": { "status": "built-in", "category": "direct-cloud", diff --git a/docs/providers/vast.md b/docs/providers/vast.md new file mode 100644 index 000000000..447254eab --- /dev/null +++ b/docs/providers/vast.md @@ -0,0 +1,293 @@ +# Vast Provider + +Read this when you are: + +- choosing `provider: vast`; +- validating a direct Vast.ai SSH lease; +- changing `internal/providers/vast` or the guarded live smoke. + +Vast is a Linux-only **SSH lease** provider for Vast.ai GPU instances. Crabbox +searches Vast offers, creates one `ssh_direct` instance from the selected offer, +injects a per-lease SSH key, marks the instance with a compact Crabbox ownership +label, waits for the direct SSH endpoint, and then uses the normal Crabbox SSH +sync/run/status/list/stop/cleanup path. + +Vast is **direct-only** in this release. It does not run through the Crabbox +coordinator, so the local CLI must have a Vast API key and direct cleanup +remains the operator's responsibility. Vast instances are billable while they +are running. The default release action destroys the instance. + +## When To Use It + +Use Vast when you need a direct Linux GPU lease and local Vast credentials are +acceptable. Prefer AWS, Azure, GCP, or Hetzner when you need brokered team +credentials, coordinator-side cost accounting, or non-GPU cloud VM coverage. +Prefer Lambda, Nebius, RunPod, or NVIDIA Brev when those provider catalogs, +images, or account policies are a better fit for the workload. 
+ +## Commands + +```sh +crabbox doctor --provider vast +crabbox warmup --provider vast --vast-gpu-name "RTX 4090" --keep +crabbox run --provider vast --vast-gpu-count 1 --no-sync -- nvidia-smi +crabbox ssh --provider vast --id my-app +crabbox stop --provider vast my-app +crabbox cleanup --provider vast --dry-run +``` + +Aliases: `vast-ai`, `vastai`. + +`--id` accepts the canonical lease id (`cbx_...`), the friendly slug, or the +Vast instance id when it resolves to a complete Crabbox-owned Vast instance. +`--class` and `--type` are not supported for `provider=vast`; use the +Vast-specific GPU, image, and offer-selection flags instead. + +## Configuration + +```yaml +provider: vast +target: linux +vast: + apiUrl: https://console.vast.ai/api/v0 + instanceType: ondemand + gpuName: "" + gpuCount: 0 + image: nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04 + templateId: "" + runtype: ssh_direct + diskGB: 20 + maxDphTotal: 0 + minReliability: 0 + order: dlperf_per_dphtotal desc + user: root + workRoot: /work/crabbox + releaseAction: destroy +``` + +Config keys under `vast:`: + +| Key | Maps to | Default | Notes | +| --- | --- | --- | --- | +| `apiUrl` | `cfg.Vast.APIURL` | `https://console.vast.ai/api/v0` | Absolute Vast REST API URL without credentials. HTTPS is required except for localhost test endpoints. | +| `instanceType` | `cfg.Vast.InstanceType` | `ondemand` | Offer type, `ondemand` or `interruptible`; `on-demand` is normalized. | +| `gpuName` | `cfg.Vast.GPUName` | empty | Optional Vast GPU name selector. | +| `gpuCount` | `cfg.Vast.GPUCount` | `0` | Minimum GPU count when greater than zero. | +| `image` | `cfg.Vast.Image` | `nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04` | Docker image requested from Vast for the instance. | +| `templateId` | `cfg.Vast.TemplateID` | empty | Optional Vast template id. | +| `runtype` | `cfg.Vast.Runtype` | `ssh_direct` | Only `ssh_direct` is supported. | +| `diskGB` | `cfg.Vast.DiskGB` | `20` | Requested disk size in GB. 
| +| `maxDphTotal` | `cfg.Vast.MaxDphTotal` | `0` | Maximum dollars per hour when greater than zero. | +| `minReliability` | `cfg.Vast.MinReliability` | `0` | Minimum reliability score from 0 to 1 when greater than zero. | +| `order` | `cfg.Vast.Order` | `dlperf_per_dphtotal desc` | Vast offer ordering expression. | +| `user` | `cfg.Vast.User` | `root` | SSH user. Explicit generic `ssh.user` still wins. | +| `workRoot` | `cfg.Vast.WorkRoot` | `/work/crabbox` | Remote work root for Crabbox sync and commands. | +| `releaseAction` | `cfg.Vast.ReleaseAction` | `destroy` | `destroy`/`delete`, `stop`, or `keep`. | + +Provider flags: + +```text +--vast-api-url +--vast-instance-type +--vast-gpu-name +--vast-gpu-count +--vast-image +--vast-template-id +--vast-runtype +--vast-disk-gb +--vast-max-dph-total +--vast-min-reliability +--vast-order +--vast-user +--vast-work-root +--vast-release-action +``` + +Environment overrides: + +```text +CRABBOX_VAST_API_KEY Vast API key for direct mode +VAST_API_KEY Fallback Vast API key +CRABBOX_VAST_API_URL Override the API URL +VAST_API_URL Fallback API URL override +CRABBOX_VAST_INSTANCE_TYPE Override `ondemand` or `interruptible` +CRABBOX_VAST_GPU_NAME Override the GPU name selector +CRABBOX_VAST_GPU_COUNT Override the minimum GPU count +CRABBOX_VAST_IMAGE Override the Docker image +CRABBOX_VAST_TEMPLATE_ID Override the template id +CRABBOX_VAST_RUNTYPE Override the runtime type; must be `ssh_direct` +CRABBOX_VAST_DISK_GB Override disk size in GB +CRABBOX_VAST_MAX_DPH_TOTAL Override maximum dollars per hour +CRABBOX_VAST_MIN_RELIABILITY Override minimum reliability score +CRABBOX_VAST_ORDER Override offer ordering +CRABBOX_VAST_USER Override the SSH user +CRABBOX_VAST_WORK_ROOT Override the remote work root +CRABBOX_VAST_RELEASE_ACTION Override release action +``` + +Do not pass the Vast API key as a command-line argument or store it in +repository config. 
Crabbox reads it from `CRABBOX_VAST_API_KEY` or +`VAST_API_KEY` and sends it only in the `Authorization: Bearer ...` header. + +## Token Scope + +The provider uses Vast account identity, offer search, instances, instance +state updates, instance destroy, and instance SSH-key attach/detach APIs. +`crabbox doctor --provider vast` is read-only: it checks auth, lists instances, +counts Crabbox-owned Vast instances, and reports the default order, runtime, and +SSH user. + +Keep API keys in the environment or a local secret manager. Do not commit Vast +keys, generated private keys, instance API keys, user data, or Jupyter token +URLs. + +## Lifecycle + +1. Load the Vast API key from `CRABBOX_VAST_API_KEY` or `VAST_API_KEY`. +2. List instances and allocate a Crabbox slug. +3. Generate a per-lease SSH key in the Crabbox testbox key store. +4. Search Vast offers with the configured type, GPU name/count, reliability, + max dollars per hour, and ordering. +5. Create one `ssh_direct` Vast instance from the selected offer, with the + configured image, template, disk, user, and Crabbox environment marker. +6. Attach the per-lease public SSH key to the instance. +7. Wait until the instance is running and exposes a direct SSH host and port. +8. Update the Vast label from provisioning to ready. +9. Wait for Crabbox SSH bootstrap readiness and write a local lease claim. +10. Run normal Crabbox SSH sync, command execution, status, list, and cleanup. + +The provider requires Linux. It does not advertise desktop, browser, code-server, +Tailscale, coordinator, or provider-managed sync support in this release. +Actions hydration works only as normal command execution on the resulting Linux +SSH lease. + +If create or bootstrap becomes indeterminate after a Vast instance id is known, +Crabbox records a local recovery claim when possible. Retry +`crabbox stop --provider vast <id-or-slug>` before deleting resources +manually so Crabbox can reconcile the instance and local key material.
+ +## Offer Selection + +By default Crabbox searches `ondemand` verified, rentable, not-rented offers +with at least one direct SSH port and orders by `dlperf_per_dphtotal desc`. +Narrow the search when the default catalog is too broad: + +```sh +crabbox run \ + --provider vast \ + --vast-instance-type interruptible \ + --vast-gpu-name "H100" \ + --vast-gpu-count 1 \ + --vast-max-dph-total 4.25 \ + --vast-min-reliability 0.95 \ + --no-sync \ + -- nvidia-smi +``` + +`vast.maxDphTotal` and `--vast-max-dph-total` are guardrails for offer search, +not a billing cap enforced by Crabbox after provisioning. Review the selected +offer and Vast account billing before running long jobs. + +## Release And Cleanup + +The default release action is `destroy`, which deletes the Vast instance, +detaches the Crabbox-managed instance SSH key when its key id is known, removes +the local claim, and removes the local per-lease key. + +Release actions: + +- `destroy` or `delete`: destroy the Vast instance on `stop` or one-shot release. +- `stop`: request Vast to stop the instance and remove the local Crabbox claim. +- `keep`: leave the instance and local claim untouched during release. + +Use `stop` or `keep` only when you explicitly accept the retained resource and +its billing implications. Direct mode has no coordinator alarm. + +Cleanup only mutates instances with a complete Crabbox Vast ownership label and +a matching local claim: + +```sh +crabbox list --provider vast --json +crabbox cleanup --provider vast --dry-run +crabbox cleanup --provider vast +``` + +Crabbox refuses to operate on non-Crabbox Vast instances, changed ownership +labels, stale local claims, missing local claims for destructive release, or +instances whose provider identity no longer matches the local claim. Vast labels +are compact strings beginning with `cbx1|`; they encode the lease id, slug, and +state. 
+ +## Guarded Live Smoke + +The repeatable live check is opt-in and billable: + +```sh +CRABBOX_LIVE=1 CRABBOX_LIVE_PROVIDERS=vast scripts/live-vast-smoke.sh +``` + +The script builds `bin/crabbox`, reads `CRABBOX_VAST_API_KEY` or +`VAST_API_KEY`, requires an empty Crabbox-owned Vast inventory, creates one kept +lease, waits for ready status, runs `nvidia-smi`, verifies `list --json`, stops +the lease, runs dry-run cleanup, and verifies the Crabbox-owned inventory is +empty afterward. + +Optional live-smoke overrides: + +```text +CRABBOX_LIVE_VAST_GPU_NAME GPU name selector, default empty +CRABBOX_LIVE_VAST_GPU_COUNT Minimum GPU count, default 1 +CRABBOX_LIVE_VAST_MAX_DPH_TOTAL Max dollars per hour for offer search, default 0 +CRABBOX_LIVE_VAST_INSTANCE_TYPE Offer type, default ondemand +CRABBOX_LIVE_VAST_IMAGE Docker image, default from provider config +CRABBOX_LIVE_VAST_RELEASE_ACTION Release action, default destroy +``` + +Final classifications include: + +```text +classification=live_vast_smoke_passed +classification=environment_blocked +classification=billing_blocked +classification=quota_blocked +classification=capacity_blocked +classification=validation_failed +classification=cleanup_failed +``` + +Missing opt-in flags, missing credentials, auth failures, disabled account API +access, missing GPU offers, billing blocks, quota blocks, and capacity blocks +are reported as classified outcomes. The script redacts `VAST_API_KEY`, +`CRABBOX_VAST_API_KEY`, `instance_api_key`, Jupyter tokens, user data, private +keys, and URLs carrying token-like query parameters from diagnostics. + +If cleanup fails, use the reported slug and Vast instance id with +`crabbox list --provider vast --json`, `crabbox stop --provider vast <slug>`, +and the Vast console. Do not delete unrelated instances that only look similar. + +## Capabilities + +- **OS targets**: Linux only. +- **SSH**: yes, Crabbox-managed SSH over Vast direct SSH endpoints. +- **Crabbox sync**: yes, rsync over SSH.
+- **Provider-managed sync**: no. +- **GPU**: yes, provider catalog dependent. +- **Coordinator**: no; direct CLI only. +- **Cleanup**: yes, ownership-label and local-claim guarded. +- **Desktop / browser / code-server**: not advertised in this release. +- **Tailscale**: not advertised in this release. + +## Gotchas + +- Vast is direct-only. Coordinator secrets, usage limits, and cost accounting do + not cover these instances. +- Vast offers are capacity-sensitive. No matching offer, quota, billing, or + capacity failures are external blockers, not docs-check failures. +- `ssh_direct` is required. Other Vast runtime types are rejected by config + validation. +- The default image is CUDA-oriented. If it lacks workload dependencies, install + them in your repo setup or select a different image/template. +- `stop` and `keep` can retain billable resources. Use `destroy` for normal + one-shot Crabbox validation. +- Destructive release requires local Crabbox claim state. Keep the claim until + cleanup is complete. 
diff --git a/docs/source-map.md b/docs/source-map.md index 515e11327..f9df46484 100644 --- a/docs/source-map.md +++ b/docs/source-map.md @@ -88,9 +88,10 @@ SSH-lease providers: - Canonical Multipass local Ubuntu VM: `internal/providers/multipass` - Cirrus Labs tart local macOS VM: `internal/providers/tart` - Microsoft Hyper-V local Windows VM: `internal/providers/hyperv` -- Daytona, Morph, exe.dev, KubeVirt, External, Tenki, Namespace devbox, RunPod, Semaphore, Sprites, Lambda: +- Daytona, Morph, exe.dev, KubeVirt, External, Tenki, Namespace devbox, RunPod, Semaphore, Sprites, Lambda, Vast: `internal/providers/daytona`, `internal/providers/morph`, `internal/providers/exedev`, `internal/providers/kubevirt`, `internal/providers/external`, `internal/providers/tenki`, `internal/providers/namespace`, - `internal/providers/runpod`, `internal/providers/semaphore`, `internal/providers/sprites`, `internal/providers/lambda` + `internal/providers/runpod`, `internal/providers/semaphore`, `internal/providers/sprites`, `internal/providers/lambda`, + `internal/providers/vast`, live smoke `scripts/live-vast-smoke.sh` Delegated-run providers (no SSH lease): @@ -148,7 +149,7 @@ Actions hydration or repo scripts. 
Provider docs: - Per-provider feature notes: `docs/features/aws.md`, `docs/features/azure.md`, `docs/features/hetzner.md`, `docs/features/blacksmith-testbox.md`, `docs/features/namespace-devbox.md`, `docs/features/namespace-devbox-setup.md`, `docs/features/semaphore.md`, `docs/features/sprites.md`, `docs/features/daytona.md`, `docs/features/islo.md`, `docs/features/e2b.md` -- Per-provider reference: `docs/providers/README.md` plus one file per provider under `docs/providers/`, including `docs/providers/aws-lambda-microvm.md` for Lambda MicroVM delegated execution and its runner image, `docs/providers/lambda.md` for the direct Lambda GPU SSH lease provider, `docs/providers/blaxel.md` for delegated Blaxel sandbox execution, `docs/providers/apple-vz.md` for the local Apple Silicon `Virtualization.framework` path, `docs/providers/digitalocean.md` for the direct Droplet provider, `docs/providers/vultr.md` for the direct Vultr provider, `docs/providers/ovh.md` for the direct OVHcloud provider, `docs/providers/incus.md` for the separate local live validation contract, `docs/providers/superserve.md` for delegated Superserve execution and live proof, and `docs/providers/cloudflare-sandbox.md` for Cloudflare Sandbox bridge-backed delegated Linux execution +- Per-provider reference: `docs/providers/README.md` plus one file per provider under `docs/providers/`, including `docs/providers/aws-lambda-microvm.md` for Lambda MicroVM delegated execution and its runner image, `docs/providers/lambda.md` for the direct Lambda GPU SSH lease provider, `docs/providers/vast.md` for the direct Vast.ai GPU SSH lease provider, `docs/providers/blaxel.md` for delegated Blaxel sandbox execution, `docs/providers/apple-vz.md` for the local Apple Silicon `Virtualization.framework` path, `docs/providers/digitalocean.md` for the direct Droplet provider, `docs/providers/vultr.md` for the direct Vultr provider, `docs/providers/ovh.md` for the direct OVHcloud provider, `docs/providers/incus.md` for the 
separate local live validation contract, `docs/providers/superserve.md` for delegated Superserve execution and live proof, and `docs/providers/cloudflare-sandbox.md` for Cloudflare Sandbox bridge-backed delegated Linux execution - Provider selection, landscape, live-smoke, and backend authoring guide: `docs/features/provider-selection.md`, `docs/features/provider-landscape.md`, `docs/features/provider-live-smoke.md`, `docs/provider-backends.md`, `docs/features/provider-authoring.md` - Tailscale contract: `docs/features/tailscale.md` @@ -236,5 +237,5 @@ Provider docs: - Release workflow and Homebrew tap fallback: `.github/workflows/release.yml` - GoReleaser archives and Homebrew formula config: `.goreleaser.yaml` - Docs command-surface check, link check, site builder, and Pages deploy: `scripts/check-command-docs.mjs`, `scripts/check-docs-links.mjs`, `scripts/build-docs-site.mjs`, `.github/workflows/pages.yml` -- Live provider smoke coverage: `scripts/live-smoke.sh`, plus provider-specific guarded smokes such as `scripts/live-blaxel-smoke.sh`, `scripts/live-digitalocean-smoke.sh`, `scripts/live-vultr-smoke.sh`, and `scripts/live-superserve-smoke.sh` +- Live provider smoke coverage: `scripts/live-smoke.sh`, plus provider-specific guarded smokes such as `scripts/live-blaxel-smoke.sh`, `scripts/live-digitalocean-smoke.sh`, `scripts/live-vultr-smoke.sh`, `scripts/live-vast-smoke.sh`, and `scripts/live-superserve-smoke.sh` - Live coordinator auth smoke coverage: `scripts/live-auth-smoke.sh` diff --git a/scripts/live-vast-smoke.sh b/scripts/live-vast-smoke.sh new file mode 100755 index 000000000..1ee1cb541 --- /dev/null +++ b/scripts/live-vast-smoke.sh @@ -0,0 +1,311 @@ +#!/usr/bin/env bash +set -euo pipefail + +provider_enabled() { + local list="${CRABBOX_LIVE_PROVIDERS:-}" + local item + IFS=',' read -ra items <<<"$list" + for item in "${items[@]}"; do + item="${item//[[:space:]]/}" + if [[ "$item" == "vast" || "$item" == "vast-ai" || "$item" == "vastai" ]]; then + 
return 0 + fi + done + return 1 +} + +redact_output() { + VAST_SMOKE_REDACT_TOKEN_1="${CRABBOX_VAST_API_KEY:-}" \ + VAST_SMOKE_REDACT_TOKEN_2="${VAST_API_KEY:-}" python3 -c ' +import os +import re +import sys + +body = sys.stdin.read() +for token in (os.environ.get("VAST_SMOKE_REDACT_TOKEN_1", ""), os.environ.get("VAST_SMOKE_REDACT_TOKEN_2", "")): + if token: + body = body.replace(token, "") +fields = ( + "api_key", + "instance_api_key", + "instanceApiKey", + "jupyter_token", + "jupyterToken", + "jupyter_url", + "jupyterUrl", + "user_data", + "userData", + "private_key", + "privateKey", + "ssh_key", +) +for field in fields: + body = re.sub(rf"(\"{re.escape(field)}\"\s*:\s*\")[^\"]*(\")", rf"\1\2", body, flags=re.IGNORECASE) + body = re.sub(rf"({re.escape(field)}\s*[=:]\s*)[^\",\s]+", rf"\1", body, flags=re.IGNORECASE) +body = re.sub(r"-----BEGIN [A-Z ]*PRIVATE KEY-----[\s\S]*?-----END [A-Z ]*PRIVATE KEY-----", "", body) +body = re.sub(r"https?://[^\s\"'\''<>]*(?:token|api_key|apikey|auth|signature|sig|access_token)=[^\s\"'\''<>]+", "", body, flags=re.IGNORECASE) +sys.stdout.write(body) +' +} + +classify_known_external_blocker() { + local command="$1" + local status="$2" + local output="$3" + local classification="" + local lower + lower="$(printf '%s' "$output" | tr '[:upper:]' '[:lower:]')" + if [[ "$lower" == *billing* || "$lower" == *fund* || "$lower" == *payment* || "$lower" == *balance* || "$lower" == *credit* ]]; then + classification="billing_blocked" + elif [[ "$lower" == *quota* || "$lower" == *"rate limit"* || "$lower" == *"account limit"* || "$lower" == *"limit exceeded"* ]]; then + classification="quota_blocked" + elif [[ "$lower" == *capacity* || "$lower" == *"no eligible offers"* || "$lower" == *"no offers"* || "$lower" == *"no matching"* || "$lower" == *"not enough"* || "$lower" == *"insufficient"* || "$lower" == *"resource exhausted"* || "$lower" == *"found no eligible offers"* ]]; then + classification="capacity_blocked" + elif [[ "$lower" == 
*unauthorized* || "$lower" == *forbidden* || "$lower" == *"permission denied"* || "$lower" == *"invalid api"* || "$lower" == *"invalid-api"* || "$lower" == *"missing_token"* || "$lower" == *"api key"* ]]; then + classification="environment_blocked" + else + return 1 + fi + printf 'classification=%s command=%q exit=%s\n' "$classification" "$command" "$status" >&2 + printf '%s\n' "$output" | redact_output >&2 + return 0 +} + +classify_validation_failure() { + local command="$1" + local status="$2" + local output="$3" + printf 'classification=validation_failed command=%q exit=%s\n' "$command" "$status" >&2 + printf '%s\n' "$output" | redact_output >&2 +} + +run_capture() { + local command="$1" + shift + local output + set +e + output="$("$@" 2>&1)" + local status=$? + set -e + if [[ "$status" -ne 0 ]]; then + if classify_known_external_blocker "$command" "$status" "$output"; then + exit 0 + fi + classify_validation_failure "$command" "$status" "$output" + exit "$status" + fi + CAPTURED_OUTPUT="$(printf '%s\n' "$output" | redact_output)" +} + +run_capture_validation() { + local command="$1" + shift + local output + set +e + output="$("$@" 2>&1)" + local status=$? 
+ set -e + if [[ "$status" -ne 0 ]]; then + classify_validation_failure "$command" "$status" "$output" + exit "$status" + fi + CAPTURED_OUTPUT="$(printf '%s\n' "$output" | redact_output)" +} + +validate_list_json_contains_slug() { + local command="$1" + local output="$2" + local validation_output="" + local status=0 + set +e + validation_output="$(CRABBOX_SMOKE_SLUG="$slug" python3 -c ' +import json +import os +import sys + +slug = os.environ["CRABBOX_SMOKE_SLUG"] +try: + payload = json.load(sys.stdin) +except Exception as exc: + print(f"invalid JSON: {exc}", file=sys.stderr) + sys.exit(1) + +def has_slug(value): + if isinstance(value, dict): + labels = value.get("labels") or value.get("tags") + if isinstance(labels, dict) and (labels.get("slug") == slug or labels.get("crabbox.slug") == slug): + return True + label = str(value.get("label", "")) + if f"|{slug}|" in label or value.get("slug") == slug or value.get("name") == slug or value.get("id") == slug or value.get("leaseId") == slug: + return True + return any(has_slug(child) for child in value.values()) + if isinstance(value, list): + return any(has_slug(child) for child in value) + return False + +if not has_slug(payload): + print(f"list JSON did not include slug {slug}", file=sys.stderr) + sys.exit(1) +' <<<"$output" 2>&1)" + status=$? + set -e + if [[ "$status" -ne 0 ]]; then + classify_validation_failure "$command" "$status" "$validation_output" + exit "$status" + fi +} + +validate_list_json_empty() { + local command="$1" + local output="$2" + local validation_output="" + local status=0 + set +e + validation_output="$(python3 -c ' +import json +import sys + +try: + payload = json.load(sys.stdin) +except Exception as exc: + print(f"invalid JSON: {exc}", file=sys.stderr) + sys.exit(1) + +if payload != []: + print("Vast Crabbox inventory is not empty", file=sys.stderr) + sys.exit(1) +' <<<"$output" 2>&1)" + status=$? 
+ set -e + if [[ "$status" -ne 0 ]]; then + classify_validation_failure "$command" "$status" "$validation_output" + exit "$status" + fi +} + +validate_nvidia_smi() { + local command="$1" + local output="$2" + if [[ "$output" != *"NVIDIA-SMI"* && "$output" != *"NVIDIA"* ]]; then + classify_validation_failure "$command" 1 "remote command output did not include NVIDIA-SMI output" + exit 1 + fi +} + +cleanup_armed=0 +slug="" +config_file="" + +cleanup() { + local status=$? + if [[ "$cleanup_armed" -eq 1 && -n "$slug" ]]; then + local cleanup_output="" + local cleanup_status=1 + local attempt + for attempt in 1 2 3; do + set +e + cleanup_output="$(bin/crabbox stop --provider vast "$slug" 2>&1)" + cleanup_status=$? + set -e + if [[ "$cleanup_status" -eq 0 ]]; then + cleanup_armed=0 + break + fi + sleep 2 + done + if [[ "$cleanup_status" -ne 0 ]]; then + printf 'classification=cleanup_failed command=%q exit=%s slug=%s\n' "bin/crabbox stop --provider vast $slug" "$cleanup_status" "$slug" >&2 + printf '%s\n' "$cleanup_output" | redact_output >&2 + if [[ "$status" -eq 0 ]]; then + status="$cleanup_status" + fi + fi + fi + if [[ -n "$config_file" ]]; then + rm -f "$config_file" + fi + exit "$status" +} +trap cleanup EXIT + +if [[ "${CRABBOX_LIVE:-}" != "1" ]]; then + printf 'classification=environment_blocked reason=CRABBOX_LIVE_not_enabled\n' + exit 0 +fi + +if ! 
provider_enabled; then + printf 'classification=environment_blocked reason=vast_not_selected providers=%q\n' "${CRABBOX_LIVE_PROVIDERS:-}" + exit 0 +fi + +if [[ -z "${CRABBOX_VAST_API_KEY:-}${VAST_API_KEY:-}" ]]; then + printf 'classification=environment_blocked reason=VAST_API_KEY_missing\n' + exit 0 +fi + +mkdir -p bin +go build -trimpath -o bin/crabbox ./cmd/crabbox + +slug="vast-smoke-$(date +%Y%m%d%H%M%S)-$$" +vast_gpu_name="${CRABBOX_LIVE_VAST_GPU_NAME:-}" +vast_gpu_count="${CRABBOX_LIVE_VAST_GPU_COUNT:-1}" +vast_max_dph_total="${CRABBOX_LIVE_VAST_MAX_DPH_TOTAL:-0}" +vast_instance_type="${CRABBOX_LIVE_VAST_INSTANCE_TYPE:-ondemand}" +vast_image="${CRABBOX_LIVE_VAST_IMAGE:-nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04}" +vast_release_action="${CRABBOX_LIVE_VAST_RELEASE_ACTION:-destroy}" + +config_file="$(mktemp)" +cat >"$config_file" <"$out" <<'SCRIPT' +${scriptBody} +SCRIPT +chmod +x "$out" +`, + ); +} + +const shellArgHelper = ` +arg_after() { + local want="$1" + shift + while [[ "$#" -gt 0 ]]; do + if [[ "$1" == "$want" ]]; then + printf '%s' "$2" + return 0 + fi + shift + done + return 1 +} +`; + +test("live vast smoke skips unless globally opted in", () => { + const dir = fs.mkdtempSync(path.join(os.tmpdir(), "crabbox-live-vast-skip-")); + const binDir = path.join(dir, "bin"); + const { tempRoot, smokeScript } = prepareSmokeRepo(dir); + fs.mkdirSync(binDir, { recursive: true }); + writeExecutable(path.join(binDir, "go"), "#!/usr/bin/env bash\nexit 99\n"); + + const result = spawnSync("bash", [smokeScript], { + cwd: tempRoot, + env: { + ...process.env, + PATH: `${binDir}${path.delimiter}${process.env.PATH ?? 
""}`, + CRABBOX_LIVE: "", + CRABBOX_LIVE_PROVIDERS: "vast", + CRABBOX_VAST_API_KEY: "", + VAST_API_KEY: "", + }, + encoding: "utf8", + }); + + assert.equal(result.status, 0, result.stderr); + assert.match(result.stdout, /classification=environment_blocked reason=CRABBOX_LIVE_not_enabled/); +}); + +test("live vast smoke skips unless vast is selected", () => { + const dir = fs.mkdtempSync(path.join(os.tmpdir(), "crabbox-live-vast-provider-skip-")); + const binDir = path.join(dir, "bin"); + const { tempRoot, smokeScript } = prepareSmokeRepo(dir); + fs.mkdirSync(binDir, { recursive: true }); + writeExecutable(path.join(binDir, "go"), "#!/usr/bin/env bash\nexit 99\n"); + + const result = spawnSync("bash", [smokeScript], { + cwd: tempRoot, + env: { + ...process.env, + PATH: `${binDir}${path.delimiter}${process.env.PATH ?? ""}`, + CRABBOX_LIVE: "1", + CRABBOX_LIVE_PROVIDERS: "lambda", + CRABBOX_VAST_API_KEY: "test-secret-token", + }, + encoding: "utf8", + }); + + assert.equal(result.status, 0, result.stderr); + assert.match(result.stdout, /classification=environment_blocked reason=vast_not_selected/); +}); + +test("live vast smoke requires token before building", () => { + const dir = fs.mkdtempSync(path.join(os.tmpdir(), "crabbox-live-vast-token-")); + const binDir = path.join(dir, "bin"); + const { tempRoot, smokeScript } = prepareSmokeRepo(dir); + fs.mkdirSync(binDir, { recursive: true }); + writeExecutable(path.join(binDir, "go"), "#!/usr/bin/env bash\nexit 99\n"); + + const result = spawnSync("bash", [smokeScript], { + cwd: tempRoot, + env: { + ...process.env, + PATH: `${binDir}${path.delimiter}${process.env.PATH ?? 
""}`, + CRABBOX_LIVE: "1", + CRABBOX_LIVE_PROVIDERS: "vast", + CRABBOX_VAST_API_KEY: "", + VAST_API_KEY: "", + }, + encoding: "utf8", + }); + + assert.equal(result.status, 0, result.stderr); + assert.match(result.stdout, /classification=environment_blocked reason=VAST_API_KEY_missing/); +}); + +test("live vast smoke runs guarded lifecycle and redacts secret material", () => { + const dir = fs.mkdtempSync(path.join(os.tmpdir(), "crabbox-live-vast-")); + const binDir = path.join(dir, "bin"); + const { tempRoot, smokeScript } = prepareSmokeRepo(dir); + const calls = path.join(dir, "calls.log"); + const slugFile = path.join(dir, "slug.txt"); + fs.mkdirSync(binDir, { recursive: true }); + + writeGoStub( + binDir, + `#!/usr/bin/env bash +set -euo pipefail +${shellArgHelper} +printf '%s\\n' "$*" >>"${calls}" +if [[ "\${CRABBOX_VAST_API_KEY:-}" != "test-secret-token" ]]; then + printf 'missing token\\n' >&2 + exit 91 +fi +case "$1" in + doctor) + printf 'auth=ready control_plane=ready inventory=ready api=list mutation=false leases=0 runtime=unchecked api_key=test-secret-token instance_api_key=visible jupyter_url=https://example.test/?token=abc\\n' + ;; + warmup) + arg_after --slug "$@" >"${slugFile}" + ;; + status) + printf 'status=ready\\n' + ;; + run) + printf 'NVIDIA-SMI 570.00.00\\n' + ;; + list) + slug="$(cat "${slugFile}" 2>/dev/null || true)" + if [[ -z "$slug" || -f "${slugFile}.stopped" ]]; then + printf '[]\\n' + else + printf '[{"labels":{"slug":"%s"},"provider":"vast"}]\\n' "$slug" + fi + ;; + stop) + printf stopped >"${slugFile}.stopped" + ;; + cleanup) + printf 'skip vast instance id=none name=none reason=missing labels\\n' + ;; + *) + printf 'unexpected args: %s\\n' "$*" >&2 + exit 99 + ;; +esac +`, + ); + + const result = spawnSync("bash", [smokeScript], { + cwd: tempRoot, + env: { + ...process.env, + PATH: `${binDir}${path.delimiter}${process.env.PATH ?? 
""}`, + CRABBOX_LIVE: "1", + CRABBOX_LIVE_PROVIDERS: " vast-ai ", + CRABBOX_VAST_API_KEY: "test-secret-token", + }, + encoding: "utf8", + }); + + assert.equal(result.status, 0, result.stdout + result.stderr); + assert.match(result.stdout, /classification=live_vast_smoke_passed/); + assert.doesNotMatch(result.stdout + result.stderr, /test-secret-token|visible|token=abc/); + + const seen = fs.readFileSync(calls, "utf8").trim().split("\n"); + assert.equal(seen[0], "doctor --provider vast"); + assert.equal(seen[1], "list --provider vast --json"); + assert.match(seen[2], /^warmup --provider vast --slug vast-smoke-\d{14}-\d+ --keep --ttl 20m --idle-timeout 5m$/); + assert.match(seen[3], /^status --provider vast --id vast-smoke-\d{14}-\d+ --wait --wait-timeout 600s$/); + assert.match(seen[4], /^run --provider vast --id vast-smoke-\d{14}-\d+ --no-sync -- nvidia-smi$/); + assert.equal(seen[5], "list --provider vast --json"); + assert.match(seen[6], /^stop --provider vast vast-smoke-\d{14}-\d+$/); + assert.equal(seen[7], "cleanup --provider vast --dry-run"); + assert.equal(seen[8], "list --provider vast --json"); +}); + +test("live vast smoke attempts cleanup after partial failure", () => { + const dir = fs.mkdtempSync(path.join(os.tmpdir(), "crabbox-live-vast-fail-")); + const binDir = path.join(dir, "bin"); + const { tempRoot, smokeScript } = prepareSmokeRepo(dir); + const calls = path.join(dir, "calls.log"); + fs.mkdirSync(binDir, { recursive: true }); + + writeGoStub( + binDir, + `#!/usr/bin/env bash +set -euo pipefail +printf '%s\\n' "$*" >>"${calls}" +if [[ "$1" == "doctor" || "$1" == "list" ]]; then + [[ "$1" == "list" ]] && printf '[]\\n' || printf 'auth=ready\\n' + exit 0 +fi +if [[ "$1" == "warmup" ]]; then + printf 'no eligible offers after partial create\\n' >&2 + exit 37 +fi +if [[ "$1" == "stop" ]]; then + exit 0 +fi +exit 99 +`, + ); + + const result = spawnSync("bash", [smokeScript], { + cwd: tempRoot, + env: { + ...process.env, + PATH: 
`${binDir}${path.delimiter}${process.env.PATH ?? ""}`, + CRABBOX_LIVE: "1", + CRABBOX_LIVE_PROVIDERS: "vast", + CRABBOX_VAST_API_KEY: "test-secret-token", + }, + encoding: "utf8", + }); + + assert.equal(result.status, 0, result.stdout + result.stderr); + assert.match(result.stderr, /classification=capacity_blocked/); + assert.match(result.stderr, /no eligible offers/); + assert.match(fs.readFileSync(calls, "utf8"), /warmup .* --keep /); + assert.match(fs.readFileSync(calls, "utf8"), /stop --provider vast vast-smoke-\d{14}-\d+/); +}); + +test("live vast smoke validates nvidia-smi output", () => { + const dir = fs.mkdtempSync(path.join(os.tmpdir(), "crabbox-live-vast-bad-run-")); + const binDir = path.join(dir, "bin"); + const { tempRoot, smokeScript } = prepareSmokeRepo(dir); + const slugFile = path.join(dir, "slug.txt"); + fs.mkdirSync(binDir, { recursive: true }); + + writeGoStub( + binDir, + `#!/usr/bin/env bash +set -euo pipefail +${shellArgHelper} +case "$1" in + doctor) + printf 'auth=ready\\n' + ;; + warmup) + arg_after --slug "$@" >"${slugFile}" + ;; + status) + printf 'status=ready\\n' + ;; + run) + printf 'ok\\n' + ;; + list) + slug="$(cat "${slugFile}" 2>/dev/null || true)" + if [[ -z "$slug" || -f "${slugFile}.stopped" ]]; then + printf '[]\\n' + else + printf '[{"labels":{"slug":"%s"}}]\\n' "$slug" + fi + ;; + stop) + printf stopped >"${slugFile}.stopped" + ;; + cleanup) + printf 'cleanup dry-run\\n' + ;; + *) + exit 99 + ;; +esac +`, + ); + + const result = spawnSync("bash", [smokeScript], { + cwd: tempRoot, + env: { + ...process.env, + PATH: `${binDir}${path.delimiter}${process.env.PATH ?? 
""}`, + CRABBOX_LIVE: "1", + CRABBOX_LIVE_PROVIDERS: "vast", + CRABBOX_VAST_API_KEY: "test-secret-token", + }, + encoding: "utf8", + }); + + assert.equal(result.status, 1, result.stdout + result.stderr); + assert.match(result.stderr, /classification=validation_failed/); + assert.match(result.stderr, /NVIDIA-SMI/); +}); From 0fd5cb8b1d88178897a165240eb2b0e2623d049c Mon Sep 17 00:00:00 2001 From: Coy Geek <65363919+coygeek@users.noreply.github.com> Date: Wed, 24 Jun 2026 16:20:22 -0700 Subject: [PATCH 05/18] test(worker): decode Vitest mock module path Use fileURLToPath for the Cloudflare Workers mock alias so Vitest resolves the local mock correctly from repository paths that contain spaces. --- worker/vitest.config.ts | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/worker/vitest.config.ts b/worker/vitest.config.ts index 4db9ee006..1907ab03f 100644 --- a/worker/vitest.config.ts +++ b/worker/vitest.config.ts @@ -1,10 +1,13 @@ +import { fileURLToPath } from "node:url"; + import { defineConfig } from "vitest/config"; export default defineConfig({ resolve: { alias: { - "cloudflare:workers": new URL("./test/cloudflare-workers-runtime.ts", import.meta.url) - .pathname, + "cloudflare:workers": fileURLToPath( + new URL("./test/cloudflare-workers-runtime.ts", import.meta.url), + ), }, }, test: { From 80607476950f52725b7ad813d46c8e477063732c Mon Sep 17 00:00:00 2001 From: Coy Geek <65363919+coygeek@users.noreply.github.com> Date: Wed, 24 Jun 2026 16:25:40 -0700 Subject: [PATCH 06/18] fix(vast): preserve explicit SSH user override Respect the generic SSH user override when Vast provider defaults are applied, so provider-specific defaults cannot replace an operator-supplied SSH user. Add regression coverage for both the Vast backend config and embedded direct SSH backend config. 
--- internal/providers/vast/backend.go | 5 ++++- internal/providers/vast/backend_test.go | 16 ++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/internal/providers/vast/backend.go b/internal/providers/vast/backend.go index 4c7ad3ca5..282110966 100644 --- a/internal/providers/vast/backend.go +++ b/internal/providers/vast/backend.go @@ -60,7 +60,10 @@ func applyVastDefaults(cfg *core.Config) { if cfg.TargetOS == "" { cfg.TargetOS = core.TargetLinux } - if cfg.Vast.User != "" { + if core.IsSSHUserExplicit(cfg) { + // The generic SSH user is the operator-facing override. Vast.User only + // provides the provider default when that override was not used. + } else if cfg.Vast.User != "" { cfg.SSHUser = cfg.Vast.User } else if cfg.SSHUser == "" { cfg.SSHUser = "root" diff --git a/internal/providers/vast/backend_test.go b/internal/providers/vast/backend_test.go index ace2fd1ba..73db130df 100644 --- a/internal/providers/vast/backend_test.go +++ b/internal/providers/vast/backend_test.go @@ -196,6 +196,22 @@ func newTestBackend(t *testing.T, api *fakeVastAPI) *backend { return b } +func TestNewBackendPreservesExplicitGenericSSHUser(t *testing.T) { + cfg := core.BaseConfig() + cfg.Provider = providerName + cfg.SSHUser = "ubuntu" + core.MarkSSHUserExplicit(&cfg) + cfg.Vast.User = "root" + + b := newBackend(Provider{}.Spec(), cfg, core.Runtime{Stderr: io.Discard}) + if b.cfg.SSHUser != "ubuntu" { + t.Fatalf("backend SSHUser=%q want explicit generic user", b.cfg.SSHUser) + } + if b.DirectSSHBackend.Cfg.SSHUser != "ubuntu" { + t.Fatalf("direct SSH backend SSHUser=%q want explicit generic user", b.DirectSSHBackend.Cfg.SSHUser) + } +} + func TestDoctorIsReadOnlyAndCountsOwnedInventory(t *testing.T) { api := &fakeVastAPI{instances: []vastInstance{ {ID: 1, Label: encodeVastOwnershipLabel("cbx_owned", "owned", "ready"), Status: "running"}, From 102e0f7678abd0652f450c10b12c61c935b35387 Mon Sep 17 00:00:00 2001 From: Coy Geek 
<65363919+coygeek@users.noreply.github.com> Date: Wed, 24 Jun 2026 16:30:58 -0700 Subject: [PATCH 07/18] fix(vast): preserve live smoke capacity blockers Keep cleanup enabled for possible partial Vast warmup failures, but do not let a pre-acquire not-found stop response convert a known capacity blocker into cleanup_failed. Add regression coverage for the no-lease cleanup path. --- scripts/live-vast-smoke.sh | 21 ++++++++++++--- scripts/live-vast-smoke.test.js | 46 +++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 4 deletions(-) diff --git a/scripts/live-vast-smoke.sh b/scripts/live-vast-smoke.sh index 1ee1cb541..91cbb58da 100755 --- a/scripts/live-vast-smoke.sh +++ b/scripts/live-vast-smoke.sh @@ -193,9 +193,17 @@ validate_nvidia_smi() { } cleanup_armed=0 +lease_acquired=0 slug="" config_file="" +is_pre_acquire_not_found() { + local output="$1" + local lower + lower="$(printf '%s' "$output" | tr '[:upper:]' '[:lower:]')" + [[ "$lower" == *"not found"* || "$lower" == *"no lease"* || "$lower" == *"unknown lease"* || "$lower" == *"missing lease"* ]] +} + cleanup() { local status=$? 
if [[ "$cleanup_armed" -eq 1 && -n "$slug" ]]; then @@ -214,10 +222,14 @@ cleanup() { sleep 2 done if [[ "$cleanup_status" -ne 0 ]]; then - printf 'classification=cleanup_failed command=%q exit=%s slug=%s\n' "bin/crabbox stop --provider vast $slug" "$cleanup_status" "$slug" >&2 - printf '%s\n' "$cleanup_output" | redact_output >&2 - if [[ "$status" -eq 0 ]]; then - status="$cleanup_status" + if [[ "$lease_acquired" -eq 0 && "$status" -eq 0 ]] && is_pre_acquire_not_found "$cleanup_output"; then + cleanup_armed=0 + else + printf 'classification=cleanup_failed command=%q exit=%s slug=%s\n' "bin/crabbox stop --provider vast $slug" "$cleanup_status" "$slug" >&2 + printf '%s\n' "$cleanup_output" | redact_output >&2 + if [[ "$status" -eq 0 ]]; then + status="$cleanup_status" + fi fi fi fi @@ -283,6 +295,7 @@ validate_list_json_empty "bin/crabbox list --provider vast --json" "$initial_lis cleanup_armed=1 run_capture "bin/crabbox warmup --provider vast --slug $slug --keep --ttl 20m --idle-timeout 5m" \ bin/crabbox warmup --provider vast --slug "$slug" --keep --ttl 20m --idle-timeout 5m +lease_acquired=1 run_capture_validation "bin/crabbox status --provider vast --id $slug --wait --wait-timeout 600s" \ bin/crabbox status --provider vast --id "$slug" --wait --wait-timeout 600s diff --git a/scripts/live-vast-smoke.test.js b/scripts/live-vast-smoke.test.js index 0b96a771f..3f55a2dc4 100644 --- a/scripts/live-vast-smoke.test.js +++ b/scripts/live-vast-smoke.test.js @@ -258,6 +258,52 @@ exit 99 assert.match(fs.readFileSync(calls, "utf8"), /stop --provider vast vast-smoke-\d{14}-\d+/); }); +test("live vast smoke preserves capacity classification when pre-acquire cleanup finds no lease", () => { + const dir = fs.mkdtempSync(path.join(os.tmpdir(), "crabbox-live-vast-no-lease-")); + const binDir = path.join(dir, "bin"); + const { tempRoot, smokeScript } = prepareSmokeRepo(dir); + const calls = path.join(dir, "calls.log"); + fs.mkdirSync(binDir, { recursive: true }); + + writeGoStub( 
+ binDir, + `#!/usr/bin/env bash +set -euo pipefail +printf '%s\\n' "$*" >>"${calls}" +if [[ "$1" == "doctor" || "$1" == "list" ]]; then + [[ "$1" == "list" ]] && printf '[]\\n' || printf 'auth=ready\\n' + exit 0 +fi +if [[ "$1" == "warmup" ]]; then + printf 'no eligible offers found\\n' >&2 + exit 37 +fi +if [[ "$1" == "stop" ]]; then + printf 'lease not found\\n' >&2 + exit 44 +fi +exit 99 +`, + ); + + const result = spawnSync("bash", [smokeScript], { + cwd: tempRoot, + env: { + ...process.env, + PATH: `${binDir}${path.delimiter}${process.env.PATH ?? ""}`, + CRABBOX_LIVE: "1", + CRABBOX_LIVE_PROVIDERS: "vast", + CRABBOX_VAST_API_KEY: "test-secret-token", + }, + encoding: "utf8", + }); + + assert.equal(result.status, 0, result.stdout + result.stderr); + assert.match(result.stderr, /classification=capacity_blocked/); + assert.doesNotMatch(result.stderr, /classification=cleanup_failed/); + assert.match(fs.readFileSync(calls, "utf8"), /stop --provider vast vast-smoke-\d{14}-\d+/); +}); + test("live vast smoke validates nvidia-smi output", () => { const dir = fs.mkdtempSync(path.join(os.tmpdir(), "crabbox-live-vast-bad-run-")); const binDir = path.join(dir, "bin"); From 8c57825d37f3c91ffb0c016255b46572be067139 Mon Sep 17 00:00:00 2001 From: Coy Geek <65363919+coygeek@users.noreply.github.com> Date: Wed, 24 Jun 2026 16:38:39 -0700 Subject: [PATCH 08/18] fix(config): preserve explicit Vast SSH user in show Keep config show from replacing an explicitly supplied generic SSH user with the Vast provider default. This keeps the displayed effective config aligned with the backend defaulting path. 
--- internal/cli/config_cmd.go | 4 +++- internal/cli/config_cmd_test.go | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/internal/cli/config_cmd.go b/internal/cli/config_cmd.go index b9bdf9311..06ec130db 100644 --- a/internal/cli/config_cmd.go +++ b/internal/cli/config_cmd.go @@ -101,7 +101,9 @@ func effectiveConfigForShow(cfg Config) Config { switch normalizeProviderName(cfg.Provider) { case "vast", "vast-ai", "vastai": cfg.WorkRoot = cfg.Vast.WorkRoot - cfg.SSHUser = cfg.Vast.User + if !IsSSHUserExplicit(&cfg) { + cfg.SSHUser = cfg.Vast.User + } cfg.SSHPort = "22" cfg.SSHFallbackPorts = nil case "nvidia-brev", "brev", "nvidia": diff --git a/internal/cli/config_cmd_test.go b/internal/cli/config_cmd_test.go index d3bc800c7..f88c447b1 100644 --- a/internal/cli/config_cmd_test.go +++ b/internal/cli/config_cmd_test.go @@ -1188,6 +1188,8 @@ func TestConfigShowIncludesVastWithoutSecretSurface(t *testing.T) { t.Setenv("CRABBOX_CONFIG", configPath) t.Setenv("CRABBOX_VAST_API_KEY", "vast-redaction-fixture-secret") if err := os.WriteFile(configPath, []byte(`provider: vast +ssh: + user: ubuntu vast: apiUrl: https://user:secret@vast.example.test/api/v0?token=hidden instanceType: on-demand @@ -1266,7 +1268,7 @@ vast: got.Vast.ReleaseAction != "stop" { t.Fatalf("unexpected vast json: %#v", got.Vast) } - if got.SSHUser != "root" || got.SSHPort != "22" { + if got.SSHUser != "ubuntu" || got.SSHPort != "22" { t.Fatalf("unexpected vast ssh json: %#v", got) } if strings.Contains(stdout.String(), "vast-redaction-fixture-secret") || strings.Contains(stdout.String(), "user:secret") || strings.Contains(stdout.String(), "hidden") { From 23e8b2dc15815bd7fb0a2ec72de72664c672e859 Mon Sep 17 00:00:00 2001 From: Coy Geek <65363919+coygeek@users.noreply.github.com> Date: Wed, 24 Jun 2026 16:38:45 -0700 Subject: [PATCH 09/18] fix(vast): align release reporting with lease policy Report the release action captured on the lease target instead of the current config, so stopped 
or kept Vast instances are not reported as destroyed. Also route cleanup of owned-looking instances without matching local claims through the safe missing-claim refusal path. --- internal/providers/vast/backend.go | 48 ++++++++++++++++++++++++- internal/providers/vast/backend_test.go | 40 +++++++++++++++++++++ 2 files changed, 87 insertions(+), 1 deletion(-) diff --git a/internal/providers/vast/backend.go b/internal/providers/vast/backend.go index 282110966..0325995fb 100644 --- a/internal/providers/vast/backend.go +++ b/internal/providers/vast/backend.go @@ -423,7 +423,7 @@ func (b *backend) ReleaseLease(ctx context.Context, req core.ReleaseLeaseRequest } func (b *backend) ReleaseLeaseMessage(lease core.LeaseTarget) string { - action := normalizeVastReleaseAction(b.cfg.Vast.ReleaseAction) + action := normalizeVastReleaseAction(firstNonBlank(lease.Server.Labels[vastReleaseActionLabel], b.cfg.Vast.ReleaseAction)) if action == "stop" || action == "keep" { return fmt.Sprintf("%s lease=%s vast=%s name=%s", action, lease.LeaseID, lease.Server.DisplayID(), lease.Server.Name) } @@ -453,9 +453,47 @@ func (b *backend) Cleanup(ctx context.Context, req core.CleanupRequest) error { if err != nil { return err } + servers, err = b.prepareCleanupServers(servers) + if err != nil { + return err + } return b.CleanupServers(ctx, req, servers) } +func (b *backend) prepareCleanupServers(servers []core.Server) ([]core.Server, error) { + for i := range servers { + updated, err := b.prepareCleanupServer(servers[i]) + if err != nil { + return nil, err + } + servers[i] = updated + } + return servers, nil +} + +func (b *backend) prepareCleanupServer(server core.Server) (core.Server, error) { + if server.Provider != providerName { + return server, nil + } + leaseID := strings.TrimSpace(server.Labels["lease"]) + if leaseID == "" { + return server, nil + } + claim, claimExists, err := core.ReadLeaseClaimWithPresence(leaseID) + if err != nil { + return core.Server{}, fmt.Errorf("read vast 
cleanup claim: %w", err) + } + if claimExists && claim.Provider == providerName && (claim.CloudID == "" || server.CloudID == "" || claim.CloudID == server.CloudID) { + return server, nil + } + + labels := cloneLabels(server.Labels) + labels["state"] = "expired" + labels["expires_at"] = core.LeaseLabelTime(b.now().Add(-time.Second)) + server.Labels = labels + return server, nil +} + func (b *backend) deleteServer(ctx context.Context, _ core.Config, server core.Server) error { if err := validateVastServer(server); err != nil { return err @@ -581,6 +619,14 @@ func vastLeaseLabels(cfg core.Config, leaseID, slug, state string, keep bool, no return labels } +func cloneLabels(labels map[string]string) map[string]string { + out := make(map[string]string, len(labels)) + for key, value := range labels { + out[key] = value + } + return out +} + func isOwnedVastInstance(item vastInstance) bool { return isVastCrabboxOwnedLabel(item.Label) } diff --git a/internal/providers/vast/backend_test.go b/internal/providers/vast/backend_test.go index 73db130df..d220ec115 100644 --- a/internal/providers/vast/backend_test.go +++ b/internal/providers/vast/backend_test.go @@ -379,6 +379,27 @@ func TestReleaseStopIsExplicitAndTested(t *testing.T) { } } +func TestReleaseLeaseMessageUsesPersistedReleaseAction(t *testing.T) { + b := newTestBackend(t, &fakeVastAPI{}) + b.cfg.Vast.ReleaseAction = "destroy" + lease := core.LeaseTarget{ + LeaseID: "cbx_message", + Server: core.Server{ + CloudID: "100", + Name: "message-me", + Provider: providerName, + Labels: map[string]string{ + vastReleaseActionLabel: "stop", + }, + }, + } + + msg := b.ReleaseLeaseMessage(lease) + if !strings.Contains(msg, "stop lease=cbx_message") || strings.Contains(msg, "destroyed") { + t.Fatalf("message=%q", msg) + } +} + func TestCleanupDryRunDoesNotDestroyExpiredOwnedInstance(t *testing.T) { api := &fakeVastAPI{instances: []vastInstance{{ID: 8, Label: encodeVastOwnershipLabel("cbx_old", "old", "ready"), Status: "running"}}} b 
:= newTestBackend(t, api) @@ -401,6 +422,25 @@ func TestCleanupDryRunDoesNotDestroyExpiredOwnedInstance(t *testing.T) { } } +func TestCleanupReportsMissingClaimForOwnedInstance(t *testing.T) { + api := &fakeVastAPI{instances: []vastInstance{{ID: 9, Label: encodeVastOwnershipLabel("cbx_orphan", "orphan", "ready"), Status: "running"}}} + b := newTestBackend(t, api) + var stderr bytes.Buffer + b.rt.Stderr = &stderr + b.DirectSSHBackend.RT = b.rt + + err := b.Cleanup(context.Background(), core.CleanupRequest{}) + if err == nil || !strings.Contains(err.Error(), "lease=cbx_orphan has no local Vast claim") { + t.Fatalf("err=%v", err) + } + if len(api.destroyed) != 0 { + t.Fatalf("destroyed=%v", api.destroyed) + } + if !strings.Contains(stderr.String(), "delete server id=9") { + t.Fatalf("stderr=%q", stderr.String()) + } +} + func TestManualUnownedCleanupIsRejected(t *testing.T) { api := &fakeVastAPI{instances: []vastInstance{{ID: 5, Label: "manual-instance", Status: "running", SSHHost: "203.0.113.5", SSHPort: 22}}} b := newTestBackend(t, api) From 4ec78a540043c353bccd638a7ce5ed5c4c2eb8f1 Mon Sep 17 00:00:00 2001 From: Coy Geek <65363919+coygeek@users.noreply.github.com> Date: Wed, 24 Jun 2026 16:44:49 -0700 Subject: [PATCH 10/18] test(cli): relax fake SSH probe deadlines Give local fake-SSH probe tests the same headroom as other terminal-bound checks so transient scheduling delays do not fail unrelated provider verification. 
--- internal/cli/bootstrap_test.go | 2 +- internal/cli/connect_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/cli/bootstrap_test.go b/internal/cli/bootstrap_test.go index d04567b6e..2e3b566a8 100644 --- a/internal/cli/bootstrap_test.go +++ b/internal/cli/bootstrap_test.go @@ -790,7 +790,7 @@ func TestWindowsWSL2BootstrapCompleteProbeUsesWindowsMarker(t *testing.T) { target := bootstrapTarget target.WindowsMode = windowsModeWSL2 - if !probeWindowsWSL2BootstrapComplete(context.Background(), bootstrapTarget, &target, 5*time.Second) { + if !probeWindowsWSL2BootstrapComplete(context.Background(), bootstrapTarget, &target, 30*time.Second) { t.Fatal("setup marker probe should pass with fake ssh") } data, err := os.ReadFile(logPath) diff --git a/internal/cli/connect_test.go b/internal/cli/connect_test.go index 4bf36f090..5cc2d63d8 100644 --- a/internal/cli/connect_test.go +++ b/internal/cli/connect_test.go @@ -140,7 +140,7 @@ cat DisableHostKeyChecking: true, NoControlMaster: true, } - if !probeConnectSSHTransport(context.Background(), &target, 5*time.Second) { + if !probeConnectSSHTransport(context.Background(), &target, 30*time.Second) { t.Fatal("resolve fallback port failed") } err := runInteractiveSSH(context.Background(), target, strings.NewReader(""), &stdout, &stderr) From e71c6679e3d24fd6652ef4f185cc66cee821123d Mon Sep 17 00:00:00 2001 From: Coy Geek <65363919+coygeek@users.noreply.github.com> Date: Wed, 24 Jun 2026 16:44:54 -0700 Subject: [PATCH 11/18] fix(vast): preserve claim metadata on resolve Preserve persisted Vast release and key metadata when normal resolve refreshes local claim endpoints, so stop/keep leases do not drift back to destroy. Also allow status-only resolution for owned instances that do not yet expose an SSH endpoint. 
--- internal/providers/vast/backend.go | 36 +++++++++++++++++++++- internal/providers/vast/backend_test.go | 41 +++++++++++++++++++++++++ internal/providers/vast/provider.go | 20 ++++++++++++ 3 files changed, 96 insertions(+), 1 deletion(-) diff --git a/internal/providers/vast/backend.go b/internal/providers/vast/backend.go index 0325995fb..70a4f3d5a 100644 --- a/internal/providers/vast/backend.go +++ b/internal/providers/vast/backend.go @@ -386,9 +386,10 @@ func (b *backend) targetFromInstance(item vastInstance, req core.ResolveRequest) return core.LeaseTarget{}, exit(5, "vast instance %d reached terminal status %s", item.ID, item.Status) } server := serverFromInstance(item, b.cfg) + server = mergeVastClaimMetadata(server) leaseID := server.Labels["lease"] target := core.LeaseTarget{Server: server, LeaseID: leaseID} - if !req.ReleaseOnly { + if !req.ReleaseOnly && !req.StatusOnly { ssh, err := sshTargetFromInstance(b.cfg, item) if err != nil { return core.LeaseTarget{}, err @@ -588,6 +589,22 @@ func mergeVastClaimLabels(server core.Server) core.Server { return server } +func mergeVastClaimMetadata(server core.Server) core.Server { + leaseID := strings.TrimSpace(server.Labels["lease"]) + if leaseID == "" { + return server + } + claim, ok, err := core.ReadLeaseClaimWithPresence(leaseID) + if err != nil || !ok || claim.Provider != providerName { + return server + } + if claim.CloudID != "" && server.CloudID != "" && claim.CloudID != server.CloudID { + return server + } + server.Labels = preserveVastClaimMetadata(server.Labels, claim.Labels) + return server +} + func serverFromInstance(item vastInstance, cfg core.Config) core.Server { labels := labelsFromVastInstance(item, cfg) server := core.Server{ @@ -627,6 +644,23 @@ func cloneLabels(labels map[string]string) map[string]string { return out } +func preserveVastClaimMetadata(labels, existing map[string]string) map[string]string { + out := cloneLabels(labels) + for _, key := range []string{ + vastReleaseActionLabel, 
+ vastKeyIDLabel, + vastKeyOwnedLabel, + vastOfferIDLabel, + "provider_key", + "recovery", + } { + if value, ok := existing[key]; ok { + out[key] = value + } + } + return out +} + func isOwnedVastInstance(item vastInstance) bool { return isVastCrabboxOwnedLabel(item.Label) } diff --git a/internal/providers/vast/backend_test.go b/internal/providers/vast/backend_test.go index d220ec115..9141d5b0d 100644 --- a/internal/providers/vast/backend_test.go +++ b/internal/providers/vast/backend_test.go @@ -282,6 +282,34 @@ func TestAcquireCreatesAttachesPollsReadinessAndClaims(t *testing.T) { } } +func TestResolvePreservesPersistedVastClaimMetadata(t *testing.T) { + api := &fakeVastAPI{offers: []vastOffer{{ID: 42, Rentable: true}}} + b := newTestBackend(t, api) + repoRoot := t.TempDir() + b.cfg.Vast.ReleaseAction = "stop" + b.DirectSSHBackend.Cfg = b.cfg + if _, err := b.Acquire(context.Background(), core.AcquireRequest{Repo: core.Repo{Root: repoRoot}, RequestedSlug: "preserve-meta"}); err != nil { + t.Fatal(err) + } + + b.cfg.Vast.ReleaseAction = "destroy" + b.DirectSSHBackend.Cfg = b.cfg + resolved, err := b.Resolve(context.Background(), core.ResolveRequest{ID: "preserve-meta", Repo: core.Repo{Root: repoRoot}}) + if err != nil { + t.Fatal(err) + } + if resolved.Server.Labels[vastReleaseActionLabel] != "stop" || resolved.Server.Labels[vastKeyIDLabel] != "key-100" || resolved.Server.Labels[vastKeyOwnedLabel] != "true" { + t.Fatalf("resolved labels=%#v", resolved.Server.Labels) + } + claim, ok, claimErr := core.ResolveLeaseClaimForProvider("preserve-meta", providerName) + if claimErr != nil || !ok { + t.Fatalf("claim ok=%v err=%v", ok, claimErr) + } + if claim.Labels[vastReleaseActionLabel] != "stop" || claim.Labels[vastKeyIDLabel] != "key-100" || claim.Labels[vastKeyOwnedLabel] != "true" { + t.Fatalf("claim labels=%#v", claim.Labels) + } +} + func TestResolveRejectsTerminalStatusForRunButAllowsRelease(t *testing.T) { api := &fakeVastAPI{instances: []vastInstance{{ID: 9, 
Label: encodeVastOwnershipLabel("cbx_failed", "failed", "ready"), Status: "failed", SSHHost: "203.0.113.9", SSHPort: 22}}} b := newTestBackend(t, api) @@ -298,6 +326,19 @@ func TestResolveRejectsTerminalStatusForRunButAllowsRelease(t *testing.T) { } } +func TestResolveStatusOnlyAllowsInstanceWithoutSSHEndpoint(t *testing.T) { + api := &fakeVastAPI{instances: []vastInstance{{ID: 10, Label: encodeVastOwnershipLabel("cbx_status", "status-me", "stopped"), Status: "stopped"}}} + b := newTestBackend(t, api) + + lease, err := b.Resolve(context.Background(), core.ResolveRequest{ID: "status-me", StatusOnly: true}) + if err != nil { + t.Fatal(err) + } + if lease.LeaseID != "cbx_status" || lease.SSH.Host != "" { + t.Fatalf("lease=%#v", lease) + } +} + func TestAcquireRollsBackOnCallbackFailure(t *testing.T) { api := &fakeVastAPI{offers: []vastOffer{{ID: 42, Rentable: true}}} b := newTestBackend(t, api) diff --git a/internal/providers/vast/provider.go b/internal/providers/vast/provider.go index 0675cf44b..2edd6db5e 100644 --- a/internal/providers/vast/provider.go +++ b/internal/providers/vast/provider.go @@ -39,6 +39,26 @@ func (Provider) ApplyFlags(cfg *core.Config, fs *flag.FlagSet, values any) error return ApplyVastProviderFlags(cfg, fs, values) } +func (Provider) PrepareLeaseClaimEndpoint(existing core.LeaseClaim, provider, slug string, server core.Server, allowProviderMetadata bool) (core.Server, error) { + if provider != providerName { + return core.Server{}, core.Exit(2, "refusing to rewrite vast lease=%s as provider=%s", existing.LeaseID, provider) + } + if slug != existing.Slug { + return core.Server{}, core.Exit(2, "refusing to rewrite vast lease=%s with slug=%s", existing.LeaseID, slug) + } + if server.Labels["lease"] != existing.LeaseID || server.Labels["slug"] != existing.Slug { + return core.Server{}, core.Exit(2, "refusing to rewrite vast lease=%s with mismatched label identity", existing.LeaseID) + } + if existing.CloudID != "" && server.CloudID != "" && 
existing.CloudID != server.CloudID { + return core.Server{}, core.Exit(2, "refusing to rewrite vast lease=%s with stale instance identity", existing.LeaseID) + } + if allowProviderMetadata { + return server, nil + } + server.Labels = preserveVastClaimMetadata(server.Labels, existing.Labels) + return server, nil +} + func (p Provider) Configure(cfg core.Config, rt core.Runtime) (core.Backend, error) { if err := p.ValidateConfig(cfg); err != nil { return nil, err From a00eba96ef5c67d33c2d611978a1b35ff7afec93 Mon Sep 17 00:00:00 2001 From: Coy Geek <65363919+coygeek@users.noreply.github.com> Date: Wed, 24 Jun 2026 16:49:48 -0700 Subject: [PATCH 12/18] fix(vast): detach SSH key before destroy Detach provider-owned Vast SSH keys before destroying the instance so normal release cleanup follows the same effective ordering as rollback. Add an order assertion to prevent the detach call from becoming a suppressed post-destroy no-op. --- internal/providers/vast/backend.go | 6 +++--- internal/providers/vast/backend_test.go | 9 +++++++++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/internal/providers/vast/backend.go b/internal/providers/vast/backend.go index 70a4f3d5a..bfaa3851d 100644 --- a/internal/providers/vast/backend.go +++ b/internal/providers/vast/backend.go @@ -540,14 +540,14 @@ func (b *backend) deleteServer(ctx context.Context, _ core.Config, server core.S return fmt.Errorf("finalize vast stop claim: %w", err) } default: - if err := client.DestroyInstance(ctx, instanceID); err != nil && !isVastNotFound(err) { - return err - } if keyID := strings.TrimSpace(claim.Labels[vastKeyIDLabel]); keyID != "" && claim.Labels[vastKeyOwnedLabel] == "true" { if err := client.DetachInstanceSSHKey(ctx, instanceID, keyID); err != nil && !isVastNotFound(err) { return err } } + if err := client.DestroyInstance(ctx, instanceID); err != nil && !isVastNotFound(err) { + return err + } if err := core.RemoveLeaseClaimIfUnchanged(leaseID, claim); err != nil { return 
fmt.Errorf("finalize vast cleanup claim: %w", err) } diff --git a/internal/providers/vast/backend_test.go b/internal/providers/vast/backend_test.go index 9141d5b0d..d7eec7d67 100644 --- a/internal/providers/vast/backend_test.go +++ b/internal/providers/vast/backend_test.go @@ -44,6 +44,7 @@ type fakeVastAPI struct { id int keyID string } + events []string nextID int } @@ -133,6 +134,7 @@ func (f *fakeVastAPI) ManageInstance(_ context.Context, id int, input vastManage func (f *fakeVastAPI) DestroyInstance(_ context.Context, id int) error { f.destroyed = append(f.destroyed, id) + f.events = append(f.events, "destroy:"+strconv.Itoa(id)) if f.destroyErr != nil { return f.destroyErr } @@ -166,6 +168,7 @@ func (f *fakeVastAPI) DetachInstanceSSHKey(_ context.Context, id int, keyID stri id int keyID string }{id: id, keyID: keyID}) + f.events = append(f.events, "detach:"+strconv.Itoa(id)+":"+keyID) return f.detachErr } @@ -395,6 +398,12 @@ func TestReleaseDestroysByDefaultAndRemovesClaim(t *testing.T) { if len(api.destroyed) != 1 || api.destroyed[0] != 100 { t.Fatalf("destroyed=%v", api.destroyed) } + if len(api.detached) != 1 || api.detached[0].id != 100 || api.detached[0].keyID != "key-100" { + t.Fatalf("detached=%v", api.detached) + } + if got, want := strings.Join(api.events, ","), "detach:100:key-100,destroy:100"; got != want { + t.Fatalf("events=%q want %q", got, want) + } if _, ok, claimErr := core.ResolveLeaseClaimForProvider("destroy-me", providerName); claimErr != nil || ok { t.Fatalf("claim ok=%v err=%v", ok, claimErr) } From 2c18d62faef17af9bcc857a23c679374d2158b65 Mon Sep 17 00:00:00 2001 From: Coy Geek <65363919+coygeek@users.noreply.github.com> Date: Wed, 24 Jun 2026 16:53:28 -0700 Subject: [PATCH 13/18] fix(vast): honor explicit release action overrides Let an explicitly supplied Vast release action override the persisted claim label during release and reporting. 
This keeps safety overrides such as --vast-release-action keep from being ignored on leases acquired with the default destroy policy. --- internal/providers/vast/backend.go | 11 ++++++++-- internal/providers/vast/backend_test.go | 28 +++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/internal/providers/vast/backend.go b/internal/providers/vast/backend.go index bfaa3851d..d94c3b3d5 100644 --- a/internal/providers/vast/backend.go +++ b/internal/providers/vast/backend.go @@ -424,7 +424,7 @@ func (b *backend) ReleaseLease(ctx context.Context, req core.ReleaseLeaseRequest } func (b *backend) ReleaseLeaseMessage(lease core.LeaseTarget) string { - action := normalizeVastReleaseAction(firstNonBlank(lease.Server.Labels[vastReleaseActionLabel], b.cfg.Vast.ReleaseAction)) + action := effectiveVastReleaseAction(b.cfg, lease.Server.Labels) if action == "stop" || action == "keep" { return fmt.Sprintf("%s lease=%s vast=%s name=%s", action, lease.LeaseID, lease.Server.DisplayID(), lease.Server.Name) } @@ -528,7 +528,7 @@ func (b *backend) deleteServer(ctx context.Context, _ core.Config, server core.S } else if !isVastNotFound(getErr) { return getErr } - action := normalizeVastReleaseAction(firstNonBlank(claim.Labels[vastReleaseActionLabel], b.cfg.Vast.ReleaseAction)) + action := effectiveVastReleaseAction(b.cfg, claim.Labels) switch action { case "keep": return nil @@ -740,6 +740,13 @@ func normalizeVastReleaseAction(value string) string { } } +func effectiveVastReleaseAction(cfg core.Config, labels map[string]string) string { + if core.DeleteOnReleaseExplicit(cfg, providerName) { + return normalizeVastReleaseAction(cfg.Vast.ReleaseAction) + } + return normalizeVastReleaseAction(firstNonBlank(labels[vastReleaseActionLabel], cfg.Vast.ReleaseAction)) +} + func parseVastInstanceID(value string) (int, bool) { id, err := strconv.Atoi(strings.TrimSpace(value)) return id, err == nil && id > 0 diff --git a/internal/providers/vast/backend_test.go 
b/internal/providers/vast/backend_test.go index d7eec7d67..15dd8e3a8 100644 --- a/internal/providers/vast/backend_test.go +++ b/internal/providers/vast/backend_test.go @@ -409,6 +409,34 @@ func TestReleaseDestroysByDefaultAndRemovesClaim(t *testing.T) { } } +func TestReleaseHonorsExplicitReleaseActionOverride(t *testing.T) { + api := &fakeVastAPI{offers: []vastOffer{{ID: 42, Rentable: true}}} + b := newTestBackend(t, api) + lease, err := b.Acquire(context.Background(), core.AcquireRequest{Repo: core.Repo{Root: t.TempDir()}, RequestedSlug: "keep-me"}) + if err != nil { + t.Fatal(err) + } + if lease.Server.Labels[vastReleaseActionLabel] != "destroy" { + t.Fatalf("lease labels=%#v", lease.Server.Labels) + } + + b.cfg.Vast.ReleaseAction = "keep" + core.MarkDeleteOnReleaseExplicit(&b.cfg, providerName) + if err := b.ReleaseLease(context.Background(), core.ReleaseLeaseRequest{Lease: lease}); err != nil { + t.Fatal(err) + } + if len(api.destroyed) != 0 || len(api.detached) != 0 { + t.Fatalf("destroyed=%v detached=%v", api.destroyed, api.detached) + } + msg := b.ReleaseLeaseMessage(lease) + if !strings.Contains(msg, "keep lease=") || strings.Contains(msg, "destroyed") { + t.Fatalf("message=%q", msg) + } + if _, ok, claimErr := core.ResolveLeaseClaimForProvider("keep-me", providerName); claimErr != nil || !ok { + t.Fatalf("claim ok=%v err=%v", ok, claimErr) + } +} + func TestReleaseStopIsExplicitAndTested(t *testing.T) { api := &fakeVastAPI{offers: []vastOffer{{ID: 42, Rentable: true}}} b := newTestBackend(t, api) From 9aa29834fcee615d5e5c3efe4c2e63caa4d94fcc Mon Sep 17 00:00:00 2001 From: Coy Geek <65363919+coygeek@users.noreply.github.com> Date: Wed, 24 Jun 2026 16:57:56 -0700 Subject: [PATCH 14/18] test(cli): widen controller subprocess deadlines Give controller subprocess tests more non-race scheduling headroom after repeated local failures in terminal-bound process lifecycle checks. This keeps provider verification from failing on unrelated launch-gate timing. 
--- internal/cli/controller_subprocess_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/cli/controller_subprocess_test.go b/internal/cli/controller_subprocess_test.go index 109832d2b..7f90dfc14 100644 --- a/internal/cli/controller_subprocess_test.go +++ b/internal/cli/controller_subprocess_test.go @@ -23,9 +23,9 @@ import ( func controllerSubprocessTestTimeout(base time.Duration) time.Duration { if raceEnabled { - return 5 * base + return 10 * base } - return base + return 5 * base } type fixedIdentityExecControllerRunner struct { From f97022818dcb1fcf64a55e8adecdf2e27782b273 Mon Sep 17 00:00:00 2001 From: Coy Geek <65363919+coygeek@users.noreply.github.com> Date: Wed, 24 Jun 2026 19:55:08 -0700 Subject: [PATCH 15/18] fix(vast): align live API and retained lease handling Encode Vast search and create requests using the documented API shapes, restore stored SSH keys on resolve, and keep stopped lease claims so retained instances can be reconciled or destroyed later. Also widen a race-instrumented Apple VZ helper test timeout encountered during the final verification gate. 
--- .../applevzhelper/cli_darwin_arm64_test.go | 2 +- internal/providers/vast/backend.go | 36 ++-- internal/providers/vast/backend_test.go | 75 +++++++- internal/providers/vast/client.go | 165 ++++++++++++++---- internal/providers/vast/client_test.go | 63 +++++-- internal/providers/vast/provider.go | 3 + internal/providers/vast/provider_test.go | 40 +++++ 7 files changed, 320 insertions(+), 64 deletions(-) diff --git a/internal/applevzhelper/cli_darwin_arm64_test.go b/internal/applevzhelper/cli_darwin_arm64_test.go index ae34048ec..d4328b245 100644 --- a/internal/applevzhelper/cli_darwin_arm64_test.go +++ b/internal/applevzhelper/cli_darwin_arm64_test.go @@ -311,7 +311,7 @@ exit 23 "--cpus", "2", "--memory-mib", "2048", "--disk-gib", "16", - "--ready-timeout", "5s", + "--ready-timeout", "15s", }, strings.NewReader(`{"image":"test.img"}`), &bytes.Buffer{}, &bytes.Buffer{}) if err == nil || (!strings.Contains(err.Error(), "helper daemon exited before the VM reached running state") && diff --git a/internal/providers/vast/backend.go b/internal/providers/vast/backend.go index d94c3b3d5..558cfbd83 100644 --- a/internal/providers/vast/backend.go +++ b/internal/providers/vast/backend.go @@ -20,6 +20,7 @@ const ( vastKeyIDLabel = "provider_key_id" vastKeyOwnedLabel = "provider_key_owned" vastOfferIDLabel = "vast_offer_id" + vastReadyCheck = "command -v git >/dev/null && command -v rsync >/dev/null && command -v tar >/dev/null && command -v python3 >/dev/null" vastReleaseActionLabel = "release_action" ) @@ -324,19 +325,12 @@ func (b *backend) Resolve(ctx context.Context, req core.ResolveRequest) (core.Le for _, item := range instances { byID[item.ID] = item } - if id, ok := parseVastInstanceID(req.ID); ok { - item, found := byID[id] - if !found { - item, err = client.GetInstance(ctx, id) - if err != nil { - return b.releaseTargetFromClaim(req.ID, err, req.ReleaseOnly) - } - } - return b.targetFromInstance(item, req) - } servers := serversFromInstances(instances, b.cfg, 
false) server, leaseID, err := core.FindServerByAlias(servers, req.ID) - if err == nil && leaseID != "" { + if err != nil { + return core.LeaseTarget{}, err + } + if leaseID != "" { if id, ok := parseVastInstanceID(server.CloudID); ok { return b.targetFromInstance(byID[id], req) } @@ -358,8 +352,15 @@ func (b *backend) Resolve(ctx context.Context, req core.ResolveRequest) (core.Le return claimTarget(claim), nil } } - if err != nil { - return core.LeaseTarget{}, err + if id, ok := parseVastInstanceID(req.ID); ok { + item, found := byID[id] + if !found { + item, err = client.GetInstance(ctx, id) + if err != nil { + return b.releaseTargetFromClaim(req.ID, err, req.ReleaseOnly) + } + } + return b.targetFromInstance(item, req) } return core.LeaseTarget{}, exit(4, "lease/instance not found: %s", req.ID) } @@ -389,11 +390,12 @@ func (b *backend) targetFromInstance(item vastInstance, req core.ResolveRequest) server = mergeVastClaimMetadata(server) leaseID := server.Labels["lease"] target := core.LeaseTarget{Server: server, LeaseID: leaseID} - if !req.ReleaseOnly && !req.StatusOnly { + if !req.ReleaseOnly && (!req.StatusOnly || req.ReadyProbe) { ssh, err := sshTargetFromInstance(b.cfg, item) if err != nil { return core.LeaseTarget{}, err } + core.UseStoredTestboxKey(&ssh, leaseID) target.SSH = ssh } if req.Repo.Root != "" && !req.NoLocalStateMutations { @@ -536,7 +538,10 @@ func (b *backend) deleteServer(ctx context.Context, _ core.Config, server core.S if _, err := client.ManageInstance(ctx, instanceID, vastManageInstanceInput{State: "stopped", Label: encodeVastOwnershipLabel(leaseID, server.Labels["slug"], "stopped")}); err != nil { return err } - if err := core.RemoveLeaseClaimIfUnchanged(leaseID, claim); err != nil { + labels := cloneLabels(claim.Labels) + labels["state"] = "stopped" + labels[vastReleaseActionLabel] = "stop" + if _, err := core.UpdateLeaseClaimLabelsIfUnchanged(leaseID, claim, labels); err != nil { return fmt.Errorf("finalize vast stop claim: %w", err) 
} default: @@ -698,6 +703,7 @@ func sshTargetFromInstance(cfg core.Config, item vastInstance) (core.SSHTarget, ssh.Port = strconv.Itoa(item.SSHPort) ssh.User = firstNonBlank(cfg.SSHUser, cfg.Vast.User, "root") ssh.TargetOS = core.TargetLinux + ssh.ReadyCheck = vastReadyCheck return ssh, nil } diff --git a/internal/providers/vast/backend_test.go b/internal/providers/vast/backend_test.go index 15dd8e3a8..494a4dbc2 100644 --- a/internal/providers/vast/backend_test.go +++ b/internal/providers/vast/backend_test.go @@ -257,6 +257,11 @@ func TestListFiltersOwnedByDefaultAndAllShowsManual(t *testing.T) { func TestAcquireCreatesAttachesPollsReadinessAndClaims(t *testing.T) { api := &fakeVastAPI{offers: []vastOffer{{ID: 42, GPUName: "RTX 4090", GPUCount: 1, Rentable: true}}} b := newTestBackend(t, api) + var readyTarget core.SSHTarget + b.waitSSH = func(_ context.Context, target *core.SSHTarget, _ string, _ time.Duration) error { + readyTarget = *target + return nil + } lease, err := b.Acquire(context.Background(), core.AcquireRequest{Repo: core.Repo{Root: t.TempDir()}, RequestedSlug: "gpu-box", Keep: true}) if err != nil { t.Fatal(err) @@ -264,6 +269,12 @@ func TestAcquireCreatesAttachesPollsReadinessAndClaims(t *testing.T) { if lease.LeaseID == "" || lease.Server.CloudID != "100" || lease.SSH.Host != "203.0.113.24" || lease.SSH.Port != "2201" || lease.SSH.User != "root" || lease.SSH.Key == "" { t.Fatalf("lease=%#v", lease) } + if lease.SSH.ReadyCheck != vastReadyCheck || strings.Contains(lease.SSH.ReadyCheck, "crabbox-ready") { + t.Fatalf("lease ready check=%q", lease.SSH.ReadyCheck) + } + if readyTarget.ReadyCheck != vastReadyCheck { + t.Fatalf("waitSSH target ready check=%q", readyTarget.ReadyCheck) + } if len(api.searches) != 1 || api.searches[0].Config.Order != "dlperf_per_dphtotal desc" { t.Fatalf("searches=%#v", api.searches) } @@ -291,7 +302,8 @@ func TestResolvePreservesPersistedVastClaimMetadata(t *testing.T) { repoRoot := t.TempDir() b.cfg.Vast.ReleaseAction = 
"stop" b.DirectSSHBackend.Cfg = b.cfg - if _, err := b.Acquire(context.Background(), core.AcquireRequest{Repo: core.Repo{Root: repoRoot}, RequestedSlug: "preserve-meta"}); err != nil { + acquired, err := b.Acquire(context.Background(), core.AcquireRequest{Repo: core.Repo{Root: repoRoot}, RequestedSlug: "preserve-meta"}) + if err != nil { t.Fatal(err) } @@ -304,6 +316,9 @@ func TestResolvePreservesPersistedVastClaimMetadata(t *testing.T) { if resolved.Server.Labels[vastReleaseActionLabel] != "stop" || resolved.Server.Labels[vastKeyIDLabel] != "key-100" || resolved.Server.Labels[vastKeyOwnedLabel] != "true" { t.Fatalf("resolved labels=%#v", resolved.Server.Labels) } + if resolved.SSH.Key != acquired.SSH.Key { + t.Fatalf("resolved SSH key=%q want %q", resolved.SSH.Key, acquired.SSH.Key) + } claim, ok, claimErr := core.ResolveLeaseClaimForProvider("preserve-meta", providerName) if claimErr != nil || !ok { t.Fatalf("claim ok=%v err=%v", ok, claimErr) @@ -342,6 +357,42 @@ func TestResolveStatusOnlyAllowsInstanceWithoutSSHEndpoint(t *testing.T) { } } +func TestResolveStatusOnlyReadyProbeIncludesSSHTarget(t *testing.T) { + api := &fakeVastAPI{offers: []vastOffer{{ID: 42, Rentable: true}}} + b := newTestBackend(t, api) + acquired, err := b.Acquire(context.Background(), core.AcquireRequest{Repo: core.Repo{Root: t.TempDir()}, RequestedSlug: "status-wait", Keep: true}) + if err != nil { + t.Fatal(err) + } + + lease, err := b.Resolve(context.Background(), core.ResolveRequest{ID: "status-wait", StatusOnly: true, ReadyProbe: true}) + if err != nil { + t.Fatal(err) + } + if lease.LeaseID != acquired.LeaseID || lease.SSH.Host != acquired.SSH.Host || lease.SSH.Key != acquired.SSH.Key { + t.Fatalf("lease=%#v acquired=%#v", lease, acquired) + } + if lease.SSH.ReadyCheck != vastReadyCheck { + t.Fatalf("ready check=%q", lease.SSH.ReadyCheck) + } +} + +func TestResolvePrefersNumericSlugOverInstanceID(t *testing.T) { + api := &fakeVastAPI{instances: []vastInstance{ + {ID: 123, Label: 
encodeVastOwnershipLabel("cbx_other", "other", "ready"), Status: "running", SSHHost: "203.0.113.123", SSHPort: 22}, + {ID: 100, Label: encodeVastOwnershipLabel("cbx_numeric", "123", "ready"), Status: "running", SSHHost: "203.0.113.100", SSHPort: 2200}, + }} + b := newTestBackend(t, api) + + lease, err := b.Resolve(context.Background(), core.ResolveRequest{ID: "123"}) + if err != nil { + t.Fatal(err) + } + if lease.LeaseID != "cbx_numeric" || lease.Server.CloudID != "100" || lease.SSH.Host != "203.0.113.100" { + t.Fatalf("lease=%#v", lease) + } +} + func TestAcquireRollsBackOnCallbackFailure(t *testing.T) { api := &fakeVastAPI{offers: []vastOffer{{ID: 42, Rentable: true}}} b := newTestBackend(t, api) @@ -455,6 +506,28 @@ func TestReleaseStopIsExplicitAndTested(t *testing.T) { if len(api.managed) < 2 || api.managed[len(api.managed)-1].input.State != "stopped" { t.Fatalf("managed=%#v", api.managed) } + claim, ok, claimErr := core.ResolveLeaseClaimForProvider("stop-me", providerName) + if claimErr != nil || !ok { + t.Fatalf("claim ok=%v err=%v", ok, claimErr) + } + if claim.Labels["state"] != "stopped" || claim.Labels[vastReleaseActionLabel] != "stop" || claim.Labels[vastKeyIDLabel] != "key-100" { + t.Fatalf("claim labels=%#v", claim.Labels) + } + + b.cfg.Vast.ReleaseAction = "destroy" + core.MarkDeleteOnReleaseExplicit(&b.cfg, providerName) + if err := b.ReleaseLease(context.Background(), core.ReleaseLeaseRequest{Lease: lease}); err != nil { + t.Fatal(err) + } + if len(api.detached) != 1 || api.detached[0].id != 100 || api.detached[0].keyID != "key-100" { + t.Fatalf("detached=%v", api.detached) + } + if len(api.destroyed) != 1 || api.destroyed[0] != 100 { + t.Fatalf("destroyed=%v", api.destroyed) + } + if _, ok, claimErr := core.ResolveLeaseClaimForProvider("stop-me", providerName); claimErr != nil || ok { + t.Fatalf("claim after destroy ok=%v err=%v", ok, claimErr) + } } func TestReleaseLeaseMessageUsesPersistedReleaseAction(t *testing.T) { diff --git 
a/internal/providers/vast/client.go b/internal/providers/vast/client.go index 406637ed2..2226b9b1f 100644 --- a/internal/providers/vast/client.go +++ b/internal/providers/vast/client.go @@ -11,6 +11,7 @@ import ( "net/http" "net/url" "regexp" + "sort" "strconv" "strings" ) @@ -172,7 +173,11 @@ func (c *vastClient) do(ctx context.Context, method, path string, body any, out } reader = bytes.NewReader(data) } - req, err := http.NewRequestWithContext(ctx, method, c.apiURL+path, reader) + endpoint := c.apiURL + path + if parsed, err := url.Parse(path); err == nil && parsed.IsAbs() { + endpoint = path + } + req, err := http.NewRequestWithContext(ctx, method, endpoint, reader) if err != nil { return err } @@ -252,11 +257,29 @@ func (c *vastClient) GetInstance(ctx context.Context, id int) (vastInstance, err } func (c *vastClient) ListInstances(ctx context.Context) ([]vastInstance, error) { - var raw json.RawMessage - if err := c.do(ctx, http.MethodGet, "/instances/", nil, &raw); err != nil { - return nil, err + var out []vastInstance + var afterToken string + for { + params := url.Values{} + params.Set("limit", "25") + if afterToken != "" { + params.Set("after_token", afterToken) + } + var raw json.RawMessage + endpoint := vastAPIURLForVersion(c.apiURL, "v1") + "/instances/?" + params.Encode() + if err := c.do(ctx, http.MethodGet, endpoint, nil, &raw); err != nil { + return nil, err + } + page, nextToken, err := decodeVastInstancesPage(raw) + if err != nil { + return nil, err + } + out = append(out, page...) 
+ if strings.TrimSpace(nextToken) == "" { + return out, nil + } + afterToken = nextToken } - return decodeVastInstances(raw) } func (c *vastClient) ManageInstance(ctx context.Context, id int, input vastManageInstanceInput) (vastInstance, error) { @@ -291,37 +314,68 @@ func (c *vastClient) DetachInstanceSSHKey(ctx context.Context, id int, keyID str } func buildVastOfferSearchPayload(cfg VastConfig) map[string]any { - query := map[string]any{ - "verified": true, - "rentable": true, - "rented": false, - "direct_port_count": map[string]any{ - "gte": 1, - }, + payload := map[string]any{ + "verified": vastFilter("eq", true), + "rentable": vastFilter("eq", true), + "rented": vastFilter("eq", false), + "direct_port_count": vastFilter("gte", 1), } if cfg.GPUName != "" { - query["gpu_name"] = cfg.GPUName + payload["gpu_name"] = vastFilter("eq", cfg.GPUName) } if cfg.GPUCount > 0 { - query["num_gpus"] = map[string]any{"gte": cfg.GPUCount} + payload["num_gpus"] = vastFilter("gte", cfg.GPUCount) } if cfg.MinReliability > 0 { - query["reliability2"] = map[string]any{"gte": cfg.MinReliability} + payload["reliability"] = vastFilter("gte", cfg.MinReliability) } if cfg.MaxDphTotal > 0 { - query["dph_total"] = map[string]any{"lte": cfg.MaxDphTotal} + payload["dph_total"] = vastFilter("lte", cfg.MaxDphTotal) } - instanceType := normalizeInstanceType(strings.TrimSpace(cfg.InstanceType)) + instanceType := vastAPIInstanceType(cfg.InstanceType) if instanceType == "" { instanceType = "ondemand" } - return map[string]any{ - "type": instanceType, - "query": query, - "order": strings.TrimSpace(cfg.Order), + payload["type"] = instanceType + if order := strings.TrimSpace(cfg.Order); order != "" { + payload["order"] = vastOrderTuples(order) + } + return payload +} + +func vastFilter(operator string, value any) map[string]any { + return map[string]any{operator: value} +} + +func vastAPIInstanceType(value string) string { + switch normalizeInstanceType(value) { + case "interruptible": + return 
"bid" + default: + return normalizeInstanceType(value) } } +func vastOrderTuples(order string) [][]string { + parts := strings.Split(order, ",") + out := make([][]string, 0, len(parts)) + for _, part := range parts { + fields := strings.Fields(strings.TrimSpace(part)) + if len(fields) == 0 { + continue + } + direction := "desc" + if len(fields) > 1 { + switch strings.ToLower(fields[1]) { + case "asc", "desc": + direction = strings.ToLower(fields[1]) + } + } + out = append(out, []string{fields[0], direction}) + } + return out +} + func buildVastCreatePayload(input vastCreateInstanceInput) map[string]any { cfg := input.Config payload := map[string]any{ @@ -334,7 +388,7 @@ func buildVastCreatePayload(input vastCreateInstanceInput) map[string]any { payload["image"] = cfg.Image } if cfg.TemplateID != "" { - payload["template_id"] = cfg.TemplateID + payload["template_hash_id"] = cfg.TemplateID } if cfg.DiskGB > 0 { payload["disk"] = cfg.DiskGB @@ -346,7 +400,7 @@ func buildVastCreatePayload(input vastCreateInstanceInput) map[string]any { payload["ssh_key"] = input.SSHKey } if len(input.Environment) > 0 { - payload["env"] = input.Environment + payload["env"] = vastEnvFlags(input.Environment) } if input.OnStart != "" { payload["onstart"] = input.OnStart @@ -354,6 +408,37 @@ func buildVastCreatePayload(input vastCreateInstanceInput) map[string]any { return payload } +func vastEnvFlags(env map[string]string) string { + keys := make([]string, 0, len(env)) + for key := range env { + key = strings.TrimSpace(key) + if key != "" { + keys = append(keys, key) + } + } + sort.Strings(keys) + parts := make([]string, 0, len(keys)) + for _, key := range keys { + parts = append(parts, "-e", key+"="+shellQuoteVastEnvValue(env[key])) + } + return strings.Join(parts, " ") +} + +func shellQuoteVastEnvValue(value string) string { + if value == "" { + return "''" + } + if strings.IndexFunc(value, func(r rune) bool { + return (r < 'A' || r > 'Z') && + (r < 'a' || r > 'z') && + (r < '0' || r > 
'9') && + !strings.ContainsRune("_-./:", r) + }) == -1 { + return value + } + return "'" + strings.ReplaceAll(value, "'", `'\''`) + "'" +} + func decodeVastOffers(raw json.RawMessage) ([]vastOffer, error) { var direct []vastOffer if err := json.Unmarshal(raw, &direct); err == nil { @@ -406,8 +491,9 @@ func decodeVastInstance(raw json.RawMessage) (vastInstance, error) { return normalizeVastInstance(direct), nil } var envelope struct { - Instance vastInstance `json:"instance"` - Data vastInstance `json:"data"` + Instance vastInstance `json:"instance"` + Instances vastInstance `json:"instances"` + Data vastInstance `json:"data"` } if err := json.Unmarshal(raw, &envelope); err != nil { return vastInstance{}, err @@ -415,30 +501,49 @@ func decodeVastInstance(raw json.RawMessage) (vastInstance, error) { if envelope.Instance.ID != 0 || envelope.Instance.ContractID != 0 || envelope.Instance.SSHHost != "" { return normalizeVastInstance(envelope.Instance), nil } + if envelope.Instances.ID != 0 || envelope.Instances.ContractID != 0 || envelope.Instances.SSHHost != "" { + return normalizeVastInstance(envelope.Instances), nil + } return normalizeVastInstance(envelope.Data), nil } func decodeVastInstances(raw json.RawMessage) ([]vastInstance, error) { + instances, _, err := decodeVastInstancesPage(raw) + return instances, err +} + +func decodeVastInstancesPage(raw json.RawMessage) ([]vastInstance, string, error) { var direct []vastInstance if err := json.Unmarshal(raw, &direct); err == nil { - return normalizeVastInstances(direct), nil + return normalizeVastInstances(direct), "", nil } var envelope struct { Instances []vastInstance `json:"instances"` Data []vastInstance `json:"data"` Results []vastInstance `json:"results"` + NextToken string `json:"next_token"` } if err := json.Unmarshal(raw, &envelope); err != nil { - return nil, err + return nil, "", err } switch { case envelope.Instances != nil: - return normalizeVastInstances(envelope.Instances), nil + return 
normalizeVastInstances(envelope.Instances), envelope.NextToken, nil case envelope.Data != nil: - return normalizeVastInstances(envelope.Data), nil + return normalizeVastInstances(envelope.Data), envelope.NextToken, nil default: - return normalizeVastInstances(envelope.Results), nil + return normalizeVastInstances(envelope.Results), envelope.NextToken, nil + } +} + +func vastAPIURLForVersion(apiURL, version string) string { + base := strings.TrimRight(apiURL, "/") + for _, suffix := range []string{"/api/v0", "/api/v1"} { + if strings.HasSuffix(base, suffix) { + return strings.TrimSuffix(base, suffix) + "/api/" + version + } } + return base } func decodeVastSSHKeys(raw json.RawMessage) ([]vastInstanceSSHKey, error) { diff --git a/internal/providers/vast/client_test.go b/internal/providers/vast/client_test.go index 6562baa9c..4b0715231 100644 --- a/internal/providers/vast/client_test.go +++ b/internal/providers/vast/client_test.go @@ -84,16 +84,25 @@ func TestOfferSearchPayloadAndDecode(t *testing.T) { if err := json.NewDecoder(r.Body).Decode(&body); err != nil { t.Fatal(err) } - query := body["query"].(map[string]any) - if body["type"] != "ondemand" || body["order"] != "dlperf_per_dphtotal desc" || - query["gpu_name"] != "H100" || query["verified"] != true || query["rentable"] != true || query["rented"] != false { + if body["type"] != "ondemand" || + body["gpu_name"].(map[string]any)["eq"] != "H100" || + body["verified"].(map[string]any)["eq"] != true || + body["rentable"].(map[string]any)["eq"] != true || + body["rented"].(map[string]any)["eq"] != false { t.Fatalf("unexpected search body: %#v", body) } - if query["direct_port_count"].(map[string]any)["gte"] != float64(1) || - query["num_gpus"].(map[string]any)["gte"] != float64(4) || - query["reliability2"].(map[string]any)["gte"] != 0.95 || - query["dph_total"].(map[string]any)["lte"] != 3.5 { - t.Fatalf("unexpected query filters: %#v", query) + order := body["order"].([]any) + if len(order) != 1 || 
order[0].([]any)[0] != "dlperf_per_dphtotal" || order[0].([]any)[1] != "desc" { + t.Fatalf("order=%#v", body["order"]) + } + if _, ok := body["query"]; ok { + t.Fatalf("search body should use top-level filters: %#v", body) + } + if body["direct_port_count"].(map[string]any)["gte"] != float64(1) || + body["num_gpus"].(map[string]any)["gte"] != float64(4) || + body["reliability"].(map[string]any)["gte"] != 0.95 || + body["dph_total"].(map[string]any)["lte"] != 3.5 { + t.Fatalf("unexpected search filters: %#v", body) } writeJSON(t, w, map[string]any{"offers": []map[string]any{{"id": 11, "gpu_name": "H100", "ssh_host": "203.0.113.10", "ssh_port": 2201}}}) })) @@ -116,6 +125,13 @@ func TestOfferSearchPayloadAndDecode(t *testing.T) { } } +func TestOfferSearchPayloadMapsInterruptibleToBid(t *testing.T) { + body := buildVastOfferSearchPayload(VastConfig{InstanceType: "interruptible"}) + if body["type"] != "bid" { + t.Fatalf("type=%#v want bid", body["type"]) + } +} + func TestCreateInstancePayloadAndDecodeNewContract(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPut || r.URL.Path != "/api/v0/asks/42/" { @@ -127,14 +143,16 @@ func TestCreateInstancePayloadAndDecodeNewContract(t *testing.T) { } if body["runtype"] != "ssh_direct" || body["target_state"] != "running" || body["cancel_unavail"] != true || body["vm"] != false || - body["image"] != "nvidia/cuda:12" || body["template_id"] != "tpl-123" || + body["image"] != "nvidia/cuda:12" || body["template_hash_id"] != "tpl-123" || body["disk"] != float64(80) || body["label"] != "cbx1|lease|slug|active" || body["ssh_key"] != "ssh-ed25519 AAAA..." 
{ t.Fatalf("unexpected create body: %#v", body) } - env := body["env"].(map[string]any) - if env["CRABBOX"] != "1" { - t.Fatalf("env=%#v", env) + if _, ok := body["template_id"]; ok { + t.Fatalf("create body should use template_hash_id: %#v", body) + } + if body["env"] != "-e CRABBOX=1" { + t.Fatalf("env=%#v", body["env"]) } writeJSON(t, w, map[string]any{"success": true, "new_contract": 99}) })) @@ -158,12 +176,22 @@ func TestCreateInstancePayloadAndDecodeNewContract(t *testing.T) { func TestInstanceMethodsAndDecoding(t *testing.T) { var seen []string server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - seen = append(seen, r.Method+" "+r.URL.Path) + seen = append(seen, r.Method+" "+r.URL.RequestURI()) switch r.Method + " " + r.URL.Path { case "GET /api/v0/instances/99/": - writeJSON(t, w, map[string]any{"id": 99, "actual_status": "running", "ssh_host": "198.51.100.8", "ssh_port": 2222, "gpu_name": "RTX 4090"}) - case "GET /api/v0/instances/": - writeJSON(t, w, map[string]any{"instances": []map[string]any{{"id": 99}, {"contract_id": 100}}}) + writeJSON(t, w, map[string]any{"instances": map[string]any{"id": 99, "actual_status": "running", "ssh_host": "198.51.100.8", "ssh_port": 2222, "gpu_name": "RTX 4090"}}) + case "GET /api/v1/instances/": + if r.URL.Query().Get("limit") != "25" { + t.Fatalf("list query=%s", r.URL.RawQuery) + } + switch r.URL.Query().Get("after_token") { + case "": + writeJSON(t, w, map[string]any{"next_token": "page-2", "instances": []map[string]any{{"id": 99}}}) + case "page-2": + writeJSON(t, w, map[string]any{"next_token": nil, "instances": []map[string]any{{"contract_id": 100}}}) + default: + t.Fatalf("unexpected after_token query=%s", r.URL.RawQuery) + } case "PUT /api/v0/instances/99/": var body vastManageInstanceInput if err := json.NewDecoder(r.Body).Decode(&body); err != nil { @@ -206,7 +234,8 @@ func TestInstanceMethodsAndDecoding(t *testing.T) { if err := 
client.DestroyInstance(context.Background(), 99); err != nil { t.Fatal(err) } - if strings.Join(seen, ",") != "GET /api/v0/instances/99/,GET /api/v0/instances/,PUT /api/v0/instances/99/,DELETE /api/v0/instances/99/" { + wantSeen := "GET /api/v0/instances/99/,GET /api/v1/instances/?limit=25,GET /api/v1/instances/?after_token=page-2&limit=25,PUT /api/v0/instances/99/,DELETE /api/v0/instances/99/" + if strings.Join(seen, ",") != wantSeen { t.Fatalf("seen=%v", seen) } } diff --git a/internal/providers/vast/provider.go b/internal/providers/vast/provider.go index 2edd6db5e..df3c13dd1 100644 --- a/internal/providers/vast/provider.go +++ b/internal/providers/vast/provider.go @@ -66,6 +66,9 @@ func (p Provider) Configure(cfg core.Config, rt core.Runtime) (core.Backend, err if cfg.TargetOS != "" && cfg.TargetOS != core.TargetLinux { return nil, exit(2, "provider=%s supports target=linux only", providerName) } + if cfg.Tailscale.Enabled || cfg.Network == core.NetworkTailscale { + return nil, exit(2, "provider=%s does not support Tailscale options", providerName) + } return newBackend(p.Spec(), cfg, rt), nil } diff --git a/internal/providers/vast/provider_test.go b/internal/providers/vast/provider_test.go index 206f05678..232746c60 100644 --- a/internal/providers/vast/provider_test.go +++ b/internal/providers/vast/provider_test.go @@ -174,3 +174,43 @@ func TestValidateConfigRejectsUnsafeValues(t *testing.T) { t.Fatalf("on-demand alias rejected: %v", err) } } + +func TestConfigureRejectsTailscaleBeforeBackend(t *testing.T) { + base := core.Config{ + TargetOS: core.TargetLinux, + Vast: core.VastConfig{ + APIURL: "https://console.vast.ai/api/v0", + InstanceType: "ondemand", + Runtype: "ssh_direct", + DiskGB: 20, + ReleaseAction: "destroy", + }, + } + tests := []struct { + name string + mutate func(*core.Config) + }{ + { + name: "enabled", + mutate: func(cfg *core.Config) { + cfg.Tailscale.Enabled = true + }, + }, + { + name: "network", + mutate: func(cfg *core.Config) { + 
cfg.Network = core.NetworkTailscale + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + cfg := base + test.mutate(&cfg) + backend, err := (Provider{}).Configure(cfg, core.Runtime{}) + if err == nil || backend != nil || !strings.Contains(err.Error(), "does not support Tailscale") { + t.Fatalf("backend=%T err=%v, want Tailscale rejection", backend, err) + } + }) + } +} From 083858c596f5cc7cc3dffaa9687ef3d7a9d79fe0 Mon Sep 17 00:00:00 2001 From: Coy Geek <65363919+coygeek@users.noreply.github.com> Date: Wed, 24 Jun 2026 21:21:02 -0700 Subject: [PATCH 16/18] fix(vast): remove dead provider helpers Drop unused Vast helper functions left behind after the pagination and release-lifecycle fixes. CI runs deadcode with test reachability, and these helpers were no longer referenced by the provider implementation or tests. --- internal/providers/vast/backend.go | 7 ------- internal/providers/vast/client.go | 5 ----- internal/providers/vast/core.go | 4 ---- 3 files changed, 16 deletions(-) diff --git a/internal/providers/vast/backend.go b/internal/providers/vast/backend.go index 558cfbd83..f79319f38 100644 --- a/internal/providers/vast/backend.go +++ b/internal/providers/vast/backend.go @@ -99,13 +99,6 @@ func (b *backend) stderr() io.Writer { return io.Discard } -func (b *backend) stdout() io.Writer { - if b.rt.Stdout != nil { - return b.rt.Stdout - } - return io.Discard -} - func (b *backend) now() time.Time { if b.rt.Clock != nil { return b.rt.Clock.Now().UTC() diff --git a/internal/providers/vast/client.go b/internal/providers/vast/client.go index 2226b9b1f..35e04ab7d 100644 --- a/internal/providers/vast/client.go +++ b/internal/providers/vast/client.go @@ -507,11 +507,6 @@ func decodeVastInstance(raw json.RawMessage) (vastInstance, error) { return normalizeVastInstance(envelope.Data), nil } -func decodeVastInstances(raw json.RawMessage) ([]vastInstance, error) { - instances, _, err := decodeVastInstancesPage(raw) - return instances, err 
-} - func decodeVastInstancesPage(raw json.RawMessage) ([]vastInstance, string, error) { var direct []vastInstance if err := json.Unmarshal(raw, &direct); err == nil { diff --git a/internal/providers/vast/core.go b/internal/providers/vast/core.go index 30d0c42b3..3f321e344 100644 --- a/internal/providers/vast/core.go +++ b/internal/providers/vast/core.go @@ -57,7 +57,3 @@ func isVastProviderName(provider string) bool { return false } } - -func notImplemented(operation string) error { - return exit(2, "provider=%s %s is not implemented yet", providerName, operation) -} From 282c4471f8e950043c152759c184b70b5369a517 Mon Sep 17 00:00:00 2001 From: Coy Geek <65363919+coygeek@users.noreply.github.com> Date: Wed, 24 Jun 2026 21:22:25 -0700 Subject: [PATCH 17/18] docs(vast): clarify stopped release claims --- docs/providers/vast.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/providers/vast.md b/docs/providers/vast.md index 447254eab..bb80ec364 100644 --- a/docs/providers/vast.md +++ b/docs/providers/vast.md @@ -197,7 +197,9 @@ the local claim, and removes the local per-lease key. Release actions: - `destroy` or `delete`: destroy the Vast instance on `stop` or one-shot release. -- `stop`: request Vast to stop the instance and remove the local Crabbox claim. +- `stop`: request Vast to stop the instance and keep the local Crabbox claim + with `state=stopped` so later status, cleanup, or explicit destroy can + reconcile the retained resource. - `keep`: leave the instance and local claim untouched during release. 
Use `stop` or `keep` only when you explicitly accept the retained resource and From 88f7d7f6c23509db03ce378cf78fbe9a78ffb489 Mon Sep 17 00:00:00 2001 From: Coy Geek <65363919+coygeek@users.noreply.github.com> Date: Wed, 24 Jun 2026 22:02:56 -0700 Subject: [PATCH 18/18] fix(vast): reject unsafe API URL components Reject query strings and fragments in the Vast API URL so configured provider endpoints cannot smuggle secret-bearing URL components or break path construction. Document the endpoint contract and cover query and fragment inputs in provider validation tests. --- docs/providers/vast.md | 2 +- internal/providers/vast/provider.go | 3 +++ internal/providers/vast/provider_test.go | 14 ++++++++++++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/docs/providers/vast.md b/docs/providers/vast.md index bb80ec364..9e2012ed7 100644 --- a/docs/providers/vast.md +++ b/docs/providers/vast.md @@ -69,7 +69,7 @@ Config keys under `vast:`: | Key | Maps to | Default | Notes | | --- | --- | --- | --- | -| `apiUrl` | `cfg.Vast.APIURL` | `https://console.vast.ai/api/v0` | Absolute Vast REST API URL without credentials. HTTPS is required except for localhost test endpoints. | +| `apiUrl` | `cfg.Vast.APIURL` | `https://console.vast.ai/api/v0` | Absolute Vast REST API URL without credentials, query strings, or fragments. HTTPS is required except for localhost test endpoints. | | `instanceType` | `cfg.Vast.InstanceType` | `ondemand` | Offer type, `ondemand` or `interruptible`; `on-demand` is normalized. | | `gpuName` | `cfg.Vast.GPUName` | empty | Optional Vast GPU name selector. | | `gpuCount` | `cfg.Vast.GPUCount` | `0` | Minimum GPU count when greater than zero. 
| diff --git a/internal/providers/vast/provider.go b/internal/providers/vast/provider.go index df3c13dd1..c8e5fd202 100644 --- a/internal/providers/vast/provider.go +++ b/internal/providers/vast/provider.go @@ -93,6 +93,9 @@ func (Provider) ValidateConfig(cfg core.Config) error { if err != nil || u.Scheme == "" || u.Host == "" || u.User != nil { return exit(2, "vast.apiUrl must be an absolute URL without credentials") } + if u.RawQuery != "" || u.Fragment != "" { + return exit(2, "vast.apiUrl must not include query strings or fragments") + } switch strings.ToLower(u.Scheme) { case "https", "http": default: diff --git a/internal/providers/vast/provider_test.go b/internal/providers/vast/provider_test.go index 232746c60..628a633df 100644 --- a/internal/providers/vast/provider_test.go +++ b/internal/providers/vast/provider_test.go @@ -120,6 +120,20 @@ func TestValidateConfigRejectsUnsafeValues(t *testing.T) { }, want: "absolute URL without credentials", }, + { + name: "query url", + mutate: func(cfg *core.Config) { + cfg.Vast.APIURL = "https://vast.example.test/api/v0?token=secret" + }, + want: "query strings or fragments", + }, + { + name: "fragment url", + mutate: func(cfg *core.Config) { + cfg.Vast.APIURL = "https://vast.example.test/api/v0#secret" + }, + want: "query strings or fragments", + }, { name: "instance type", mutate: func(cfg *core.Config) {