diff --git a/.vscode/settings.json b/.vscode/settings.json index b84dffae5..d7a16600d 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -97,6 +97,7 @@ "netcat", "newrequest", "NEWTAG", + "nolint", "nosec", "NOSONAR", "Numerify", @@ -142,9 +143,11 @@ "samlsp", "samltypes", "sarama", + "SASL", "Satosa", "SATOSAV", "schac", + "SCRAMSHA", "SDJWT", "sdjwtvc", "sdktrace", @@ -152,6 +155,7 @@ "securego", "semconv", "setnx", + "setnxttl", "sexp", "shortuuid", "simplequeue", @@ -192,6 +196,7 @@ "vulncheck", "wrongverifier", "wwwallet", + "XDGSCRAM", "zapr", "zenghongtu", "zenor", diff --git a/internal/apigw/apiv1/handlers_oauth.go b/internal/apigw/apiv1/handlers_oauth.go index 3d972720a..b315f6f0b 100644 --- a/internal/apigw/apiv1/handlers_oauth.go +++ b/internal/apigw/apiv1/handlers_oauth.go @@ -182,28 +182,10 @@ func (c *Client) OAuthToken(ctx context.Context, req *openid4vci.TokenRequest) ( code = req.PreAuthorizedCode } - // Look up the client to enforce type-specific requirements (client_id is optional for pre-auth flow) - // When client_assertion is provided (private_key_jwt auth), extract client_id from the assertion's sub claim. - // SECURITY: ExtractClientIDFromAssertion only decodes the JWT payload — it does NOT verify the signature. - // The extracted client_id is used for both client config lookup (Clients.Get) AND client binding - // verification (WalletClientID check). Without signature verification, an attacker could forge - // the sub claim to impersonate another client. - // Full JWT assertion verification (signature, aud, exp, jti) is NOT yet implemented. - // TODO(security): Implement RFC 7523 private_key_jwt assertion verification - // (signature, aud, exp, jti) before enabling client_assertion in production. - // Tracked as a known risk — see risk register SID-RISK-CLIENT-ASSERTION. + // When client_assertion is provided (private_key_jwt auth), verify the assertion + // signature per RFC 7523 and extract client_id from the sub claim. clientID := req.ClientID if req.ClientAssertion != "" { - // SECURITY: client_assertion signature is NOT verified — only the payload is decoded. - // Reject unless explicitly opted in via allow_unverified_client_assertion config flag. - // Full RFC 7523 verification (signature, aud, exp, jti) must be implemented before - // removing this flag. Tracked in risk register SID-RISK-CLIENT-ASSERTION. - if !c.cfg.APIGW.Delivery.OpenID4VCI.AllowUnverifiedClientAssertion { - return nil, oauth2.NewOAuthError(oauth2.ErrCodeInvalidRequest, - "client_assertion is not supported (RFC 7523 verification not implemented)", 400) - } - c.log.Warn("accepting unverified client_assertion (allow_unverified_client_assertion is enabled) — signature verification not implemented", - "client_assertion_type", req.ClientAssertionType) if req.ClientAssertionType == "" { return nil, oauth2.NewOAuthError(oauth2.ErrCodeInvalidRequest, "client_assertion_type is required when client_assertion is provided", 400) @@ -212,18 +194,73 @@ func (c *Client) OAuthToken(ctx context.Context, req *openid4vci.TokenRequest) ( return nil, oauth2.NewOAuthError(oauth2.ErrCodeInvalidRequest, fmt.Sprintf("unsupported client_assertion_type %q; expected urn:ietf:params:oauth:client-assertion-type:jwt-bearer", req.ClientAssertionType), 400) } - sub, err := oauth2.ExtractClientIDFromAssertion(req.ClientAssertion) - if err != nil { - c.log.Error(err, "failed to extract client_id from client_assertion") - return nil, oauth2.NewOAuthErrorWithCause(oauth2.ErrCodeInvalidClient, - "Invalid client assertion", 401, err) - } - if clientID != "" && clientID != sub { - return nil, oauth2.NewOAuthError(oauth2.ErrCodeInvalidClient, - "client_id does not match assertion subject", 401) + + if c.cfg.APIGW.Delivery.OpenID4VCI.AllowUnverifiedClientAssertion { + // CONFORMANCE TESTING ONLY: accept assertion without signature verification. + c.log.Warn("accepting unverified client_assertion (allow_unverified_client_assertion is enabled)", + "client_assertion_type", req.ClientAssertionType) + sub, err := oauth2.ExtractClientIDFromAssertion(req.ClientAssertion) + if err != nil { + c.log.Error(err, "failed to extract client_id from client_assertion") + return nil, oauth2.NewOAuthErrorWithCause(oauth2.ErrCodeInvalidClient, + "Invalid client assertion", 401, err) + } + if clientID != "" && clientID != sub { + return nil, oauth2.NewOAuthError(oauth2.ErrCodeInvalidClient, + "client_id does not match assertion subject", 401) + } + clientID = sub + } else { + // RFC 7523 full verification path: extract sub for client lookup, then verify signature. + sub, err := oauth2.ExtractClientIDFromAssertion(req.ClientAssertion) + if err != nil { + c.log.Error(err, "failed to extract client_id from client_assertion") + return nil, oauth2.NewOAuthErrorWithCause(oauth2.ErrCodeInvalidClient, + "Invalid client assertion", 401, err) + } + if clientID != "" && clientID != sub { + return nil, oauth2.NewOAuthError(oauth2.ErrCodeInvalidClient, + "client_id does not match assertion subject", 401) + } + clientID = sub + + // Look up client to get JWKS URI for signature verification + oauthClientForVerify, err := c.cfg.APIGW.Delivery.OpenID4VCI.Clients.Get(clientID) + if err != nil { + return nil, oauth2.NewOAuthErrorWithCause(oauth2.ErrCodeInvalidClient, + "Client authentication failed", 401, err) + } + + verifier := &oauth2.ClientAssertionVerifier{ + TokenEndpoint: c.cfg.APIGW.Delivery.OpenID4VCI.TokenEndpoint, + JWKSCache: c.cacheService.JWKS, + JTICheck: func(jti string, exp time.Time) error { + // Scope by clientID to prevent cross-client collisions; + // use time.Until(exp) as TTL so entries expire with the assertion. + cacheKey := "client_assertion:" + clientID + ":" + jti + // Add the same leeway as the JWT verifier to tolerate small clock skews. + ttl := time.Until(exp.Add(30 * time.Second)) + if ttl <= 0 { + return errors.New("client_assertion jti has already expired") + } + unique, err := c.cacheService.DPopJTI.SetNXWithTTL(ctx, cacheKey, true, ttl) + if err != nil { + return fmt.Errorf("jti cache error: %w", err) + } + if !unique { + return errors.New("client_assertion jti already used") + } + return nil + }, + } + assertionClaims, err := verifier.Verify(ctx, req.ClientAssertion, oauthClientForVerify) + if err != nil { + c.log.Error(err, "client_assertion verification failed", "client_id", clientID) + return nil, oauth2.NewOAuthErrorWithCause(oauth2.ErrCodeInvalidClient, + "Client assertion verification failed", 401, err) + } + c.log.Debug("client_assertion verified", "client_id", clientID, "jti", assertionClaims.JTI) } - clientID = sub - c.log.Debug("OAuthToken: resolved client_id from client_assertion", "client_id", clientID) } else if req.ClientAssertionType != "" { return nil, oauth2.NewOAuthError(oauth2.ErrCodeInvalidRequest, "client_assertion_type provided without client_assertion", 400) diff --git a/internal/apigw/auth_providers/samlsp/service.go b/internal/apigw/auth_providers/samlsp/service.go index 357276325..c7cb89399 100644 --- a/internal/apigw/auth_providers/samlsp/service.go +++ b/internal/apigw/auth_providers/samlsp/service.go @@ -94,6 +94,25 @@ func New(ctx context.Context, cfg *model.SAMLSP, sessionCache pkgcache.Cache[*Se "cert_path", cfg.MetadataSigningCertPath) } + // Determine if the metadata source is MDQ or URL-based (requires signature verification) + requiresRemoteFetch := cfg.MDQServer != "" || + (cfg.StaticIDPMetadata != nil && cfg.StaticIDPMetadata.MetadataURL != "") + + // Enforce metadata signature verification for MDQ/URL sources unless explicitly opted out + if requiresRemoteFetch && metadataSigningCert == nil && !cfg.AllowUnsignedMetadata { + return nil, fmt.Errorf("metadata_signing_cert_path is required when using MDQ or URL-based metadata sources " + + "(set allow_unsigned_metadata: true to override — NOT RECOMMENDED for production)") + } + if requiresRemoteFetch && metadataSigningCert == nil && cfg.AllowUnsignedMetadata { + s.log.Warn("INSECURE: accepting unsigned metadata from remote sources (allow_unsigned_metadata is enabled)") + } + + // Warn for local file metadata without signature verification (allowed but noted) + if cfg.StaticIDPMetadata != nil && cfg.StaticIDPMetadata.MetadataPath != "" && metadataSigningCert == nil { + s.log.Warn("static IdP metadata loaded from local file without signature verification", + "path", cfg.StaticIDPMetadata.MetadataPath) + } + // Initialize MDQ client (either MDQ or static mode) if cfg.StaticIDPMetadata != nil { // Static IdP mode diff --git a/internal/apigw/outbound/kafka_message_publisher.go b/internal/apigw/outbound/kafka_message_publisher.go index 5faf97d1e..0523f97d8 100644 --- a/internal/apigw/outbound/kafka_message_publisher.go +++ b/internal/apigw/outbound/kafka_message_publisher.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "reflect" "github.com/SUNET/vc/internal/apigw/apiv1" @@ -22,7 +23,10 @@ type kafkaMessageProducer struct { // New creates a new instance of a kafka event publisher used by apigw func New(ctx context.Context, cfg *model.Cfg, tracer *trace.Tracer, log *logger.Log) (apiv1.EventPublisher, error) { - saramaConfig := kafka.CommonProducerConfig(cfg) + saramaConfig, err := kafka.CommonProducerConfig(cfg) + if err != nil { + return nil, fmt.Errorf("kafka producer security config: %w", err) + } client, err := kafka.NewSyncProducerClient(ctx, saramaConfig, cfg, tracer, log.New("kafka_message_producer_client")) if err != nil { return nil, err diff --git a/internal/verifier/apiv1/handlers_verification.go b/internal/verifier/apiv1/handlers_verification.go index cbb6371a1..3be58f36f 100644 --- a/internal/verifier/apiv1/handlers_verification.go +++ b/internal/verifier/apiv1/handlers_verification.go @@ -26,7 +26,7 @@ type VerificationRequestObjectRequest struct { } func (c *Client) VerificationRequestObject(ctx context.Context, req *VerificationRequestObjectRequest) (string, error) { - c.log.Debug("Verification request object", "req", req) + c.log.Debug("Verification request object", "id", req.ID) // Query by RequestObjectID since that's what the wallet sends via ?id= parameter authorizationContext, err := c.cacheService.AuthContext.Get(ctx, &cache.AuthorizationContext{ @@ -50,7 +50,7 @@ func (c *Client) VerificationRequestObject(ctx context.Context, req *Verificatio return "", err } - c.log.Debug("Signed JWT", "jwt", signedJWT) + c.log.Debug("Signed JWT created", "requestObjectID", authorizationContext.RequestObjectID) return signedJWT, nil } @@ -103,7 +103,7 @@ func (c *Client) VerificationDirectPost(ctx context.Context, req *VerificationDi return nil, err } - c.log.Debug("directPost", "vpResponse", vpResponse) + c.log.Debug("directPost", "state", vpResponse.State, "credential_count", len(vpResponse.VPToken)) // Get authorization context by state authCtx, err := c.cacheService.AuthContext.Get(ctx, &cache.AuthorizationContext{State: vpResponse.State}) diff --git a/pkg/cache/generic.go b/pkg/cache/generic.go index 4bce69200..a382b56b0 100644 --- a/pkg/cache/generic.go +++ b/pkg/cache/generic.go @@ -36,6 +36,12 @@ type Cache[V any] interface { // "already exists" from operational errors. SetNX(ctx context.Context, key string, value V) (bool, error) + // SetNXWithTTL stores a value only if the key does not already exist (atomic), + // using a custom TTL instead of the default. Returns true if the value was set, + // false if the key already existed. + // If ttl <= 0, implementations MUST fall back to SetNX (default TTL). + SetNXWithTTL(ctx context.Context, key string, value V, ttl time.Duration) (bool, error) + // SetWithTTL stores a value with a custom TTL, overriding the default. SetWithTTL(ctx context.Context, key string, value V, ttl time.Duration) diff --git a/pkg/cache/generic_memory.go b/pkg/cache/generic_memory.go index 7950315f1..c6c8169c4 100644 --- a/pkg/cache/generic_memory.go +++ b/pkg/cache/generic_memory.go @@ -46,6 +46,17 @@ func (m *MemoryCache[V]) SetNX(_ context.Context, key string, value V) (bool, er return !found, nil } +// SetNXWithTTL stores a value only if the key does not already exist, using a custom TTL. +// Returns true if the value was set, false if the key already existed. +// If ttl <= 0, falls back to SetNX (default TTL). +func (m *MemoryCache[V]) SetNXWithTTL(ctx context.Context, key string, value V, ttl time.Duration) (bool, error) { + if ttl <= 0 { + return m.SetNX(ctx, key, value) + } + _, found := m.cache.GetOrSet(key, value, ttlcache.WithTTL[string, V](ttl)) + return !found, nil +} + // SetWithTTL stores a value with a custom TTL. func (m *MemoryCache[V]) SetWithTTL(_ context.Context, key string, value V, ttl time.Duration) { m.cache.Set(key, value, ttl) diff --git a/pkg/cache/generic_mongo.go b/pkg/cache/generic_mongo.go index bb84179c1..5ce713532 100644 --- a/pkg/cache/generic_mongo.go +++ b/pkg/cache/generic_mongo.go @@ -148,6 +148,29 @@ func (m *MongoCache[V]) SetNX(ctx context.Context, key string, value V) (bool, e return true, nil } +// SetNXWithTTL stores a value only if the key does not already exist (atomic), +// using a custom TTL approximated via created_at shifting (same as SetWithTTL). +// If ttl <= 0, falls back to SetNX (default TTL). +func (m *MongoCache[V]) SetNXWithTTL(ctx context.Context, key string, value V, ttl time.Duration) (bool, error) { + if ttl <= 0 { + return m.SetNX(ctx, key, value) + } + shift := m.ttl - ttl + createdAt := time.Now().Add(-shift) + entry, err := m.marshalEntry(key, value, createdAt) + if err != nil { + return false, fmt.Errorf("mongo cache setnxttl marshal failed (cache=%s): %w", m.collection, err) + } + _, err = m.coll.InsertOne(ctx, entry) + if err != nil { + if mongo.IsDuplicateKeyError(err) { + return false, nil + } + return false, fmt.Errorf("mongo cache setnxttl failed (cache=%s): %w", m.collection, err) + } + return true, nil +} + // SetWithTTL stores a value with a custom TTL. // MongoDB TTL indexes are collection-wide, so per-entry TTL is approximated // by shifting created_at: the document expires when diff --git a/pkg/cache/generic_test.go b/pkg/cache/generic_test.go index e90f766b7..1341c5b14 100644 --- a/pkg/cache/generic_test.go +++ b/pkg/cache/generic_test.go @@ -78,6 +78,62 @@ func runGenericCacheContractTests[V comparable](t *testing.T, c Cache[V], val1, fresh := c.Len() assert.GreaterOrEqual(t, fresh, 0) }) + + t.Run("SetNX_New", func(t *testing.T) { + ok, err := c.SetNX(ctx, "nx-new", val1) + require.NoError(t, err) + assert.True(t, ok, "SetNX on a new key should succeed") + got, found := c.Get(ctx, "nx-new") + require.True(t, found) + assert.Equal(t, val1, got) + }) + + t.Run("SetNX_Existing", func(t *testing.T) { + c.Set(ctx, "nx-exist", val1) + ok, err := c.SetNX(ctx, "nx-exist", val2) + require.NoError(t, err) + assert.False(t, ok, "SetNX on existing key should return false") + got, found := c.Get(ctx, "nx-exist") + require.True(t, found) + assert.Equal(t, val1, got, "value should not be overwritten") + }) + + t.Run("SetNXWithTTL_New", func(t *testing.T) { + ok, err := c.SetNXWithTTL(ctx, "nxttl-new", val1, 1*time.Hour) + require.NoError(t, err) + assert.True(t, ok, "SetNXWithTTL on a new key should succeed") + got, found := c.Get(ctx, "nxttl-new") + require.True(t, found) + assert.Equal(t, val1, got) + }) + + t.Run("SetNXWithTTL_Existing", func(t *testing.T) { + c.Set(ctx, "nxttl-exist", val1) + ok, err := c.SetNXWithTTL(ctx, "nxttl-exist", val2, 1*time.Hour) + require.NoError(t, err) + assert.False(t, ok, "SetNXWithTTL on existing key should return false") + got, found := c.Get(ctx, "nxttl-exist") + require.True(t, found) + assert.Equal(t, val1, got, "value should not be overwritten") + }) + + t.Run("SetNXWithTTL_ZeroTTL_FallsBackToDefault", func(t *testing.T) { + ok, err := c.SetNXWithTTL(ctx, "nxttl-zero", val1, 0) + require.NoError(t, err) + assert.True(t, ok, "SetNXWithTTL with zero TTL should succeed on new key") + got, found := c.Get(ctx, "nxttl-zero") + require.True(t, found) + assert.Equal(t, val1, got) + }) + + t.Run("SetNXWithTTL_NegativeTTL_FallsBackToDefault", func(t *testing.T) { + ok, err := c.SetNXWithTTL(ctx, "nxttl-neg", val1, -5*time.Second) + require.NoError(t, err) + assert.True(t, ok, "SetNXWithTTL with negative TTL should succeed on new key") + got, found := c.Get(ctx, "nxttl-neg") + require.True(t, found) + assert.Equal(t, val1, got) + }) } // --- MemoryCache type-specific tests --- @@ -146,6 +202,45 @@ func TestMemoryCache_TTLExpiration(t *testing.T) { assert.False(t, ok, "expected item to expire") } +func TestMemoryCache_SetNXWithTTL_Expiration(t *testing.T) { + ctx := context.Background() + c := NewMemoryCache[string](5 * time.Minute) + + // Store with a short custom TTL. + ok, err := c.SetNXWithTTL(ctx, "nx-expire", "val", 50*time.Millisecond) + require.NoError(t, err) + require.True(t, ok) + + // Value should be readable immediately. + got, found := c.Get(ctx, "nx-expire") + require.True(t, found) + assert.Equal(t, "val", got) + + // After the custom TTL elapses, the value should be gone. + time.Sleep(150 * time.Millisecond) + _, found = c.Get(ctx, "nx-expire") + assert.False(t, found, "expected item to expire after custom TTL") +} + +func TestMemoryCache_SetNXWithTTL_ZeroTTL_UsesDefault(t *testing.T) { + ctx := context.Background() + // Default TTL = 50ms. Zero TTL should fall back to default, not infinite. + c := NewMemoryCache[string](50 * time.Millisecond) + + ok, err := c.SetNXWithTTL(ctx, "nx-zero-ttl", "val", 0) + require.NoError(t, err) + require.True(t, ok) + + // Value should be readable immediately. + _, found := c.Get(ctx, "nx-zero-ttl") + require.True(t, found) + + // After default TTL elapses, value should expire. + time.Sleep(150 * time.Millisecond) + _, found = c.Get(ctx, "nx-zero-ttl") + assert.False(t, found, "zero TTL should fall back to default TTL, not infinite") +} + func TestMemoryCache_Stop(t *testing.T) { c := NewMemoryCache[string](5 * time.Minute) c.Stop() // should not panic @@ -338,3 +433,57 @@ func TestMongoCache_SetWithTTL_SameAsDefault(t *testing.T) { assert.InDelta(t, 0, age, float64(5*time.Second), "created_at should be approximately now when TTL equals default") } + +func TestMongoCache_SetNXWithTTL_CreatedAtShift(t *testing.T) { + client, cleanup := startMongoContainer(t) + defer cleanup() + + ctx := t.Context() + // Collection TTL = 10 minutes. + c, err := NewMongoCache[string](ctx, client, "test_generic", "cache_nxttl_shift", 10*time.Minute, nil) + require.NoError(t, err) + + // SetNXWithTTL with 3-minute TTL should shift created_at 7 minutes into the past. + ok, err := c.SetNXWithTTL(ctx, "nx-shift", "val", 3*time.Minute) + require.NoError(t, err) + require.True(t, ok) + + coll := client.Database("test_generic").Collection("cache_nxttl_shift") + var doc bson.M + err = coll.FindOne(ctx, bson.M{"_id": "nx-shift"}).Decode(&doc) + require.NoError(t, err) + + createdAt := doc["created_at"].(bson.DateTime).Time() + age := time.Since(createdAt) + assert.InDelta(t, 7*time.Minute, age, float64(5*time.Second), + "created_at should be shifted ~7 minutes into the past for 3-minute TTL") + + // Value should still be readable. + got, found := c.Get(ctx, "nx-shift") + require.True(t, found) + assert.Equal(t, "val", got) +} + +func TestMongoCache_SetNXWithTTL_ZeroTTL_NoShift(t *testing.T) { + client, cleanup := startMongoContainer(t) + defer cleanup() + + ctx := t.Context() + c, err := NewMongoCache[string](ctx, client, "test_generic", "cache_nxttl_zero", 10*time.Minute, nil) + require.NoError(t, err) + + // Zero TTL should fall back to SetNX (default TTL), meaning created_at ~now. + ok, err := c.SetNXWithTTL(ctx, "nx-zero", "val", 0) + require.NoError(t, err) + require.True(t, ok) + + coll := client.Database("test_generic").Collection("cache_nxttl_zero") + var doc bson.M + err = coll.FindOne(ctx, bson.M{"_id": "nx-zero"}).Decode(&doc) + require.NoError(t, err) + + createdAt := doc["created_at"].(bson.DateTime).Time() + age := time.Since(createdAt) + assert.InDelta(t, 0, age, float64(5*time.Second), + "zero TTL should fall back to default (created_at ~now)") +} diff --git a/pkg/messagebroker/kafka/consumer.go b/pkg/messagebroker/kafka/consumer.go index 2000143e5..c1a1ce22d 100644 --- a/pkg/messagebroker/kafka/consumer.go +++ b/pkg/messagebroker/kafka/consumer.go @@ -41,8 +41,13 @@ type MessageConsumerClient struct { } func NewConsumerClient(ctx context.Context, cfg *model.Cfg, brokers []string, log *logger.Log) (*MessageConsumerClient, error) { + saramaConfig, err := commonConsumerConfig(cfg) + if err != nil { + return nil, fmt.Errorf("kafka consumer security config: %w", err) + } + client := &MessageConsumerClient{ - SaramaConfig: commonConsumerConfig(cfg), + SaramaConfig: saramaConfig, brokers: brokers, wg: sync.WaitGroup{}, log: log, @@ -52,14 +57,16 @@ func NewConsumerClient(ctx context.Context, cfg *model.Cfg, brokers []string, lo } // commonConsumerConfig returns a new Kafka consumer configuration instance with sane defaults for vc. -func commonConsumerConfig(cfg *model.Cfg) *sarama.Config { - // TODO: set cfg from file - is now hardcoded +func commonConsumerConfig(cfg *model.Cfg) (*sarama.Config, error) { saramaConfig := sarama.NewConfig() saramaConfig.Consumer.Offsets.Initial = sarama.OffsetOldest saramaConfig.Consumer.Group.Rebalance.GroupStrategies = []sarama.BalanceStrategy{sarama.NewBalanceStrategyRange()} - saramaConfig.Net.SASL.Enable = false - // TODO: enable and configure security when consuming from Kafka - return saramaConfig + + if err := applySecurityConfig(saramaConfig, cfg); err != nil { + return nil, err + } + + return saramaConfig, nil } // Start starts the actual event consuming from specified kafka topics diff --git a/pkg/messagebroker/kafka/producer.go b/pkg/messagebroker/kafka/producer.go index 6ab5623a0..ee5e3860c 100644 --- a/pkg/messagebroker/kafka/producer.go +++ b/pkg/messagebroker/kafka/producer.go @@ -48,17 +48,19 @@ func NewSyncProducerClient(ctx context.Context, saramaConfig *sarama.Config, cfg } // CommonProducerConfig returns a new Kafka producer configuration instance with sane defaults for vc. -func CommonProducerConfig(cfg *model.Cfg) *sarama.Config { - // TODO(mk): set cfg from file - is now hardcoded +func CommonProducerConfig(cfg *model.Cfg) (*sarama.Config, error) { saramaConfig := sarama.NewConfig() saramaConfig.Producer.Return.Successes = true saramaConfig.Producer.RequiredAcks = sarama.WaitForAll saramaConfig.Producer.Idempotent = true saramaConfig.Net.MaxOpenRequests = 1 saramaConfig.Producer.Retry.Max = 3 - saramaConfig.Net.SASL.Enable = false - // TODO(mk): enable and configure security when publishing to Kafka - return saramaConfig + + if err := applySecurityConfig(saramaConfig, cfg); err != nil { + return nil, err + } + + return saramaConfig, nil } // Close close the producer diff --git a/pkg/messagebroker/kafka/scram.go b/pkg/messagebroker/kafka/scram.go new file mode 100644 index 000000000..86312b637 --- /dev/null +++ b/pkg/messagebroker/kafka/scram.go @@ -0,0 +1,41 @@ +package kafka + +import ( + "crypto/sha256" + "crypto/sha512" + "hash" + + "github.com/xdg-go/scram" +) + +// SHA256 is a hash generator function for SCRAM-SHA-256. +var SHA256 scram.HashGeneratorFcn = func() hash.Hash { return sha256.New() } + +// SHA512 is a hash generator function for SCRAM-SHA-512. +var SHA512 scram.HashGeneratorFcn = func() hash.Hash { return sha512.New() } + +// XDGSCRAMClient implements the sarama.SCRAMClient interface using xdg-go/scram. +type XDGSCRAMClient struct { + *scram.ClientConversation + scram.HashGeneratorFcn +} + +// Begin starts a new SCRAM conversation. +func (x *XDGSCRAMClient) Begin(userName, password, authzID string) (err error) { + client, err := x.HashGeneratorFcn.NewClient(userName, password, authzID) + if err != nil { + return err + } + x.ClientConversation = client.NewConversation() + return nil +} + +// Step advances the SCRAM conversation by one step. +func (x *XDGSCRAMClient) Step(challenge string) (response string, err error) { + return x.ClientConversation.Step(challenge) +} + +// Done returns true if the conversation is complete. +func (x *XDGSCRAMClient) Done() bool { + return x.ClientConversation.Done() +} diff --git a/pkg/messagebroker/kafka/scram_test.go b/pkg/messagebroker/kafka/scram_test.go new file mode 100644 index 000000000..86cf1bd79 --- /dev/null +++ b/pkg/messagebroker/kafka/scram_test.go @@ -0,0 +1,106 @@ +package kafka + +import ( + "testing" + + "github.com/xdg-go/scram" +) + +func TestHashGenerators(t *testing.T) { + tests := []struct { + name string + gen scram.HashGeneratorFcn + wantDigest int + }{ + {name: "SHA256", gen: SHA256, wantDigest: 32}, + {name: "SHA512", gen: SHA512, wantDigest: 64}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := tt.gen() + if h == nil { + t.Fatal("hash generator returned nil") + } + h.Write([]byte("test")) + if got := len(h.Sum(nil)); got != tt.wantDigest { + t.Fatalf("digest length = %d, want %d", got, tt.wantDigest) + } + }) + } +} + +func TestXDGSCRAMClient_Begin(t *testing.T) { + tests := []struct { + name string + hashFcn scram.HashGeneratorFcn + user string + pass string + authzID string + }{ + {name: "SHA256", hashFcn: SHA256, user: "user", pass: "password", authzID: ""}, + {name: "SHA512", hashFcn: SHA512, user: "user", pass: "password", authzID: ""}, + {name: "SHA512_with_authzID", hashFcn: SHA512, user: "user", pass: "password", authzID: "authz-user"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := &XDGSCRAMClient{HashGeneratorFcn: tt.hashFcn} + if err := client.Begin(tt.user, tt.pass, tt.authzID); err != nil { + t.Fatalf("Begin() error = %v", err) + } + if client.ClientConversation == nil { + t.Fatal("Begin() did not initialize ClientConversation") + } + }) + } +} + +func TestXDGSCRAMClient_Step(t *testing.T) { + tests := []struct { + name string + hashFcn scram.HashGeneratorFcn + }{ + {name: "SHA256", hashFcn: SHA256}, + {name: "SHA512", hashFcn: SHA512}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := &XDGSCRAMClient{HashGeneratorFcn: tt.hashFcn} + if err := client.Begin("user", "password", ""); err != nil { + t.Fatalf("Begin() error = %v", err) + } + + resp, err := client.Step("") + if err != nil { + t.Fatalf("Step() error = %v", err) + } + if resp == "" { + t.Fatal("Step() returned empty client-first message") + } + }) + } +} + +func TestXDGSCRAMClient_Done(t *testing.T) { + tests := []struct { + name string + hashFcn scram.HashGeneratorFcn + }{ + {name: "SHA256", hashFcn: SHA256}, + {name: "SHA512", hashFcn: SHA512}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := &XDGSCRAMClient{HashGeneratorFcn: tt.hashFcn} + if err := client.Begin("user", "password", ""); err != nil { + t.Fatalf("Begin() error = %v", err) + } + if client.Done() { + t.Fatal("Done() = true immediately after Begin(), want false") + } + }) + } +} diff --git a/pkg/messagebroker/kafka/security.go b/pkg/messagebroker/kafka/security.go new file mode 100644 index 000000000..5f4dbfb40 --- /dev/null +++ b/pkg/messagebroker/kafka/security.go @@ -0,0 +1,81 @@ +package kafka + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "os" + "path/filepath" + + "github.com/SUNET/vc/pkg/model" + + "github.com/IBM/sarama" +) + +// applySecurityConfig applies SASL and TLS settings from cfg to the Sarama configuration. +func applySecurityConfig(saramaConfig *sarama.Config, cfg *model.Cfg) error { + if cfg == nil || cfg.Common == nil { + return nil + } + kafka := &cfg.Common.Kafka + + // SASL + if kafka.SASL != nil && kafka.SASL.Enable { + saramaConfig.Net.SASL.Enable = true + saramaConfig.Net.SASL.User = kafka.SASL.Username + saramaConfig.Net.SASL.Password = kafka.SASL.Password + + switch kafka.SASL.Mechanism { + case "SCRAM-SHA-256": + saramaConfig.Net.SASL.SCRAMClientGeneratorFunc = func() sarama.SCRAMClient { return &XDGSCRAMClient{HashGeneratorFcn: SHA256} } + saramaConfig.Net.SASL.Mechanism = sarama.SASLTypeSCRAMSHA256 + case "SCRAM-SHA-512": + saramaConfig.Net.SASL.SCRAMClientGeneratorFunc = func() sarama.SCRAMClient { return &XDGSCRAMClient{HashGeneratorFcn: SHA512} } + saramaConfig.Net.SASL.Mechanism = sarama.SASLTypeSCRAMSHA512 + case "PLAIN": + saramaConfig.Net.SASL.Mechanism = sarama.SASLTypePlaintext + default: + return fmt.Errorf("unsupported SASL mechanism %q; supported values are SCRAM-SHA-256, SCRAM-SHA-512, PLAIN", kafka.SASL.Mechanism) + } + } + + // mTLS + if kafka.MTLS.Enable { + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS12, + } + + // Load CA cert if provided + if kafka.MTLS.CACertPath != "" { + caCert, err := os.ReadFile(filepath.Clean(kafka.MTLS.CACertPath)) + if err != nil { + return fmt.Errorf("reading CA cert %q: %w", kafka.MTLS.CACertPath, err) + } + caCertPool := x509.NewCertPool() + if !caCertPool.AppendCertsFromPEM(caCert) { + return fmt.Errorf("parsing CA cert %q: no valid certificates found", kafka.MTLS.CACertPath) + } + tlsConfig.RootCAs = caCertPool + } + + // Load client cert/key for mTLS — required when mTLS is enabled + if kafka.MTLS.CertFilePath == "" || kafka.MTLS.KeyFilePath == "" { + return fmt.Errorf("kafka mTLS is enabled but cert_file_path and key_file_path must both be set") + } + cert, err := tls.LoadX509KeyPair( + filepath.Clean(kafka.MTLS.CertFilePath), + filepath.Clean(kafka.MTLS.KeyFilePath), + ) + if err != nil { + return fmt.Errorf("loading client cert/key (%q, %q): %w", kafka.MTLS.CertFilePath, kafka.MTLS.KeyFilePath, err) + } + tlsConfig.Certificates = []tls.Certificate{cert} + + tlsConfig.InsecureSkipVerify = kafka.MTLS.InsecureSkipVerify //nolint:gosec // configurable for testing only + + saramaConfig.Net.TLS.Enable = true + saramaConfig.Net.TLS.Config = tlsConfig + } + + return nil +} diff --git a/pkg/messagebroker/kafka/security_test.go b/pkg/messagebroker/kafka/security_test.go new file mode 100644 index 000000000..0a791827c --- /dev/null +++ b/pkg/messagebroker/kafka/security_test.go @@ -0,0 +1,435 @@ +package kafka + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "os" + "path/filepath" + "testing" + "time" + + "github.com/IBM/sarama" + "github.com/SUNET/vc/pkg/model" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// writeTempCertAndKey generates a self-signed cert+key pair and writes them to dir. +// Returns (certPath, keyPath, caCertPath). +func writeTempCertAndKey(t *testing.T, dir string) (string, string, string) { + t.Helper() + + // Generate CA key + caKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + caTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "Test CA"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(24 * time.Hour), + IsCA: true, + BasicConstraintsValid: true, + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + } + caDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey) + require.NoError(t, err) + + caCertPath := filepath.Join(dir, "ca.pem") + require.NoError(t, os.WriteFile(caCertPath, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: caDER}), 0o600)) + + // Generate client cert signed by CA + clientKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + clientTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(2), + Subject: pkix.Name{CommonName: "Test Client"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } + clientDER, err := x509.CreateCertificate(rand.Reader, clientTemplate, caTemplate, &clientKey.PublicKey, caKey) + require.NoError(t, err) + + certPath := filepath.Join(dir, "client.pem") + require.NoError(t, os.WriteFile(certPath, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: clientDER}), 0o600)) + + keyDER, err := x509.MarshalECPrivateKey(clientKey) + require.NoError(t, err) + keyPath := filepath.Join(dir, "client-key.pem") + require.NoError(t, os.WriteFile(keyPath, pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}), 0o600)) + + return certPath, keyPath, caCertPath +} + +func TestApplySecurityConfig_NilConfig(t *testing.T) { + sc := sarama.NewConfig() + err := applySecurityConfig(sc, nil) + require.NoError(t, err) + assert.False(t, sc.Net.SASL.Enable) + assert.False(t, sc.Net.TLS.Enable) +} + +func TestApplySecurityConfig_NilCommon(t *testing.T) { + sc := sarama.NewConfig() + cfg := &model.Cfg{Common: nil} + err := applySecurityConfig(sc, cfg) + require.NoError(t, err) + assert.False(t, sc.Net.SASL.Enable) + assert.False(t, sc.Net.TLS.Enable) +} + +func TestApplySecurityConfig_SASL_SCRAM256(t *testing.T) { + sc := sarama.NewConfig() + cfg := &model.Cfg{ + Common: &model.Common{ + Kafka: model.Kafka{ + SASL: &model.KafkaSASL{ + Enable: true, + Mechanism: "SCRAM-SHA-256", + Username: "user", + Password: "pass", + }, + }, + }, + } + + err := applySecurityConfig(sc, cfg) + require.NoError(t, err) + assert.True(t, sc.Net.SASL.Enable) + assert.Equal(t, sarama.SASLMechanism(sarama.SASLTypeSCRAMSHA256), sc.Net.SASL.Mechanism) + assert.Equal(t, "user", sc.Net.SASL.User) + assert.Equal(t, "pass", sc.Net.SASL.Password) + assert.NotNil(t, sc.Net.SASL.SCRAMClientGeneratorFunc) +} + +func TestApplySecurityConfig_SASL_SCRAM512(t *testing.T) { + sc := sarama.NewConfig() + cfg := &model.Cfg{ + Common: &model.Common{ + Kafka: model.Kafka{ + SASL: &model.KafkaSASL{ + Enable: true, + Mechanism: "SCRAM-SHA-512", + Username: "u", + Password: "p", + }, + }, + }, + } + + err := applySecurityConfig(sc, cfg) + require.NoError(t, err) + assert.True(t, sc.Net.SASL.Enable) + assert.Equal(t, sarama.SASLMechanism(sarama.SASLTypeSCRAMSHA512), sc.Net.SASL.Mechanism) + assert.NotNil(t, sc.Net.SASL.SCRAMClientGeneratorFunc) +} + +func TestApplySecurityConfig_SASL_PLAIN(t *testing.T) { + sc := sarama.NewConfig() + cfg := &model.Cfg{ + Common: &model.Common{ + Kafka: model.Kafka{ + SASL: &model.KafkaSASL{ + Enable: true, + Mechanism: "PLAIN", + Username: "u", + Password: "p", + }, + }, + }, + } + + err := applySecurityConfig(sc, cfg) + require.NoError(t, err) + assert.True(t, sc.Net.SASL.Enable) + assert.Equal(t, sarama.SASLMechanism(sarama.SASLTypePlaintext), sc.Net.SASL.Mechanism) +} + +func TestApplySecurityConfig_SASL_UnsupportedMechanism(t *testing.T) { + sc := sarama.NewConfig() + cfg := &model.Cfg{ + Common: &model.Common{ + Kafka: model.Kafka{ + SASL: &model.KafkaSASL{ + Enable: true, + Mechanism: "OAUTHBEARER", + Username: "u", + Password: "p", + }, + }, + }, + } + + err := applySecurityConfig(sc, cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported SASL mechanism") + assert.Contains(t, err.Error(), "OAUTHBEARER") +} + +func TestApplySecurityConfig_SASL_Disabled(t *testing.T) { + sc := sarama.NewConfig() + cfg := &model.Cfg{ + Common: &model.Common{ + Kafka: model.Kafka{ + SASL: &model.KafkaSASL{ + Enable: false, + Mechanism: "SCRAM-SHA-512", + Username: "u", + Password: "p", + }, + }, + }, + } + + err := applySecurityConfig(sc, cfg) + require.NoError(t, err) + assert.False(t, sc.Net.SASL.Enable, "SASL should remain disabled when Enable=false") +} + +func TestApplySecurityConfig_MTLS_Valid(t *testing.T) { + dir := t.TempDir() + certPath, keyPath, caCertPath := writeTempCertAndKey(t, dir) + + sc := sarama.NewConfig() + cfg := &model.Cfg{ + Common: &model.Common{ + Kafka: model.Kafka{ + MTLS: model.MTLS{ + Enable: true, + CACertPath: caCertPath, + CertFilePath: certPath, + KeyFilePath: keyPath, + }, + }, + }, + } + + err := applySecurityConfig(sc, cfg) + require.NoError(t, err) + assert.True(t, sc.Net.TLS.Enable) + require.NotNil(t, sc.Net.TLS.Config) + assert.NotNil(t, sc.Net.TLS.Config.RootCAs) + assert.Len(t, sc.Net.TLS.Config.Certificates, 1) + assert.False(t, sc.Net.TLS.Config.InsecureSkipVerify) +} + +func TestApplySecurityConfig_MTLS_NoCA(t *testing.T) { + dir := t.TempDir() + certPath, keyPath, _ := writeTempCertAndKey(t, dir) + + sc := sarama.NewConfig() + cfg := &model.Cfg{ + Common: &model.Common{ + Kafka: model.Kafka{ + MTLS: model.MTLS{ + Enable: true, + CACertPath: "", // no CA — uses system roots + CertFilePath: certPath, + KeyFilePath: keyPath, + }, + }, + }, + } + + err := applySecurityConfig(sc, cfg) + require.NoError(t, err) + assert.True(t, sc.Net.TLS.Enable) + assert.Nil(t, sc.Net.TLS.Config.RootCAs, "RootCAs should be nil when no CA path is given") + assert.Len(t, sc.Net.TLS.Config.Certificates, 1) +} + +func TestApplySecurityConfig_MTLS_MissingCertPath(t *testing.T) { + sc := sarama.NewConfig() + cfg := &model.Cfg{ + Common: &model.Common{ + Kafka: model.Kafka{ + MTLS: model.MTLS{ + Enable: true, + CertFilePath: "", + KeyFilePath: "/some/key.pem", + }, + }, + }, + } + + err := applySecurityConfig(sc, cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), "cert_file_path and key_file_path must both be set") +} + +func TestApplySecurityConfig_MTLS_MissingKeyPath(t *testing.T) { + sc := sarama.NewConfig() + cfg := &model.Cfg{ + Common: &model.Common{ + Kafka: model.Kafka{ + MTLS: model.MTLS{ + Enable: true, + CertFilePath: "/some/cert.pem", + KeyFilePath: "", + }, + }, + }, + } + + err := applySecurityConfig(sc, cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), "cert_file_path and key_file_path must both be set") +} + +func TestApplySecurityConfig_MTLS_CACertNotFound(t *testing.T) { + sc := sarama.NewConfig() + cfg := &model.Cfg{ + Common: &model.Common{ + Kafka: model.Kafka{ + MTLS: model.MTLS{ + Enable: true, + CACertPath: "/nonexistent/ca.pem", + CertFilePath: "/some/cert.pem", + KeyFilePath: "/some/key.pem", + }, + }, + }, + } + + err := applySecurityConfig(sc, cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), "reading CA cert") +} + +func TestApplySecurityConfig_MTLS_CACertInvalidPEM(t *testing.T) { + dir := t.TempDir() + badCA := filepath.Join(dir, "bad-ca.pem") + require.NoError(t, os.WriteFile(badCA, []byte("not a valid PEM"), 0o600)) + + certPath, keyPath, _ := writeTempCertAndKey(t, dir) + + sc := sarama.NewConfig() + cfg := &model.Cfg{ + Common: &model.Common{ + Kafka: model.Kafka{ + MTLS: model.MTLS{ + Enable: true, + CACertPath: badCA, + CertFilePath: certPath, + KeyFilePath: keyPath, + }, + }, + }, + } + + err := applySecurityConfig(sc, cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), "no valid certificates found") +} + +func TestApplySecurityConfig_MTLS_InvalidKeyPair(t *testing.T) { + dir := t.TempDir() + _, _, caCertPath := writeTempCertAndKey(t, dir) + + // Write a cert but pair it with a different key + certPath := filepath.Join(dir, "mismatch-cert.pem") + keyPath := filepath.Join(dir, "mismatch-key.pem") + + // Generate a separate key (not matching the cert) + key1, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + key2, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + // Self-signed cert with key1 + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(99), + Subject: pkix.Name{CommonName: "mismatch"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + } + certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key1.PublicKey, key1) + require.NoError(t, err) + require.NoError(t, os.WriteFile(certPath, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}), 0o600)) + + // Write key2 (does not match cert) + keyDER, err := x509.MarshalECPrivateKey(key2) + require.NoError(t, err) + require.NoError(t, os.WriteFile(keyPath, pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}), 0o600)) + + sc := sarama.NewConfig() + cfg := &model.Cfg{ + Common: &model.Common{ + Kafka: model.Kafka{ + MTLS: model.MTLS{ + Enable: true, + CACertPath: caCertPath, + CertFilePath: certPath, + KeyFilePath: keyPath, + }, + }, + }, + } + + err = applySecurityConfig(sc, cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), "loading client cert/key") +} + +func TestApplySecurityConfig_MTLS_InsecureSkipVerify(t *testing.T) { + dir := t.TempDir() + certPath, keyPath, _ := writeTempCertAndKey(t, dir) + + sc := sarama.NewConfig() + cfg := &model.Cfg{ + Common: &model.Common{ + Kafka: model.Kafka{ + MTLS: model.MTLS{ + Enable: true, + CertFilePath: certPath, + KeyFilePath: keyPath, + InsecureSkipVerify: true, + }, + }, + }, + } + + err := applySecurityConfig(sc, cfg) + require.NoError(t, err) + assert.True(t, sc.Net.TLS.Config.InsecureSkipVerify) +} + +func TestApplySecurityConfig_SASL_And_MTLS_Combined(t *testing.T) { + dir := t.TempDir() + certPath, keyPath, caCertPath := writeTempCertAndKey(t, dir) + + sc := sarama.NewConfig() + cfg := &model.Cfg{ + Common: &model.Common{ + Kafka: model.Kafka{ + SASL: &model.KafkaSASL{ + Enable: true, + Mechanism: "SCRAM-SHA-512", + Username: "user", + Password: "pass", + }, + MTLS: model.MTLS{ + Enable: true, + CACertPath: caCertPath, + CertFilePath: certPath, + KeyFilePath: keyPath, + }, + }, + }, + } + + err := applySecurityConfig(sc, cfg) + require.NoError(t, err) + assert.True(t, sc.Net.SASL.Enable) + assert.Equal(t, sarama.SASLMechanism(sarama.SASLTypeSCRAMSHA512), sc.Net.SASL.Mechanism) + assert.True(t, sc.Net.TLS.Enable) + assert.Len(t, sc.Net.TLS.Config.Certificates, 1) +} diff --git a/pkg/model/config.go b/pkg/model/config.go index fcef31318..35d1e7d71 100644 --- a/pkg/model/config.go +++ b/pkg/model/config.go @@ -57,14 +57,28 @@ type CORS struct { AllowedOrigins []string `yaml:"allowed_origins" validate:"omitempty" default:"[]" doc_example:"[\"https://wallet.sunet.se\", \"https://app.sunet.se\"]"` } -// TLS holds the TLS configuration +// TLS holds server-side TLS configuration (presenting a certificate to clients) type TLS struct { // Enable enables TLS Enable bool `yaml:"enable" default:"false"` // CertFilePath is the path to the TLS certificate - CertFilePath string `yaml:"cert_file_path" validate:"required"` + CertFilePath string `yaml:"cert_file_path" validate:"required_if=Enable true"` // KeyFilePath is the path to the TLS private key - KeyFilePath string `yaml:"key_file_path" validate:"required"` + KeyFilePath string `yaml:"key_file_path" validate:"required_if=Enable true"` +} + +// MTLS holds mutual TLS configuration for client connections (verifying peer + presenting own cert) +type MTLS struct { + // Enable enables mTLS for the connection + Enable bool `yaml:"enable" default:"false"` + // CACertPath is the path to a CA certificate for verifying the remote peer (optional; uses system roots if empty) + CACertPath string `yaml:"ca_cert_path,omitempty"` + // CertFilePath is the path to a client certificate for mutual authentication + CertFilePath string `yaml:"cert_file_path" validate:"required_if=Enable true"` + // KeyFilePath is the path to the client private key + KeyFilePath string `yaml:"key_file_path" validate:"required_if=Enable true"` + // InsecureSkipVerify disables certificate verification (TESTING ONLY — never use in production) + InsecureSkipVerify bool `yaml:"insecure_skip_verify" default:"false"` } // Mongo holds the MongoDB configuration @@ -98,7 +112,23 @@ type Kafka struct { // Enable enables Kafka integration Enable bool `yaml:"enable" default:"false"` // Brokers is the list of Kafka broker addresses - Brokers []string `yaml:"brokers" validate:"required" doc_example:"[\"kafka0:9092\", \"kafka1:9092\"]"` + Brokers []string `yaml:"brokers" validate:"required_if=Enable true" doc_example:"[\"kafka0:9092\", \"kafka1:9092\"]"` + // SASL configures SASL authentication for Kafka connections + SASL *KafkaSASL `yaml:"sasl,omitempty"` + // MTLS configures mutual TLS (mTLS) for Kafka broker connections + MTLS MTLS `yaml:"mtls" validate:"omitempty"` +} + +// KafkaSASL holds SASL authentication settings for Kafka +type KafkaSASL struct { + // Enable activates SASL authentication + Enable bool `yaml:"enable" default:"false"` + // Mechanism is the SASL mechanism (PLAIN, SCRAM-SHA-256, SCRAM-SHA-512) + Mechanism string `yaml:"mechanism" validate:"required_if=Enable true,omitempty,oneof=PLAIN SCRAM-SHA-256 SCRAM-SHA-512" default:"SCRAM-SHA-512"` + // Username is the SASL username + Username string `yaml:"username" validate:"required_if=Enable true"` + // Password is the SASL password + Password string `yaml:"password" validate:"required_if=Enable true"` } // Log holds the logging configuration @@ -232,6 +262,12 @@ type SAMLSP struct { // carry a valid XML signature from this certificate. MetadataSigningCertPath string `yaml:"metadata_signing_cert_path,omitempty"` + // AllowUnsignedMetadata permits MDQ/URL metadata without signature verification. + // This is INSECURE (MITM → fake IdP) and should only be used in development. + // When false (default), MDQ and URL metadata sources require MetadataSigningCertPath. + // Local metadata files are allowed unsigned regardless (with a startup warning). + AllowUnsignedMetadata bool `yaml:"allow_unsigned_metadata" default:"false"` + // MetadataCacheTTL in seconds (default: 3600) - how long to cache IdP metadata from MDQ MetadataCacheTTL int `yaml:"metadata_cache_ttl"` } diff --git a/pkg/oauth2/client_assertion_verifier.go b/pkg/oauth2/client_assertion_verifier.go new file mode 100644 index 000000000..edbf7ad77 --- /dev/null +++ b/pkg/oauth2/client_assertion_verifier.go @@ -0,0 +1,245 @@ +package oauth2 + +import ( + "context" + "crypto" + "encoding/json" + "errors" + "fmt" + "slices" + "time" + + "github.com/SUNET/vc/pkg/cache" + "github.com/golang-jwt/jwt/v5" + "github.com/lestrrat-go/jwx/v3/jwk" +) + +// ClientAssertionClaims holds the validated claims from a client_assertion JWT (RFC 7523 §3). +type ClientAssertionClaims struct { + Issuer string + Subject string + Audience string + JTI string + IssuedAt time.Time + Expiry time.Time +} + +// ClientAssertionVerifier verifies client_assertion JWTs per RFC 7523. +type ClientAssertionVerifier struct { + // AllowedAlgorithms is the set of permitted signing algorithms (e.g. RS256, ES256). + AllowedAlgorithms []string + // TokenEndpoint is the expected audience value (the token endpoint URL). + TokenEndpoint string + // MaxLifetime is the maximum allowed lifetime of the assertion (exp - iat). + // Defaults to 5 minutes if zero. + MaxLifetime time.Duration + // JTICheck is called to verify that the jti has not been replayed. + // Returns an error if the jti was already seen. May be nil to skip replay checks. + JTICheck func(jti string, exp time.Time) error + // JWKSCache caches raw JWKS JSON keyed by the client's jwks_uri. + // When non-nil, avoids fetching the JWKS on every token request. + // When nil, falls back to fetching the JWKS directly on each call. + JWKSCache cache.Cache[[]byte] +} + +// defaultAllowedAlgorithms is the allowlist of signing algorithms accepted for client assertions. +var defaultAllowedAlgorithms = []string{"RS256", "RS384", "RS512", "ES256", "ES384", "ES512", "PS256", "PS384", "PS512", "EdDSA"} + +// Verify verifies a client_assertion JWT against the client's JWKS, validating +// the signature, audience, issuer/subject, expiration, and jti (replay protection). +// Returns the validated claims or an error. +func (v *ClientAssertionVerifier) Verify(ctx context.Context, assertion string, client *Client) (*ClientAssertionClaims, error) { + if client == nil { + return nil, errors.New("client is nil") + } + if client.JWKSURI == "" { + return nil, errors.New("client has no jwks_uri configured for assertion verification") + } + if v.TokenEndpoint == "" { + return nil, errors.New("token endpoint (audience) is not configured for assertion verification") + } + + // Fetch the client's JWKS (with optional cache) + keySet, err := v.fetchJWKS(ctx, client.JWKSURI) + if err != nil { + return nil, fmt.Errorf("failed to fetch client JWKS from %s: %w", client.JWKSURI, err) + } + + allowedAlgs := v.AllowedAlgorithms + if len(allowedAlgs) == 0 { + allowedAlgs = defaultAllowedAlgorithms + } + + // Parse and verify the JWT + token, err := jwt.Parse(assertion, func(token *jwt.Token) (any, error) { + // Reject "none" algorithm unconditionally + if token.Method.Alg() == "none" { + return nil, errors.New("algorithm 'none' is not allowed") + } + + // Check algorithm allowlist + alg := token.Method.Alg() + if !slices.Contains(allowedAlgs, alg) { + return nil, fmt.Errorf("algorithm %q is not in the allowed set", alg) + } + + // Find matching key in JWKS by kid (if present in header) + kid, _ := token.Header["kid"].(string) + if kid != "" { + if matchedKey, ok := keySet.LookupKeyID(kid); ok { + // If the key declares an algorithm, enforce it to avoid algorithm confusion. + if kAlg, hasAlg := matchedKey.Algorithm(); hasAlg && kAlg.String() != alg { + return nil, fmt.Errorf("key %q algorithm %q does not match token algorithm %q", kid, kAlg.String(), alg) + } + var rawKey crypto.PublicKey + if err := jwk.Export(matchedKey, &rawKey); err != nil { + return nil, fmt.Errorf("failed to extract raw key: %w", err) + } + return rawKey, nil + } + } + + // No kid match — collect all candidate keys and return a VerificationKeySet + // so the parser can try each key until one verifies. + var keys []jwt.VerificationKey + for i := 0; i < keySet.Len(); i++ { + k, ok := keySet.Key(i) + if !ok { + continue + } + // If the key declares an algorithm, skip keys that don't match + if kAlg, hasAlg := k.Algorithm(); hasAlg && kAlg.String() != alg { + continue + } + var rawKey crypto.PublicKey + if err := jwk.Export(k, &rawKey); err != nil { + continue + } + keys = append(keys, rawKey) + } + if len(keys) == 0 { + return nil, errors.New("no suitable key found in client JWKS") + } + return jwt.VerificationKeySet{Keys: keys}, nil + }, + jwt.WithValidMethods(allowedAlgs), + jwt.WithAudience(v.TokenEndpoint), + jwt.WithExpirationRequired(), + jwt.WithLeeway(30*time.Second), + ) + if err != nil { + return nil, fmt.Errorf("client assertion verification failed: %w", err) + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return nil, errors.New("unexpected claims type") + } + + // RFC 7523 §3: iss MUST equal sub (both identify the client) + iss, _ := claims["iss"].(string) + sub, _ := claims["sub"].(string) + if iss == "" || sub == "" { + return nil, errors.New("client assertion must contain 'iss' and 'sub' claims") + } + if iss != sub { + return nil, fmt.Errorf("client assertion 'iss' (%s) must equal 'sub' (%s) per RFC 7523 §3", iss, sub) + } + + // Verify jti for replay protection + jti, _ := claims["jti"].(string) + if jti == "" && v.JTICheck != nil { + return nil, errors.New("client assertion must contain 'jti' claim for replay protection") + } + + // Parse time claims + expFloat, _ := claims["exp"].(float64) + iatFloat, okIat := claims["iat"].(float64) + if !okIat || iatFloat == 0 { + return nil, errors.New("client assertion must contain a numeric 'iat' claim") + } + expTime := time.Unix(int64(expFloat), 0) + iatTime := time.Unix(int64(iatFloat), 0) + if expTime.Before(iatTime) { + return nil, errors.New("client assertion 'exp' must be after 'iat'") + } + // Check max lifetime + maxLifetime := v.MaxLifetime + if maxLifetime == 0 { + maxLifetime = 5 * time.Minute + } + + now := time.Now() + clockSkew := 30 * time.Second + + // Reject iat in the future (beyond clock skew tolerance) + if iatTime.After(now.Add(clockSkew)) { + return nil, errors.New("client assertion 'iat' is in the future") + } + + // Reject exp too far ahead of now (beyond maxLifetime + clock skew) + if expTime.After(now.Add(maxLifetime + clockSkew)) { + return nil, fmt.Errorf("client assertion 'exp' is too far in the future (max %s from now)", maxLifetime) + } + + if expTime.Sub(iatTime) > maxLifetime { + return nil, fmt.Errorf("client assertion lifetime exceeds maximum (%s)", maxLifetime) + } + + // JTI replay check + if v.JTICheck != nil { + if err := v.JTICheck(jti, expTime); err != nil { + return nil, fmt.Errorf("client assertion jti replay detected: %w", err) + } + } + + result := &ClientAssertionClaims{ + Issuer: iss, + Subject: sub, + JTI: jti, + Expiry: expTime, + IssuedAt: iatTime, + } + switch aud := claims["aud"].(type) { + case string: + result.Audience = aud + case []any: + for _, elem := range aud { + if s, ok := elem.(string); ok && s == v.TokenEndpoint { + result.Audience = s + break + } + } + } + + return result, nil +} + +// fetchJWKS retrieves a JWKS keyset, using the cache when available. +func (v *ClientAssertionVerifier) fetchJWKS(ctx context.Context, uri string) (jwk.Set, error) { + if v.JWKSCache == nil { + return jwk.Fetch(ctx, uri) + } + + // Try cache first. + if raw, ok := v.JWKSCache.Get(ctx, uri); ok { + set, err := jwk.Parse(raw) + if err == nil { + return set, nil + } + // Cached data is corrupt – fall through to re-fetch. + } + + // Fetch from remote. + set, err := jwk.Fetch(ctx, uri) + if err != nil { + return nil, err + } + + // Serialize and store in cache. + if raw, err := json.Marshal(set); err == nil { + v.JWKSCache.Set(ctx, uri, raw) + } + + return set, nil +} diff --git a/pkg/oauth2/client_assertion_verifier_test.go b/pkg/oauth2/client_assertion_verifier_test.go new file mode 100644 index 000000000..e3e0f3f43 --- /dev/null +++ b/pkg/oauth2/client_assertion_verifier_test.go @@ -0,0 +1,557 @@ +package oauth2 + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/SUNET/vc/pkg/cache" + "github.com/golang-jwt/jwt/v5" + "github.com/lestrrat-go/jwx/v3/jwk" +) + +// testKeySetup generates an ECDSA key pair and serves the public key as a JWKS endpoint. +func testKeySetup(t *testing.T) (*ecdsa.PrivateKey, *httptest.Server) { + t.Helper() + privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("failed to generate key: %v", err) + } + + pubJWK, err := jwk.Import(privKey.Public()) + if err != nil { + t.Fatalf("failed to import public key to JWK: %v", err) + } + _ = pubJWK.Set(jwk.KeyIDKey, "test-kid") + _ = pubJWK.Set(jwk.AlgorithmKey, "ES256") + + set := jwk.NewSet() + _ = set.AddKey(pubJWK) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(set) + })) + t.Cleanup(srv.Close) + return privKey, srv +} + +// signAssertion creates a signed JWT client assertion. +func signAssertion(t *testing.T, key *ecdsa.PrivateKey, claims jwt.MapClaims) string { + t.Helper() + token := jwt.NewWithClaims(jwt.SigningMethodES256, claims) + token.Header["kid"] = "test-kid" + signed, err := token.SignedString(key) + if err != nil { + t.Fatalf("failed to sign assertion: %v", err) + } + return signed +} + +func validClaims(tokenEndpoint string) jwt.MapClaims { + now := time.Now() + return jwt.MapClaims{ + "iss": "client-id", + "sub": "client-id", + "aud": tokenEndpoint, + "jti": "unique-jti-123", + "iat": float64(now.Unix()), + "exp": float64(now.Add(2 * time.Minute).Unix()), + } +} + +func TestClientAssertionVerifier_Verify(t *testing.T) { + privKey, srv := testKeySetup(t) + tokenEndpoint := "https://verifier.example.com/token" + + client := &Client{ + JWKSURI: srv.URL, + } + + tests := []struct { + name string + verifier *ClientAssertionVerifier + claims jwt.MapClaims + client *Client + wantErr string + jtiCheck func(string, time.Time) error + checkResult func(*testing.T, *ClientAssertionClaims) + }{ + { + name: "valid_assertion", + verifier: &ClientAssertionVerifier{ + TokenEndpoint: tokenEndpoint, + }, + claims: validClaims(tokenEndpoint), + client: client, + checkResult: func(t *testing.T, c *ClientAssertionClaims) { + if c.Issuer != "client-id" { + t.Errorf("Issuer = %q, want %q", c.Issuer, "client-id") + } + if c.Subject != "client-id" { + t.Errorf("Subject = %q, want %q", c.Subject, "client-id") + } + if c.JTI != "unique-jti-123" { + t.Errorf("JTI = %q, want %q", c.JTI, "unique-jti-123") + } + }, + }, + { + name: "no_jwks_uri", + verifier: &ClientAssertionVerifier{ + TokenEndpoint: tokenEndpoint, + }, + claims: validClaims(tokenEndpoint), + client: &Client{}, + wantErr: "client has no jwks_uri configured", + }, + { + name: "wrong_audience", + verifier: &ClientAssertionVerifier{ + TokenEndpoint: tokenEndpoint, + }, + claims: func() jwt.MapClaims { + c := validClaims(tokenEndpoint) + c["aud"] = "https://wrong.example.com/token" + return c + }(), + client: client, + wantErr: "verification failed", + }, + { + name: "iss_not_equal_sub", + verifier: &ClientAssertionVerifier{ + TokenEndpoint: tokenEndpoint, + }, + claims: func() jwt.MapClaims { + c := validClaims(tokenEndpoint) + c["iss"] = "client-id" + c["sub"] = "different-id" + return c + }(), + client: client, + wantErr: "must equal 'sub'", + }, + { + name: "missing_jti", + verifier: &ClientAssertionVerifier{ + TokenEndpoint: tokenEndpoint, + JTICheck: func(_ string, _ time.Time) error { return nil }, + }, + claims: func() jwt.MapClaims { + c := validClaims(tokenEndpoint) + delete(c, "jti") + return c + }(), + client: client, + wantErr: "must contain 'jti' claim", + }, + { + name: "missing_iss_and_sub", + verifier: &ClientAssertionVerifier{ + TokenEndpoint: tokenEndpoint, + }, + claims: func() jwt.MapClaims { + c := validClaims(tokenEndpoint) + delete(c, "iss") + delete(c, "sub") + return c + }(), + client: client, + wantErr: "must contain 'iss' and 'sub' claims", + }, + { + name: "missing_iat", + verifier: &ClientAssertionVerifier{ + TokenEndpoint: tokenEndpoint, + }, + claims: func() jwt.MapClaims { + c := validClaims(tokenEndpoint) + delete(c, "iat") + return c + }(), + client: client, + wantErr: "must contain a numeric 'iat' claim", + }, + { + name: "missing_iat_with_far_future_exp_bypass_attempt", + verifier: &ClientAssertionVerifier{ + TokenEndpoint: tokenEndpoint, + MaxLifetime: 5 * time.Minute, + }, + claims: func() jwt.MapClaims { + return jwt.MapClaims{ + "iss": "client-id", + "sub": "client-id", + "aud": tokenEndpoint, + "jti": "bypass-jti", + "exp": float64(time.Now().Add(24 * time.Hour).Unix()), + } + }(), + client: client, + wantErr: "must contain a numeric 'iat' claim", + }, + { + name: "zero_iat_bypass_attempt", + verifier: &ClientAssertionVerifier{ + TokenEndpoint: tokenEndpoint, + MaxLifetime: 5 * time.Minute, + }, + claims: func() jwt.MapClaims { + return jwt.MapClaims{ + "iss": "client-id", + "sub": "client-id", + "aud": tokenEndpoint, + "jti": "zero-iat-jti", + "iat": float64(0), + "exp": float64(time.Now().Add(24 * time.Hour).Unix()), + } + }(), + client: client, + wantErr: "must contain a numeric 'iat' claim", + }, + { + name: "non_numeric_iat_bypass_attempt", + verifier: &ClientAssertionVerifier{ + TokenEndpoint: tokenEndpoint, + MaxLifetime: 5 * time.Minute, + }, + claims: func() jwt.MapClaims { + return jwt.MapClaims{ + "iss": "client-id", + "sub": "client-id", + "aud": tokenEndpoint, + "jti": "string-iat-jti", + "iat": "not-a-number", + "exp": float64(time.Now().Add(24 * time.Hour).Unix()), + } + }(), + client: client, + wantErr: "must contain a numeric 'iat' claim", + }, + { + name: "future_iat_bypass_max_lifetime", + verifier: &ClientAssertionVerifier{ + TokenEndpoint: tokenEndpoint, + MaxLifetime: 5 * time.Minute, + }, + claims: func() jwt.MapClaims { + futureIat := time.Now().Add(24 * time.Hour) + return jwt.MapClaims{ + "iss": "client-id", + "sub": "client-id", + "aud": tokenEndpoint, + "jti": "future-iat-jti", + "iat": float64(futureIat.Unix()), + "exp": float64(futureIat.Add(2 * time.Minute).Unix()), + } + }(), + client: client, + wantErr: "'iat' is in the future", + }, + { + name: "expired_token", + verifier: &ClientAssertionVerifier{ + TokenEndpoint: tokenEndpoint, + }, + claims: func() jwt.MapClaims { + past := time.Now().Add(-10 * time.Minute) + return jwt.MapClaims{ + "iss": "client-id", + "sub": "client-id", + "aud": tokenEndpoint, + "jti": "expired-jti", + "iat": float64(past.Unix()), + "exp": float64(past.Add(2 * time.Minute).Unix()), + } + }(), + client: client, + wantErr: "verification failed", + }, + { + name: "exceeds_max_lifetime", + verifier: &ClientAssertionVerifier{ + TokenEndpoint: tokenEndpoint, + MaxLifetime: 1 * time.Minute, + }, + claims: func() jwt.MapClaims { + now := time.Now() + return jwt.MapClaims{ + "iss": "client-id", + "sub": "client-id", + "aud": tokenEndpoint, + "jti": "long-lived-jti", + "iat": float64(now.Unix()), + "exp": float64(now.Add(10 * time.Minute).Unix()), + } + }(), + client: client, + wantErr: "'exp' is too far in the future", + }, + { + name: "jti_replay_detected", + verifier: &ClientAssertionVerifier{ + TokenEndpoint: tokenEndpoint, + }, + claims: validClaims(tokenEndpoint), + client: client, + jtiCheck: func(string, time.Time) error { return errors.New("already seen") }, + wantErr: "jti replay detected", + }, + { + name: "disallowed_algorithm", + verifier: &ClientAssertionVerifier{ + TokenEndpoint: tokenEndpoint, + AllowedAlgorithms: []string{"RS256"}, + }, + claims: validClaims(tokenEndpoint), + client: client, + wantErr: "verification failed", + }, + { + name: "unreachable_jwks_uri", + verifier: &ClientAssertionVerifier{ + TokenEndpoint: tokenEndpoint, + }, + claims: validClaims(tokenEndpoint), + client: &Client{JWKSURI: "http://127.0.0.1:1/nonexistent"}, + wantErr: "failed to fetch client JWKS", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.jtiCheck != nil { + tt.verifier.JTICheck = tt.jtiCheck + } + + assertion := signAssertion(t, privKey, tt.claims) + result, err := tt.verifier.Verify(context.Background(), assertion, tt.client) + + if tt.wantErr != "" { + if err == nil { + t.Fatalf("expected error containing %q, got nil", tt.wantErr) + } + if got := err.Error(); !contains(got, tt.wantErr) { + t.Fatalf("error = %q, want substring %q", got, tt.wantErr) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tt.checkResult != nil { + tt.checkResult(t, result) + } + }) + } +} + +func TestClientAssertionVerifier_DefaultAlgorithms(t *testing.T) { + expected := []string{"RS256", "RS384", "RS512", "ES256", "ES384", "ES512", "PS256", "PS384", "PS512", "EdDSA"} + if len(defaultAllowedAlgorithms) != len(expected) { + t.Fatalf("defaultAllowedAlgorithms length = %d, want %d", len(defaultAllowedAlgorithms), len(expected)) + } + for i, alg := range expected { + if defaultAllowedAlgorithms[i] != alg { + t.Errorf("defaultAllowedAlgorithms[%d] = %q, want %q", i, defaultAllowedAlgorithms[i], alg) + } + } +} + +func contains(s, substr string) bool { + if substr == "" { + return true + } + for i := 0; i+len(substr) <= len(s); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} + +func TestClientAssertionVerifier_ValidWithCustomAlgorithms(t *testing.T) { + privKey, srv := testKeySetup(t) + tokenEndpoint := "https://verifier.example.com/token" + + verifier := &ClientAssertionVerifier{ + TokenEndpoint: tokenEndpoint, + AllowedAlgorithms: []string{"ES256"}, + } + + assertion := signAssertion(t, privKey, validClaims(tokenEndpoint)) + result, err := verifier.Verify(context.Background(), assertion, &Client{JWKSURI: srv.URL}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Issuer != "client-id" { + t.Errorf("Issuer = %q, want %q", result.Issuer, "client-id") + } +} + +func TestClientAssertionVerifier_JTICheckCalled(t *testing.T) { + privKey, srv := testKeySetup(t) + tokenEndpoint := "https://verifier.example.com/token" + + var capturedJTI string + verifier := &ClientAssertionVerifier{ + TokenEndpoint: tokenEndpoint, + JTICheck: func(jti string, exp time.Time) error { + capturedJTI = jti + return nil + }, + } + + claims := validClaims(tokenEndpoint) + claims["jti"] = "specific-jti-value" + assertion := signAssertion(t, privKey, claims) + + _, err := verifier.Verify(context.Background(), assertion, &Client{JWKSURI: srv.URL}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if capturedJTI != "specific-jti-value" { + t.Errorf("JTICheck received jti = %q, want %q", capturedJTI, "specific-jti-value") + } +} + +func TestClientAssertionVerifier_KeyLookupFallback(t *testing.T) { + // Test that verification works even without kid in token header + privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("failed to generate key: %v", err) + } + + pubJWK, err := jwk.Import(privKey.Public()) + if err != nil { + t.Fatalf("failed to import public key: %v", err) + } + // Set algorithm but NO kid + _ = pubJWK.Set(jwk.AlgorithmKey, "ES256") + + set := jwk.NewSet() + _ = set.AddKey(pubJWK) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(set) + })) + t.Cleanup(srv.Close) + + tokenEndpoint := "https://verifier.example.com/token" + verifier := &ClientAssertionVerifier{TokenEndpoint: tokenEndpoint} + + // Sign without kid in header + token := jwt.NewWithClaims(jwt.SigningMethodES256, validClaims(tokenEndpoint)) + // No kid header set + assertion, err := token.SignedString(privKey) + if err != nil { + t.Fatalf("failed to sign: %v", err) + } + + result, err := verifier.Verify(context.Background(), assertion, &Client{JWKSURI: srv.URL}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Subject != "client-id" { + t.Errorf("Subject = %q, want %q", result.Subject, "client-id") + } +} + +func TestClientAssertionVerifier_InvalidJWT(t *testing.T) { + _, srv := testKeySetup(t) + tokenEndpoint := "https://verifier.example.com/token" + + tests := []struct { + name string + assertion string + wantErr string + }{ + {name: "garbage_token", assertion: "not.a.jwt", wantErr: "verification failed"}, + {name: "empty_token", assertion: "", wantErr: "verification failed"}, + {name: "partial_token", assertion: "eyJhbGciOiJFUzI1NiJ9.", wantErr: "verification failed"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + verifier := &ClientAssertionVerifier{TokenEndpoint: tokenEndpoint} + _, err := verifier.Verify(context.Background(), tt.assertion, &Client{JWKSURI: srv.URL}) + if err == nil { + t.Fatal("expected error, got nil") + } + if got := fmt.Sprint(err); !contains(got, tt.wantErr) { + t.Errorf("error = %q, want substring %q", got, tt.wantErr) + } + }) + } +} + +func TestClientAssertionVerifier_JWKSCache(t *testing.T) { + privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("failed to generate key: %v", err) + } + + pubJWK, err := jwk.Import(privKey.Public()) + if err != nil { + t.Fatalf("failed to import public key to JWK: %v", err) + } + _ = pubJWK.Set(jwk.KeyIDKey, "test-kid") + _ = pubJWK.Set(jwk.AlgorithmKey, "ES256") + + set := jwk.NewSet() + _ = set.AddKey(pubJWK) + + tokenEndpoint := "https://verifier.example.com/token" + + // Count how many times the JWKS endpoint is hit. + var fetchCount atomic.Int32 + countingSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fetchCount.Add(1) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(set) + })) + t.Cleanup(countingSrv.Close) + + jwksCache := cache.NewMemoryCache[[]byte](5 * time.Minute) + verifier := &ClientAssertionVerifier{ + TokenEndpoint: tokenEndpoint, + JWKSCache: jwksCache, + } + + client := &Client{JWKSURI: countingSrv.URL} + + // First call: should fetch from remote. + claims1 := validClaims(tokenEndpoint) + claims1["jti"] = "cache-test-1" + assertion1 := signAssertion(t, privKey, claims1) + _, err = verifier.Verify(context.Background(), assertion1, client) + if err != nil { + t.Fatalf("first verify failed: %v", err) + } + if fetchCount.Load() != 1 { + t.Fatalf("expected 1 fetch, got %d", fetchCount.Load()) + } + + // Second call: should use cache, no additional fetch. + claims2 := validClaims(tokenEndpoint) + claims2["jti"] = "cache-test-2" + assertion2 := signAssertion(t, privKey, claims2) + _, err = verifier.Verify(context.Background(), assertion2, client) + if err != nil { + t.Fatalf("second verify failed: %v", err) + } + if fetchCount.Load() != 1 { + t.Fatalf("expected still 1 fetch after cache hit, got %d", fetchCount.Load()) + } +} diff --git a/pkg/oauth2/clients.go b/pkg/oauth2/clients.go index 7d0326003..d3fde0568 100644 --- a/pkg/oauth2/clients.go +++ b/pkg/oauth2/clients.go @@ -27,6 +27,9 @@ type Client struct { RedirectURIs RedirectURIs `json:"redirect_uri" yaml:"redirect_uri" validate:"required,min=1,dive,required" doc_example:"\"https://example.com/callback\""` // Scopes is the list of OAuth2 scopes allowed for the client Scopes []string `json:"scopes" yaml:"scopes" validate:"required"` + // JWKSURI is the URL to the client's JWKS for verifying client_assertion signatures (RFC 7523). + // Required for confidential clients using private_key_jwt authentication. + JWKSURI string `json:"jwks_uri,omitempty" yaml:"jwks_uri,omitempty" validate:"required_if=Type confidential,omitempty,httpsurl"` } // RedirectURIs holds one or more allowed redirect URIs.