Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions internal/dag.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,48 @@ func (d *DAG[Node]) AddEdge(from, to string) {
d.Parents[to] = append(d.Parents[to], from)
}

// findCycle returns the nodes forming a dependency cycle, or nil if the graph is acyclic.
func (d *DAG[Node]) findCycle() []string {
const (
unvisited = 0
onStack = 1
done = 2
)
state := make(map[string]int, len(d.Nodes))
var stack []string
var visit func(string) []string
visit = func(name string) []string {
state[name] = onStack
stack = append(stack, name)
for _, child := range d.Children[name] {
switch state[child] {
case onStack:
// found a back-edge: return the cycle from child to here
for i, n := range stack {
if n == child {
return append(stack[i:], child)
}
}
case unvisited:
if cycle := visit(child); cycle != nil {
return cycle
}
}
}
stack = stack[:len(stack)-1]
state[name] = done
return nil
}
for name := range d.Nodes {
if state[name] == unvisited {
if cycle := visit(name); cycle != nil {
return cycle
}
}
}
return nil
}

func (d *DAG[Node]) Subgraph(nodeNames []string) map[string]bool {
visited := make(map[string]bool)
var visit func(string)
Expand Down
16 changes: 16 additions & 0 deletions internal/dag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,19 @@ func TestDAG_Subgraph(t *testing.T) {
t.Fatalf("expected c in subgraph")
}
}

func TestDAG_findCycle(t *testing.T) {
d := NewDAG[bool]("")
d.AddNode("a", true)
d.AddNode("b", true)
d.AddNode("c", true)
d.AddEdge("a", "b")
d.AddEdge("b", "c")
if cycle := d.findCycle(); cycle != nil {
t.Fatalf("acyclic graph reported cycle: %v", cycle)
}
d.AddEdge("c", "a") // close the loop
if d.findCycle() == nil {
t.Fatal("cycle not detected")
}
}
7 changes: 7 additions & 0 deletions internal/log_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,25 @@ package internal
import (
"bytes"
"log"
"sync"
)

type logWriter struct {
// prefixSuffixProvider returns the prefix and suffix to use when logging.
prefixSuffixProvider func() (string, string)
buffer bytes.Buffer
logger *log.Logger
// mu guards buffer: the same writer is used for both stdout and stderr,
// which os/exec copies on separate goroutines
mu sync.Mutex
}

