Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions lib/ocrypto/ec_key_pair.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@ type ECCMode uint8
type KeyType string

const (
RSA2048Key KeyType = "rsa:2048"
RSA4096Key KeyType = "rsa:4096"
EC256Key KeyType = "ec:secp256r1"
EC384Key KeyType = "ec:secp384r1"
EC521Key KeyType = "ec:secp521r1"
RSA2048Key KeyType = "rsa:2048"
RSA4096Key KeyType = "rsa:4096"
EC256Key KeyType = "ec:secp256r1"
EC384Key KeyType = "ec:secp384r1"
EC521Key KeyType = "ec:secp521r1"
MLKem768Key KeyType = "mlkem:768"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In Go, acronyms like ML-KEM should be consistently cased (e.g., MLKEM). This constant should be renamed to MLKEM768Key to match the casing used in the helper function IsMLKEMKeyType and other parts of the PR.

Suggested change
MLKem768Key KeyType = "mlkem:768"
MLKEM768Key KeyType = "mlkem:768"

)

const (
Expand Down Expand Up @@ -91,6 +92,10 @@ func IsRSAKeyType(kt KeyType) bool {
}
}

func IsMLKEMKeyType(kt KeyType) bool {
return kt == MLKem768Key
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Update this reference to use the consistently cased constant name.

Suggested change
return kt == MLKem768Key
return kt == MLKEM768Key

}

