diff --git a/internal/encryption/raft_envelope.go b/internal/encryption/raft_envelope.go new file mode 100644 index 000000000..bf317c744 --- /dev/null +++ b/internal/encryption/raft_envelope.go @@ -0,0 +1,113 @@ +package encryption + +import ( + "encoding/binary" + + "github.com/cockroachdb/errors" +) + +// RaftAADPurpose is the literal byte 'R' (0x52) that prefixes the +// raft-envelope AAD per design §4.2. It distinguishes a raft envelope +// from a storage envelope: a storage-layer ciphertext replayed into +// the raft layer (or the reverse) fails GCM verification because the +// AAD prefix does not match. +const RaftAADPurpose byte = 'R' // 0x52 + +// raftAADSize is the length of the raft-envelope AAD: +// +// purpose(1) ‖ envelope_version(1) ‖ key_id(4) +// +// No pebble_key — the raft envelope is location-independent. The +// engine identifies the entry by raftpb.Entry.Index, which the +// pre-apply hook uses to gate the Unwrap (§6.3). +const raftAADSize = 1 + versionBytes + keyIDBytes // 6 + +// BuildRaftAAD composes the §4.2 raft-envelope AAD: a single-byte +// purpose tag ('R'), the envelope version, and the 4-byte big-endian +// key_id. Exposed for tests; production callers go through +// WrapRaftPayload / UnwrapRaftPayload. +func BuildRaftAAD(version byte, keyID uint32) []byte { + aad := make([]byte, raftAADSize) + aad[0] = RaftAADPurpose + aad[1] = version + binary.BigEndian.PutUint32(aad[2:2+keyIDBytes], keyID) + return aad +} + +// WrapRaftPayload wraps payload in a §4.2 raft envelope under the DEK +// identified by keyID, using the supplied 12-byte nonce. The cipher +// must already hold the keyID under the "raft" purpose (the keystore +// itself does not enforce purpose — that contract is maintained by the +// sidecar loader). +// +// The flag byte is fixed at 0x00; raft proposals do not carry the +// Snappy compression bit (the apply path is latency-sensitive and +// proposals are small / high-entropy). +// +// Nonce uniqueness is the caller's responsibility: re-using a +// (keyID, nonce) pair under the same DEK is a catastrophic AES-GCM +// failure (key-recovery + plaintext XOR). The §4.2 deterministic +// nonce construction (`node_id ‖ local_epoch ‖ write_count`) +// guarantees uniqueness by construction; do not substitute a +// different scheme without an equivalent uniqueness proof. +func WrapRaftPayload(c *Cipher, keyID uint32, nonce, payload []byte) ([]byte, error) { + if c == nil { + return nil, errors.WithStack(ErrNilKeystore) + } + if len(nonce) != NonceSize { + return nil, errors.Wrapf(ErrBadNonceSize, "got %d bytes, want %d", len(nonce), NonceSize) + } + const envelopeFlag byte = 0 + aad := BuildRaftAAD(EnvelopeVersionV1, keyID) + body, err := c.Encrypt(payload, aad, keyID, nonce) + if err != nil { + return nil, errors.Wrap(err, "encryption: raft envelope encrypt") + } + var nonceArr [NonceSize]byte + copy(nonceArr[:], nonce) + env := Envelope{ + Version: EnvelopeVersionV1, + Flag: envelopeFlag, + KeyID: keyID, + Nonce: nonceArr, + Body: body, + } + encoded, err := env.Encode() + if err != nil { + return nil, errors.Wrap(err, "encryption: raft envelope encode") + } + return encoded, nil +} + +// UnwrapRaftPayload reverses WrapRaftPayload. Decodes the envelope, +// rebuilds the AAD identically, and calls Decrypt. The same `*Cipher` +// instance used at wrap time must hold the embedded keyID (or one of +// its rotated successors) for unwrap to succeed. +// +// Surfaces typed errors callers can disambiguate via errors.Is: +// +// - ErrEnvelopeShort: encoded shorter than HeaderSize+TagSize +// - ErrEnvelopeVersion: unknown version byte +// - ErrUnknownKeyID: DEK is not loaded (retired or sidecar missing) +// - ErrIntegrity: GCM tag mismatch (tampered envelope, wrong DEK, +// or layer confusion with a storage envelope) +// +// A storage envelope fed to UnwrapRaftPayload fails with +// ErrIntegrity because the storage AAD prefix +// ('envelope_version ‖ flag ‖ key_id ‖ value_header(9B) ‖ pebble_key') +// does not start with the raft-purpose byte 'R'. +func UnwrapRaftPayload(c *Cipher, encoded []byte) ([]byte, error) { + if c == nil { + return nil, errors.WithStack(ErrNilKeystore) + } + env, err := DecodeEnvelope(encoded) + if err != nil { + return nil, errors.Wrap(err, "encryption: raft envelope decode") + } + aad := BuildRaftAAD(env.Version, env.KeyID) + plain, err := c.Decrypt(env.Body, aad, env.KeyID, env.Nonce[:]) + if err != nil { + return nil, errors.Wrap(err, "encryption: raft envelope decrypt") + } + return plain, nil +} diff --git a/internal/encryption/raft_envelope_test.go b/internal/encryption/raft_envelope_test.go new file mode 100644 index 000000000..12915a321 --- /dev/null +++ b/internal/encryption/raft_envelope_test.go @@ -0,0 +1,243 @@ +package encryption_test + +import ( + "bytes" + "crypto/rand" + "testing" + + "github.com/bootjp/elastickv/internal/encryption" + "github.com/cockroachdb/errors" +) + +// raftFixture wires a freshly-keyed Cipher with one DEK at testKeyID +// — the Stage-3 raft envelope tests don't need to model multiple DEKs; +// rotation / retire scenarios are covered by the storage-envelope +// suite that exercises the same Cipher / Keystore. +func raftFixture(t *testing.T) (*encryption.Cipher, uint32) { + t.Helper() + ks, kid := newKeystoreWithKey(t) + return mustCipher(t, ks), kid +} + +func TestRaftEnvelope_RoundTrip(t *testing.T) { + t.Parallel() + c, kid := raftFixture(t) + cases := []struct { + name string + payload []byte + }{ + {"empty", []byte{}}, + {"short", []byte("op=put key=k1 v=hello")}, + {"binary", []byte{0x00, 0xff, 0x10, 0x42, 0xde, 0xad}}, + {"4 KiB", bytes.Repeat([]byte("X"), 4096)}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + nonce := newRandomNonce(t) + encoded, err := encryption.WrapRaftPayload(c, kid, nonce, tc.payload) + if err != nil { + t.Fatalf("Wrap: %v", err) + } + if got, want := len(encoded), len(tc.payload)+encryption.EnvelopeOverhead; got != want { + t.Fatalf("encoded len=%d, want %d", got, want) + } + plain, err := encryption.UnwrapRaftPayload(c, encoded) + if err != nil { + t.Fatalf("Unwrap: %v", err) + } + if !bytes.Equal(plain, tc.payload) { + t.Fatalf("plaintext mismatch: got %x, want %x", plain, tc.payload) + } + }) + } +} + +// TestRaftEnvelope_DeterministicNonce confirms that the same +// (keyID, nonce, payload) triple produces the same encoded bytes. +// The §4.2 deterministic nonce factory relies on this property to +// reproduce ciphertexts deterministically across replays. +func TestRaftEnvelope_DeterministicNonce(t *testing.T) { + t.Parallel() + c, kid := raftFixture(t) + nonce := newRandomNonce(t) + payload := []byte("repeatable payload") + a, err := encryption.WrapRaftPayload(c, kid, nonce, payload) + if err != nil { + t.Fatalf("Wrap A: %v", err) + } + b, err := encryption.WrapRaftPayload(c, kid, nonce, payload) + if err != nil { + t.Fatalf("Wrap B: %v", err) + } + if !bytes.Equal(a, b) { + t.Fatal("deterministic-nonce wrap produced different bytes for identical inputs") + } +} + +// TestRaftEnvelope_RejectsTagTamper flips a byte inside the GCM tag +// region (last 16 bytes) and confirms Unwrap surfaces ErrIntegrity. +// Mirrors the §4.1 storage-envelope tag-tamper test. +func TestRaftEnvelope_RejectsTagTamper(t *testing.T) { + t.Parallel() + c, kid := raftFixture(t) + encoded, err := encryption.WrapRaftPayload(c, kid, newRandomNonce(t), []byte("payload")) + if err != nil { + t.Fatalf("Wrap: %v", err) + } + encoded[len(encoded)-1] ^= 0xff + _, err = encryption.UnwrapRaftPayload(c, encoded) + if !errors.Is(err, encryption.ErrIntegrity) { + t.Fatalf("expected ErrIntegrity, got %v", err) + } +} + +// TestRaftEnvelope_RejectsKeyIDTamper flips a key_id byte inside the +// envelope header. The key_id participates in the AAD via +// BuildRaftAAD, so the flip changes the AAD on Unwrap and GCM +// rejects the tag. +func TestRaftEnvelope_RejectsKeyIDTamper(t *testing.T) { + t.Parallel() + c, kid := raftFixture(t) + encoded, err := encryption.WrapRaftPayload(c, kid, newRandomNonce(t), []byte("payload")) + if err != nil { + t.Fatalf("Wrap: %v", err) + } + // keyID lives at offset 2..5 (version=0, flag=1, keyID=2-5). + encoded[2] ^= 0x01 + _, err = encryption.UnwrapRaftPayload(c, encoded) + if !errors.Is(err, encryption.ErrIntegrity) && !errors.Is(err, encryption.ErrUnknownKeyID) { + // Either outcome is acceptable: ErrIntegrity if the tampered + // key_id remains in the keystore (impossible with a single + // loaded DEK), ErrUnknownKeyID if the tampered key_id is no + // longer loaded. With a single DEK loaded, a flipped low bit + // almost always lands on an unknown key_id. + t.Fatalf("expected ErrIntegrity or ErrUnknownKeyID, got %v", err) + } +} + +// TestRaftEnvelope_RejectsStorageEnvelope confirms a §4.1 storage +// envelope (whose AAD includes the value-header bytes and pebble +// key, but NOT the 'R' purpose byte) fails GCM verification when +// fed to UnwrapRaftPayload. This is the layer-confusion defence +// design §4.2 calls out: a storage ciphertext replayed into the +// raft path must not silently decrypt. +func TestRaftEnvelope_RejectsStorageEnvelope(t *testing.T) { + t.Parallel() + c, kid := raftFixture(t) + nonce := newRandomNonce(t) + // Build a storage envelope by hand: AAD = HeaderAADBytes ‖ value-header ‖ pebble-key. + storageAAD := encryption.AppendHeaderAADBytes(nil, encryption.EnvelopeVersionV1, 0, kid) + storageAAD = append(storageAAD, []byte("synthetic 9B header")...) + storageAAD = append(storageAAD, []byte("synthetic pebble key")...) + body, err := c.Encrypt([]byte("payload"), storageAAD, kid, nonce) + if err != nil { + t.Fatalf("Encrypt storage-style: %v", err) + } + var nonceArr [encryption.NonceSize]byte + copy(nonceArr[:], nonce) + env := encryption.Envelope{ + Version: encryption.EnvelopeVersionV1, + Flag: 0, + KeyID: kid, + Nonce: nonceArr, + Body: body, + } + storageEncoded, err := env.Encode() + if err != nil { + t.Fatalf("Encode: %v", err) + } + // Feed the storage envelope to the raft unwrap path. + _, err = encryption.UnwrapRaftPayload(c, storageEncoded) + if !errors.Is(err, encryption.ErrIntegrity) { + t.Fatalf("expected ErrIntegrity for layer-confusion replay, got %v", err) + } +} + +// TestRaftEnvelope_RejectsRetiredKey confirms that an envelope +// whose key_id has been deleted from the keystore (DEK retirement +// or sidecar mismatch) surfaces ErrUnknownKeyID rather than a +// silent garbage decrypt. +func TestRaftEnvelope_RejectsRetiredKey(t *testing.T) { + t.Parallel() + ks, kid := newKeystoreWithKey(t) + c := mustCipher(t, ks) + encoded, err := encryption.WrapRaftPayload(c, kid, newRandomNonce(t), []byte("payload")) + if err != nil { + t.Fatalf("Wrap: %v", err) + } + ks.Delete(kid) + _, err = encryption.UnwrapRaftPayload(c, encoded) + if !errors.Is(err, encryption.ErrUnknownKeyID) { + t.Fatalf("expected ErrUnknownKeyID after retire, got %v", err) + } +} + +// TestRaftEnvelope_ShortInputRejected covers DecodeEnvelope's +// length precondition (HeaderSize + TagSize = 34 bytes minimum). +func TestRaftEnvelope_ShortInputRejected(t *testing.T) { + t.Parallel() + c, _ := raftFixture(t) + for _, l := range []int{0, 1, 17, 33} { + _, err := encryption.UnwrapRaftPayload(c, make([]byte, l)) + if !errors.Is(err, encryption.ErrEnvelopeShort) { + t.Fatalf("len=%d: expected ErrEnvelopeShort, got %v", l, err) + } + } +} + +func TestRaftEnvelope_RejectsBadNonceSize(t *testing.T) { + t.Parallel() + c, kid := raftFixture(t) + for _, n := range []int{0, 1, 11, 13} { + _, err := encryption.WrapRaftPayload(c, kid, make([]byte, n), []byte("p")) + if !errors.Is(err, encryption.ErrBadNonceSize) { + t.Fatalf("nonce=%d: expected ErrBadNonceSize, got %v", n, err) + } + } +} + +func TestRaftEnvelope_NilCipher(t *testing.T) { + t.Parallel() + if _, err := encryption.WrapRaftPayload(nil, 1, make([]byte, encryption.NonceSize), nil); !errors.Is(err, encryption.ErrNilKeystore) { + t.Fatalf("Wrap: expected ErrNilKeystore, got %v", err) + } + if _, err := encryption.UnwrapRaftPayload(nil, make([]byte, encryption.EnvelopeOverhead)); !errors.Is(err, encryption.ErrNilKeystore) { + t.Fatalf("Unwrap: expected ErrNilKeystore, got %v", err) + } +} + +func newRandomBytes(t *testing.T, n int) []byte { + t.Helper() + b := make([]byte, n) + if _, err := rand.Read(b); err != nil { + t.Fatalf("rand.Read: %v", err) + } + return b +} + +// TestBuildRaftAAD pins the byte layout: 'R' ‖ version ‖ key_id BE. +func TestBuildRaftAAD(t *testing.T) { + t.Parallel() + got := encryption.BuildRaftAAD(0x01, 0xCAFEBABE) + want := []byte{'R', 0x01, 0xCA, 0xFE, 0xBA, 0xBE} + if !bytes.Equal(got, want) { + t.Fatalf("AAD mismatch:\n got %x\n want %x", got, want) + } +} + +// TestRaftEnvelope_NoLeakedPlaintextBytes — defensive: an encoded +// envelope must NOT contain the raw plaintext as a suffix. A +// regression where the envelope appended plaintext alongside the +// ciphertext would leak data on disk; this catches any such bug. +func TestRaftEnvelope_NoLeakedPlaintextBytes(t *testing.T) { + t.Parallel() + c, kid := raftFixture(t) + plaintext := newRandomBytes(t, 256) + encoded, err := encryption.WrapRaftPayload(c, kid, newRandomNonce(t), plaintext) + if err != nil { + t.Fatalf("Wrap: %v", err) + } + if bytes.Contains(encoded, plaintext) { + t.Fatal("encoded envelope contains the plaintext suffix verbatim") + } +} diff --git a/internal/raftengine/etcd/encryption.go b/internal/raftengine/etcd/encryption.go new file mode 100644 index 000000000..7b40d2d7b --- /dev/null +++ b/internal/raftengine/etcd/encryption.go @@ -0,0 +1,77 @@ +package etcd + +import ( + "github.com/bootjp/elastickv/internal/encryption" + "github.com/cockroachdb/errors" +) + +// ErrRaftUnwrapFailed is returned by applyNormalEntry when the +// pre-apply hook cannot unwrap a §4.2 raft envelope (GCM tag +// mismatch, missing DEK in the local keystore, malformed +// envelope, or active tampering). +// +// Per design §6.3 this is a process-fatal event, NOT a recoverable +// Apply error. applyCommitted propagates the error up to runLoop, +// which exits via the engine's existing fatal-error path. The +// failing entry's index is NOT advanced through setApplied — the +// next restart must replay the same entry, not skip it. Silently +// skipping would let the local FSM diverge from peers that DID +// successfully unwrap and apply, breaking the consistency +// invariant the integrity tag was added to detect. +// +// Operator response: investigate sidecar / Raft-log divergence +// (§5.5 of the encryption design doc) or KEK custody (§9.3); a +// supervised restart with a corrected sidecar is the only safe +// recovery path. +var ErrRaftUnwrapFailed = errors.New("raftengine/etcd: raft envelope unwrap failed; halting apply") + +// RaftCutoverIndex returns the §7.1 Phase 2 cutover Raft index. +// Entries with index strictly greater than the returned value carry +// raft-envelope-wrapped fsm payloads; entries at or below the +// cutover are cleartext. The returned value is read on every +// applyNormalEntry, so implementations should be lock-free +// (atomic.Uint64.Load) — the engine does not synchronize the read. +// +// The Stage 3 default (when OpenConfig.RaftCutoverIndex is nil) is +// `^uint64(0)` (no entry's index is greater) so the unwrap path is +// inert until Stage 6 wires the sidecar's +// raft_envelope_cutover_index in. +type RaftCutoverIndex func() uint64 + +// inertRaftCutoverIndex is the OpenConfig-default returned when no +// cutover function is supplied: every entry index is treated as +// "below cutover" and the pre-apply hook is a no-op. +func inertRaftCutoverIndex() uint64 { + return ^uint64(0) +} + +// orInertCutover returns the supplied callback if non-nil, otherwise +// the inert default. Letting Engine.raftCutoverIndex be a real +// closure avoids a nil-check in the apply hot path. +func orInertCutover(fn RaftCutoverIndex) RaftCutoverIndex { + if fn == nil { + return inertRaftCutoverIndex + } + return fn +} + +// unwrapRaftPayload runs the §4.2 raft envelope Unwrap when both a +// cipher is wired AND entry.Index > cutover(). Returns the +// cleartext payload on success, or wraps any decrypt failure with +// ErrRaftUnwrapFailed for the caller to recognise via errors.Is. +// +// Extracted so applyNormalEntry stays a one-liner and the unit +// tests can exercise the cutover gate + error mapping without +// constructing a full Engine. +func unwrapRaftPayload(cipher *encryption.Cipher, payload []byte) ([]byte, error) { + plain, err := encryption.UnwrapRaftPayload(cipher, payload) + if err != nil { + // Mark wraps the encryption-package error with + // ErrRaftUnwrapFailed so the apply loop's errors.Is check + // distinguishes envelope-unwrap from other Apply paths; + // the underlying ErrIntegrity / ErrUnknownKeyID stays + // available for diagnostic logs via errors.Is. + return nil, errors.Wrap(errors.Mark(err, ErrRaftUnwrapFailed), "raftengine/etcd: raft envelope unwrap") + } + return plain, nil +} diff --git a/internal/raftengine/etcd/encryption_test.go b/internal/raftengine/etcd/encryption_test.go new file mode 100644 index 000000000..3d287877b --- /dev/null +++ b/internal/raftengine/etcd/encryption_test.go @@ -0,0 +1,397 @@ +package etcd + +import ( + "crypto/rand" + "io" + "sync/atomic" + "testing" + + "github.com/bootjp/elastickv/internal/encryption" + "github.com/cockroachdb/errors" + raftpb "go.etcd.io/raft/v3/raftpb" +) + +// fakeStateMachine records every Apply call so the encryption tests +// can assert (a) what bytes the FSM saw, and (b) whether Apply was +// even reached on the unwrap-failure path. +type fakeStateMachine struct { + calls atomic.Int32 + last []byte +} + +func (f *fakeStateMachine) Apply(data []byte) any { + f.calls.Add(1) + cp := make([]byte, len(data)) + copy(cp, data) + f.last = cp + return nil +} + +func (f *fakeStateMachine) Snapshot() (Snapshot, error) { return nil, nil } +func (f *fakeStateMachine) Restore(_ io.Reader) error { return nil } + +// raftCipherFixture wires a Cipher with a single DEK at testKeyID. +// The Stage-3 unit tests don't need to model purpose enforcement; +// the Cipher itself does not check purpose, so a single keystore +// with one key is enough to exercise the apply hook's gate logic. +func raftCipherFixture(t *testing.T) (*encryption.Cipher, uint32) { + t.Helper() + ks := encryption.NewKeystore() + dek := make([]byte, encryption.KeySize) + if _, err := rand.Read(dek); err != nil { + t.Fatalf("rand.Read DEK: %v", err) + } + const kid uint32 = 0xDEADBEEF + if err := ks.Set(kid, dek); err != nil { + t.Fatalf("Set DEK: %v", err) + } + c, err := encryption.NewCipher(ks) + if err != nil { + t.Fatalf("NewCipher: %v", err) + } + return c, kid +} + +func newRaftNonce(t *testing.T) []byte { + t.Helper() + n := make([]byte, encryption.NonceSize) + if _, err := rand.Read(n); err != nil { + t.Fatalf("rand.Read nonce: %v", err) + } + return n +} + +func newTestEngine(fsm StateMachine, cipher *encryption.Cipher, cutover RaftCutoverIndex) *Engine { + return &Engine{ + fsm: fsm, + raftCipher: cipher, + raftCutoverIndex: orInertCutover(cutover), + } +} + +func envelopeEntry(t *testing.T, c *encryption.Cipher, kid uint32, index uint64, plaintext []byte) raftpb.Entry { + t.Helper() + wrapped, err := encryption.WrapRaftPayload(c, kid, newRaftNonce(t), plaintext) + if err != nil { + t.Fatalf("WrapRaftPayload: %v", err) + } + return raftpb.Entry{Index: index, Type: raftpb.EntryNormal, Data: encodeProposalEnvelope(7, wrapped)} +} + +// TestApplyNormalEntry_CutoverActive_NoCipher_FailsClosed locks down +// the misconfig case: when the cluster has crossed the raft envelope +// cutover (an above-cutover entry arrives) but raftCipher is nil — +// a sidecar/init race or operator wiring mistake — the engine MUST +// refuse to apply, NOT silently hand wrapped envelope bytes to +// fsm.Apply. The latter would permanently diverge this node from +// peers that DID unwrap and apply the cleartext. +func TestApplyNormalEntry_CutoverActive_NoCipher_FailsClosed(t *testing.T) { + t.Parallel() + const cutover uint64 = 100 + fsm := &fakeStateMachine{} + // cipher == nil simulates the misconfig: cutover is set, no + // cipher wired. + e := newTestEngine(fsm, nil, func() uint64 { return cutover }) + + // Any payload above the cutover index hits the fail-closed branch + // regardless of whether it's a real envelope or cleartext, because + // without a cipher we cannot tell them apart. + entry := raftpb.Entry{ + Type: raftpb.EntryNormal, + Index: cutover + 1, + Data: encodeProposalEnvelope(99, []byte("would-be wrapped payload")), + } + _, err := e.applyNormalEntry(entry) + if !errors.Is(err, ErrRaftUnwrapFailed) { + t.Fatalf("expected ErrRaftUnwrapFailed for cutover-active+no-cipher misconfig, got %v", err) + } + if got := fsm.calls.Load(); got != 0 { + t.Fatalf("fsm.Apply called %d times despite refused apply", got) + } + + // Below-cutover entries still pass through unchanged in this + // configuration — the misconfig detection fires only when an + // above-cutover entry actually arrives. (Pre-cutover entries + // were written before encryption was activated and remain + // legitimately cleartext.) + belowCutoverEntry := raftpb.Entry{ + Type: raftpb.EntryNormal, + Index: cutover, + Data: encodeProposalEnvelope(11, []byte("legacy cleartext")), + } + if _, err := e.applyNormalEntry(belowCutoverEntry); err != nil { + t.Fatalf("below-cutover should pass through, got %v", err) + } + if got := fsm.calls.Load(); got != 1 { + t.Fatalf("fsm.Apply call count after below-cutover = %d, want 1", got) + } +} + +// TestApplyNormalEntry_NoCipher_PassThrough confirms Stage-3 wiring +// preserves Stage-0/2 behaviour: with raftCipher == nil AND no +// cutover (the OpenConfig defaults), the apply hook is inert and the +// FSM sees the proposal envelope's inner payload byte-for-byte. +func TestApplyNormalEntry_NoCipher_PassThrough(t *testing.T) { + t.Parallel() + fsm := &fakeStateMachine{} + e := newTestEngine(fsm, nil, nil) + plain := []byte("op=put key=k1 v=hello") + entry := raftpb.Entry{Type: raftpb.EntryNormal, Data: encodeProposalEnvelope(42, plain)} + if _, err := e.applyNormalEntry(entry); err != nil { + t.Fatalf("applyNormalEntry: %v", err) + } + if got := fsm.calls.Load(); got != 1 { + t.Fatalf("fsm.Apply call count = %d, want 1", got) + } + if string(fsm.last) != string(plain) { + t.Fatalf("FSM saw %q, want %q", fsm.last, plain) + } +} + +// TestApplyNormalEntry_BelowCutover_PassThrough confirms entries +// whose index is at or below the cutover are NOT unwrapped — they +// carry cleartext payloads (the legacy / pre-Phase-2 regime). +// Strict greater-than: index == cutover is the enable-flag entry +// itself and stays cleartext. +func TestApplyNormalEntry_BelowCutover_PassThrough(t *testing.T) { + t.Parallel() + c, _ := raftCipherFixture(t) + cutover := uint64(100) + fsm := &fakeStateMachine{} + e := newTestEngine(fsm, c, func() uint64 { return cutover }) + + cleartextPayload := []byte("legacy cleartext") + for _, idx := range []uint64{1, 50, cutover - 1, cutover} { + fsm.calls.Store(0) + entry := raftpb.Entry{ + Type: raftpb.EntryNormal, + Index: idx, + Data: encodeProposalEnvelope(11, cleartextPayload), + } + if _, err := e.applyNormalEntry(entry); err != nil { + t.Fatalf("idx=%d: applyNormalEntry: %v", idx, err) + } + if got := fsm.calls.Load(); got != 1 { + t.Fatalf("idx=%d: fsm.Apply call count = %d, want 1", idx, got) + } + if string(fsm.last) != string(cleartextPayload) { + t.Fatalf("idx=%d: FSM saw %q, want %q", idx, fsm.last, cleartextPayload) + } + } +} + +// TestApplyNormalEntry_AboveCutover_Unwraps confirms entries whose +// index is strictly greater than the cutover go through the §4.2 +// raft envelope Unwrap and the FSM observes the cleartext payload. +func TestApplyNormalEntry_AboveCutover_Unwraps(t *testing.T) { + t.Parallel() + c, kid := raftCipherFixture(t) + cutover := uint64(100) + fsm := &fakeStateMachine{} + e := newTestEngine(fsm, c, func() uint64 { return cutover }) + + for _, idx := range []uint64{cutover + 1, cutover + 100, cutover + 1_000_000} { + fsm.calls.Store(0) + plaintext := []byte("op=put key=k1 v=secret") + entry := envelopeEntry(t, c, kid, idx, plaintext) + if _, err := e.applyNormalEntry(entry); err != nil { + t.Fatalf("idx=%d: applyNormalEntry: %v", idx, err) + } + if got := fsm.calls.Load(); got != 1 { + t.Fatalf("idx=%d: fsm.Apply call count = %d, want 1", idx, got) + } + if string(fsm.last) != string(plaintext) { + t.Fatalf("idx=%d: FSM saw %q, want %q", idx, fsm.last, plaintext) + } + } +} + +// TestApplyNormalEntry_UnwrapFailure_Halts is the integrity test: +// when an above-cutover entry's envelope fails GCM verification +// (DEK retired, tampered bytes, missing key), applyNormalEntry +// returns ErrRaftUnwrapFailed and the FSM is NOT called. The +// caller (applyCommitted) is responsible for not advancing +// setApplied; that's covered by TestApplyCommitted_UnwrapFailure_DoesNotAdvanceApplied. +func TestApplyNormalEntry_UnwrapFailure_Halts(t *testing.T) { + t.Parallel() + c, kid := raftCipherFixture(t) + cutover := uint64(100) + fsm := &fakeStateMachine{} + e := newTestEngine(fsm, c, func() uint64 { return cutover }) + + entry := envelopeEntry(t, c, kid, cutover+1, []byte("payload")) + // Tamper the GCM tag (last byte of the wrapped envelope, before + // the proposal-envelope wrapper). encodeProposalEnvelope uses + // `[0x01][8B id][envelope...]` so the last byte is the tag's + // last byte. + entry.Data[len(entry.Data)-1] ^= 0xff + + _, err := e.applyNormalEntry(entry) + if !errors.Is(err, ErrRaftUnwrapFailed) { + t.Fatalf("expected ErrRaftUnwrapFailed, got %v", err) + } + if got := fsm.calls.Load(); got != 0 { + t.Fatalf("fsm.Apply was called %d times despite unwrap failure", got) + } +} + +// TestApplyCommitted_UnwrapFailure_DoesNotAdvanceApplied locks down +// the design §6.3 invariant: an unwrap failure halts the apply +// loop WITHOUT advancing setApplied. The next restart must replay +// the same entry, not skip it. A regression here would let a +// divergent FSM survive across restarts — exactly the safety +// property the integrity tag was added to detect. +func TestApplyCommitted_UnwrapFailure_DoesNotAdvanceApplied(t *testing.T) { + t.Parallel() + c, kid := raftCipherFixture(t) + cutover := uint64(100) + fsm := &fakeStateMachine{} + e := newTestEngine(fsm, c, func() uint64 { return cutover }) + const startApplied uint64 = 99 + e.applied = startApplied + e.appliedIndex.Store(startApplied) + + good := envelopeEntry(t, c, kid, cutover+1, []byte("ok")) + bad := envelopeEntry(t, c, kid, cutover+2, []byte("tampered")) + bad.Data[len(bad.Data)-1] ^= 0xff + // A third entry that we expect NOT to be processed (it sits + // after the failing one, so applyCommitted must stop at bad). + never := envelopeEntry(t, c, kid, cutover+3, []byte("never")) + + err := e.applyCommitted([]raftpb.Entry{good, bad, never}) + if !errors.Is(err, ErrRaftUnwrapFailed) { + t.Fatalf("applyCommitted: expected ErrRaftUnwrapFailed, got %v", err) + } + // good was applied, so applied advanced to good.Index. + if e.applied != cutover+1 { + t.Fatalf("applied = %d, want %d (good entry advanced, bad did not)", e.applied, cutover+1) + } + if got := e.appliedIndex.Load(); got != cutover+1 { + t.Fatalf("appliedIndex = %d, want %d", got, cutover+1) + } + // good was applied (1 call), bad halted, never untouched. + if got := fsm.calls.Load(); got != 1 { + t.Fatalf("fsm.Apply count = %d, want 1 (only good entry)", got) + } +} + +// TestApplyNormalEntry_BoundaryCutover exercises the strict +// greater-than gate at exactly index = cutover and cutover + 1. +// The cutover index is itself the §7.1 enable-raft-envelope flag +// entry and is NOT raft-DEK-wrapped; only entries strictly greater +// must Unwrap. +func TestApplyNormalEntry_BoundaryCutover(t *testing.T) { + t.Parallel() + c, kid := raftCipherFixture(t) + const cutover uint64 = 100 + fsm := &fakeStateMachine{} + e := newTestEngine(fsm, c, func() uint64 { return cutover }) + + // cutover itself: cleartext payload — must NOT be unwrapped. + cleartext := []byte("enable-raft-envelope flag") + atCutover := raftpb.Entry{ + Type: raftpb.EntryNormal, + Index: cutover, + Data: encodeProposalEnvelope(13, cleartext), + } + if _, err := e.applyNormalEntry(atCutover); err != nil { + t.Fatalf("at-cutover: %v", err) + } + if string(fsm.last) != string(cleartext) { + t.Fatalf("at-cutover: FSM saw %q, want %q", fsm.last, cleartext) + } + + // cutover+1: wrapped payload — MUST be unwrapped. + above := envelopeEntry(t, c, kid, cutover+1, []byte("first encrypted")) + if _, err := e.applyNormalEntry(above); err != nil { + t.Fatalf("above-cutover: %v", err) + } + if string(fsm.last) != "first encrypted" { + t.Fatalf("above-cutover: FSM saw %q, want %q", fsm.last, "first encrypted") + } +} + +// TestApplyNormalEntry_ProposalIDStillResolvable confirms the +// design's load-bearing invariant from §6.3: the engine pre-apply +// hook unwraps the *inner* fsm payload, NOT entry.Data itself, so +// decodeProposalEnvelope(entry.Data) continues to recover the +// proposal id even after unwrap. An earlier draft of the design +// proposed unwrapping entry.Data directly, which would have +// destroyed the leading 0x01 proposalEnvelopeVersion byte and +// timed out every coordinator write. +func TestApplyNormalEntry_ProposalIDStillResolvable(t *testing.T) { + t.Parallel() + c, kid := raftCipherFixture(t) + const cutover uint64 = 100 + fsm := &fakeStateMachine{} + e := newTestEngine(fsm, c, func() uint64 { return cutover }) + + const wantID uint64 = 31337 + wrapped, err := encryption.WrapRaftPayload(c, kid, newRaftNonce(t), []byte("payload")) + if err != nil { + t.Fatalf("WrapRaftPayload: %v", err) + } + data := encodeProposalEnvelope(wantID, wrapped) + entry := raftpb.Entry{Type: raftpb.EntryNormal, Index: cutover + 1, Data: data} + if _, err := e.applyNormalEntry(entry); err != nil { + t.Fatalf("applyNormalEntry: %v", err) + } + gotID, _, ok := decodeProposalEnvelope(entry.Data) + if !ok { + t.Fatal("decodeProposalEnvelope(entry.Data) failed after applyNormalEntry — entry.Data was mutated") + } + if gotID != wantID { + t.Fatalf("proposal id = %d, want %d", gotID, wantID) + } +} + +// TestApplyNormalEntry_NoCutoverDefault confirms an OpenConfig +// without a cutover callback installs the inert default +// (^uint64(0)). With raftCipher set but cutover = MaxUint64, no +// entry index is greater than the cutover, so the unwrap path +// stays inert. +func TestApplyNormalEntry_NoCutoverDefault(t *testing.T) { + t.Parallel() + c, _ := raftCipherFixture(t) + fsm := &fakeStateMachine{} + // nil cutover → Open's orInertCutover(nil) → inertRaftCutoverIndex + e := newTestEngine(fsm, c, nil) + + cleartext := []byte("legacy cleartext") + for _, idx := range []uint64{1, 1 << 20, 1 << 40, ^uint64(0) - 1} { + entry := raftpb.Entry{ + Type: raftpb.EntryNormal, + Index: idx, + Data: encodeProposalEnvelope(7, cleartext), + } + if _, err := e.applyNormalEntry(entry); err != nil { + t.Fatalf("idx=%d: %v", idx, err) + } + if string(fsm.last) != string(cleartext) { + t.Fatalf("idx=%d: FSM saw %q, want %q", idx, fsm.last, cleartext) + } + } +} + +// TestUnwrapRaftPayload_Helper sanity-checks that the engine-internal +// unwrapRaftPayload helper marks all encryption errors with +// ErrRaftUnwrapFailed (so callers can errors.Is-match without +// caring which underlying code class fired). +func TestUnwrapRaftPayload_Helper(t *testing.T) { + t.Parallel() + c, kid := raftCipherFixture(t) + wrapped, err := encryption.WrapRaftPayload(c, kid, newRaftNonce(t), []byte("payload")) + if err != nil { + t.Fatalf("Wrap: %v", err) + } + // Tag tamper. + tampered := append([]byte(nil), wrapped...) + tampered[len(tampered)-1] ^= 0xff + _, err = unwrapRaftPayload(c, tampered) + if !errors.Is(err, ErrRaftUnwrapFailed) { + t.Fatalf("expected ErrRaftUnwrapFailed via Mark, got %v", err) + } + if !errors.Is(err, encryption.ErrIntegrity) { + t.Fatalf("expected nested ErrIntegrity preserved, got %v", err) + } +} diff --git a/internal/raftengine/etcd/engine.go b/internal/raftengine/etcd/engine.go index 678104f82..fa115c3c0 100644 --- a/internal/raftengine/etcd/engine.go +++ b/internal/raftengine/etcd/engine.go @@ -15,6 +15,7 @@ import ( "sync/atomic" "time" + "github.com/bootjp/elastickv/internal/encryption" "github.com/bootjp/elastickv/internal/monoclock" "github.com/bootjp/elastickv/internal/raftengine" "github.com/cockroachdb/errors" @@ -175,6 +176,17 @@ type OpenConfig struct { // Default: 256. Increase for deeper pipelining on high-bandwidth links; // lower in memory-constrained clusters. MaxInflightMsg int + // RaftCipher carries the AES-GCM Cipher used by the §4.2 raft + // envelope pre-apply hook. nil disables the hook (Stage 3 + // default — production wiring lands once Stage 6's cluster + // flag flow is in place). + RaftCipher *encryption.Cipher + // RaftCutoverIndex returns the §7.1 Phase 2 raft envelope + // cutover index. Only entries whose Raft log index is strictly + // greater than this value go through Unwrap. nil ⇒ no cutover + // has been observed yet, equivalent to "raft envelope hook + // off". + RaftCutoverIndex RaftCutoverIndex } type Engine struct { @@ -203,6 +215,16 @@ type Engine struct { peers map[uint64]Peer transport *GRPCTransport + // raftCipher and raftCutoverIndex implement the §4.2 / §6.3 + // pre-apply unwrap hook. raftCipher is nil unless the + // OpenConfig wiring is provided; raftCutoverIndex is set to a + // permanent-no-op default by Open so applyNormalEntry never + // branches on a nil callback. Both are read-only after Open + // (single-writer discipline matches the rest of the apply-loop + // state). + raftCipher *encryption.Cipher + raftCutoverIndex RaftCutoverIndex + nextRequestID atomic.Uint64 proposeCh chan proposalRequest @@ -495,6 +517,8 @@ func Open(ctx context.Context, cfg OpenConfig) (*Engine, error) { fsm: prepared.cfg.StateMachine, peers: peerMap, transport: prepared.cfg.Transport, + raftCipher: prepared.cfg.RaftCipher, + raftCutoverIndex: orInertCutover(prepared.cfg.RaftCutoverIndex), proposeCh: make(chan proposalRequest), readCh: make(chan readRequest), adminCh: make(chan adminRequest), @@ -1972,38 +1996,20 @@ func (e *Engine) enqueueSnapshotRequest(req snapshotRequest) error { func (e *Engine) applyCommitted(entries []raftpb.Entry) error { for _, entry := range entries { + var err error switch entry.Type { case raftpb.EntryNormal: - response := e.applyNormalEntry(entry) - e.setApplied(entry.Index) - e.resolveProposal(entry.Index, entry.Data, response) + err = e.applyNormalCommitted(entry) case raftpb.EntryConfChange: - var cc raftpb.ConfChange - if err := cc.Unmarshal(entry.Data); err != nil { - return errors.WithStack(err) - } - confState := e.rawNode.ApplyConfChange(cc) - nextPeers := e.nextPeersAfterConfigChange(cc.Type, cc.NodeID, cc.Context, *confState) - if err := e.persistConfigState(entry.Index, *confState, nextPeers); err != nil { - return err - } - e.applyConfigChange(cc.Type, cc.NodeID, cc.Context, entry.Index, *confState) - e.setApplied(entry.Index) + err = e.applyConfChangeCommitted(entry) case raftpb.EntryConfChangeV2: - var cc raftpb.ConfChangeV2 - if err := cc.Unmarshal(entry.Data); err != nil { - return errors.WithStack(err) - } - confState := e.rawNode.ApplyConfChange(cc) - nextPeers := e.nextPeersAfterConfigChangeV2(cc, *confState) - if err := e.persistConfigState(entry.Index, *confState, nextPeers); err != nil { - return err - } - e.applyConfigChangeV2(cc, entry.Index, *confState) - e.setApplied(entry.Index) + err = e.applyConfChangeV2Committed(entry) default: e.setApplied(entry.Index) } + if err != nil { + return err + } } return nil } @@ -2017,15 +2023,127 @@ func (e *Engine) setApplied(index uint64) { e.appliedIndex.Store(index) } -func (e *Engine) applyNormalEntry(entry raftpb.Entry) any { +// applyNormalCommitted is the EntryNormal arm of applyCommitted. +// Extracted so applyCommitted stays under the cyclomatic-complexity +// budget while preserving the load-bearing fail-closed semantics: +// if applyNormalEntry surfaces an error, setApplied is NOT called +// and the error propagates so runLoop's existing fatal-error path +// can halt the process. Silent skip is never an option — the +// integrity tag was added to detect divergence and skipping past +// it would defeat that property (design §6.3). +// +// On the error path resolveProposal is intentionally NOT called: the +// originating coordinator's done channel for this proposal is left +// dangling. That is correct under the §6.3 process-fatal contract — +// runLoop halts and the OS reaps the goroutine plus its channel as +// part of process exit. A graceful resolveProposal on the error path +// would deliver a nil/error response that the coordinator might mis- +// interpret as a committed-but-failed apply (rather than the actual +// "halting; please restart and retry" semantics). +func (e *Engine) applyNormalCommitted(entry raftpb.Entry) error { + response, err := e.applyNormalEntry(entry) + if err != nil { + return err + } + e.setApplied(entry.Index) + e.resolveProposal(entry.Index, entry.Data, response) + return nil +} + +// applyConfChangeCommitted is the EntryConfChange arm of +// applyCommitted (extracted for cyclomatic-budget hygiene; see +// applyNormalCommitted). +func (e *Engine) applyConfChangeCommitted(entry raftpb.Entry) error { + var cc raftpb.ConfChange + if err := cc.Unmarshal(entry.Data); err != nil { + return errors.WithStack(err) + } + confState := e.rawNode.ApplyConfChange(cc) + nextPeers := e.nextPeersAfterConfigChange(cc.Type, cc.NodeID, cc.Context, *confState) + if err := e.persistConfigState(entry.Index, *confState, nextPeers); err != nil { + return err + } + e.applyConfigChange(cc.Type, cc.NodeID, cc.Context, entry.Index, *confState) + e.setApplied(entry.Index) + return nil +} + +// applyConfChangeV2Committed is the EntryConfChangeV2 arm of +// applyCommitted. +func (e *Engine) applyConfChangeV2Committed(entry raftpb.Entry) error { + var cc raftpb.ConfChangeV2 + if err := cc.Unmarshal(entry.Data); err != nil { + return errors.WithStack(err) + } + confState := e.rawNode.ApplyConfChange(cc) + nextPeers := e.nextPeersAfterConfigChangeV2(cc, *confState) + if err := e.persistConfigState(entry.Index, *confState, nextPeers); err != nil { + return err + } + e.applyConfigChangeV2(cc, entry.Index, *confState) + e.setApplied(entry.Index) + return nil +} + +func (e *Engine) applyNormalEntry(entry raftpb.Entry) (any, error) { if len(entry.Data) == 0 { - return nil + return nil, nil } _, payload, ok := decodeProposalEnvelope(entry.Data) if !ok { - return nil + return nil, nil + } + // §6.3 raft envelope pre-apply hook. The order is load-bearing: + // 1. decodeProposalEnvelope already ran above to recover + // (id, payload). This is what keeps resolveProposal's + // proposal-ID lookup working — entry.Data still starts + // with 0x01 proposalEnvelopeVersion. + // 2. If entry.Index is past the cluster-wide cutover, the + // inner fsm payload is a raft envelope and MUST be + // unwrapped via the raft DEK. Strict greater-than: + // entry.Index == cutover is itself the enable-raft- + // envelope flag entry and is NOT raft-DEK-wrapped + // (§7.1). If the cipher is missing while cutover is + // active, fail closed with ErrRaftUnwrapFailed — + // handing wrapped envelope bytes to fsm.Apply would + // diverge this node from peers that successfully + // unwrapped, exactly the safety property the integrity + // tag was added to detect. + // 3. Hand the (now cleartext) payload to fsm.Apply. The + // FSM contract is unchanged at Apply(data []byte) any — + // the FSM never sees a raft envelope. + // + // FSM payload aliasing: when the cipher path is OFF, payload + // is a slice into entry.Data's backing array (DecodeProposalEnvelope + // returns a sub-slice). When the cipher path is ON, payload is + // a fresh allocation from cipher.Decrypt. FSM implementations + // MUST NOT retain `data` past Apply's return without copying, + // or the cipher / no-cipher paths will diverge in ownership. + // Stage 6 plans a defensive copy at the apply boundary so the + // contract becomes uniform regardless of cipher state. + if entry.Index > e.raftCutoverIndex() { + if e.raftCipher == nil { + // Cutover is active (a non-default cutover index has + // elected this entry as wrapped) but the cipher is + // missing — sidecar/init race or operator misconfig. + // Refuse to apply: a silent skip would diverge state + // permanently from peers that DID unwrap. + slog.Error("raft envelope cutover active but no cipher configured; halting apply", + slog.Uint64("entry_index", entry.Index), + slog.Uint64("cutover_index", e.raftCutoverIndex())) + return nil, errors.Wrap(ErrRaftUnwrapFailed, + "raftengine/etcd: entry past raft envelope cutover but no raft cipher wired") + } + plain, err := unwrapRaftPayload(e.raftCipher, payload) + if err != nil { + slog.Error("raft envelope unwrap failed; halting apply", + slog.Uint64("entry_index", entry.Index), + slog.Any("err", err)) + return nil, err + } + payload = plain } - return e.fsm.Apply(payload) + return e.fsm.Apply(payload), nil } func (e *Engine) resolveProposal(commitIndex uint64, data []byte, response any) { diff --git a/kv/raft_payload_wrapper.go b/kv/raft_payload_wrapper.go new file mode 100644 index 000000000..66d8e9109 --- /dev/null +++ b/kv/raft_payload_wrapper.go @@ -0,0 +1,77 @@ +package kv + +import ( + "context" + + "github.com/bootjp/elastickv/internal/raftengine" + "github.com/cockroachdb/errors" +) + +// RaftPayloadWrapper transforms an FSM payload into a §4.2 raft +// envelope just before submission to the engine. The Stage 3 default +// (when no wrapper is installed on a coordinator) is identity — +// payloads pass through unchanged. Stage 6's cluster-flag pipeline +// installs an active wrapper, sourced from the sidecar's currently- +// active raft DEK and a writer-registry-backed nonce factory. +// +// Implementations MUST be safe to call concurrently from many +// goroutines: the coordinator may invoke this on every concurrent +// proposal. Encryption-state transitions (Phase 1 → Phase 2 cutover) +// publish a fresh closure via atomic.Pointer so the wrapper +// observes one consistent (cipher, key_id, nonce_factory) tuple per +// call. +type RaftPayloadWrapper func(payload []byte) ([]byte, error) + +// applyRaftPayloadWrap is a coordinator-internal helper that runs +// the configured wrapper, or returns the payload verbatim when no +// wrapper is installed. Centralised so every coordinator call site +// gates payload bytes through the same path — a future audit can +// grep for engine.Propose / proposer.Propose and verify each goes +// through this helper or has an explicit "intentionally cleartext" +// reason. +func applyRaftPayloadWrap(wrap RaftPayloadWrapper, payload []byte) ([]byte, error) { + if wrap == nil { + return payload, nil + } + wrapped, err := wrap(payload) + if err != nil { + return nil, errors.Wrap(err, "kv: raft payload wrap") + } + return wrapped, nil +} + +// wrappedProposer adapts a raftengine.Proposer so every Propose call +// transparently runs the configured RaftPayloadWrapper. Used by +// transaction.go's applyRequests path and by future code that needs +// to share a single Proposer between callers some of whom wrap and +// some of whom do not — the wrapping decision lives with the +// constructed proposer, not the call site. +// +// When wrap is nil the wrappedProposer is functionally identical to +// the inner proposer; this keeps the Stage 3 default a no-op. +type wrappedProposer struct { + inner raftengine.Proposer + wrap RaftPayloadWrapper +} + +// newWrappedProposer returns the inner proposer untouched when the +// wrapper is nil, so the cipher-disabled path stays a single +// pointer assignment. +func newWrappedProposer(inner raftengine.Proposer, wrap RaftPayloadWrapper) raftengine.Proposer { + if wrap == nil { + return inner + } + return &wrappedProposer{inner: inner, wrap: wrap} +} + +func (p *wrappedProposer) Propose(ctx context.Context, data []byte) (*raftengine.ProposalResult, error) { + wrapped, err := applyRaftPayloadWrap(p.wrap, data) + if err != nil { + return nil, err + } + res, err := p.inner.Propose(ctx, wrapped) + if err != nil { + return nil, errors.Wrap(err, "kv: wrapped propose") + } + return res, nil +} diff --git a/kv/raft_payload_wrapper_test.go b/kv/raft_payload_wrapper_test.go new file mode 100644 index 000000000..7da358214 --- /dev/null +++ b/kv/raft_payload_wrapper_test.go @@ -0,0 +1,154 @@ +package kv + +import ( + "bytes" + "context" + "crypto/rand" + "errors" + "sync/atomic" + "testing" + + "github.com/bootjp/elastickv/internal/encryption" + "github.com/bootjp/elastickv/internal/raftengine" +) + +// fakeProposer records every Propose call so the wrapper tests can +// inspect what bytes the engine would have seen. +type fakeProposer struct { + calls atomic.Int32 + last []byte + resp *raftengine.ProposalResult + err error +} + +func (p *fakeProposer) Propose(_ context.Context, data []byte) (*raftengine.ProposalResult, error) { + p.calls.Add(1) + cp := make([]byte, len(data)) + copy(cp, data) + p.last = cp + if p.err != nil { + return nil, p.err + } + if p.resp == nil { + return &raftengine.ProposalResult{CommitIndex: 1}, nil + } + return p.resp, nil +} + +func TestApplyRaftPayloadWrap_NilIsPassThrough(t *testing.T) { + t.Parallel() + got, err := applyRaftPayloadWrap(nil, []byte("hello")) + if err != nil { + t.Fatalf("applyRaftPayloadWrap: %v", err) + } + if !bytes.Equal(got, []byte("hello")) { + t.Fatalf("nil wrapper mutated payload: got %q", got) + } +} + +func TestApplyRaftPayloadWrap_PropagatesError(t *testing.T) { + t.Parallel() + sentinel := errors.New("wrap-side fail") + wrap := func([]byte) ([]byte, error) { return nil, sentinel } + _, err := applyRaftPayloadWrap(wrap, []byte("x")) + if !errors.Is(err, sentinel) { + t.Fatalf("expected wrapped sentinel, got %v", err) + } +} + +func TestNewWrappedProposer_NilWrapperReturnsInnerVerbatim(t *testing.T) { + t.Parallel() + inner := &fakeProposer{} + got := newWrappedProposer(inner, nil) + // Stage 3 default: identical pointer; no allocation. + if got != raftengine.Proposer(inner) { + t.Fatal("nil wrapper: newWrappedProposer should return the inner proposer verbatim") + } +} + +func TestWrappedProposer_InvokesWrapperOncePerCall(t *testing.T) { + t.Parallel() + var wrapCalls atomic.Int32 + wrap := func(p []byte) ([]byte, error) { + wrapCalls.Add(1) + out := make([]byte, len(p)+1) + out[0] = 'W' + copy(out[1:], p) + return out, nil + } + inner := &fakeProposer{} + wp := newWrappedProposer(inner, wrap) + if _, err := wp.Propose(context.Background(), []byte("payload")); err != nil { + t.Fatalf("Propose: %v", err) + } + if got := wrapCalls.Load(); got != 1 { + t.Fatalf("wrapper call count = %d, want 1", got) + } + if got := inner.calls.Load(); got != 1 { + t.Fatalf("inner.Propose call count = %d, want 1", got) + } + want := append([]byte{'W'}, []byte("payload")...) + if !bytes.Equal(inner.last, want) { + t.Fatalf("inner saw %q, want %q (wrapper output)", inner.last, want) + } +} + +func TestWrappedProposer_PropagatesWrapperError(t *testing.T) { + t.Parallel() + sentinel := errors.New("wrapper denied") + inner := &fakeProposer{} + wp := newWrappedProposer(inner, func([]byte) ([]byte, error) { return nil, sentinel }) + _, err := wp.Propose(context.Background(), []byte("x")) + if !errors.Is(err, sentinel) { + t.Fatalf("expected sentinel, got %v", err) + } + if got := inner.calls.Load(); got != 0 { + t.Fatalf("inner.Propose called %d times despite wrap fail", got) + } +} + +// TestWrappedProposer_RoundTripWithRealCipher exercises the seam end- +// to-end: wrap with a real raft envelope, the inner proposer +// observes the encrypted bytes, and a hand-rolled Unwrap recovers +// the original plaintext (the engine-side hook from +// internal/raftengine/etcd is unit-tested separately; this test +// proves the coordinator's wrap output is shape-compatible with +// what the engine expects). +func TestWrappedProposer_RoundTripWithRealCipher(t *testing.T) { + t.Parallel() + ks := encryption.NewKeystore() + dek := make([]byte, encryption.KeySize) + if _, err := rand.Read(dek); err != nil { + t.Fatalf("rand.Read: %v", err) + } + const kid uint32 = 0x42 + if err := ks.Set(kid, dek); err != nil { + t.Fatalf("Set: %v", err) + } + c, err := encryption.NewCipher(ks) + if err != nil { + t.Fatalf("NewCipher: %v", err) + } + + wrap := func(p []byte) ([]byte, error) { + nonce := make([]byte, encryption.NonceSize) + if _, err := rand.Read(nonce); err != nil { + return nil, err + } + return encryption.WrapRaftPayload(c, kid, nonce, p) + } + inner := &fakeProposer{} + wp := newWrappedProposer(inner, wrap) + + plaintext := []byte("op=put key=k1 v=secret") + if _, err := wp.Propose(context.Background(), plaintext); err != nil { + t.Fatalf("Propose: %v", err) + } + got, err := encryption.UnwrapRaftPayload(c, inner.last) + if err != nil { + t.Fatalf("UnwrapRaftPayload: %v", err) + } + if !bytes.Equal(got, plaintext) { + t.Fatalf("round-trip mismatch: got %q, want %q", got, plaintext) + } +}