diff --git a/internal/dag.go b/internal/dag.go index 32d2c36..50b64d1 100644 --- a/internal/dag.go +++ b/internal/dag.go @@ -33,6 +33,48 @@ func (d *DAG[Node]) AddEdge(from, to string) { d.Parents[to] = append(d.Parents[to], from) } +// findCycle returns the nodes forming a dependency cycle, or nil if the graph is acyclic. +func (d *DAG[Node]) findCycle() []string { + const ( + unvisited = 0 + onStack = 1 + done = 2 + ) + state := make(map[string]int, len(d.Nodes)) + var stack []string + var visit func(string) []string + visit = func(name string) []string { + state[name] = onStack + stack = append(stack, name) + for _, child := range d.Children[name] { + switch state[child] { + case onStack: + // found a back-edge: return the cycle from child to here + for i, n := range stack { + if n == child { + return append(stack[i:], child) + } + } + case unvisited: + if cycle := visit(child); cycle != nil { + return cycle + } + } + } + stack = stack[:len(stack)-1] + state[name] = done + return nil + } + for name := range d.Nodes { + if state[name] == unvisited { + if cycle := visit(name); cycle != nil { + return cycle + } + } + } + return nil +} + func (d *DAG[Node]) Subgraph(nodeNames []string) map[string]bool { visited := make(map[string]bool) var visit func(string) diff --git a/internal/dag_test.go b/internal/dag_test.go index fc04bb1..f69bc72 100644 --- a/internal/dag_test.go +++ b/internal/dag_test.go @@ -21,3 +21,19 @@ func TestDAG_Subgraph(t *testing.T) { t.Fatalf("expected c in subgraph") } } + +func TestDAG_findCycle(t *testing.T) { + d := NewDAG[bool]("") + d.AddNode("a", true) + d.AddNode("b", true) + d.AddNode("c", true) + d.AddEdge("a", "b") + d.AddEdge("b", "c") + if cycle := d.findCycle(); cycle != nil { + t.Fatalf("acyclic graph reported cycle: %v", cycle) + } + d.AddEdge("c", "a") // close the loop + if d.findCycle() == nil { + t.Fatal("cycle not detected") + } +} diff --git a/internal/log_writer.go b/internal/log_writer.go index 0003862..b5ed1e8 100644 --- a/internal/log_writer.go +++ b/internal/log_writer.go @@ -3,6 +3,7 @@ package internal import ( "bytes" "log" + "sync" ) type logWriter struct { @@ -10,11 +11,17 @@ type logWriter struct { prefixSuffixProvider func() (string, string) buffer bytes.Buffer logger *log.Logger + // mu guards buffer: the same writer is used for both stdout and stderr, + // which os/exec copies on separate goroutines + mu sync.Mutex } func (lw *logWriter) Write(p []byte) (int, error) { prefix, suffix := lw.prefixSuffixProvider() + lw.mu.Lock() + defer lw.mu.Unlock() + for _, b := range p { if b == '\n' { lw.logger.Printf("%s%s%s\n", prefix, lw.buffer.String(), suffix) diff --git a/internal/probe.go b/internal/probe.go index af3f899..7feb1b2 100644 --- a/internal/probe.go +++ b/internal/probe.go @@ -31,6 +31,7 @@ func probeLoop(ctx context.Context, probe types.Probe, callback func(ok bool, er if err != nil { return fmt.Errorf("failed to get %q: %w", httpGet.GetURL(), err) } + defer resp.Body.Close() if resp.StatusCode >= 300 { data, _ := io.ReadAll(resp.Body) return fmt.Errorf("%s: %q", resp.Status, data) diff --git a/internal/proc/container.go b/internal/proc/container.go index 0044df4..4324b7d 100644 --- a/internal/proc/container.go +++ b/internal/proc/container.go @@ -51,8 +51,13 @@ func (c *container) Run(ctx context.Context, stdout, stderr io.Writer) error { if err != nil { return fmt.Errorf("failed to create docker client: %w", err) } - defer cli.Close() c.cli = cli + // close the client once the context is done rather than when Run returns: + // the metrics goroutine keeps using c.cli until ctx is cancelled + go func() { + <-ctx.Done() + cli.Close() + }() dockerfile := filepath.Join(c.Image, "Dockerfile") id, existingHash, err := c.getContainer(ctx, cli) diff --git a/internal/proc/host.go b/internal/proc/host.go index 28576b9..27f3de5 100644 --- a/internal/proc/host.go +++ b/internal/proc/host.go @@ -40,7 +40,8 @@ func (h *host) Run(ctx context.Context, stdout, stderr io.Writer) error { cmd.SysProcAttr = &syscall.SysProcAttr{ Setpgid: true, } - cmd.Env = append(environ, os.Environ()...) + // spec env takes precedence over the inherited environment (exec is last-wins) + cmd.Env = append(os.Environ(), environ...) log := h.log log.Println("starting process") err = cmd.Start() @@ -53,6 +54,7 @@ func (h *host) Run(ctx context.Context, stdout, stderr io.Writer) error { h.pid = pid pgid, err := syscall.Getpgid(pid) if err != nil { + _ = cmd.Process.Kill() return fmt.Errorf("failed get pgid: %w", err) } go func() { diff --git a/internal/proc/kubernetes.go b/internal/proc/kubernetes.go index c69a32f..b7fdebe 100644 --- a/internal/proc/kubernetes.go +++ b/internal/proc/kubernetes.go @@ -44,6 +44,7 @@ type k8s struct { log *log.Logger spec types.Spec name string + podsMu sync.Mutex pods []string // namespace/name clientset kubernetes.Interface restConfig *rest.Config @@ -282,9 +283,11 @@ func (k *k8s) Run(ctx context.Context, stdout io.Writer, stderr io.Writer) error podKey := pod.Namespace + "/" + pod.Name + k.podsMu.Lock() if !slices.Contains(k.pods, podKey) { k.pods = append(k.pods, podKey) } + k.podsMu.Unlock() running := make(map[string]bool) @@ -443,7 +446,10 @@ func sortUnstructureds(uns []*unstructured.Unstructured) { func (k *k8s) GetMetrics(ctx context.Context) (*types.Metrics, error) { sum := &types.Metrics{} - for _, podKey := range k.pods { + k.podsMu.Lock() + pods := append([]string(nil), k.pods...) + k.podsMu.Unlock() + for _, podKey := range pods { parts := strings.SplitN(podKey, "/", 2) namespace := parts[0] podName := parts[1] diff --git a/internal/run.go b/internal/run.go index 865b995..dd1426d 100644 --- a/internal/run.go +++ b/internal/run.go @@ -53,9 +53,15 @@ func RunSubgraph(ctx context.Context, cancel context.CancelFunc, port int, openB for name, t := range wf.Tasks { dag.AddNode(name, true) for _, dependency := range t.Dependencies { + if _, ok := wf.Tasks[dependency]; !ok { + return fmt.Errorf("task %q depends on unknown task %q", name, dependency) + } dag.AddEdge(dependency, name) } } + if cycle := dag.findCycle(); cycle != nil { + return fmt.Errorf("dependency cycle detected: %s", strings.Join(cycle, " -> ")) + } visited := dag.Subgraph(taskNames) taskByName := wf.Tasks @@ -191,12 +197,11 @@ func RunSubgraph(ctx context.Context, cancel context.CancelFunc, port int, openB for name, taskNode := range subgraph.Nodes { stalledTime := taskNode.Task.GetStalledTimeout() stallTimers[name] = time.AfterFunc(stalledTime, func() { - if taskNode.Phase == "starting" || taskNode.Phase == "running" { + if phase := taskNode.getPhase(); phase == "starting" || phase == "running" { // we suffix the message with "starting" so we can differentiate between a task that is starting and one that is running, later on we can change the message to "output received" // and restore the phase to "running" or "starting" - taskNode.Message = fmt.Sprintf("no output for %s or more while %s", stalledTime, taskNode.Phase) - taskNode.Phase = "stalled" - logger.Printf("[%s] %s\n", taskNode.Name, taskNode.Message) + taskNode.setStatus("stalled", fmt.Sprintf("no output for %s or more while %s", stalledTime, phase)) + logger.Printf("[%s] %s\n", taskNode.Name, taskNode.getMessage()) statusEvents <- taskNode } }) @@ -210,7 +215,21 @@ func RunSubgraph(ctx context.Context, cancel context.CancelFunc, port int, openB logger.Println("waiting for all tasks to complete") + // keep draining events so task/watcher goroutines blocked on a full + // channel can reach wg.Done() instead of deadlocking shutdown + drained := make(chan struct{}) + go func() { + for { + select { + case <-drained: + return + case <-events: + case <-statusEvents: + } + } + }() wg.Wait() + close(drained) // if any task failed, we will return an error var failures []string @@ -218,7 +237,8 @@ func RunSubgraph(ctx context.Context, cancel context.CancelFunc, port int, openB color := 30 faint := 0 - switch node.Phase { + phase := node.getPhase() + switch phase { case "failed": // red color = 31 @@ -228,7 +248,7 @@ func RunSubgraph(ctx context.Context, cancel context.CancelFunc, port int, openB faint = 2 } - logger.Printf("\033[%d;%dm[%s] (%s) %s\033[0m\n", faint, color, node.Name, node.Phase, node.Message) + logger.Printf("\033[%d;%dm[%s] (%s) %s\033[0m\n", faint, color, node.Name, phase, node.getMessage()) } if len(failures) > 0 { @@ -251,8 +271,9 @@ func RunSubgraph(ctx context.Context, cancel context.CancelFunc, port int, openB } for _, node := range subgraph.Nodes { + phase := node.getPhase() // Check if task should cause immediate exit - if node.Phase == "failed" && node.Task.GetRestartPolicy() == "Never" { + if phase == "failed" && node.Task.GetRestartPolicy() == "Never" { logger.Printf("🚫 exiting because task %q failed and should not be restarted", node.Name) cancel() continue @@ -260,7 +281,7 @@ func RunSubgraph(ctx context.Context, cancel context.CancelFunc, port int, openB // Check if task is complete and should be removed from tracking isComplete := false - switch node.Phase { + switch phase { case "succeeded", "skipped": isComplete = true case "running", "stalled": @@ -288,7 +309,7 @@ func RunSubgraph(ctx context.Context, cancel context.CancelFunc, port int, openB logger.Println("🔵 all requested tasks are running:") // print a list of running tasks, and their ports for _, node := range subgraph.Nodes { - if (node.Phase == "running" || node.Phase == "stalled") && node.Task.Ports != nil { + if p := node.getPhase(); (p == "running" || p == "stalled") && node.Task.Ports != nil { for _, port := range node.Task.Ports { logger.Printf(" - %s: http://localhost:%d\n", node.Name, port.HostPort) } @@ -310,7 +331,7 @@ func RunSubgraph(ctx context.Context, cancel context.CancelFunc, port int, openB for _, parentName := range subgraph.Parents[taskName] { parent := subgraph.Nodes[parentName] if parent.blocked() { - logger.Printf("task %q is blocked by %q (%s): %s\n", taskName, parentName, parent.Phase, parent.Message) + logger.Printf("task %q is blocked by %q (%s): %s\n", taskName, parentName, parent.getPhase(), parent.getMessage()) blocked = true } } @@ -322,7 +343,7 @@ func RunSubgraph(ctx context.Context, cancel context.CancelFunc, port int, openB // we might already be pending, waiting, starting or running this task, so we don't want to start it again node := subgraph.Nodes[taskName] - node.cancel() + node.doCancel() allRunning = false // each task is executed in a separate goroutine @@ -336,7 +357,7 @@ func RunSubgraph(ctx context.Context, cancel context.CancelFunc, port int, openB ctx, cancel := context.WithCancel(ctx) defer cancel() - node.cancel = cancel + node.setCancel(cancel) // send a poison pill to indicate that we've finish and the main loop must check to see if we need to exit defer func() { events <- poisonPill }() @@ -348,18 +369,17 @@ func RunSubgraph(ctx context.Context, cancel context.CancelFunc, port int, openB var out io.Writer = &logWriter{ logger: logger, prefixSuffixProvider: func() (string, string) { - return fmt.Sprintf("%s[%s] (%s) ", color(node.Name), node.Name, node.Phase), "\033[0m" + return fmt.Sprintf("%s[%s] (%s) ", color(node.Name), node.Name, node.getPhase()), "\033[0m" }, } logger := log.New(out, "", 0) setNodeStatus := func(node *TaskNode, phase string, message string) { - node.Phase = phase - node.Message = message + node.setStatus(phase, message) stallTimers[node.Name].Reset(node.Task.GetStalledTimeout()) setTerminalTitle(workflowTitle(name, subgraph.Nodes)) - logger.Println(node.Message) + logger.Println(message) statusEvents <- node events <- poisonPill } @@ -461,8 +481,8 @@ func RunSubgraph(ctx context.Context, cancel context.CancelFunc, port int, openB // so when we tail the log file, we see the output immediately buf := funcWriter(func(p []byte) (int, error) { stallTimers[node.Name].Reset(node.Task.GetStalledTimeout()) - if node.Phase == "stalled" { - if strings.HasSuffix(node.Message, "starting") { + if node.getPhase() == "stalled" { + if strings.HasSuffix(node.getMessage(), "starting") { setNodeStatus(node, "starting", "output received") } else { setNodeStatus(node, "running", "output received") @@ -493,7 +513,7 @@ func RunSubgraph(ctx context.Context, cancel context.CancelFunc, port int, openB case <-ctx.Done(): return case <-ticker.C: - if node.Phase != "running" && node.Phase != "stalled" { + if ph := node.getPhase(); ph != "running" && ph != "stalled" { continue } metrics, err := p.GetMetrics(ctx) @@ -501,7 +521,7 @@ func RunSubgraph(ctx context.Context, cancel context.CancelFunc, port int, openB logger.Printf("failed to get metrics: %v", err) continue } - node.Metrics = metrics + node.setMetrics(metrics) statusEvents <- node } } diff --git a/internal/server.go b/internal/server.go index 978c461..e2f70b0 100644 --- a/internal/server.go +++ b/internal/server.go @@ -27,7 +27,11 @@ func StartServer(ctx context.Context, port int, wg *sync.WaitGroup, dag DAG[*Tas go func() { for event := range events { streams.Range(func(key, value any) bool { - value.(chan *TaskNode) <- event + // non-blocking: a slow client must not stall the broadcast + select { + case value.(chan *TaskNode) <- event: + default: + } return true }) } @@ -78,18 +82,22 @@ func StartServer(ctx context.Context, port int, wg *sync.WaitGroup, dag DAG[*Tas // return an event stream w.Header().Set("Content-Type", "text/event-stream") - for event := range stream { - marshal, err := json.Marshal(event) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - _, err = fmt.Fprintf(w, "data: %s\n\n", marshal) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + for { + select { + case <-r.Context().Done(): return + case event := <-stream: + marshal, err := json.Marshal(event) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + _, err = fmt.Fprintf(w, "data: %s\n\n", marshal) + if err != nil { + return + } + w.(http.Flusher).Flush() } - w.(http.Flusher).Flush() } }) mux.HandleFunc("/logs/{task}", func(w http.ResponseWriter, r *http.Request) { @@ -125,8 +133,12 @@ func StartServer(ctx context.Context, port int, wg *sync.WaitGroup, dag DAG[*Tas return } - // Sleep for a short duration before checking for new lines - time.Sleep(1 * time.Second) + // stop tailing when the client disconnects + select { + case <-r.Context().Done(): + return + case <-time.After(1 * time.Second): + } // Reset the scanner to continue reading new lines _, err := file.Seek(0, io.SeekCurrent) diff --git a/internal/task_node.go b/internal/task_node.go index b752b4d..6f43e99 100644 --- a/internal/task_node.go +++ b/internal/task_node.go @@ -1,6 +1,7 @@ package internal import ( + "encoding/json" "sync" "github.com/kitproj/kit/internal/types" @@ -28,12 +29,62 @@ type TaskNode struct { Metrics *types.Metrics `json:"metrics,omitempty"` // cancel function cancel func() - // a mutex + // mu serializes execution so two instances of a task don't run at once mu *sync.Mutex + // status guards the Phase, Message, Metrics and cancel fields, which are + // written by task/timer/metrics goroutines and read by the main loop and the server + status sync.RWMutex } -func (n TaskNode) blocked() bool { - switch n.Phase { +func (n *TaskNode) setStatus(phase, message string) { + n.status.Lock() + n.Phase = phase + n.Message = message + n.status.Unlock() +} + +func (n *TaskNode) setMetrics(m *types.Metrics) { + n.status.Lock() + n.Metrics = m + n.status.Unlock() +} + +func (n *TaskNode) getPhase() string { + n.status.RLock() + defer n.status.RUnlock() + return n.Phase +} + +func (n *TaskNode) getMessage() string { + n.status.RLock() + defer n.status.RUnlock() + return n.Message +} + +func (n *TaskNode) setCancel(cancel func()) { + n.status.Lock() + n.cancel = cancel + n.status.Unlock() +} + +func (n *TaskNode) doCancel() { + n.status.RLock() + cancel := n.cancel + n.status.RUnlock() + cancel() +} + +// MarshalJSON takes the status lock so the phase/message/metrics fields are +// read consistently while task goroutines may be writing them. +func (n *TaskNode) MarshalJSON() ([]byte, error) { + type alias TaskNode + n.status.RLock() + defer n.status.RUnlock() + return json.Marshal((*alias)(n)) +} + +func (n *TaskNode) blocked() bool { + switch n.getPhase() { case "running", "stalled": return n.Task.GetType() == types.TaskTypeJob case "succeeded", "skipped": diff --git a/internal/terminal.go b/internal/terminal.go index 627045a..72a8b87 100644 --- a/internal/terminal.go +++ b/internal/terminal.go @@ -49,7 +49,7 @@ func workflowTitle(name string, nodes map[string]*TaskNode) string { running := 0 failures := []string{} for _, node := range nodes { - switch node.Phase { + switch node.getPhase() { case "failed": failures = append(failures, node.Name) case "running", "stalled": diff --git a/internal/types/env_var.go b/internal/types/env_var.go index b454b38..0d8d9fb 100644 --- a/internal/types/env_var.go +++ b/internal/types/env_var.go @@ -17,15 +17,13 @@ func (v EnvVar) String() (string, error) { } func (v *EnvVar) Unstring(s string) error { - parts := strings.Split(s, "=") - switch len(parts) { - case 2: - v.Name = parts[0] - v.Value = parts[1] - return nil - default: + parts := strings.SplitN(s, "=", 2) + if len(parts) != 2 { return fmt.Errorf("invalid EnvVar string %q", s) } + v.Name = parts[0] + v.Value = parts[1] + return nil } func (v *EnvVar) UnmarshalJSON(data []byte) error { diff --git a/internal/types/envfile.go b/internal/types/envfile.go index 920da95..373288a 100644 --- a/internal/types/envfile.go +++ b/internal/types/envfile.go @@ -13,21 +13,29 @@ type Envfile Strings func (f Envfile) Environ(workingDir string) ([]string, error) { var environ []string for _, e := range f { - file, err := os.Open(filepath.Join(workingDir, e)) + lines, err := readEnvfile(filepath.Join(workingDir, e)) if err != nil { return nil, err } - defer file.Close() - scanner := bufio.NewScanner(file) - for scanner.Scan() { - line := scanner.Text() - if !strings.HasPrefix(line, "#") { - environ = append(environ, line) - } - } - if err := scanner.Err(); err != nil { - return nil, err - } + environ = append(environ, lines...) } return environ, nil } + +func readEnvfile(path string) ([]string, error) { + file, err := os.Open(path) + if err != nil { + return nil, err + } + defer file.Close() + var environ []string + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + environ = append(environ, line) + } + return environ, scanner.Err() +} diff --git a/internal/types/probe.go b/internal/types/probe.go index fe23cca..682513f 100644 --- a/internal/types/probe.go +++ b/internal/types/probe.go @@ -95,10 +95,13 @@ func parsePort(s string) uint16 { func (p Probe) URL() *url.URL { var u *url.URL - if p.TCPSocket != nil { + switch { + case p.TCPSocket != nil: u = p.TCPSocket.URL() - } else { + case p.HTTPGet != nil: u = p.HTTPGet.URL() + default: + return &url.URL{} } var x = url.Values{} if p.InitialDelaySeconds > 0 {