// GetECCurveFromECCMode return elliptic curve from ecc mode
func GetECCurveFromECCMode(mode ECCMode) (elliptic.Curve, error) {
var c elliptic.Curve
Expand Down
33 changes: 33 additions & 0 deletions lib/ocrypto/mlkem.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package ocrypto

import (
"crypto/mlkem"
"encoding/pem"
"fmt"
)

const (
// MLKem768CiphertextSize is the byte length of an ML-KEM-768 ciphertext.
MLKem768CiphertextSize = 1088
// MLKem768PublicKeySize is the byte length of an ML-KEM-768 encapsulation key.
MLKem768PublicKeySize = 1184

mlkem768PEMType = "ML-KEM-768 PUBLIC KEY"
)
Comment on lines +10 to +16
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

It is better to use the constants provided by the crypto/mlkem package instead of hardcoding magic numbers. Also, the PEM type constant should be exported so it can be used by other packages (like tests) to avoid string duplication, and acronym casing should be consistent.

const (
	// MLKEM768CiphertextSize is the byte length of an ML-KEM-768 ciphertext.
	MLKEM768CiphertextSize = mlkem.CiphertextSize768
	// MLKEM768PublicKeySize is the byte length of an ML-KEM-768 encapsulation key.
	MLKEM768PublicKeySize = mlkem.EncapsulationKeySize768

	MLKEM768PEMType = "ML-KEM-768 PUBLIC KEY"
)


// MLKEMPublicKeyFromPEM parses an ML-KEM-768 encapsulation key from a PEM block
// with type "ML-KEM-768 PUBLIC KEY".
func MLKEMPublicKeyFromPEM(pemData []byte) (*mlkem.EncapsulationKey768, error) {
block, _ := pem.Decode(pemData)
if block == nil {
return nil, fmt.Errorf("failed to decode PEM block for ML-KEM-768 public key")

Check failure on line 23 in lib/ocrypto/mlkem.go

View workflow job for this annotation

GitHub Actions / go (lib/ocrypto)

error-format: fmt.Errorf can be replaced with errors.New (perfsprint)
}
if block.Type != mlkem768PEMType {
return nil, fmt.Errorf("unexpected PEM type %q, expected %q", block.Type, mlkem768PEMType)
Comment on lines +25 to +26
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Update these lines to use the renamed and exported PEM type constant.

Suggested change
if block.Type != mlkem768PEMType {
return nil, fmt.Errorf("unexpected PEM type %q, expected %q", block.Type, mlkem768PEMType)
if block.Type != MLKEM768PEMType {
return nil, fmt.Errorf("unexpected PEM type %q, expected %q", block.Type, MLKEM768PEMType)

}
key, err := mlkem.NewEncapsulationKey768(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse ML-KEM-768 encapsulation key: %w", err)
}
return key, nil
}
33 changes: 33 additions & 0 deletions sdk/experimental/tdf/key_access.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ func wrapKeyWithPublicKey(symKey []byte, pubKeyInfo keysplit.KASPublicKey) (stri
// Handle EC key wrapping
return wrapKeyWithEC(ktype, pubKeyInfo.PEM, symKey)
}
if ocrypto.IsMLKEMKeyType(ktype) {
wrapped, err := wrapKeyWithMLKEM(pubKeyInfo.PEM, symKey)
return wrapped, "wrapped", "", err
}
// Handle RSA key wrapping
wrapped, err := wrapKeyWithRSA(pubKeyInfo.PEM, symKey)
return wrapped, "wrapped", "", err
Expand Down Expand Up @@ -245,3 +249,32 @@ func wrapKeyWithRSA(kasPublicKeyPEM string, symKey []byte) (string, error) {

return string(ocrypto.Base64Encode(encryptedKey)), nil
}

// wrapKeyWithMLKEM encapsulates a symmetric key using ML-KEM-768.
// Wire format: base64(ml_kem_ciphertext [1088 bytes] || aes_gcm_wrapped_dek)
// The ephemeralPublicKey field is not set for this key type.
func wrapKeyWithMLKEM(kasPublicKeyPEM string, symKey []byte) (string, error) {
encKey, err := ocrypto.MLKEMPublicKeyFromPEM([]byte(kasPublicKeyPEM))
if err != nil {
return "", fmt.Errorf("failed to parse ML-KEM-768 public key: %w", err)
}

// Encapsulate: sharedKey is used to wrap the DEK; ciphertext goes on the wire.
sharedKey, ciphertext := encKey.Encapsulate()

gcm, err := ocrypto.NewAESGcm(sharedKey)
if err != nil {
return "", fmt.Errorf("failed to create AES-GCM for ML-KEM DEK wrap: %w", err)
}

wrappedDEK, err := gcm.Encrypt(symKey)
if err != nil {
return "", fmt.Errorf("failed to AES-GCM wrap DEK: %w", err)
}

payload := make([]byte, 0, len(ciphertext)+len(wrappedDEK))
payload = append(payload, ciphertext...)
payload = append(payload, wrappedDEK...)

return string(ocrypto.Base64Encode(payload)), nil
}
62 changes: 62 additions & 0 deletions sdk/experimental/tdf/key_access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
package tdf

import (
"crypto/mlkem"
"crypto/rand"
"encoding/json"
"encoding/pem"
"strings"
"testing"

Expand Down Expand Up @@ -473,3 +475,63 @@ func TestTdfSalt(t *testing.T) {
assert.NotEmpty(t, salt1, "Salt should not be empty")
})
}

// mlkem768TestPublicKeyPEM generates a fresh ML-KEM-768 key pair and returns
// the public key as a PEM block with type "ML-KEM-768 PUBLIC KEY".
func mlkem768TestPublicKeyPEM(t *testing.T) (string, *mlkem.DecapsulationKey768) {
t.Helper()
dk, err := mlkem.GenerateKey768()
require.NoError(t, err)
pubKeyBytes := dk.EncapsulationKey().Bytes()
block := &pem.Block{Type: "ML-KEM-768 PUBLIC KEY", Bytes: pubKeyBytes}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Use the exported constant from the ocrypto package instead of hardcoding the PEM type string.

Suggested change
block := &pem.Block{Type: "ML-KEM-768 PUBLIC KEY", Bytes: pubKeyBytes}
block := &pem.Block{Type: ocrypto.MLKEM768PEMType, Bytes: pubKeyBytes}

return string(pem.EncodeToMemory(block)), dk
}

func TestWrapKeyWithMLKEM(t *testing.T) {
t.Run("wraps key and produces correct wire format", func(t *testing.T) {
pubKeyPEM, dk := mlkem768TestPublicKeyPEM(t)

symKey := make([]byte, 32)
_, err := rand.Read(symKey)
require.NoError(t, err)

wrappedB64, err := wrapKeyWithMLKEM(pubKeyPEM, symKey)
require.NoError(t, err)
assert.NotEmpty(t, wrappedB64)

payload, err := ocrypto.Base64Decode([]byte(wrappedB64))
require.NoError(t, err)
assert.Greater(t, len(payload), ocrypto.MLKem768CiphertextSize, "payload must include ciphertext + wrapped DEK")

ciphertext := payload[:ocrypto.MLKem768CiphertextSize]
wrappedDEK := payload[ocrypto.MLKem768CiphertextSize:]
Comment on lines +504 to +507
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Update these references to use the renamed MLKEM768CiphertextSize constant.

Suggested change
assert.Greater(t, len(payload), ocrypto.MLKem768CiphertextSize, "payload must include ciphertext + wrapped DEK")
ciphertext := payload[:ocrypto.MLKem768CiphertextSize]
wrappedDEK := payload[ocrypto.MLKem768CiphertextSize:]
assert.Greater(t, len(payload), ocrypto.MLKEM768CiphertextSize, "payload must include ciphertext + wrapped DEK")
ciphertext := payload[:ocrypto.MLKEM768CiphertextSize]
wrappedDEK := payload[ocrypto.MLKEM768CiphertextSize:]


sharedKey, err := dk.Decapsulate(ciphertext)
require.NoError(t, err)

gcm, err := ocrypto.NewAESGcm(sharedKey)
require.NoError(t, err)

recoveredDEK, err := gcm.Decrypt(wrappedDEK)
require.NoError(t, err)
assert.Equal(t, symKey, recoveredDEK, "recovered DEK must match original")
})

t.Run("buildKeyAccessObjects uses wrapped key type with no ephemeral key", func(t *testing.T) {
pubKeyPEM, _ := mlkem768TestPublicKeyPEM(t)
splitResult := createTestSplitResult(testKAS1URL, pubKeyPEM, "mlkem:768")

keyAccessList, err := buildKeyAccessObjects(splitResult, []byte(testPolicyJSON), "")
require.NoError(t, err)
require.Len(t, keyAccessList, 1)

ka := keyAccessList[0]
assert.Equal(t, "wrapped", ka.KeyType, "ML-KEM KAO must use 'wrapped' key type")
assert.Empty(t, ka.EphemeralPublicKey, "ML-KEM KAO must not set ephemeralPublicKey")
assert.NotEmpty(t, ka.WrappedKey)

payload, err := ocrypto.Base64Decode([]byte(ka.WrappedKey))
require.NoError(t, err)
assert.Greater(t, len(payload), ocrypto.MLKem768CiphertextSize)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Update this reference to use the renamed MLKEM768CiphertextSize constant.

Suggested change
assert.Greater(t, len(payload), ocrypto.MLKem768CiphertextSize)
assert.Greater(t, len(payload), ocrypto.MLKEM768CiphertextSize)

})
}
11 changes: 11 additions & 0 deletions sdk/experimental/tdf/keysplit/attributes.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@ import (

const unknownAlgorithm = "unknown"

// These mirror enum values that the platform-proto cell adds to objects.proto.
// Replace with generated named constants once those proto changes are merged.
const (
kasPublicKeyAlgEnumMLKem768 policy.KasPublicKeyAlgEnum = 13
algorithmMLKem768 policy.Algorithm = 6
)

// resolveAttributeGrants follows the hierarchy: value → definition → namespace
// Returns the most specific grants available for the given attribute value
func resolveAttributeGrants(value *policy.Value) (*AttributeGrant, error) {
Expand Down Expand Up @@ -199,6 +206,8 @@ func formatAlgorithm(alg policy.Algorithm) string {
return "rsa:2048"
case policy.Algorithm_ALGORITHM_RSA_4096:
return "rsa:4096"
case algorithmMLKem768:
return "mlkem:768"
default:
return unknownAlgorithm
}
Expand All @@ -217,6 +226,8 @@ func convertAlgEnum2Simple(a policy.KasPublicKeyAlgEnum) policy.Algorithm {
return policy.Algorithm_ALGORITHM_RSA_2048
case policy.KasPublicKeyAlgEnum_KAS_PUBLIC_KEY_ALG_ENUM_RSA_4096:
return policy.Algorithm_ALGORITHM_RSA_4096
case kasPublicKeyAlgEnumMLKem768:
return algorithmMLKem768
case policy.KasPublicKeyAlgEnum_KAS_PUBLIC_KEY_ALG_ENUM_UNSPECIFIED:
return policy.Algorithm_ALGORITHM_UNSPECIFIED
default:
Expand Down
Loading