Skip to content

Commit b501989

Browse files
Support ML-DSA
Co-Authored-By: Koji Takeda <[email protected]>
1 parent fa4142d commit b501989

File tree

3 files changed

+492
-2
lines changed

3 files changed

+492
-2
lines changed

scripts/build_ffi.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,9 @@ def make_flags(prefix, fips):
235235
# ML-KEM
236236
flags.append("--enable-kyber")
237237

238+
# ML-DSA
239+
flags.append("--enable-dilithium")
240+
238241
# disabling other configs enabled by default
239242
flags.append("--disable-oldtls")
240243
flags.append("--disable-oldnames")
@@ -447,6 +450,7 @@ def build_ffi(local_wolfssl, features):
447450
#include <wolfssl/wolfcrypt/chacha20_poly1305.h>
448451
#include <wolfssl/wolfcrypt/kyber.h>
449452
#include <wolfssl/wolfcrypt/wc_kyber.h>
453+
#include <wolfssl/wolfcrypt/dilithium.h>
450454
"""
451455

452456
init_source_string = """
@@ -484,6 +488,7 @@ def build_ffi(local_wolfssl, features):
484488
int RSA_PSS_ENABLED = """ + str(features["RSA_PSS"]) + """;
485489
int CHACHA20_POLY1305_ENABLED = """ + str(features["CHACHA20_POLY1305"]) + """;
486490
int ML_KEM_ENABLED = """ + str(features["ML_KEM"]) + """;
491+
int ML_DSA_ENABLED = """ + str(features["ML_DSA"]) + """;
487492
"""
488493

489494
ffibuilder.set_source( "wolfcrypt._ffi", init_source_string,
@@ -520,6 +525,7 @@ def build_ffi(local_wolfssl, features):
520525
extern int RSA_PSS_ENABLED;
521526
extern int CHACHA20_POLY1305_ENABLED;
522527
extern int ML_KEM_ENABLED;
528+
extern int ML_DSA_ENABLED;
523529
524530
typedef unsigned char byte;
525531
typedef unsigned int word32;
@@ -950,7 +956,28 @@ def build_ffi(local_wolfssl, features):
950956
int wc_KyberKey_EncapsulateWithRandom(KyberKey* key, unsigned char* ct, unsigned char* ss, const unsigned char* rand, int len);
951957
int wc_KyberKey_Decapsulate(KyberKey* key, unsigned char* ss, const unsigned char* ct, word32 len);
952958
int wc_KyberKey_EncodePrivateKey(KyberKey* key, unsigned char* out, word32 len);
953-
int wc_KyberKey_DecodePrivateKey(KyberKey* key, const unsigned char* in, word32 len);
959+
int wc_KyberKey_DecodePrivateKey(KyberKey* key, const unsigned char* in, word32 len);
960+
"""
961+
962+
if features["ML_DSA"]:
963+
cdef += """
964+
static const int WC_ML_DSA_44;
965+
static const int WC_ML_DSA_65;
966+
static const int WC_ML_DSA_87;
967+
typedef struct {...; } dilithium_key;
968+
int wc_dilithium_init_ex(dilithium_key* key, void* heap, int devId);
969+
int wc_dilithium_set_level(dilithium_key* key, byte level);
970+
void wc_dilithium_free(dilithium_key* key);
971+
int wc_dilithium_priv_size(dilithium_key* key);
972+
int wc_dilithium_pub_size(dilithium_key* key);
973+
int wc_dilithium_sig_size(dilithium_key* key);
974+
int wc_dilithium_make_key(dilithium_key* key, WC_RNG* rng);
975+
int wc_dilithium_export_private(dilithium_key* key, byte* out, word32* outLen);
976+
int wc_dilithium_import_private(const byte* priv, word32 privSz, dilithium_key* key);
977+
int wc_dilithium_export_public(dilithium_key* key, byte* out, word32* outLen);
978+
int wc_dilithium_import_public(const byte* in, word32 inLen, dilithium_key* key);
979+
int wc_dilithium_sign_msg(const byte* msg, word32 msgLen, byte* sig, word32* sigLen, dilithium_key* key, WC_RNG* rng);
980+
int wc_dilithium_verify_msg(const byte* sig, word32 sigLen, const byte* msg, word32 msgLen, int* res, dilithium_key* key);
954981
"""
955982

956983
ffibuilder.cdef(cdef)
@@ -983,7 +1010,8 @@ def main(ffibuilder):
9831010
"AESGCM_STREAM": 1,
9841011
"RSA_PSS": 1,
9851012
"CHACHA20_POLY1305": 1,
986-
"ML_KEM": 1
1013+
"ML_KEM": 1,
1014+
"ML_DSA": 1
9871015
}
9881016

9891017
# Ed448 requires SHAKE256, which isn't part of the Windows build, yet.

