Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions lib/ocrypto/ec_key_pair.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@ type ECCMode uint8
type KeyType string

const (
RSA2048Key KeyType = "rsa:2048"
RSA4096Key KeyType = "rsa:4096"
EC256Key KeyType = "ec:secp256r1"
EC384Key KeyType = "ec:secp384r1"
EC521Key KeyType = "ec:secp521r1"
RSA2048Key KeyType = "rsa:2048"
RSA4096Key KeyType = "rsa:4096"
EC256Key KeyType = "ec:secp256r1"
EC384Key KeyType = "ec:secp384r1"
EC521Key KeyType = "ec:secp521r1"
MLKEM768Key KeyType = "mlkem:768"
)

const (
Expand Down Expand Up @@ -64,6 +65,8 @@ func NewKeyPair(kt KeyType) (KeyPair, error) {
return nil, err
}
return NewECKeyPair(mode)
case MLKEM768Key:
return NewMLKEMKeyPair()
default:
return nil, fmt.Errorf("unsupported key type: %v", kt)
}
Expand Down
205 changes: 205 additions & 0 deletions lib/ocrypto/mlkem_key_pair.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
package ocrypto

import (
"crypto/aes"
"crypto/cipher"
"crypto/mlkem"
"crypto/rand"
"crypto/sha256"
"encoding/pem"
"errors"
"fmt"
"io"

"golang.org/x/crypto/hkdf"
)

const (
// MLKEM768CiphertextSize is the byte length of an ML-KEM-768 ciphertext.
MLKEM768CiphertextSize = 1088
)

// MLKEMKeyPair holds an ML-KEM-768 decapsulation (private) key.
// The public encapsulation key is derived from the private key.
type MLKEMKeyPair struct {
dk *mlkem.DecapsulationKey768
}

// NewMLKEMKeyPair generates a fresh ML-KEM-768 key pair.
func NewMLKEMKeyPair() (MLKEMKeyPair, error) {
dk, err := mlkem.GenerateKey768()
if err != nil {
return MLKEMKeyPair{}, fmt.Errorf("mlkem.GenerateKey768 failed: %w", err)
}
return MLKEMKeyPair{dk: dk}, nil
}

// IsMLKEMKeyType reports whether the given KeyType is an ML-KEM type.
func IsMLKEMKeyType(kt KeyType) bool {
return kt == MLKEM768Key
}

// GetKeyType implements KeyPair.
func (kp MLKEMKeyPair) GetKeyType() KeyType {
return MLKEM768Key
}

// PublicKeyInPemFormat returns the ML-KEM-768 encapsulation key in PEM format.
func (kp MLKEMKeyPair) PublicKeyInPemFormat() (string, error) {
if kp.dk == nil {
return "", errors.New("nil ML-KEM-768 key")
}
b := kp.dk.EncapsulationKey().Bytes()
block := &pem.Block{
Type: "ML-KEM-768 PUBLIC KEY",
Bytes: b,
}
return string(pem.EncodeToMemory(block)), nil
}

// PrivateKeyInPemFormat returns the ML-KEM-768 seed (private key) in PEM format.
func (kp MLKEMKeyPair) PrivateKeyInPemFormat() (string, error) {
if kp.dk == nil {
return "", errors.New("nil ML-KEM-768 key")
}
block := &pem.Block{
Type: "ML-KEM-768 PRIVATE KEY",
Bytes: kp.dk.Bytes(),
}
return string(pem.EncodeToMemory(block)), nil
}

// MLKEMDecapsulateAndUnwrap recovers the DEK from an ML-KEM-768 wrapped key blob.
//
// wrappedKey layout (after base64 decoding by the caller):
//
// [0 : 1088] ML-KEM-768 ciphertext
// [1088 : ] AES-256-GCM encrypted DEK (12-byte nonce prepended)
//
// The AES wrapping key is: HKDF-SHA256(shared_secret, salt=TDF-salt).
func MLKEMDecapsulateAndUnwrap(privateKeyPEM []byte, wrappedKey []byte) ([]byte, error) {
if len(wrappedKey) <= MLKEM768CiphertextSize {
return nil, fmt.Errorf("mlkem wrapped key too short: %d bytes", len(wrappedKey))
}

dk, err := mlkemDecapsKeyFromPEM(privateKeyPEM)
if err != nil {
return nil, err
}

ct := wrappedKey[:MLKEM768CiphertextSize]
encDEK := wrappedKey[MLKEM768CiphertextSize:]

sharedSecret, err := dk.Decapsulate(ct)
if err != nil {
return nil, fmt.Errorf("mlkem decapsulate failed: %w", err)
}

wk, err := deriveMLKEMWrappingKey(sharedSecret)
if err != nil {
return nil, err
}

return aesGCMDecrypt(wk, encDEK)
}

