diff --git a/lib/ocrypto/ec_key_pair.go b/lib/ocrypto/ec_key_pair.go index f9a9554d4f..7ff17ed4e6 100644 --- a/lib/ocrypto/ec_key_pair.go +++ b/lib/ocrypto/ec_key_pair.go @@ -22,11 +22,12 @@ type ECCMode uint8 type KeyType string const ( - RSA2048Key KeyType = "rsa:2048" - RSA4096Key KeyType = "rsa:4096" - EC256Key KeyType = "ec:secp256r1" - EC384Key KeyType = "ec:secp384r1" - EC521Key KeyType = "ec:secp521r1" + RSA2048Key KeyType = "rsa:2048" + RSA4096Key KeyType = "rsa:4096" + EC256Key KeyType = "ec:secp256r1" + EC384Key KeyType = "ec:secp384r1" + EC521Key KeyType = "ec:secp521r1" + MLKem768Key KeyType = "mlkem:768" ) const ( @@ -91,6 +92,10 @@ func IsRSAKeyType(kt KeyType) bool { } } +func IsMLKEMKeyType(kt KeyType) bool { + return kt == MLKem768Key +} + // GetECCurveFromECCMode return elliptic curve from ecc mode func GetECCurveFromECCMode(mode ECCMode) (elliptic.Curve, error) { var c elliptic.Curve diff --git a/lib/ocrypto/mlkem.go b/lib/ocrypto/mlkem.go new file mode 100644 index 0000000000..c0e814e854 --- /dev/null +++ b/lib/ocrypto/mlkem.go @@ -0,0 +1,33 @@ +package ocrypto + +import ( + "crypto/mlkem" + "encoding/pem" + "fmt" +) + +const ( + // MLKem768CiphertextSize is the byte length of an ML-KEM-768 ciphertext. + MLKem768CiphertextSize = 1088 + // MLKem768PublicKeySize is the byte length of an ML-KEM-768 encapsulation key. + MLKem768PublicKeySize = 1184 + + mlkem768PEMType = "ML-KEM-768 PUBLIC KEY" +) + +// MLKEMPublicKeyFromPEM parses an ML-KEM-768 encapsulation key from a PEM block +// with type "ML-KEM-768 PUBLIC KEY". +func MLKEMPublicKeyFromPEM(pemData []byte) (*mlkem.EncapsulationKey768, error) { + block, _ := pem.Decode(pemData) + if block == nil { + return nil, fmt.Errorf("failed to decode PEM block for ML-KEM-768 public key") + } + if block.Type != mlkem768PEMType { + return nil, fmt.Errorf("unexpected PEM type %q, expected %q", block.Type, mlkem768PEMType) + } + key, err := mlkem.NewEncapsulationKey768(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse ML-KEM-768 encapsulation key: %w", err) + } + return key, nil +} diff --git a/sdk/experimental/tdf/key_access.go b/sdk/experimental/tdf/key_access.go index 6e97701ad5..6b8ff9956d 100644 --- a/sdk/experimental/tdf/key_access.go +++ b/sdk/experimental/tdf/key_access.go @@ -169,6 +169,10 @@ func wrapKeyWithPublicKey(symKey []byte, pubKeyInfo keysplit.KASPublicKey) (stri // Handle EC key wrapping return wrapKeyWithEC(ktype, pubKeyInfo.PEM, symKey) } + if ocrypto.IsMLKEMKeyType(ktype) { + wrapped, err := wrapKeyWithMLKEM(pubKeyInfo.PEM, symKey) + return wrapped, "wrapped", "", err + } // Handle RSA key wrapping wrapped, err := wrapKeyWithRSA(pubKeyInfo.PEM, symKey) return wrapped, "wrapped", "", err @@ -245,3 +249,32 @@ func wrapKeyWithRSA(kasPublicKeyPEM string, symKey []byte) (string, error) { return string(ocrypto.Base64Encode(encryptedKey)), nil } + +// wrapKeyWithMLKEM encapsulates a symmetric key using ML-KEM-768. +// Wire format: base64(ml_kem_ciphertext [1088 bytes] || aes_gcm_wrapped_dek) +// The ephemeralPublicKey field is not set for this key type. +func wrapKeyWithMLKEM(kasPublicKeyPEM string, symKey []byte) (string, error) { + encKey, err := ocrypto.MLKEMPublicKeyFromPEM([]byte(kasPublicKeyPEM)) + if err != nil { + return "", fmt.Errorf("failed to parse ML-KEM-768 public key: %w", err) + } + + // Encapsulate: sharedKey is used to wrap the DEK; ciphertext goes on the wire. + sharedKey, ciphertext := encKey.Encapsulate() + + gcm, err := ocrypto.NewAESGcm(sharedKey) + if err != nil { + return "", fmt.Errorf("failed to create AES-GCM for ML-KEM DEK wrap: %w", err) + } + + wrappedDEK, err := gcm.Encrypt(symKey) + if err != nil { + return "", fmt.Errorf("failed to AES-GCM wrap DEK: %w", err) + } + + payload := make([]byte, 0, len(ciphertext)+len(wrappedDEK)) + payload = append(payload, ciphertext...) + payload = append(payload, wrappedDEK...) + + return string(ocrypto.Base64Encode(payload)), nil +} diff --git a/sdk/experimental/tdf/key_access_test.go b/sdk/experimental/tdf/key_access_test.go index 9dac3ca3b6..562484c46d 100644 --- a/sdk/experimental/tdf/key_access_test.go +++ b/sdk/experimental/tdf/key_access_test.go @@ -3,8 +3,10 @@ package tdf import ( + "crypto/mlkem" "crypto/rand" "encoding/json" + "encoding/pem" "strings" "testing" @@ -473,3 +475,63 @@ func TestTdfSalt(t *testing.T) { assert.NotEmpty(t, salt1, "Salt should not be empty") }) } + +// mlkem768TestPublicKeyPEM generates a fresh ML-KEM-768 key pair and returns +// the public key as a PEM block with type "ML-KEM-768 PUBLIC KEY". +func mlkem768TestPublicKeyPEM(t *testing.T) (string, *mlkem.DecapsulationKey768) { + t.Helper() + dk, err := mlkem.GenerateKey768() + require.NoError(t, err) + pubKeyBytes := dk.EncapsulationKey().Bytes() + block := &pem.Block{Type: "ML-KEM-768 PUBLIC KEY", Bytes: pubKeyBytes} + return string(pem.EncodeToMemory(block)), dk +} + +func TestWrapKeyWithMLKEM(t *testing.T) { + t.Run("wraps key and produces correct wire format", func(t *testing.T) { + pubKeyPEM, dk := mlkem768TestPublicKeyPEM(t) + + symKey := make([]byte, 32) + _, err := rand.Read(symKey) + require.NoError(t, err) + + wrappedB64, err := wrapKeyWithMLKEM(pubKeyPEM, symKey) + require.NoError(t, err) + assert.NotEmpty(t, wrappedB64) + + payload, err := ocrypto.Base64Decode([]byte(wrappedB64)) + require.NoError(t, err) + assert.Greater(t, len(payload), ocrypto.MLKem768CiphertextSize, "payload must include ciphertext + wrapped DEK") + + ciphertext := payload[:ocrypto.MLKem768CiphertextSize] + wrappedDEK := payload[ocrypto.MLKem768CiphertextSize:] + + sharedKey, err := dk.Decapsulate(ciphertext) + require.NoError(t, err) + + gcm, err := ocrypto.NewAESGcm(sharedKey) + require.NoError(t, err) + + recoveredDEK, err := gcm.Decrypt(wrappedDEK) + require.NoError(t, err) + assert.Equal(t, symKey, recoveredDEK, "recovered DEK must match original") + }) + + t.Run("buildKeyAccessObjects uses wrapped key type with no ephemeral key", func(t *testing.T) { + pubKeyPEM, _ := mlkem768TestPublicKeyPEM(t) + splitResult := createTestSplitResult(testKAS1URL, pubKeyPEM, "mlkem:768") + + keyAccessList, err := buildKeyAccessObjects(splitResult, []byte(testPolicyJSON), "") + require.NoError(t, err) + require.Len(t, keyAccessList, 1) + + ka := keyAccessList[0] + assert.Equal(t, "wrapped", ka.KeyType, "ML-KEM KAO must use 'wrapped' key type") + assert.Empty(t, ka.EphemeralPublicKey, "ML-KEM KAO must not set ephemeralPublicKey") + assert.NotEmpty(t, ka.WrappedKey) + + payload, err := ocrypto.Base64Decode([]byte(ka.WrappedKey)) + require.NoError(t, err) + assert.Greater(t, len(payload), ocrypto.MLKem768CiphertextSize) + }) +} diff --git a/sdk/experimental/tdf/keysplit/attributes.go b/sdk/experimental/tdf/keysplit/attributes.go index e12d838e97..a6b924fa21 100644 --- a/sdk/experimental/tdf/keysplit/attributes.go +++ b/sdk/experimental/tdf/keysplit/attributes.go @@ -11,6 +11,13 @@ import ( const unknownAlgorithm = "unknown" +// These mirror enum values that the platform-proto cell adds to objects.proto. +// Replace with generated named constants once those proto changes are merged. +const ( + kasPublicKeyAlgEnumMLKem768 policy.KasPublicKeyAlgEnum = 13 + algorithmMLKem768 policy.Algorithm = 6 +) + // resolveAttributeGrants follows the hierarchy: value → definition → namespace // Returns the most specific grants available for the given attribute value func resolveAttributeGrants(value *policy.Value) (*AttributeGrant, error) { @@ -199,6 +206,8 @@ func formatAlgorithm(alg policy.Algorithm) string { return "rsa:2048" case policy.Algorithm_ALGORITHM_RSA_4096: return "rsa:4096" + case algorithmMLKem768: + return "mlkem:768" default: return unknownAlgorithm } @@ -217,6 +226,8 @@ func convertAlgEnum2Simple(a policy.KasPublicKeyAlgEnum) policy.Algorithm { return policy.Algorithm_ALGORITHM_RSA_2048 case policy.KasPublicKeyAlgEnum_KAS_PUBLIC_KEY_ALG_ENUM_RSA_4096: return policy.Algorithm_ALGORITHM_RSA_4096 + case kasPublicKeyAlgEnumMLKem768: + return algorithmMLKem768 case policy.KasPublicKeyAlgEnum_KAS_PUBLIC_KEY_ALG_ENUM_UNSPECIFIED: return policy.Algorithm_ALGORITHM_UNSPECIFIED default: