Skip to content

Commit c59e591

Browse files
committed
Fix details by hand
1 parent b501989 commit c59e591

File tree

5 files changed

+374
-386
lines changed

5 files changed

+374
-386
lines changed

docs/asymmetric.rst

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,4 +112,38 @@ ML-KEM
112112
>>>
113113
>>> ss_recv = mlkem_priv.decapsulate(ct)
114114
>>> ss_send == ss_recv
115+
True
116+
117+
ML-DSA
118+
------
119+
120+
.. autoclass:: MlDsaType
121+
:show-inheritance:
122+
123+
.. autoclass:: MlDsaPublic
124+
:private-members:
125+
:members:
126+
:inherited-members:
127+
128+
.. autoclass:: MlDsaPrivate
129+
:members:
130+
:inherited-members:
131+
132+
**Example:**
133+
134+
>>> from wolfcrypt.ciphers import MlDsaType, MlDsaPrivate, MlDsaPublic
135+
>>>
136+
>>> mldsa_type = MlDsaType.ML_DSA_44
137+
>>>
138+
>>> mldsa_priv = MlDsaPrivate.make_key(mldsa_type)
139+
>>> pub_key = mldsa_priv.encode_pub_key()
140+
>>>
141+
>>> mldsa_pub = MlDsaPublic(mldsa_type)
142+
>>> mldsa_pub.decode_key(pub_key)
143+
>>>
144+
>>> msg = "This is an example message"
145+
>>>
146+
>>> sig = mldsa_priv.sign(msg)
147+
>>>
148+
>>> mldsa_pub.verify(sig, msg)
115149
True

scripts/build_ffi.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,7 @@ def get_features(local_wolfssl, features):
374374
features["AESGCM_STREAM"] = 1 if '#define WOLFSSL_AESGCM_STREAM' in defines else 0
375375
features["RSA_PSS"] = 1 if '#define WC_RSA_PSS' in defines else 0
376376
features["CHACHA20_POLY1305"] = 1 if '#define HAVE_CHACHA' and '#define HAVE_POLY1305' in defines else 0
377+
features["ML_DSA"] = 1 if '#define WOLFSSL_WC_DILITHIUM' in defines else 0
377378