func (lw *logWriter) Write(p []byte) (int, error) {
prefix, suffix := lw.prefixSuffixProvider()

lw.mu.Lock()
defer lw.mu.Unlock()

for _, b := range p {
if b == '\n' {
lw.logger.Printf("%s%s%s\n", prefix, lw.buffer.String(), suffix)
Expand Down
1 change: 1 addition & 0 deletions internal/probe.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ func probeLoop(ctx context.Context, probe types.Probe, callback func(ok bool, er
if err != nil {
return fmt.Errorf("failed to get %q: %w", httpGet.GetURL(), err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
data, _ := io.ReadAll(resp.Body)
return fmt.Errorf("%s: %q", resp.Status, data)
Expand Down
7 changes: 6 additions & 1 deletion internal/proc/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,13 @@ func (c *container) Run(ctx context.Context, stdout, stderr io.Writer) error {
if err != nil {
return fmt.Errorf("failed to create docker client: %w", err)
}
defer cli.Close()
c.cli = cli
// close the client once the context is done rather than when Run returns:
// the metrics goroutine keeps using c.cli until ctx is cancelled
go func() {
<-ctx.Done()
cli.Close()
}()

dockerfile := filepath.Join(c.Image, "Dockerfile")
id, existingHash, err := c.getContainer(ctx, cli)
Expand Down
4 changes: 3 additions & 1 deletion internal/proc/host.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ func (h *host) Run(ctx context.Context, stdout, stderr io.Writer) error {
cmd.SysProcAttr = &syscall.SysProcAttr{
Setpgid: true,
}
cmd.Env = append(environ, os.Environ()...)
// spec env takes precedence over the inherited environment (exec is last-wins)
cmd.Env = append(os.Environ(), environ...)
log := h.log
log.Println("starting process")
err = cmd.Start()
Expand All @@ -53,6 +54,7 @@ func (h *host) Run(ctx context.Context, stdout, stderr io.Writer) error {
h.pid = pid
pgid, err := syscall.Getpgid(pid)
if err != nil {
_ = cmd.Process.Kill()
return fmt.Errorf("failed get pgid: %w", err)
}
Comment thread
alexec marked this conversation as resolved.
go func() {
Expand Down
8 changes: 7 additions & 1 deletion internal/proc/kubernetes.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ type k8s struct {
log *log.Logger
spec types.Spec
name string
podsMu sync.Mutex
pods []string // namespace/name
clientset kubernetes.Interface
restConfig *rest.Config
Expand Down Expand Up @@ -282,9 +283,11 @@ func (k *k8s) Run(ctx context.Context, stdout io.Writer, stderr io.Writer) error

podKey := pod.Namespace + "/" + pod.Name

k.podsMu.Lock()
if !slices.Contains(k.pods, podKey) {
k.pods = append(k.pods, podKey)
}
k.podsMu.Unlock()

running := make(map[string]bool)

Expand Down Expand Up @@ -443,7 +446,10 @@ func sortUnstructureds(uns []*unstructured.Unstructured) {

func (k *k8s) GetMetrics(ctx context.Context) (*types.Metrics, error) {
sum := &types.Metrics{}
for _, podKey := range k.pods {
k.podsMu.Lock()
pods := append([]string(nil), k.pods...)
k.podsMu.Unlock()
for _, podKey := range pods {
parts := strings.SplitN(podKey, "/", 2)
namespace := parts[0]
podName := parts[1]
Expand Down
60 changes: 40 additions & 20 deletions internal/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,15 @@ func RunSubgraph(ctx context.Context, cancel context.CancelFunc, port int, openB
for name, t := range wf.Tasks {
dag.AddNode(name, true)
for _, dependency := range t.Dependencies {
if _, ok := wf.Tasks[dependency]; !ok {
return fmt.Errorf("task %q depends on unknown task %q", name, dependency)
}
dag.AddEdge(dependency, name)
}
}
if cycle := dag.findCycle(); cycle != nil {
return fmt.Errorf("dependency cycle detected: %s", strings.Join(cycle, " -> "))
}
visited := dag.Subgraph(taskNames)

taskByName := wf.Tasks
Expand Down Expand Up @@ -191,12 +197,11 @@ func RunSubgraph(ctx context.Context, cancel context.CancelFunc, port int, openB
for name, taskNode := range subgraph.Nodes {
stalledTime := taskNode.Task.GetStalledTimeout()
stallTimers[name] = time.AfterFunc(stalledTime, func() {
if taskNode.Phase == "starting" || taskNode.Phase == "running" {
if phase := taskNode.getPhase(); phase == "starting" || phase == "running" {
// we suffix the message with "starting" so we can differentiate between a task that is starting and one that is running, later on we can change the message to "output received"
// and restore the phase to "running" or "starting"
taskNode.Message = fmt.Sprintf("no output for %s or more while %s", stalledTime, taskNode.Phase)
taskNode.Phase = "stalled"
logger.Printf("[%s] %s\n", taskNode.Name, taskNode.Message)
taskNode.setStatus("stalled", fmt.Sprintf("no output for %s or more while %s", stalledTime, phase))
logger.Printf("[%s] %s\n", taskNode.Name, taskNode.getMessage())
statusEvents <- taskNode
}
})
Expand All @@ -210,15 +215,30 @@ func RunSubgraph(ctx context.Context, cancel context.CancelFunc, port int, openB

logger.Println("waiting for all tasks to complete")

// keep draining events so task/watcher goroutines blocked on a full
// channel can reach wg.Done() instead of deadlocking shutdown
drained := make(chan struct{})
go func() {
for {
select {
case <-drained:
return
case <-events:
case <-statusEvents:
}
}
}()
wg.Wait()
close(drained)

// if any task failed, we will return an error
var failures []string
for _, node := range subgraph.Nodes {

color := 30
faint := 0
switch node.Phase {
phase := node.getPhase()
switch phase {
case "failed":
// red
color = 31
Expand All @@ -228,7 +248,7 @@ func RunSubgraph(ctx context.Context, cancel context.CancelFunc, port int, openB
faint = 2
}

logger.Printf("\033[%d;%dm[%s] (%s) %s\033[0m\n", faint, color, node.Name, node.Phase, node.Message)
logger.Printf("\033[%d;%dm[%s] (%s) %s\033[0m\n", faint, color, node.Name, phase, node.getMessage())
}

if len(failures) > 0 {
Expand All @@ -251,16 +271,17 @@ func RunSubgraph(ctx context.Context, cancel context.CancelFunc, port int, openB
}

for _, node := range subgraph.Nodes {
phase := node.getPhase()
// Check if task should cause immediate exit
if node.Phase == "failed" && node.Task.GetRestartPolicy() == "Never" {
if phase == "failed" && node.Task.GetRestartPolicy() == "Never" {
logger.Printf("🚫 exiting because task %q failed and should not be restarted", node.Name)
cancel()
continue
}

// Check if task is complete and should be removed from tracking
isComplete := false
switch node.Phase {
switch phase {
case "succeeded", "skipped":
isComplete = true
case "running", "stalled":
Expand Down Expand Up @@ -288,7 +309,7 @@ func RunSubgraph(ctx context.Context, cancel context.CancelFunc, port int, openB
logger.Println("🔵 all requested tasks are running:")
// print a list of running tasks, and their ports
for _, node := range subgraph.Nodes {
if (node.Phase == "running" || node.Phase == "stalled") && node.Task.Ports != nil {
if p := node.getPhase(); (p == "running" || p == "stalled") && node.Task.Ports != nil {
for _, port := range node.Task.Ports {
logger.Printf(" - %s: http://localhost:%d\n", node.Name, port.HostPort)
}
Expand All @@ -310,7 +331,7 @@ func RunSubgraph(ctx context.Context, cancel context.CancelFunc, port int, openB
for _, parentName := range subgraph.Parents[taskName] {
parent := subgraph.Nodes[parentName]
if parent.blocked() {
logger.Printf("task %q is blocked by %q (%s): %s\n", taskName, parentName, parent.Phase, parent.Message)
logger.Printf("task %q is blocked by %q (%s): %s\n", taskName, parentName, parent.getPhase(), parent.getMessage())
blocked = true
}
}
Expand All @@ -322,7 +343,7 @@ func RunSubgraph(ctx context.Context, cancel context.CancelFunc, port int, openB
// we might already be pending, waiting, starting or running this task, so we don't want to start it again
node := subgraph.Nodes[taskName]

node.cancel()
node.doCancel()
allRunning = false

// each task is executed in a separate goroutine
Expand All @@ -336,7 +357,7 @@ func RunSubgraph(ctx context.Context, cancel context.CancelFunc, port int, openB
ctx, cancel := context.WithCancel(ctx)
defer cancel()

node.cancel = cancel
node.setCancel(cancel)

// send a poison pill to indicate that we've finish and the main loop must check to see if we need to exit
defer func() { events <- poisonPill }()
Expand All @@ -348,18 +369,17 @@ func RunSubgraph(ctx context.Context, cancel context.CancelFunc, port int, openB
var out io.Writer = &logWriter{
logger: logger,
prefixSuffixProvider: func() (string, string) {
return fmt.Sprintf("%s[%s] (%s) ", color(node.Name), node.Name, node.Phase), "\033[0m"
return fmt.Sprintf("%s[%s] (%s) ", color(node.Name), node.Name, node.getPhase()), "\033[0m"
},
}

logger := log.New(out, "", 0)

setNodeStatus := func(node *TaskNode, phase string, message string) {
node.Phase = phase
node.Message = message
node.setStatus(phase, message)
stallTimers[node.Name].Reset(node.Task.GetStalledTimeout())
setTerminalTitle(workflowTitle(name, subgraph.Nodes))
logger.Println(node.Message)
logger.Println(message)
statusEvents <- node
events <- poisonPill
}
Expand Down Expand Up @@ -461,8 +481,8 @@ func RunSubgraph(ctx context.Context, cancel context.CancelFunc, port int, openB
// so when we tail the log file, we see the output immediately
buf := funcWriter(func(p []byte) (int, error) {
stallTimers[node.Name].Reset(node.Task.GetStalledTimeout())
if node.Phase == "stalled" {
if strings.HasSuffix(node.Message, "starting") {
if node.getPhase() == "stalled" {
if strings.HasSuffix(node.getMessage(), "starting") {
setNodeStatus(node, "starting", "output received")
} else {
setNodeStatus(node, "running", "output received")
Expand Down Expand Up @@ -493,15 +513,15 @@ func RunSubgraph(ctx context.Context, cancel context.CancelFunc, port int, openB
case <-ctx.Done():
return
case <-ticker.C:
if node.Phase != "running" && node.Phase != "stalled" {
if ph := node.getPhase(); ph != "running" && ph != "stalled" {
continue
}
metrics, err := p.GetMetrics(ctx)
if err != nil {
logger.Printf("failed to get metrics: %v", err)
continue
}
node.Metrics = metrics
node.setMetrics(metrics)
statusEvents <- node
}
}
Expand Down
Loading
Loading