Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions providers/vault/vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package vault

import (
"bytes"
"context"
"encoding/base64"
"io"
"path/filepath"
"time"

vaultapi "github.com/hashicorp/vault/api"
"github.com/pkg/errors"
Expand Down Expand Up @@ -66,23 +68,26 @@ func (v *VaultCrypter) Encrypt(w io.Writer, r io.Reader) error {
}

// Plaintext must be base64-encoded
p := map[string]interface{}{
p := map[string]any{
"plaintext": base64.StdEncoding.EncodeToString(src.Bytes()),
}

resp, err := v.client.Logical().Write(v.getEncryptEndpoint(), p)
vCtx, vCancel := context.WithTimeout(context.Background(), 3*time.Second)
defer vCancel()

resp, err := v.client.Logical().WriteWithContext(vCtx, v.getEncryptEndpoint(), p)
if err != nil {
return errors.Wrapf(err, "failed to encrypt data using transit secrets engine: mount %q and key %q", v.mount, v.key)
}

data, ok := resp.Data["ciphertext"]
if !ok {
return errors.Wrap(err, "failed to extract ciphertext from Vault's response")
return errors.New("failed to extract ciphertext from Vault's response")
}

ciphertext, ok := data.(string)
if !ok {
return errors.Wrap(err, "failed to convert ciphertext to string")
return errors.Errorf("failed to convert ciphertext to string: got %T", data)
}

w.Write([]byte(ciphertext))
Expand All @@ -96,28 +101,31 @@ func (v *VaultCrypter) Encrypt(w io.Writer, r io.Reader) error {
// See: https://www.vaultproject.io/api-docs/secret/transit#decrypt-data
func (v *VaultCrypter) Decrypt(w io.Writer, r io.Reader) error {
src := new(bytes.Buffer)
_, err := src.ReadFrom(r)
if err != nil {

if _, err := src.ReadFrom(r); err != nil {
return errors.Wrap(err, "failed to read from io.Reader")
}

p := map[string]interface{}{
p := map[string]any{
"ciphertext": src.String(),
}

resp, err := v.client.Logical().Write(v.getDecryptEndpoint(), p)
vCtx, vCancel := context.WithTimeout(context.Background(), 3*time.Second)
defer vCancel()

resp, err := v.client.Logical().WriteWithContext(vCtx, v.getDecryptEndpoint(), p)
if err != nil {
return errors.Wrapf(err, "failed to decrypt data using transit secrets engine: mount %q and key %q", v.mount, v.key)
}

data, ok := resp.Data["plaintext"]
if !ok {
return errors.Wrap(err, "failed to extract plaintext from Vault's response")
return errors.New("failed to extract plaintext from Vault's response")
}

b64Plaintext, ok := data.(string)
if !ok {
return errors.Wrap(err, "failed to convert plaintext to string")
return errors.Errorf("failed to convert plaintext to string: got %T", data)
}

// Plaintext is base64 encoded and must be decoded
Expand Down
89 changes: 89 additions & 0 deletions providers/vault/vault_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@ package vault

import (
"bytes"
"net/http"
"net/http/httptest"
"testing"

vaultapi "github.com/hashicorp/vault/api"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"

"github.com/bincyber/go-sqlcrypter"
Expand Down Expand Up @@ -33,6 +37,19 @@ type VaultCrypterTestSuite struct {
vaultCrypter sqlcrypter.Crypterer
}

func newTestVaultClient(t *testing.T, h http.HandlerFunc) (*vaultapi.Client, func()) {
t.Helper()

server := httptest.NewServer(h)

client, err := vaultapi.NewClient(&vaultapi.Config{
Address: server.URL,
})
require.NoError(t, err)

return client, server.Close
}

func (s *VaultCrypterTestSuite) SetupTest() {
s.client = getVaultClient()

Expand Down Expand Up @@ -161,6 +178,78 @@ func (s *VaultCrypterTestSuite) Test_Decrypt_err() {
r.Contains(err.Error(), "failed to decrypt data using transit secrets engine")
}

func Test_Encrypt_ResponseMissingCiphertext(t *testing.T) {
client, closeServer := newTestVaultClient(t, func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"data":{}}`))
})
defer closeServer()

crypter, err := New(client, transitMount, transitKey)
require.NoError(t, err)

writer := new(bytes.Buffer)
reader := bytes.NewBufferString("Hello World")

err = crypter.Encrypt(writer, reader)
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to extract ciphertext from Vault's response")
}

func Test_Encrypt_ResponseCiphertextNotString(t *testing.T) {
client, closeServer := newTestVaultClient(t, func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"data":{"ciphertext":123}}`))
})
defer closeServer()

crypter, err := New(client, transitMount, transitKey)
require.NoError(t, err)

writer := new(bytes.Buffer)
reader := bytes.NewBufferString("Hello World")

err = crypter.Encrypt(writer, reader)
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to convert ciphertext to string")
}

func Test_Decrypt_ResponseMissingPlaintext(t *testing.T) {
client, closeServer := newTestVaultClient(t, func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"data":{}}`))
})
defer closeServer()

crypter, err := New(client, transitMount, transitKey)
require.NoError(t, err)

writer := new(bytes.Buffer)
reader := bytes.NewBufferString("vault:v1:some-ciphertext")

err = crypter.Decrypt(writer, reader)
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to extract plaintext from Vault's response")
}

func Test_Decrypt_ResponsePlaintextNotString(t *testing.T) {
client, closeServer := newTestVaultClient(t, func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"data":{"plaintext":123}}`))
})
defer closeServer()

crypter, err := New(client, transitMount, transitKey)
require.NoError(t, err)

writer := new(bytes.Buffer)
reader := bytes.NewBufferString("vault:v1:some-ciphertext")

err = crypter.Decrypt(writer, reader)
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to convert plaintext to string")
}

func Test_VaultCrypterTestSuite(t *testing.T) {
suite.Run(t, new(VaultCrypterTestSuite))
}
Loading