378379
if '#define HAVE_FIPS' in defines:
379380
if not fips:
@@ -935,12 +936,16 @@ def build_ffi(local_wolfssl, features):
935936
int wolfCrypt_GetPrivateKeyReadEnable_fips(enum wc_KeyType);
936937
"""
937938

939+
if features["ML_KEM"] or features["ML_DSA"]:
940+
cdef += """
941+
static const int INVALID_DEVID;
942+
"""
943+
938944
if features["ML_KEM"]:
939945
cdef += """
940946
static const int WC_ML_KEM_512;
941947
static const int WC_ML_KEM_768;
942948
static const int WC_ML_KEM_1024;
943-
static const int INVALID_DEVID;
944949
typedef struct {...; } KyberKey;
945950
int wc_KyberKey_CipherTextSize(KyberKey* key, word32* len);
946951
int wc_KyberKey_SharedSecretSize(KyberKey* key, word32* len);
@@ -968,16 +973,17 @@ def build_ffi(local_wolfssl, features):
968973
int wc_dilithium_init_ex(dilithium_key* key, void* heap, int devId);
969974
int wc_dilithium_set_level(dilithium_key* key, byte level);
970975
void wc_dilithium_free(dilithium_key* key);
971-
int wc_dilithium_priv_size(dilithium_key* key);
972-
int wc_dilithium_pub_size(dilithium_key* key);
973-
int wc_dilithium_sig_size(dilithium_key* key);
974976
int wc_dilithium_make_key(dilithium_key* key, WC_RNG* rng);
975977
int wc_dilithium_export_private(dilithium_key* key, byte* out, word32* outLen);
976978
int wc_dilithium_import_private(const byte* priv, word32 privSz, dilithium_key* key);
977979
int wc_dilithium_export_public(dilithium_key* key, byte* out, word32* outLen);
978980
int wc_dilithium_import_public(const byte* in, word32 inLen, dilithium_key* key);
979981
int wc_dilithium_sign_msg(const byte* msg, word32 msgLen, byte* sig, word32* sigLen, dilithium_key* key, WC_RNG* rng);
980982
int wc_dilithium_verify_msg(const byte* sig, word32 sigLen, const byte* msg, word32 msgLen, int* res, dilithium_key* key);
983+
typedef dilithium_key MlDsaKey;
984+
int wc_MlDsaKey_GetPrivLen(MlDsaKey* key, int* len);
985+
int wc_MlDsaKey_GetPubLen(MlDsaKey* key, int* len);
986+
int wc_MlDsaKey_GetSigLen(MlDsaKey* key, int* len);
981987
"""
982988

983989
ffibuilder.cdef(cdef)

tests/test_mldsa.py

Lines changed: 54 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,19 @@
2222

2323
from wolfcrypt._ffi import lib as _lib
2424

25-
if hasattr(_lib, "ML_DSA_ENABLED") and _lib.ML_DSA_ENABLED:
26-
from binascii import unhexlify as h2b
27-
25+
if _lib.ML_DSA_ENABLED:
2826
import pytest
2927

30-
from wolfcrypt.mldsa import MlDsaPrivate, MlDsaPublic, MlDsaType
28+
from wolfcrypt.ciphers import MlDsaPrivate, MlDsaPublic, MlDsaType
3129
from wolfcrypt.random import Random
3230

3331
@pytest.fixture
3432
def rng():
3533
return Random()
3634

37-
@pytest.fixture(params=[MlDsaType.ML_DSA_44, MlDsaType.ML_DSA_65, MlDsaType.ML_DSA_87])
35+
@pytest.fixture(
36+
params=[MlDsaType.ML_DSA_44, MlDsaType.ML_DSA_65, MlDsaType.ML_DSA_87]
37+
)
3838
def mldsa_type(request):
3939
return request.param
4040

@@ -45,71 +45,75 @@ def test_init_base(mldsa_type):
4545
mldsa_pub = MlDsaPublic(mldsa_type)
4646
assert isinstance(mldsa_pub, MlDsaPublic)
4747

48-
def test_key_sizes(mldsa_type):
49-
mldsa_priv = MlDsaPrivate(mldsa_type)
50-
51-
# Check that key sizes are returned correctly
52-
assert mldsa_priv.priv_key_size > 0
53-
assert mldsa_priv.pub_key_size > 0
54-
assert mldsa_priv.sig_size > 0
48+
def test_size_properties(mldsa_type):
49+
refvals = {
50+
MlDsaType.ML_DSA_44: {
51+
"sig_size": 2420,
52+
"pub_key_size": 1312,
53+
"priv_key_size": 2560,
54+
},
55+
MlDsaType.ML_DSA_65: {
56+
"sig_size": 3309,
57+
"pub_key_size": 1952,
58+
"priv_key_size": 4032,
59+
},
60+
MlDsaType.ML_DSA_87: {
61+
"sig_size": 4627,
62+
"pub_key_size": 2592,
63+
"priv_key_size": 4896,
64+
},
65+
}
5566

56-
# Public key should have the same pub_key_size
5767
mldsa_pub = MlDsaPublic(mldsa_type)
58-
assert mldsa_pub.pub_key_size == mldsa_priv.pub_key_size
59-
assert mldsa_pub.sig_size == mldsa_priv.sig_size
68+
assert mldsa_pub.sig_size == refvals[mldsa_type]["sig_size"]
69+
assert mldsa_pub.key_size == refvals[mldsa_type]["pub_key_size"]
70+
71+
mldsa_priv = MlDsaPrivate(mldsa_type)
72+
assert mldsa_priv.sig_size == refvals[mldsa_type]["sig_size"]
73+
assert mldsa_priv.pub_key_size == refvals[mldsa_type]["pub_key_size"]
74+
assert mldsa_priv.priv_key_size == refvals[mldsa_type]["priv_key_size"]
6075

61-
"""
62-
def test_key_generation(mldsa_type, rng):
63-
# Test key generation
76+
def test_initializations(mldsa_type, rng):
6477
mldsa_priv = MlDsaPrivate.make_key(mldsa_type, rng)
65-
assert isinstance(mldsa_priv, MlDsaPrivate)
78+
assert type(mldsa_priv) is MlDsaPrivate
6679

67-
# Export keys
68-
priv_key = mldsa_priv.encode_priv_key()
69-
pub_key = mldsa_priv.encode_pub_key()
80+
mldsa_priv2 = MlDsaPrivate(mldsa_type)
81+
assert type(mldsa_priv2) is MlDsaPrivate
7082

