Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions cmd/tls-scanner/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,11 @@ func run(args []string) (exitCode int) {
return 0
}

policy := scanner.Policy()
policy, err := scanner.Policy()
if err != nil {
log.Printf("Error loading policy: %v", err)
return 1
}

defer func() {
if *timingFile != "" {
Expand Down Expand Up @@ -133,7 +137,6 @@ func run(args []string) (exitCode int) {
}

var client *k8s.Client
var err error
var pods []k8s.PodInfo

if *targets != "" {
Expand Down
9 changes: 4 additions & 5 deletions internal/scanner/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,18 @@ type ComponentPolicy struct {
// Policy returns the org-wide component policy embedded in the binary.
// It is the single source of truth for which TLS profile applies to each
// component type. To change the policy, edit policy.yaml and submit for review.
func Policy() *ComponentPolicy {
func Policy() (*ComponentPolicy, error) {
var p ComponentPolicy
if err := yaml.Unmarshal(embeddedPolicyYAML, &p); err != nil {
// policy.yaml is checked into source; a parse failure is a programming error.
panic(fmt.Sprintf("failed to parse embedded policy: %v", err))
return nil, fmt.Errorf("failed to parse embedded policy: %w", err)
}
for i := range p.Rules {
if err := p.Rules[i].compile(); err != nil {
panic(fmt.Sprintf("policy rule %d: %v", i, err))
return nil, fmt.Errorf("policy rule %d: %w", i, err)
}
}
p.warnShadowedRules()
return &p
return &p, nil
}

// warnShadowedRules logs a warning for each rule that can never be reached
Expand Down
13 changes: 11 additions & 2 deletions internal/scanner/policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,17 @@ func TestPolicyRuleShadowing(t *testing.T) {
}
}

func testPolicy(t *testing.T) *ComponentPolicy {
t.Helper()
p, err := Policy()
if err != nil {
t.Fatalf("Policy() error: %v", err)
}
return p
}

func TestPolicy(t *testing.T) {
p := Policy()
p := testPolicy(t)
if p == nil {
t.Fatal("Policy() returned nil")
}
Expand Down Expand Up @@ -250,7 +259,7 @@ func TestPolicyResolve(t *testing.T) {
}

func TestPolicyBehaviour(t *testing.T) {
p := Policy()
p := testPolicy(t)

tests := []struct {
name string
Expand Down
6 changes: 3 additions & 3 deletions internal/scanner/scanner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func TestScanWithMockTestSSL(t *testing.T) {
{IP: "10.0.0.1", Port: 443},
{IP: "10.0.0.2", Port: 8443},
}
results := Scan(jobs, 2, nil, nil, Policy())
results := Scan(jobs, 2, nil, nil, testPolicy(t))

if results.ScannedIPs != 2 {
t.Fatalf("expected 2 scanned IPs, got %d", results.ScannedIPs)
Expand All @@ -65,7 +65,7 @@ func TestScanPQCEnrichment(t *testing.T) {
testutil.InstallMockTestSSL(t)

jobs := []ScanJob{{IP: "10.0.0.1", Port: 443}}
results := Scan(jobs, 1, nil, nil, Policy())
results := Scan(jobs, 1, nil, nil, testPolicy(t))

pr := results.IPResults[0].PortResults[0]

Expand Down Expand Up @@ -98,7 +98,7 @@ func TestPerformClusterScanWithMockPods(t *testing.T) {
makePod("no-ports", "openshift-console", "10.128.0.30"),
}

results := PerformClusterScan(pods, 2, nil, Policy())
results := PerformClusterScan(pods, 2, nil, testPolicy(t))

if results.ScannedIPs != 3 {
t.Errorf("expected 3 scanned IPs (including no-ports), got %d", results.ScannedIPs)
Expand Down
8 changes: 4 additions & 4 deletions internal/scanner/testssl_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func TestIntegrationSingleTarget(t *testing.T) {
jobs := []ScanJob{{IP: tgt.ip, Port: tgt.port}}

start := time.Now()
results := batchScan(jobs, 1, nil, nil, Policy())
results := batchScan(jobs, 1, nil, nil, testPolicy(t))
elapsed := time.Since(start)

t.Logf("Single target %s (%s:%d): %v", tgt.desc, tgt.ip, tgt.port, elapsed)
Expand Down Expand Up @@ -73,7 +73,7 @@ func TestIntegrationBatchTargets(t *testing.T) {
}

start := time.Now()
results := batchScan(jobs, 4, nil, nil, Policy())
results := batchScan(jobs, 4, nil, nil, testPolicy(t))
elapsed := time.Since(start)

t.Logf("Batch %d targets (MAX_PARALLEL=4): %v (%.1fs/target)",
Expand Down Expand Up @@ -102,12 +102,12 @@ func TestIntegrationParallelScaling(t *testing.T) {

t.Log("--- Sequential run (MAX_PARALLEL=1) ---")
start := time.Now()
seqResults := batchScan(jobs, 1, nil, nil, Policy())
seqResults := batchScan(jobs, 1, nil, nil, testPolicy(t))
sequential := time.Since(start)

t.Log("--- Parallel run (MAX_PARALLEL=", len(jobs), ") ---")
start = time.Now()
parResults := batchScan(jobs, len(jobs), nil, nil, Policy())
parResults := batchScan(jobs, len(jobs), nil, nil, testPolicy(t))
parallel := time.Since(start)

speedup := sequential.Seconds() / parallel.Seconds()
Expand Down