// mlkemDecapsKeyFromPEM parses a PEM-encoded ML-KEM-768 private key (seed).
func mlkemDecapsKeyFromPEM(privateKeyPEM []byte) (*mlkem.DecapsulationKey768, error) {
block, _ := pem.Decode(privateKeyPEM)
if block == nil {
return nil, errors.New("failed to parse ML-KEM-768 PEM private key")
}
Comment on lines +108 to +111
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The PEM decoding logic should validate the Type of the PEM block to ensure it matches the expected ML-KEM-768 private key header. This prevents accidental processing of incorrect key types.

Suggested change
block, _ := pem.Decode(privateKeyPEM)
if block == nil {
return nil, errors.New("failed to parse ML-KEM-768 PEM private key")
}
block, _ := pem.Decode(privateKeyPEM)
if block == nil || block.Type != "ML-KEM-768 PRIVATE KEY" {
return nil, errors.New("failed to parse ML-KEM-768 PEM private key")
}

dk, err := mlkem.NewDecapsulationKey768(block.Bytes)
if err != nil {
return nil, fmt.Errorf("mlkem.NewDecapsulationKey768 failed: %w", err)
}
return dk, nil
}

// deriveMLKEMWrappingKey derives a 32-byte AES key from the ML-KEM shared secret
// using HKDF-SHA256 with the standard TDF salt.
func deriveMLKEMWrappingKey(sharedSecret []byte) ([]byte, error) {
salt := mlkemTDFSalt()
h := hkdf.New(sha256.New, sharedSecret, salt, nil)
key := make([]byte, 32) //nolint:mnd // AES-256
if _, err := io.ReadFull(h, key); err != nil {
return nil, fmt.Errorf("hkdf derivation failed: %w", err)
}
return key, nil
}

// mlkemTDFSalt returns the SHA-256("TDF") salt used for HKDF in ML-KEM key wrapping.
func mlkemTDFSalt() []byte {
h := sha256.New()
h.Write([]byte("TDF"))
return h.Sum(nil)
}
Comment on lines +131 to +136
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Computing the SHA-256 hash of the "TDF" string on every call is inefficient as the result is constant. Consider precomputing this salt as a package-level variable.

Suggested change
// mlkemTDFSalt returns the SHA-256("TDF") salt used for HKDF in ML-KEM key wrapping.
func mlkemTDFSalt() []byte {
h := sha256.New()
h.Write([]byte("TDF"))
return h.Sum(nil)
}
var mlkemTDFSaltValue = sha256.Sum256([]byte("TDF"))
// mlkemTDFSalt returns the SHA-256("TDF") salt used for HKDF in ML-KEM key wrapping.
func mlkemTDFSalt() []byte {
return mlkemTDFSaltValue[:]
}


// aesGCMDecrypt decrypts AES-256-GCM ciphertext of the form: [12-byte nonce | ciphertext+tag].
func aesGCMDecrypt(key, data []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, fmt.Errorf("aes.NewCipher failed: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("cipher.NewGCM failed: %w", err)
}
nonceSize := gcm.NonceSize()
if len(data) < nonceSize {
return nil, errors.New("ciphertext too short for nonce")
}
Comment on lines +149 to +151
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The length check for the ciphertext is insufficient. It should ensure the data is long enough to contain both the nonce and the GCM authentication tag (typically 12 + 16 = 28 bytes).

Suggested change
if len(data) < nonceSize {
return nil, errors.New("ciphertext too short for nonce")
}
nonceSize := gcm.NonceSize()
if len(data) < nonceSize+gcm.Overhead() {
return nil, errors.New("ciphertext too short")
}

plaintext, err := gcm.Open(nil, data[:nonceSize], data[nonceSize:], nil)

Check failure on line 152 in lib/ocrypto/mlkem_key_pair.go

View workflow job for this annotation

GitHub Actions / go (lib/ocrypto)

G602: slice bounds out of range (gosec)
if err != nil {
return nil, fmt.Errorf("aes-gcm open failed: %w", err)
}
return plaintext, nil
}

// MLKEMEncapsulateAndWrap encapsulates a DEK for the given ML-KEM-768 public key PEM.
// Returns wrappedKey = ciphertext || AES-GCM(wk, dek).
// This is the counterpart used by SDK implementations; provided here for testing.
func MLKEMEncapsulateAndWrap(publicKeyPEM []byte, dek []byte) ([]byte, error) {
block, _ := pem.Decode(publicKeyPEM)
if block == nil {
return nil, errors.New("failed to parse ML-KEM-768 PEM public key")
}
Comment on lines +163 to +166
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The PEM decoding logic should validate the Type of the PEM block to ensure it matches the expected ML-KEM-768 public key header.

Suggested change
block, _ := pem.Decode(publicKeyPEM)
if block == nil {
return nil, errors.New("failed to parse ML-KEM-768 PEM public key")
}
block, _ := pem.Decode(publicKeyPEM)
if block == nil || block.Type != "ML-KEM-768 PUBLIC KEY" {
return nil, errors.New("failed to parse ML-KEM-768 PEM public key")
}

ek, err := mlkem.NewEncapsulationKey768(block.Bytes)
if err != nil {
return nil, fmt.Errorf("mlkem.NewEncapsulationKey768 failed: %w", err)
}

sharedSecret, ct := ek.Encapsulate()

wk, err := deriveMLKEMWrappingKey(sharedSecret)
if err != nil {
return nil, err
}

encDEK, err := aesGCMEncrypt(wk, dek)
if err != nil {
return nil, err
}

result := make([]byte, 0, MLKEM768CiphertextSize+len(encDEK))
result = append(result, ct...)
result = append(result, encDEK...)
return result, nil
}

