diff --git a/pkcs11/_pkcs11.pyx b/pkcs11/_pkcs11.pyx index 28e8887..9e9f4f0 100644 --- a/pkcs11/_pkcs11.pyx +++ b/pkcs11/_pkcs11.pyx @@ -1072,6 +1072,9 @@ cdef class Session(HasFuncList, types.Session): public_template_ = self.attribute_mapper.public_key_template( id_=id, label=label, store=store, capabilities=capabilities, ) + private_template_ = self.attribute_mapper.private_key_template( + id_=id, label=label, store=store, capabilities=capabilities, + ) if key_type is KeyType.RSA: if key_length is None: @@ -1091,11 +1094,16 @@ cdef class Session(HasFuncList, types.Session): "in `public_template` (e.g. MLDSAParameterSet.ML_DSA_65)." ) - public_attrs = self.make_attribute_list(merge_templates(public_template_, public_template)) + elif key_type is KeyType.ML_KEM: + if public_template is None or Attribute.PARAMETER_SET not in public_template: + raise ArgumentsBad( + "ML-KEM key generation requires `Attribute.PARAMETER_SET` " + "in `public_template` (e.g. MLKEMParameterSet.ML_KEM_768)." + ) + public_template_[Attribute.ENCAPSULATE] = True + private_template_[Attribute.DECAPSULATE] = True - private_template_ = self.attribute_mapper.private_key_template( - id_=id, label=label, store=store, capabilities=capabilities, - ) + public_attrs = self.make_attribute_list(merge_templates(public_template_, public_template)) private_attrs = self.make_attribute_list(merge_templates(private_template_, private_template)) return self.generate_keypair_from_attrs(public_attrs, private_attrs, mech) @@ -1370,7 +1378,7 @@ cdef object make_object(Session session, CK_OBJECT_HANDLE handle) with gil: """ wrapper = ObjectHandleWrapper.wrap(session, handle) - cdef CK_ATTRIBUTE_TYPE[8] attr_keys = [ + cdef CK_ATTRIBUTE_TYPE[10] attr_keys = [ Attribute.CLASS, Attribute.ENCRYPT, Attribute.DECRYPT, @@ -1378,13 +1386,15 @@ cdef object make_object(Session session, CK_OBJECT_HANDLE handle) with gil: Attribute.VERIFY, Attribute.WRAP, Attribute.UNWRAP, - Attribute.DERIVE + Attribute.DERIVE, + Attribute.ENCAPSULATE, + Attribute.DECAPSULATE, ] try: # Determine a list of base classes to manufacture our class with try: - attributes = wrapper.get_attribute_list(&attr_keys[0], 8) + attributes = wrapper.get_attribute_list(&attr_keys[0], 10) except PKCS11Error: # retry fetching the flags one by one, some tokens do not implement error handling # on bulk fetches correctly. @@ -1407,6 +1417,8 @@ cdef object make_object(Session session, CK_OBJECT_HANDLE handle) with gil: (Attribute.WRAP, WrapMixin), (Attribute.UNWRAP, UnwrapMixin), (Attribute.DERIVE, DeriveMixin), + (Attribute.ENCAPSULATE, EncapsulateMixin), + (Attribute.DECAPSULATE, DecapsulateMixin), ): try: if attributes.get(attribute, session.attribute_mapper): @@ -1957,6 +1969,117 @@ class DeriveMixin(types.DeriveMixin): return make_object(session, key) +class EncapsulateMixin(types.EncapsulateMixin): + """Expand EncapsulateMixin with an implementation (ML-KEM).""" + + def encapsulate_key(self, key_type, + key_length=None, + id=None, label=None, + store=False, capabilities=None, + mechanism=None, mechanism_param=None, + template=None): + + if not isinstance(key_type, KeyType): + raise ArgumentsBad("`key_type` must be KeyType.") + + if capabilities is None: + try: + capabilities = DEFAULT_KEY_CAPABILITIES[key_type] + except KeyError: + raise ArgumentsBad("No default capabilities for this key " + "type. Please specify `capabilities`.") + + mech = MechanismWithParam(self.key_type, DEFAULT_ENCAPSULATE_MECHANISMS, mechanism, mechanism_param) + + cdef Session session = self.session + + if session.funclist32 == NULL: + raise FunctionNotSupported("C_EncapsulateKey requires a PKCS#11 v3.2+ library") + + template_ = session.attribute_mapper.secret_key_template( + capabilities=capabilities, id_=id, label=label, store=store, + ) + if key_length is not None: + template_[Attribute.VALUE_LEN] = key_length // 8 + template_[Attribute.KEY_TYPE] = key_type + cdef AttributeList attrs = session.make_attribute_list(merge_templates(template_, template)) + cdef CK_MECHANISM *mech_data = mech.data + cdef CK_OBJECT_HANDLE pub_key = self.handle + cdef CK_ATTRIBUTE *attr_data = attrs.data + cdef CK_ULONG attr_count = attrs.count + cdef CK_ULONG ct_len = 0 + cdef CK_OBJECT_HANDLE ss_handle + cdef CK_RV retval + + with nogil: + retval = session.funclist32.C_EncapsulateKey( + session.handle, mech_data, pub_key, + attr_data, attr_count, NULL, &ct_len, &ss_handle) + assertRV(retval) + + cdef CK_BYTE [:] ct_buf = CK_BYTE_buffer(ct_len) + + with nogil: + retval = session.funclist32.C_EncapsulateKey( + session.handle, mech_data, pub_key, + attr_data, attr_count, &ct_buf[0], &ct_len, &ss_handle) + assertRV(retval) + + return (bytes(ct_buf[:ct_len]), make_object(session, ss_handle)) + + +class DecapsulateMixin(types.DecapsulateMixin): + """Expand DecapsulateMixin with an implementation (ML-KEM).""" + + def decapsulate_key(self, ciphertext, key_type, + key_length=None, + id=None, label=None, + store=False, capabilities=None, + mechanism=None, mechanism_param=None, + template=None): + + if not isinstance(key_type, KeyType): + raise ArgumentsBad("`key_type` must be KeyType.") + + if capabilities is None: + try: + capabilities = DEFAULT_KEY_CAPABILITIES[key_type] + except KeyError: + raise ArgumentsBad("No default capabilities for this key " + "type. Please specify `capabilities`.") + + mech = MechanismWithParam(self.key_type, DEFAULT_ENCAPSULATE_MECHANISMS, mechanism, mechanism_param) + + cdef Session session = self.session + + if session.funclist32 == NULL: + raise FunctionNotSupported("C_DecapsulateKey requires a PKCS#11 v3.2+ library") + + template_ = session.attribute_mapper.secret_key_template( + capabilities=capabilities, id_=id, label=label, store=store, + ) + if key_length is not None: + template_[Attribute.VALUE_LEN] = key_length // 8 + template_[Attribute.KEY_TYPE] = key_type + cdef AttributeList attrs = session.make_attribute_list(merge_templates(template_, template)) + cdef CK_MECHANISM *mech_data = mech.data + cdef CK_OBJECT_HANDLE priv_key = self.handle + cdef CK_BYTE *ct_ptr = ciphertext + cdef CK_ULONG ct_len = len(ciphertext) + cdef CK_ATTRIBUTE *attr_data = attrs.data + cdef CK_ULONG attr_count = attrs.count + cdef CK_OBJECT_HANDLE key + cdef CK_RV retval + + with nogil: + retval = session.funclist32.C_DecapsulateKey( + session.handle, mech_data, priv_key, + attr_data, attr_count, ct_ptr, ct_len, &key) + assertRV(retval) + + return make_object(session, key) + + _CLASS_MAP = { ObjectClass.SECRET_KEY: SecretKey, ObjectClass.PUBLIC_KEY: PublicKey, diff --git a/pkcs11/attributes.py b/pkcs11/attributes.py index 93c85c5..be11420 100644 --- a/pkcs11/attributes.py +++ b/pkcs11/attributes.py @@ -49,10 +49,12 @@ def _enum(type_: type[IntEnum]) -> Handler: Attribute.CHECK_VALUE: handle_bytes, Attribute.CLASS: _enum(ObjectClass), Attribute.COEFFICIENT: handle_biginteger, + Attribute.DECAPSULATE: handle_bool, Attribute.DECRYPT: handle_bool, Attribute.DERIVE: handle_bool, Attribute.EC_PARAMS: handle_bytes, Attribute.EC_POINT: handle_bytes, + Attribute.ENCAPSULATE: handle_bool, Attribute.ENCRYPT: handle_bool, Attribute.END_DATE: handle_date, Attribute.EXPONENT_1: handle_biginteger, @@ -214,7 +216,9 @@ def public_key_template( ) -> dict[Attribute, Any]: template = dict(self.default_public_key_template) _apply_capabilities( - template, (Attribute.ENCRYPT, Attribute.WRAP, Attribute.VERIFY), capabilities + template, + (Attribute.ENCRYPT, Attribute.WRAP, Attribute.VERIFY), + capabilities, ) _apply_common(template, id_, label, store) return template diff --git a/pkcs11/constants.py b/pkcs11/constants.py index bc98e22..be61cee 100644 --- a/pkcs11/constants.py +++ b/pkcs11/constants.py @@ -349,6 +349,11 @@ class Attribute(IntEnum): PARAMETER_SET = 0x0000061D SEED = 0x00000637 + ENCAPSULATE_TEMPLATE = 0x0000062A + DECAPSULATE_TEMPLATE = 0x0000062B + ENCAPSULATE = 0x00000633 + DECAPSULATE = 0x00000634 + _VENDOR_DEFINED = 0x80000000 def __repr__(self) -> str: @@ -409,6 +414,11 @@ class MechanismFlag(IntFlag): EC_UNCOMPRESS = 0x01000000 EC_COMPRESS = 0x02000000 + ENCAPSULATE = 0x10000000 + """Can perform ML-KEM key encapsulation.""" + DECAPSULATE = 0x20000000 + """Can perform ML-KEM key decapsulation.""" + EXTENSION = 0x80000000 diff --git a/pkcs11/defaults.py b/pkcs11/defaults.py index fa55736..d3d7a04 100644 --- a/pkcs11/defaults.py +++ b/pkcs11/defaults.py @@ -22,6 +22,7 @@ KeyType.RSA: Mechanism.RSA_PKCS_KEY_PAIR_GEN, KeyType.X9_42_DH: Mechanism.X9_42_DH_KEY_PAIR_GEN, KeyType.EC_EDWARDS: Mechanism.EC_EDWARDS_KEY_PAIR_GEN, + KeyType.ML_KEM: Mechanism.ML_KEM_KEY_PAIR_GEN, KeyType.ML_DSA: Mechanism.ML_DSA_KEY_PAIR_GEN, KeyType.GENERIC_SECRET: Mechanism.GENERIC_SECRET_KEY_GEN, } @@ -32,6 +33,7 @@ _ENCRYPTION: Final[MechanismFlag] = MechanismFlag.ENCRYPT | MechanismFlag.DECRYPT _SIGNING: Final[MechanismFlag] = MechanismFlag.SIGN | MechanismFlag.VERIFY _WRAPPING: Final[MechanismFlag] = MechanismFlag.WRAP | MechanismFlag.UNWRAP +_ENCAPSULATING: Final[MechanismFlag] = MechanismFlag.ENCAPSULATE | MechanismFlag.DECAPSULATE DEFAULT_KEY_CAPABILITIES: Final[dict[KeyType, MechanismFlag | int]] = { KeyType.AES: _ENCRYPTION | _SIGNING | _WRAPPING, @@ -43,6 +45,7 @@ KeyType.RSA: _ENCRYPTION | _SIGNING | _WRAPPING, KeyType.GENERIC_SECRET: 0, KeyType.EC_EDWARDS: _SIGNING, + KeyType.ML_KEM: _ENCAPSULATING, KeyType.ML_DSA: _SIGNING, } """ @@ -83,6 +86,11 @@ Default mechanism for wrap/unwrap. """ +DEFAULT_ENCAPSULATE_MECHANISMS: Final[dict[KeyType, Mechanism]] = { + KeyType.ML_KEM: Mechanism.ML_KEM, +} +"""Default mechanisms for ML-KEM encapsulate/decapsulate.""" + DEFAULT_DERIVE_MECHANISMS: Final[dict[KeyType, Mechanism]] = { KeyType.DH: Mechanism.DH_PKCS_DERIVE, KeyType.EC: Mechanism.ECDH1_DERIVE, diff --git a/pkcs11/mechanisms.py b/pkcs11/mechanisms.py index ece7017..49f7003 100644 --- a/pkcs11/mechanisms.py +++ b/pkcs11/mechanisms.py @@ -107,6 +107,7 @@ class KeyType(IntEnum): # from version 3.0 EC_EDWARDS = 0x00000040 # from version 3.2 + ML_KEM = 0x00000049 ML_DSA = 0x0000004A _VENDOR_DEFINED = 0x80000000 @@ -736,6 +737,10 @@ class Mechanism(IntEnum): EDDSA = 0x00001057 EC_EDWARDS_KEY_PAIR_GEN = 0x00001055 + # ML-KEM (v3.2) + ML_KEM_KEY_PAIR_GEN = 0x0000000F + ML_KEM = 0x00000017 + # ML-DSA (v3.2) ML_DSA_KEY_PAIR_GEN = 0x0000001C ML_DSA = 0x0000001D @@ -806,6 +811,17 @@ def __repr__(self) -> str: return "" % self.name +class MLKEMParameterSet(IntEnum): + """ML-KEM parameter sets as defined in FIPS 203.""" + + ML_KEM_512 = 0x00000001 + ML_KEM_768 = 0x00000002 + ML_KEM_1024 = 0x00000003 + + def __repr__(self) -> str: + return "" % self.name + + class MLDSAParameterSet(IntEnum): """ML-DSA parameter sets as defined in FIPS 204.""" diff --git a/pkcs11/types.py b/pkcs11/types.py index bfd9946..e32131f 100644 --- a/pkcs11/types.py +++ b/pkcs11/types.py @@ -1376,3 +1376,81 @@ def derive_key( :rtype: SecretKey """ raise NotImplementedError() + + +class EncapsulateMixin(HasKeyType): + """ + This :class:`Object` supports the encapsulate capability (ML-KEM). + """ + + def encapsulate_key( + self, + key_type: KeyType, + key_length: int | None = None, + id: bytes | None = None, + label: str | None = None, + store: bool = False, + capabilities: MechanismFlag | None = None, + mechanism: Mechanism | None = None, + mechanism_param: bytes | None = None, + template: dict[Attribute, Any] | None = None, + ) -> tuple[bytes, SecretKey]: + """ + Use this ML-KEM public key to encapsulate a fresh shared secret. + + Returns ``(ciphertext, shared_secret)`` where *ciphertext* is the + KEM ciphertext to be transmitted to the decapsulating party, and + *shared_secret* is the newly-created secret key object on the token. + + :param KeyType key_type: Key type for the shared secret (e.g. KeyType.GENERIC_SECRET). + :param int key_length: Shared secret length in bits. + :param bytes id: Key identifier. + :param str label: Key label. + :param store: Store key on token (requires R/W session). + :param MechanismFlag capabilities: Key capabilities (or default). + :param Mechanism mechanism: Encapsulation mechanism (or default). + :param bytes mechanism_param: Optional mechanism parameter. + :param dict(Attribute,*) template: Additional attributes. + + :rtype: tuple[bytes, SecretKey] + """ + raise NotImplementedError() + + +class DecapsulateMixin(HasKeyType): + """ + This :class:`Object` supports the decapsulate capability (ML-KEM). + """ + + def decapsulate_key( + self, + ciphertext: bytes, + key_type: KeyType, + key_length: int | None = None, + id: bytes | None = None, + label: str | None = None, + store: bool = False, + capabilities: MechanismFlag | None = None, + mechanism: Mechanism | None = None, + mechanism_param: bytes | None = None, + template: dict[Attribute, Any] | None = None, + ) -> SecretKey: + """ + Use this ML-KEM private key to decapsulate the shared secret from *ciphertext*. + + Returns the recovered shared secret key object. + + :param bytes ciphertext: KEM ciphertext from the encapsulating party. + :param KeyType key_type: Key type for the shared secret (e.g. KeyType.GENERIC_SECRET). + :param int key_length: Shared secret length in bits. + :param bytes id: Key identifier. + :param str label: Key label. + :param store: Store key on token (requires R/W session). + :param MechanismFlag capabilities: Key capabilities (or default). + :param Mechanism mechanism: Decapsulation mechanism (or default). + :param bytes mechanism_param: Optional mechanism parameter. + :param dict(Attribute,*) template: Additional attributes. + + :rtype: SecretKey + """ + raise NotImplementedError() diff --git a/tests/test_ml_kem.py b/tests/test_ml_kem.py new file mode 100644 index 0000000..132436e --- /dev/null +++ b/tests/test_ml_kem.py @@ -0,0 +1,82 @@ +from parameterized import parameterized + +import pkcs11 +from pkcs11 import Attribute, KeyType, Mechanism +from pkcs11.mechanisms import MLKEMParameterSet + +from . import TestCase, requires + + +class MLKEMTests(TestCase): + @requires(Mechanism.ML_KEM_KEY_PAIR_GEN, Mechanism.ML_KEM) + def test_generate_ml_kem_512(self): + pub, priv = self.session.generate_keypair( + KeyType.ML_KEM, + public_template={Attribute.PARAMETER_SET: MLKEMParameterSet.ML_KEM_512}, + ) + self.assertIsNotNone(pub) + self.assertIsNotNone(priv) + self.assertEqual(pub[Attribute.PARAMETER_SET], int(MLKEMParameterSet.ML_KEM_512)) + + @requires(Mechanism.ML_KEM_KEY_PAIR_GEN, Mechanism.ML_KEM) + def test_generate_ml_kem_768(self): + pub, priv = self.session.generate_keypair( + KeyType.ML_KEM, + public_template={Attribute.PARAMETER_SET: MLKEMParameterSet.ML_KEM_768}, + ) + self.assertIsNotNone(pub) + self.assertIsNotNone(priv) + self.assertEqual(pub[Attribute.PARAMETER_SET], int(MLKEMParameterSet.ML_KEM_768)) + + @requires(Mechanism.ML_KEM_KEY_PAIR_GEN, Mechanism.ML_KEM) + def test_generate_ml_kem_1024(self): + pub, priv = self.session.generate_keypair( + KeyType.ML_KEM, + public_template={Attribute.PARAMETER_SET: MLKEMParameterSet.ML_KEM_1024}, + ) + self.assertIsNotNone(pub) + self.assertIsNotNone(priv) + self.assertEqual(pub[Attribute.PARAMETER_SET], int(MLKEMParameterSet.ML_KEM_1024)) + + @parameterized.expand( + [ + ("AES-256", 256), + ("AES-128", 128), + ] + ) + @requires(Mechanism.ML_KEM_KEY_PAIR_GEN, Mechanism.ML_KEM) + def test_encapsulate_decapsulate(self, test_type, key_length): + pub, priv = self.session.generate_keypair( + KeyType.ML_KEM, + public_template={Attribute.PARAMETER_SET: MLKEMParameterSet.ML_KEM_768}, + ) + + # Encapsulate: public key produces ciphertext + shared secret key + ciphertext, ss_enc = pub.encapsulate_key( + KeyType.AES, + key_length=key_length, + store=False, + ) + self.assertIsInstance(ciphertext, bytes) + self.assertGreater(len(ciphertext), 0) + self.assertIsNotNone(ss_enc) + + ss_dec = priv.decapsulate_key( + ciphertext, + KeyType.AES, + store=False, + key_length=key_length, + ) + self.assertIsNotNone(ss_dec) + + # Verify the two shared secrets are equal by encrypting/decrypting. + # If they differ, the AES_ECB decrypt will produce wrong output. + plaintext = b"mlkem test data!" # exactly 16 bytes for AES_ECB + encrypted = ss_enc.encrypt(plaintext, mechanism=Mechanism.AES_ECB) + recovered = ss_dec.decrypt(encrypted, mechanism=Mechanism.AES_ECB) + self.assertEqual(plaintext, recovered) + + @requires(Mechanism.ML_KEM_KEY_PAIR_GEN, Mechanism.ML_KEM) + def test_missing_parameter_set_raises(self): + with self.assertRaises(pkcs11.exceptions.ArgumentsBad): + self.session.generate_keypair(KeyType.ML_KEM)