71-
# Check key sizes
72-
assert len(priv_key) == mldsa_priv.priv_key_size
73-
assert len(pub_key) == mldsa_priv.pub_key_size
74-
"""
83+
mldsa_pub = MlDsaPublic(mldsa_type)
84+
assert type(mldsa_pub) is MlDsaPublic
7585

76-
"""
7786
def test_key_import_export(mldsa_type, rng):
78-
# Generate a key pair
87+
# Generate key pair and export keys
7988
mldsa_priv = MlDsaPrivate.make_key(mldsa_type, rng)
80-
81-
# Export keys
8289
priv_key = mldsa_priv.encode_priv_key()
8390
pub_key = mldsa_priv.encode_pub_key()
91+
assert len(priv_key) == mldsa_priv.priv_key_size
92+
assert len(pub_key) == mldsa_priv.pub_key_size
8493

85-
# Import private key
94+
# Export key pair from imported one
8695
mldsa_priv2 = MlDsaPrivate(mldsa_type)
87-
mldsa_priv2.decode_key(priv_key)
88-
89-
# Export keys from imported private key
96+
mldsa_priv2.decode_key(priv_key, pub_key)
9097
priv_key2 = mldsa_priv2.encode_priv_key()
9198
pub_key2 = mldsa_priv2.encode_pub_key()
92-
93-
# Keys should match
9499
assert priv_key == priv_key2
95100
assert pub_key == pub_key2
96101

97-
# Import public key
102+
# Export private key from imported one
103+
mldsa_priv3 = MlDsaPrivate(mldsa_type)
104+
mldsa_priv3.decode_key(priv_key)
105+
priv_key3 = mldsa_priv3.encode_priv_key()
106+
assert priv_key == priv_key3
107+
108+
# Export public key from imported one
98109
mldsa_pub = MlDsaPublic(mldsa_type)
99110
mldsa_pub.decode_key(pub_key)
100-
101-
# Export public key from imported public key
102111
pub_key3 = mldsa_pub.encode_key()
103-
104-
# Public keys should match
105112
assert pub_key == pub_key3
106-
"""
107113

108114
def test_sign_verify(mldsa_type, rng):
109-
# Generate a key pair
115+
# Generate a key pair and export public key
110116
mldsa_priv = MlDsaPrivate.make_key(mldsa_type, rng)
111-
112-
# Export public key
113117
pub_key = mldsa_priv.encode_pub_key()
114118

115119
# Import public key
@@ -119,32 +123,14 @@ def test_sign_verify(mldsa_type, rng):
119123
# Sign a message
120124
message = b"This is a test message for ML-DSA signature"
121125
signature = mldsa_priv.sign(message, rng)
126+
assert len(signature) == mldsa_priv.sig_size
127+
128+
# Verify the signature by MlDsaPrivate
129+
assert mldsa_priv.verify(signature, message)
122130

123-
# Verify the signature
131+
# Verify the signature by MlDsaPublic
124132
assert mldsa_pub.verify(signature, message)
125133

126134
# Verify with wrong message
127135
wrong_message = b"This is a wrong message for ML-DSA signature"
128136
assert not mldsa_pub.verify(signature, wrong_message)
129-
130-
"""
131-
def test_der_encoding(mldsa_type, rng):
132-
# Generate a key pair
133-
mldsa_priv = MlDsaPrivate.make_key(mldsa_type, rng)
134-
135-
# Export keys in DER format
136-
priv_key_der = mldsa_priv.encode_priv_key_der()
137-
pub_key_der = mldsa_priv.encode_pub_key_der()
138-
139-
# Check that DER encoded keys are longer than raw keys
140-
assert len(priv_key_der) > mldsa_priv.priv_key_size
141-
assert len(pub_key_der) > mldsa_priv.pub_key_size
142-
143-
# Test public key DER encoding from public key object
144-
mldsa_pub = MlDsaPublic(mldsa_type)
145-
mldsa_pub.decode_key(mldsa_priv.encode_pub_key())
146-
pub_key_der2 = mldsa_pub.encode_key_der()
147-
148-
# DER encoded public keys should match
149-
assert pub_key_der == pub_key_der2
150-
"""

0 commit comments

Comments
 (0)