tests/test_mldsa.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# test_mldsa.py
2+
#
3+
# Copyright (C) 2025 wolfSSL Inc.
4+
#
5+
# This file is part of wolfSSL. (formerly known as CyaSSL)
6+
#
7+
# wolfSSL is free software; you can redistribute it and/or modify
8+
# it under the terms of the GNU General Public License as published by
9+
# the Free Software Foundation; either version 2 of the License, or
10+
# (at your option) any later version.
11+
#
12+
# wolfSSL is distributed in the hope that it will be useful,
13+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
# GNU General Public License for more details.
16+
#
17+
# You should have received a copy of the GNU General Public License
18+
# along with this program; if not, write to the Free Software
19+
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA
20+
21+
# pylint: disable=redefined-outer-name
22+
23+
from wolfcrypt._ffi import lib as _lib
24+
25+
if hasattr(_lib, "ML_DSA_ENABLED") and _lib.ML_DSA_ENABLED:
26+
from binascii import unhexlify as h2b
27+
28+
import pytest
29+
30+
from wolfcrypt.mldsa import MlDsaPrivate, MlDsaPublic, MlDsaType
31+
from wolfcrypt.random import Random
32+
33+
@pytest.fixture
34+
def rng():
35+
return Random()
36+
37+
@pytest.fixture(params=[MlDsaType.ML_DSA_44, MlDsaType.ML_DSA_65, MlDsaType.ML_DSA_87])
38+
def mldsa_type(request):
39+
return request.param
40+
41+
def test_init_base(mldsa_type):
42+
mldsa_priv = MlDsaPrivate(mldsa_type)
43+
assert isinstance(mldsa_priv, MlDsaPrivate)
44+
45+
mldsa_pub = MlDsaPublic(mldsa_type)
46+
assert isinstance(mldsa_pub, MlDsaPublic)
47+
48+
def test_key_sizes(mldsa_type):
49+
mldsa_priv = MlDsaPrivate(mldsa_type)
50+
51+
# Check that key sizes are returned correctly
52+
assert mldsa_priv.priv_key_size > 0
53+
assert mldsa_priv.pub_key_size > 0
54+
assert mldsa_priv.sig_size > 0
55+
56+
# Public key should have the same pub_key_size
57+
mldsa_pub = MlDsaPublic(mldsa_type)
58+
assert mldsa_pub.pub_key_size == mldsa_priv.pub_key_size
59+
assert mldsa_pub.sig_size == mldsa_priv.sig_size
60+
61+
"""
62+
def test_key_generation(mldsa_type, rng):
63+
# Test key generation
64+
mldsa_priv = MlDsaPrivate.make_key(mldsa_type, rng)
65+
assert isinstance(mldsa_priv, MlDsaPrivate)
66+
67+
# Export keys
68+
priv_key = mldsa_priv.encode_priv_key()
69+
pub_key = mldsa_priv.encode_pub_key()
70+
71+
# Check key sizes
72+
assert len(priv_key) == mldsa_priv.priv_key_size
73+
assert len(pub_key) == mldsa_priv.pub_key_size
74+
"""
75+
76+
"""
77+
def test_key_import_export(mldsa_type, rng):
78+
# Generate a key pair
79+
mldsa_priv = MlDsaPrivate.make_key(mldsa_type, rng)
80+
81+
# Export keys
82+
priv_key = mldsa_priv.encode_priv_key()
83+
pub_key = mldsa_priv.encode_pub_key()
84+
85+
# Import private key
86+
mldsa_priv2 = MlDsaPrivate(mldsa_type)
87+
mldsa_priv2.decode_key(priv_key)
88+
89+
# Export keys from imported private key
90+
priv_key2 = mldsa_priv2.encode_priv_key()
91+
pub_key2 = mldsa_priv2.encode_pub_key()
92+
93+
# Keys should match
94+
assert priv_key == priv_key2
95+
assert pub_key == pub_key2
96+
97+
# Import public key
98+
mldsa_pub = MlDsaPublic(mldsa_type)
99+
mldsa_pub.decode_key(pub_key)
100+
101+
# Export public key from imported public key
102+
pub_key3 = mldsa_pub.encode_key()
103+
104+
# Public keys should match
105+
assert pub_key == pub_key3
106+
"""
107+
108+
def test_sign_verify(mldsa_type, rng):
109+
# Generate a key pair
110+
mldsa_priv = MlDsaPrivate.make_key(mldsa_type, rng)
111+
112+
# Export public key
113+
pub_key = mldsa_priv.encode_pub_key()
114+
115+
# Import public key
116+
mldsa_pub = MlDsaPublic(mldsa_type)
117+
mldsa_pub.decode_key(pub_key)
118+
119+
# Sign a message
120+
message = b"This is a test message for ML-DSA signature"
121+
signature = mldsa_priv.sign(message, rng)
122+
123+
# Verify the signature
124+
assert mldsa_pub.verify(signature, message)
125+
126+
# Verify with wrong message
127+
wrong_message = b"This is a wrong message for ML-DSA signature"
128+
assert not mldsa_pub.verify(signature, wrong_message)
129+
130+
"""
131+
def test_der_encoding(mldsa_type, rng):
132+
# Generate a key pair
133+
mldsa_priv = MlDsaPrivate.make_key(mldsa_type, rng)
134+
135+
# Export keys in DER format
136+
priv_key_der = mldsa_priv.encode_priv_key_der()
137+
pub_key_der = mldsa_priv.encode_pub_key_der()
138+
139+
# Check that DER encoded keys are longer than raw keys
140+
assert len(priv_key_der) > mldsa_priv.priv_key_size
141+
assert len(pub_key_der) > mldsa_priv.pub_key_size
142+
143+
# Test public key DER encoding from public key object
144+
mldsa_pub = MlDsaPublic(mldsa_type)
145+
mldsa_pub.decode_key(mldsa_priv.encode_pub_key())
146+
pub_key_der2 = mldsa_pub.encode_key_der()
147+
148+
# DER encoded public keys should match
149+
assert pub_key_der == pub_key_der2
150+
"""

0 commit comments

Comments
 (0)