From 7b490dac009e5643a0c1fda6be74b6a59ea5483a Mon Sep 17 00:00:00 2001 From: Ali Date: Sat, 4 Apr 2026 15:03:16 +0400 Subject: [PATCH] fix: error handling and missing request context for Vault provider --- providers/vault/vault.go | 28 +++++++---- providers/vault/vault_test.go | 89 +++++++++++++++++++++++++++++++++++ 2 files changed, 107 insertions(+), 10 deletions(-) diff --git a/providers/vault/vault.go b/providers/vault/vault.go index dd79838..11e7b58 100644 --- a/providers/vault/vault.go +++ b/providers/vault/vault.go @@ -2,9 +2,11 @@ package vault import ( "bytes" + "context" "encoding/base64" "io" "path/filepath" + "time" vaultapi "github.com/hashicorp/vault/api" "github.com/pkg/errors" @@ -66,23 +68,26 @@ func (v *VaultCrypter) Encrypt(w io.Writer, r io.Reader) error { } // Plaintext must be base64-encoded - p := map[string]interface{}{ + p := map[string]any{ "plaintext": base64.StdEncoding.EncodeToString(src.Bytes()), } - resp, err := v.client.Logical().Write(v.getEncryptEndpoint(), p) + vCtx, vCancel := context.WithTimeout(context.Background(), 3*time.Second) + defer vCancel() + + resp, err := v.client.Logical().WriteWithContext(vCtx, v.getEncryptEndpoint(), p) if err != nil { return errors.Wrapf(err, "failed to encrypt data using transit secrets engine: mount %q and key %q", v.mount, v.key) } data, ok := resp.Data["ciphertext"] if !ok { - return errors.Wrap(err, "failed to extract ciphertext from Vault's response") + return errors.New("failed to extract ciphertext from Vault's response") } ciphertext, ok := data.(string) if !ok { - return errors.Wrap(err, "failed to convert ciphertext to string") + return errors.Errorf("failed to convert ciphertext to string: got %T", data) } w.Write([]byte(ciphertext)) @@ -96,28 +101,31 @@ func (v *VaultCrypter) Encrypt(w io.Writer, r io.Reader) error { // See: https://www.vaultproject.io/api-docs/secret/transit#decrypt-data func (v *VaultCrypter) Decrypt(w io.Writer, r io.Reader) error { src := new(bytes.Buffer) - _, err := src.ReadFrom(r) - if err != nil { + + if _, err := src.ReadFrom(r); err != nil { return errors.Wrap(err, "failed to read from io.Reader") } - p := map[string]interface{}{ + p := map[string]any{ "ciphertext": src.String(), } - resp, err := v.client.Logical().Write(v.getDecryptEndpoint(), p) + vCtx, vCancel := context.WithTimeout(context.Background(), 3*time.Second) + defer vCancel() + + resp, err := v.client.Logical().WriteWithContext(vCtx, v.getDecryptEndpoint(), p) if err != nil { return errors.Wrapf(err, "failed to decrypt data using transit secrets engine: mount %q and key %q", v.mount, v.key) } data, ok := resp.Data["plaintext"] if !ok { - return errors.Wrap(err, "failed to extract plaintext from Vault's response") + return errors.New("failed to extract plaintext from Vault's response") } b64Plaintext, ok := data.(string) if !ok { - return errors.Wrap(err, "failed to convert plaintext to string") + return errors.Errorf("failed to convert plaintext to string: got %T", data) } // Plaintext is base64 encoded and must be decoded diff --git a/providers/vault/vault_test.go b/providers/vault/vault_test.go index 8d32e31..4a99a89 100644 --- a/providers/vault/vault_test.go +++ b/providers/vault/vault_test.go @@ -2,9 +2,13 @@ package vault import ( "bytes" + "net/http" + "net/http/httptest" "testing" vaultapi "github.com/hashicorp/vault/api" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/bincyber/go-sqlcrypter" @@ -33,6 +37,19 @@ type VaultCrypterTestSuite struct { vaultCrypter sqlcrypter.Crypterer } +func newTestVaultClient(t *testing.T, h http.HandlerFunc) (*vaultapi.Client, func()) { + t.Helper() + + server := httptest.NewServer(h) + + client, err := vaultapi.NewClient(&vaultapi.Config{ + Address: server.URL, + }) + require.NoError(t, err) + + return client, server.Close +} + func (s *VaultCrypterTestSuite) SetupTest() { s.client = getVaultClient() @@ -161,6 +178,78 @@ func (s *VaultCrypterTestSuite) Test_Decrypt_err() { r.Contains(err.Error(), "failed to decrypt data using transit secrets engine") } +func Test_Encrypt_ResponseMissingCiphertext(t *testing.T) { + client, closeServer := newTestVaultClient(t, func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"data":{}}`)) + }) + defer closeServer() + + crypter, err := New(client, transitMount, transitKey) + require.NoError(t, err) + + writer := new(bytes.Buffer) + reader := bytes.NewBufferString("Hello World") + + err = crypter.Encrypt(writer, reader) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to extract ciphertext from Vault's response") +} + +func Test_Encrypt_ResponseCiphertextNotString(t *testing.T) { + client, closeServer := newTestVaultClient(t, func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"data":{"ciphertext":123}}`)) + }) + defer closeServer() + + crypter, err := New(client, transitMount, transitKey) + require.NoError(t, err) + + writer := new(bytes.Buffer) + reader := bytes.NewBufferString("Hello World") + + err = crypter.Encrypt(writer, reader) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to convert ciphertext to string") +} + +func Test_Decrypt_ResponseMissingPlaintext(t *testing.T) { + client, closeServer := newTestVaultClient(t, func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"data":{}}`)) + }) + defer closeServer() + + crypter, err := New(client, transitMount, transitKey) + require.NoError(t, err) + + writer := new(bytes.Buffer) + reader := bytes.NewBufferString("vault:v1:some-ciphertext") + + err = crypter.Decrypt(writer, reader) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to extract plaintext from Vault's response") +} + +func Test_Decrypt_ResponsePlaintextNotString(t *testing.T) { + client, closeServer := newTestVaultClient(t, func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"data":{"plaintext":123}}`)) + }) + defer closeServer() + + crypter, err := New(client, transitMount, transitKey) + require.NoError(t, err) + + writer := new(bytes.Buffer) + reader := bytes.NewBufferString("vault:v1:some-ciphertext") + + err = crypter.Decrypt(writer, reader) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to convert plaintext to string") +} + func Test_VaultCrypterTestSuite(t *testing.T) { suite.Run(t, new(VaultCrypterTestSuite)) }