// aesGCMEncrypt encrypts plaintext using AES-256-GCM, prepending a random 12-byte nonce.
func aesGCMEncrypt(key, plaintext []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, fmt.Errorf("aes.NewCipher failed: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("cipher.NewGCM failed: %w", err)
}
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return nil, fmt.Errorf("nonce generation failed: %w", err)
}
return gcm.Seal(nonce, nonce, plaintext, nil), nil
}
94 changes: 94 additions & 0 deletions lib/ocrypto/mlkem_key_pair_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package ocrypto

import (
"crypto/rand"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestNewMLKEMKeyPair(t *testing.T) {
kp, err := NewMLKEMKeyPair()
require.NoError(t, err)
assert.Equal(t, MLKEM768Key, kp.GetKeyType())

pubPEM, err := kp.PublicKeyInPemFormat()
require.NoError(t, err)
assert.Contains(t, pubPEM, "ML-KEM-768 PUBLIC KEY")

privPEM, err := kp.PrivateKeyInPemFormat()
require.NoError(t, err)
assert.Contains(t, privPEM, "ML-KEM-768 PRIVATE KEY")
}

func TestMLKEMRoundtrip(t *testing.T) {
kp, err := NewMLKEMKeyPair()
require.NoError(t, err)

pubPEM, err := kp.PublicKeyInPemFormat()
require.NoError(t, err)

privPEM, err := kp.PrivateKeyInPemFormat()
require.NoError(t, err)

// Generate a random DEK to wrap
dek := make([]byte, 32)
_, err = rand.Read(dek)
require.NoError(t, err)

// Encapsulate (SDK side)
wrappedKey, err := MLKEMEncapsulateAndWrap([]byte(pubPEM), dek)
require.NoError(t, err)
assert.Len(t, wrappedKey, MLKEM768CiphertextSize+12+32+16, "wrappedKey should be ciphertext + nonce + dek + gcm-tag")

// Decapsulate (KAS side)
recovered, err := MLKEMDecapsulateAndUnwrap([]byte(privPEM), wrappedKey)
require.NoError(t, err)
assert.Equal(t, dek, recovered)
}

func TestMLKEMDecapsulateWrongKey(t *testing.T) {
kp1, err := NewMLKEMKeyPair()
require.NoError(t, err)
kp2, err := NewMLKEMKeyPair()
require.NoError(t, err)

pubPEM1, err := kp1.PublicKeyInPemFormat()
require.NoError(t, err)
privPEM2, err := kp2.PrivateKeyInPemFormat()
require.NoError(t, err)

dek := make([]byte, 32)
_, err = rand.Read(dek)
require.NoError(t, err)

wrappedKey, err := MLKEMEncapsulateAndWrap([]byte(pubPEM1), dek)
require.NoError(t, err)

// Decapsulating with a different key should fail
_, err = MLKEMDecapsulateAndUnwrap([]byte(privPEM2), wrappedKey)
assert.Error(t, err)
}

func TestMLKEMDecapsulateTooShort(t *testing.T) {
kp, err := NewMLKEMKeyPair()
require.NoError(t, err)
privPEM, err := kp.PrivateKeyInPemFormat()
require.NoError(t, err)

_, err = MLKEMDecapsulateAndUnwrap([]byte(privPEM), make([]byte, 100))
assert.Error(t, err)
}

func TestIsMLKEMKeyType(t *testing.T) {
assert.True(t, IsMLKEMKeyType(MLKEM768Key))
assert.False(t, IsMLKEMKeyType(RSA2048Key))
assert.False(t, IsMLKEMKeyType(EC256Key))
}

func TestNewKeyPairMLKEM(t *testing.T) {
kp, err := NewKeyPair(MLKEM768Key)
require.NoError(t, err)
assert.Equal(t, MLKEM768Key, kp.GetKeyType())
}
1 change: 1 addition & 0 deletions opentdf-dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ services:
preview:
ec_tdf_enabled: false
key_management: false
mlkem_enabled: true
root_key: a8c4824daafcfa38ed0d13002e92b08720e6c4fcee67d52e954c1a6e045907d1 # For local development testing only
keyring:
- kid: e1
Expand Down
Loading
Loading