|
| 1 | +package v20260301 |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "errors" |
| 6 | + "fmt" |
| 7 | + "net/http" |
| 8 | + |
| 9 | + "github.com/Azure/azure-sdk-for-go/sdk/azcore" |
| 10 | + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" |
| 11 | + "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud" |
| 12 | + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" |
| 13 | + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" |
| 14 | + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v8" |
| 15 | + "github.com/sirupsen/logrus" |
| 16 | + "google.golang.org/grpc/codes" |
| 17 | + "google.golang.org/grpc/status" |
| 18 | + "google.golang.org/protobuf/types/known/anypb" |
| 19 | + |
| 20 | + "github.com/Azure/AKSFlexNode/components/aksmachine" |
| 21 | + "github.com/Azure/AKSFlexNode/components/services/actions" |
| 22 | + "github.com/Azure/AKSFlexNode/pkg/utils/utilpb" |
| 23 | +) |
| 24 | + |
const (
	// aksFlexNodePoolName is the dedicated agent pool that hosts AKS flex nodes.
	aksFlexNodePoolName = "aksflexnodes"
	// flexNodeTagKey is the tag that identifies this machine as an AKS flex node.
	flexNodeTagKey = "aks-flex-node"

	// armEndpointOverride, when non-empty, redirects all
	// ARM calls to a local test server. It is for testing only and should not be set in production.
	armEndpointOverride = ""
)
| 33 | + |
// ensureMachineAction implements actions.Server for the EnsureMachine action:
// it ensures a dedicated "Machines"-mode agent pool exists and registers the
// local machine in it as an AKS flex node.
type ensureMachineAction struct {
	// logger records action progress; set by newEnsureMachineAction.
	logger *logrus.Logger
}
| 37 | + |
| 38 | +func newEnsureMachineAction() (actions.Server, error) { |
| 39 | + return &ensureMachineAction{ |
| 40 | + logger: logrus.New(), |
| 41 | + }, nil |
| 42 | +} |
| 43 | + |
| 44 | +var _ actions.Server = (*ensureMachineAction)(nil) |
| 45 | + |
| 46 | +// ApplyAction runs two sequential sub-steps: |
| 47 | +// 1. Ensure the "aksflexnodes" agent pool exists with mode "Machines". |
| 48 | +// 2. Ensure the local machine exists in that pool tagged as a flex node. |
| 49 | +// |
| 50 | +// If drift detection and remediation is not enabled in the agent config, the |
| 51 | +// action returns immediately without performing any Azure operations. |
| 52 | +func (a *ensureMachineAction) ApplyAction( |
| 53 | + ctx context.Context, |
| 54 | + req *actions.ApplyActionRequest, |
| 55 | +) (*actions.ApplyActionResponse, error) { |
| 56 | + action, err := utilpb.AnyTo[*aksmachine.EnsureMachine](req.GetItem()) |
| 57 | + if err != nil { |
| 58 | + return nil, err |
| 59 | + } |
| 60 | + |
| 61 | + spec := action.GetSpec() |
| 62 | + |
| 63 | + // Skip all Azure operations when drift detection/remediation is disabled. |
| 64 | + if !spec.GetEnabled() { |
| 65 | + a.logger.Info("EnsureMachine: drift detection and remediation is disabled, skipping") |
| 66 | + item, err := anypb.New(action) |
| 67 | + if err != nil { |
| 68 | + return nil, err |
| 69 | + } |
| 70 | + return actions.ApplyActionResponse_builder{Item: item}.Build(), nil |
| 71 | + } |
| 72 | + |
| 73 | + subID := spec.GetSubscriptionId() |
| 74 | + rg := spec.GetResourceGroup() |
| 75 | + clusterName := spec.GetClusterName() |
| 76 | + machineName := spec.GetMachineName() |
| 77 | + k8sVersion := spec.GetKubernetesVersion() |
| 78 | + |
| 79 | + if subID == "" || rg == "" || clusterName == "" || machineName == "" || k8sVersion == "" { |
| 80 | + return nil, status.Errorf(codes.InvalidArgument, |
| 81 | + "EnsureMachine: spec fields incomplete: subscriptionId=%q resourceGroup=%q clusterName=%q machineName=%q kubernetesVersion=%q", |
| 82 | + subID, rg, clusterName, machineName, k8sVersion) |
| 83 | + } |
| 84 | + |
| 85 | + cred, err := credentialFromSpec(spec.GetAzureCredential()) |
| 86 | + if err != nil { |
| 87 | + return nil, status.Errorf(codes.Internal, "EnsureMachine: resolve credential: %v", err) |
| 88 | + } |
| 89 | + |
| 90 | + armOpts := buildARMClientOptions(armEndpointOverride) |
| 91 | + |
| 92 | + // Step 1: ensure the agent pool exists with mode "Machines". |
| 93 | + if err := a.ensureAgentPool(ctx, cred, armOpts, subID, rg, clusterName); err != nil { |
| 94 | + return nil, status.Errorf(codes.Internal, "EnsureMachine: ensure agent pool: %v", err) |
| 95 | + } |
| 96 | + |
| 97 | + // Step 2: ensure this machine is registered in the pool as a flex node. |
| 98 | + if err := a.ensureMachine(ctx, cred, armOpts, spec); err != nil { |
| 99 | + return nil, status.Errorf(codes.Internal, "EnsureMachine: ensure machine: %v", err) |
| 100 | + } |
| 101 | + |
| 102 | + item, err := anypb.New(action) |
| 103 | + if err != nil { |
| 104 | + return nil, err |
| 105 | + } |
| 106 | + return actions.ApplyActionResponse_builder{Item: item}.Build(), nil |
| 107 | +} |
| 108 | + |
| 109 | +// ensureAgentPool calls CreateOrUpdate on the "aksflexnodes" agent pool with |
| 110 | +// mode "Machines" and waits for the long-running operation to complete. |
| 111 | +func (a *ensureMachineAction) ensureAgentPool(ctx context.Context, cred azcore.TokenCredential, armOpts *arm.ClientOptions, subID, rg, clusterName string) error { |
| 112 | + client, err := armcontainerservice.NewAgentPoolsClient(subID, cred, armOpts) |
| 113 | + if err != nil { |
| 114 | + return fmt.Errorf("create agent pools client: %w", err) |
| 115 | + } |
| 116 | + |
| 117 | + mode := armcontainerservice.AgentPoolMode("Machines") |
| 118 | + params := armcontainerservice.AgentPool{ |
| 119 | + Properties: &armcontainerservice.ManagedClusterAgentPoolProfileProperties{ |
| 120 | + Mode: &mode, |
| 121 | + }, |
| 122 | + } |
| 123 | + |
| 124 | + a.logger.Infof("Ensuring agent pool %q (mode=Machines) on cluster %s/%s", aksFlexNodePoolName, rg, clusterName) |
| 125 | + |
| 126 | + // Check whether the agent pool already exists; if so, skip the PUT. |
| 127 | + _, err = client.Get(ctx, rg, clusterName, aksFlexNodePoolName, nil) |
| 128 | + if err == nil { |
| 129 | + a.logger.Infof("Agent pool %q already exists on cluster %s/%s, skipping", aksFlexNodePoolName, rg, clusterName) |
| 130 | + return nil |
| 131 | + } |
| 132 | + if !isNotFound(err) { |
| 133 | + return fmt.Errorf("get agent pool %q: %w", aksFlexNodePoolName, err) |
| 134 | + } |
| 135 | + |
| 136 | + poller, err := client.BeginCreateOrUpdate(ctx, rg, clusterName, aksFlexNodePoolName, params, nil) |
| 137 | + if err != nil { |
| 138 | + return fmt.Errorf("begin create or update agent pool %q: %w", aksFlexNodePoolName, err) |
| 139 | + } |
| 140 | + |
| 141 | + if _, err = poller.PollUntilDone(ctx, nil); err != nil { |
| 142 | + return fmt.Errorf("wait for agent pool %q: %w", aksFlexNodePoolName, err) |
| 143 | + } |
| 144 | + |
| 145 | + a.logger.Infof("Agent pool %q ensured on cluster %s/%s", aksFlexNodePoolName, rg, clusterName) |
| 146 | + return nil |
| 147 | +} |
| 148 | + |
| 149 | +// ensureMachine registers this machine in the "aksflexnodes" agent pool as a |
| 150 | +// flex node. It first checks whether the machine resource already exists; if so |
| 151 | +// it skips the PUT to avoid overwriting properties managed by the AKS control plane. |
| 152 | +func (a *ensureMachineAction) ensureMachine(ctx context.Context, cred azcore.TokenCredential, armOpts *arm.ClientOptions, spec *aksmachine.EnsureMachineSpec) error { |
| 153 | + subID := spec.GetSubscriptionId() |
| 154 | + rg := spec.GetResourceGroup() |
| 155 | + clusterName := spec.GetClusterName() |
| 156 | + machineName := spec.GetMachineName() |
| 157 | + |
| 158 | + client, err := armcontainerservice.NewMachinesClient(subID, cred, armOpts) |
| 159 | + if err != nil { |
| 160 | + return fmt.Errorf("create machines client: %w", err) |
| 161 | + } |
| 162 | + |
| 163 | + // Check whether the machine is already registered; if so, skip the PUT. |
| 164 | + _, err = client.Get(ctx, rg, clusterName, aksFlexNodePoolName, machineName, nil) |
| 165 | + if err == nil { |
| 166 | + a.logger.Infof("Machine %q already exists in pool %q on cluster %s/%s, skipping", machineName, aksFlexNodePoolName, rg, clusterName) |
| 167 | + return nil |
| 168 | + } |
| 169 | + if !isNotFound(err) { |
| 170 | + return fmt.Errorf("get machine %q: %w", machineName, err) |
| 171 | + } |
| 172 | + |
| 173 | + params := armcontainerservice.Machine{ |
| 174 | + Properties: &armcontainerservice.MachineProperties{ |
| 175 | + Tags: map[string]*string{ |
| 176 | + flexNodeTagKey: to.Ptr("true"), |
| 177 | + }, |
| 178 | + Kubernetes: buildK8sProfile(spec), |
| 179 | + }, |
| 180 | + } |
| 181 | + |
| 182 | + poller, err := client.BeginCreateOrUpdate(ctx, rg, clusterName, aksFlexNodePoolName, machineName, params, nil) |
| 183 | + if err != nil { |
| 184 | + return fmt.Errorf("begin create or update machine %q: %w", machineName, err) |
| 185 | + } |
| 186 | + |
| 187 | + // if the ARM server returns a synchronous 2xx response |
| 188 | + // with no Azure-AsyncOperation / Operation-Location / Location header, the SDK treats it as synchronously |
| 189 | + // complete and PollUntilDone returns right away with the response body — no looping occurs. |
| 190 | + if _, err = poller.PollUntilDone(ctx, nil); err != nil { |
| 191 | + return fmt.Errorf("wait for machine %q: %w", machineName, err) |
| 192 | + } |
| 193 | + |
| 194 | + a.logger.Infof("Machine %q ensured in pool %q on cluster %s/%s", machineName, aksFlexNodePoolName, rg, clusterName) |
| 195 | + return nil |
| 196 | +} |
| 197 | + |
| 198 | +// buildK8sProfile constructs a MachineKubernetesProfile from the spec using |
| 199 | +// the explicit allow-list of fields permitted for flex nodes: |
| 200 | +// - OrchestratorVersion, MaxPods, NodeLabels, NodeTaints, |
| 201 | +// NodeInitializationTaints, KubeletConfig (image GC thresholds). |
| 202 | +func buildK8sProfile(spec *aksmachine.EnsureMachineSpec) *armcontainerservice.MachineKubernetesProfile { |
| 203 | + p := &armcontainerservice.MachineKubernetesProfile{} |
| 204 | + |
| 205 | + if v := spec.GetKubernetesVersion(); v != "" { |
| 206 | + p.OrchestratorVersion = to.Ptr(v) |
| 207 | + } |
| 208 | + if mp := spec.GetMaxPods(); mp > 0 { |
| 209 | + p.MaxPods = to.Ptr(mp) |
| 210 | + } |
| 211 | + if labels := spec.GetNodeLabels(); len(labels) > 0 { |
| 212 | + p.NodeLabels = make(map[string]*string, len(labels)) |
| 213 | + for k, v := range labels { |
| 214 | + p.NodeLabels[k] = to.Ptr(v) |
| 215 | + } |
| 216 | + } |
| 217 | + if taints := spec.GetNodeTaints(); len(taints) > 0 { |
| 218 | + p.NodeTaints = make([]*string, len(taints)) |
| 219 | + for i, t := range taints { |
| 220 | + p.NodeTaints[i] = to.Ptr(t) |
| 221 | + } |
| 222 | + } |
| 223 | + if initTaints := spec.GetNodeInitializationTaints(); len(initTaints) > 0 { |
| 224 | + p.NodeInitializationTaints = make([]*string, len(initTaints)) |
| 225 | + for i, t := range initTaints { |
| 226 | + p.NodeInitializationTaints[i] = to.Ptr(t) |
| 227 | + } |
| 228 | + } |
| 229 | + if kc := spec.GetKubeletConfig(); kc != nil { |
| 230 | + p.KubeletConfig = &armcontainerservice.KubeletConfig{} |
| 231 | + if h := kc.GetImageGcHighThreshold(); h > 0 { |
| 232 | + p.KubeletConfig.ImageGcHighThreshold = to.Ptr(h) |
| 233 | + } |
| 234 | + if l := kc.GetImageGcLowThreshold(); l > 0 { |
| 235 | + p.KubeletConfig.ImageGcLowThreshold = to.Ptr(l) |
| 236 | + } |
| 237 | + } |
| 238 | + |
| 239 | + return p |
| 240 | +} |
| 241 | + |
| 242 | +// credentialFromSpec resolves an Azure ARM credential from the proto AzureCredential field. |
| 243 | +// Falls back to Azure CLI credential when the field is absent or empty. |
| 244 | +func credentialFromSpec(cred *aksmachine.AzureCredential) (azcore.TokenCredential, error) { |
| 245 | + // Prefer explicitly configured credentials when present. |
| 246 | + if cred != nil { |
| 247 | + if sp := cred.GetServicePrincipal(); sp != nil { |
| 248 | + return azidentity.NewClientSecretCredential(sp.GetTenantId(), sp.GetClientId(), sp.GetClientSecret(), nil) |
| 249 | + } |
| 250 | + if mi := cred.GetManagedIdentity(); mi != nil { |
| 251 | + opts := &azidentity.ManagedIdentityCredentialOptions{} |
| 252 | + if id := mi.GetClientId(); id != "" { |
| 253 | + opts.ID = azidentity.ClientID(id) |
| 254 | + } |
| 255 | + return azidentity.NewManagedIdentityCredential(opts) |
| 256 | + } |
| 257 | + } |
| 258 | + // Fall back to Azure CLI credential when no explicit credential is configured. |
| 259 | + return azidentity.NewAzureCLICredential(nil) |
| 260 | +} |
| 261 | + |
| 262 | +// buildARMClientOptions returns ARM client options that redirect all calls to |
| 263 | +// endpointOverride when non-empty (e.g. "http://localhost:8080" for local testing). |
| 264 | +// Returns nil when the override is empty, which causes the SDK to use the default |
| 265 | +// public Azure Resource Manager endpoint. |
| 266 | +func buildARMClientOptions(endpointOverride string) *arm.ClientOptions { |
| 267 | + if endpointOverride == "" { |
| 268 | + return nil |
| 269 | + } |
| 270 | + return &arm.ClientOptions{ |
| 271 | + ClientOptions: azcore.ClientOptions{ |
| 272 | + Cloud: cloud.Configuration{ |
| 273 | + Services: map[cloud.ServiceName]cloud.ServiceConfiguration{ |
| 274 | + cloud.ResourceManager: { |
| 275 | + Endpoint: endpointOverride, |
| 276 | + // No audience needed for local servers that don't validate tokens. |
| 277 | + Audience: endpointOverride, |
| 278 | + }, |
| 279 | + }, |
| 280 | + }, |
| 281 | + InsecureAllowCredentialWithHTTP: true, |
| 282 | + }, |
| 283 | + } |
| 284 | +} |
| 285 | + |
| 286 | +// isNotFound reports whether the Azure SDK error is an HTTP 404. |
| 287 | +func isNotFound(err error) bool { |
| 288 | + var respErr *azcore.ResponseError |
| 289 | + return errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound |
| 290 | +} |
0 commit comments