diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 3d285b638ae2..4e87d9b1b2f6 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -30,6 +30,10 @@ Changelog :class:`~cryptography.x509.CertificateSigningRequest`, and :class:`~cryptography.x509.CertificateRevocationList` as field types in :doc:`/hazmat/asn1/index` structures. +* Added :func:`~cryptography.hazmat.asn1.value_set`, a class decorator that + registers an :class:`enum.Enum` subclass as an ASN.1 value set: members + are encoded as their underlying value, and decoding fails if the decoded + value does not match one of the declared members. .. _v48-0-0: diff --git a/docs/hazmat/asn1/reference.rst b/docs/hazmat/asn1/reference.rst index 0f94772620aa..8bec4a903199 100644 --- a/docs/hazmat/asn1/reference.rst +++ b/docs/hazmat/asn1/reference.rst @@ -46,7 +46,8 @@ Serialization Serialize an ASN.1 object into DER-encoded bytes. :param value: The ASN.1 object to encode. Must be an instance of a - class decorated with :func:`sequence` or :func:`set`, or a primitive ASN.1 type + class decorated with :func:`sequence`, :func:`set`, or + :func:`value_set`, or a primitive ASN.1 type (``int``, ``bool``, ``bytes``, ``str``, :class:`~cryptography.x509.ObjectIdentifier`, :class:`PrintableString`, :class:`IA5String`, :class:`UTCTime`, @@ -154,6 +155,39 @@ that have no direct Python equivalent: ... ValueError: error parsing asn1 value: ... +.. decorator:: value_set(value_type) + + A class decorator that registers an :class:`enum.Enum` subclass as an + ASN.1 value set: a set of named values of a single underlying type. + All the member values must be instances of ``value_type``. + + Members are encoded exactly as their underlying value. When decoding, + the value is decoded and mapped back to the corresponding enum member; + decoding fails with :class:`ValueError` if the decoded value does not + match any member. + + Fields of a value set type can be annotated with :class:`Explicit`, + :class:`Implicit`, and :class:`Default` using :class:`typing.Annotated`. + + :param value_type: The underlying ASN.1 type of the member values. + :type value_type: :class:`type` + + .. doctest:: + + >>> import enum + >>> from cryptography import x509 + >>> from cryptography.hazmat import asn1 + >>> @asn1.value_set(x509.ObjectIdentifier) + ... class HashAlgorithm(enum.Enum): + ... SHA_256 = x509.ObjectIdentifier("2.16.840.1.101.3.4.2.1") + ... SHA_384 = x509.ObjectIdentifier("2.16.840.1.101.3.4.2.2") + >>> @asn1.sequence + ... class Example: + ... algorithm: HashAlgorithm + >>> encoded = asn1.encode_der(Example(algorithm=HashAlgorithm.SHA_256)) + >>> asn1.decode_der(Example, encoded).algorithm + > + .. class:: PrintableString(value) Wraps ASN.1 ``PrintableString`` values. ``PrintableString`` is a restricted diff --git a/src/cryptography/hazmat/asn1/__init__.py b/src/cryptography/hazmat/asn1/__init__.py index ac3d4bd8590b..7fb0fb4a47d0 100644 --- a/src/cryptography/hazmat/asn1/__init__.py +++ b/src/cryptography/hazmat/asn1/__init__.py @@ -20,6 +20,7 @@ encode_der, sequence, set, + value_set, ) __all__ = [ @@ -40,4 +41,5 @@ "encode_der", "sequence", "set", + "value_set", ] diff --git a/src/cryptography/hazmat/asn1/asn1.py b/src/cryptography/hazmat/asn1/asn1.py index a308d69db7dd..8c01f74985c6 100644 --- a/src/cryptography/hazmat/asn1/asn1.py +++ b/src/cryptography/hazmat/asn1/asn1.py @@ -6,6 +6,7 @@ import builtins import dataclasses +import enum import sys import types import typing @@ -170,7 +171,11 @@ def _normalize_field_type( root_type = field_type.__asn1_root__ if not isinstance( root_type, - (declarative_asn1.Type.Sequence, declarative_asn1.Type.Set), + ( + declarative_asn1.Type.Sequence, + declarative_asn1.Type.Set, + declarative_asn1.Type.ValueSet, + ), ): raise TypeError(f"unsupported root type: {root_type}") return declarative_asn1.AnnotatedType( @@ -425,6 +430,47 @@ def set(cls: type[U]) -> type[U]: return dataclass_cls +def value_set( + value_type: type, +) -> typing.Callable[[type[U]], type[U]]: + """ + A class decorator that registers an `enum.Enum` subclass as an + ASN.1 value set of the given underlying type. All the member + values must be instances of `value_type`. Members are encoded as + their value; decoding fails if the decoded value does not match + any member. + """ + rust_type = declarative_asn1.non_root_python_to_rust(value_type) + + def decorator(cls: type[U]) -> type[U]: + if not issubclass(cls, enum.Enum): + raise TypeError( + "value sets can only be defined from enum.Enum subclasses" + ) + members = list(cls) + if not members: + raise TypeError( + f"value set '{cls.__name__}' must have at least one member" + ) + for member in members: + if not isinstance(member.value, value_type): + raise TypeError( + f"member '{member.name}' of value set '{cls.__name__}' " + f"must have a value of type " + f"'{value_type.__name__}', got: " + f"'{type(member.value).__name__}'" + ) + inner = declarative_asn1.AnnotatedType( + rust_type, declarative_asn1.Annotation() + ) + root = declarative_asn1.Type.ValueSet(cls, inner) + + setattr(cls, "__asn1_root__", root) + return cls + + return decorator + + # TODO: replace with `Default[U]` once the min Python version is >= 3.12 @dataclasses.dataclass(frozen=True) class Default(typing.Generic[U]): diff --git a/src/cryptography/hazmat/bindings/_rust/declarative_asn1.pyi b/src/cryptography/hazmat/bindings/_rust/declarative_asn1.pyi index e499eb5c846e..6480ba647809 100644 --- a/src/cryptography/hazmat/bindings/_rust/declarative_asn1.pyi +++ b/src/cryptography/hazmat/bindings/_rust/declarative_asn1.pyi @@ -20,6 +20,7 @@ class Type: SetOf: typing.ClassVar[type] Option: typing.ClassVar[type] Choice: typing.ClassVar[type] + ValueSet: typing.ClassVar[type] PyBool: typing.ClassVar[type] PyInt: typing.ClassVar[type] PyBytes: typing.ClassVar[type] diff --git a/src/rust/src/declarative_asn1/decode.rs b/src/rust/src/declarative_asn1/decode.rs index a945fa7bc38e..a328045faa2d 100644 --- a/src/rust/src/declarative_asn1/decode.rs +++ b/src/rust/src/declarative_asn1/decode.rs @@ -3,13 +3,13 @@ // for complete details. use asn1::Parser; -use pyo3::types::{PyAnyMethods, PyListMethods}; +use pyo3::types::{PyAnyMethods, PyListMethods, PyTypeMethods}; use crate::asn1::big_byte_slice_to_py_int; use crate::declarative_asn1::types::{ - check_size_constraint, is_tag_valid_for_type, is_tag_valid_for_variant, AnnotatedType, - Annotation, BitString, Encoding, GeneralizedTime, IA5String, Null, PrintableString, SetOf, Tlv, - Type, UtcTime, Variant, + check_size_constraint, is_tag_valid_for_type, is_tag_valid_for_variant, value_set_inner_type, + AnnotatedType, Annotation, BitString, Encoding, GeneralizedTime, IA5String, Null, + PrintableString, SetOf, Tlv, Type, UtcTime, Variant, }; use crate::error::CryptographyError; @@ -254,6 +254,38 @@ fn decode_null<'a>( Ok(pyo3::Bound::new(py, Null {})?) } +// Decodes a value set field: decodes the underlying value, then maps +// it back to the enum member with that value. Fails if the decoded +// value does not correspond to any member. +fn decode_value_set<'a>( + py: pyo3::Python<'a>, + parser: &mut Parser<'a>, + cls: &pyo3::Py, + inner_type: &AnnotatedType, + annotation: &Annotation, +) -> ParseResult> { + let inner_ann_type = value_set_inner_type(py, inner_type, annotation)?; + let decoded = decode_annotated_type(py, parser, &inner_ann_type)?; + // NOTE: This is a linear scan over the members of the enum. If this + // ever becomes a performance problem, it could be replaced with a + // value -> member map stored in `Type::ValueSet` (keeping in mind + // that hash-based lookups won't work for the asn1 wrapper types, + // which implement `__eq__` but not `__hash__`). + for member in cls.bind(py).try_iter()? { + let member = member?; + if member.getattr(pyo3::intern!(py, "value"))?.eq(&decoded)? { + return Ok(member); + } + } + Err(CryptographyError::Py( + pyo3::exceptions::PyValueError::new_err(format!( + "{} is not a valid value for {}", + decoded.repr()?, + cls.bind(py).name()?, + )), + )) +} + // Utility function to handle explicit encoding when parsing // CHOICE fields. fn decode_choice_with_encoding<'a>( @@ -420,6 +452,9 @@ pub(crate) fn decode_annotated_type<'a>( ))? } }, + Type::ValueSet(cls, inner_type) => { + decode_value_set(py, parser, cls, inner_type.get(), annotation)? + } Type::PyBool() => decode_pybool(py, parser, encoding)?.into_any(), Type::PyInt() => decode_pyint(py, parser, encoding)?.into_any(), Type::PyBytes() => decode_pybytes(py, parser, annotation)?.into_any(), diff --git a/src/rust/src/declarative_asn1/encode.rs b/src/rust/src/declarative_asn1/encode.rs index 6d6493645985..139978fc173b 100644 --- a/src/rust/src/declarative_asn1/encode.rs +++ b/src/rust/src/declarative_asn1/encode.rs @@ -3,11 +3,11 @@ // for complete details. use asn1::{SimpleAsn1Writable, Writer}; -use pyo3::types::{PyAnyMethods, PyListMethods}; +use pyo3::types::{PyAnyMethods, PyListMethods, PyTypeMethods}; use crate::declarative_asn1::types::{ - check_size_constraint, AnnotatedType, AnnotatedTypeObject, BitString, Encoding, - GeneralizedTime, IA5String, PrintableString, Type, UtcTime, Variant, + check_size_constraint, value_set_inner_type, AnnotatedType, AnnotatedTypeObject, BitString, + Encoding, GeneralizedTime, IA5String, PrintableString, Type, UtcTime, Variant, }; use crate::error::CryptographyError; @@ -181,6 +181,22 @@ impl asn1::Asn1Writable for AnnotatedTypeObject<'_> { ), )) } + Type::ValueSet(cls, inner_type) => { + if !value.is_instance(cls.bind(py))? { + return Err(CryptographyError::Py( + pyo3::exceptions::PyTypeError::new_err(format!( + "value set field must be an instance of {}, got: {}", + cls.bind(py).name()?, + value.get_type().name()?, + )), + )); + } + let object = AnnotatedTypeObject { + annotated_type: &value_set_inner_type(py, inner_type.get(), annotation)?, + value: value.getattr(pyo3::intern!(py, "value"))?, + }; + object.write(writer) + } Type::PyBool() => { let val: bool = value.extract()?; Ok(write_value(writer, &val, encoding)?) diff --git a/src/rust/src/declarative_asn1/types.rs b/src/rust/src/declarative_asn1/types.rs index 4efbe75a849d..a7fd94a22df7 100644 --- a/src/rust/src/declarative_asn1/types.rs +++ b/src/rust/src/declarative_asn1/types.rs @@ -34,6 +34,12 @@ pub enum Type { /// CHOICE (`T | U | ...`) /// The list contains elements of type Variant Choice(pyo3::Py), + /// Value set (an `enum.Enum` whose member values all share + /// a single underlying ASN.1 type). + /// The first element is the Python enum class, the second + /// element is the (already converted) underlying type of the + /// member values. + ValueSet(pyo3::Py, pyo3::Py), // Python types that we map to canonical ASN.1 types // @@ -658,6 +664,7 @@ pub(crate) fn is_tag_valid_for_type( Type::Choice(variants) => variants.bind(py).into_iter().any(|v| { is_tag_valid_for_variant(py, tag, v.cast::().unwrap().get(), encoding) }), + Type::ValueSet(_, t) => is_tag_valid_for_type(py, tag, t.get().inner.get(), encoding), Type::PyBool() => check_tag_with_encoding(bool::TAG, encoding, tag), Type::PyInt() => check_tag_with_encoding(asn1::BigInt::TAG, encoding, tag), Type::PyBytes() => { @@ -698,6 +705,29 @@ pub(crate) fn is_tag_valid_for_type( } } +// Builds the AnnotatedType used to encode/decode the underlying value of +// a value set member: the underlying type, annotated with the encoding of +// the value set field. The DEFAULT annotation (if any) applies to the enum +// member (not the underlying value), so it is handled at the value set +// level and not propagated here. +pub(crate) fn value_set_inner_type( + py: pyo3::Python<'_>, + inner: &AnnotatedType, + annotation: &Annotation, +) -> pyo3::PyResult { + Ok(AnnotatedType { + inner: inner.inner.clone_ref(py), + annotation: pyo3::Py::new( + py, + Annotation { + default: None, + encoding: annotation.encoding.as_ref().map(|e| e.clone_ref(py)), + size: annotation.size.as_ref().map(|s| s.clone_ref(py)), + }, + )?, + }) +} + pub(crate) fn check_size_constraint( size_annotation: &Option>, data_length: impl FnOnce() -> usize, diff --git a/tests/hazmat/asn1/test_api.py b/tests/hazmat/asn1/test_api.py index 63294768fad8..372530349cf0 100644 --- a/tests/hazmat/asn1/test_api.py +++ b/tests/hazmat/asn1/test_api.py @@ -3,6 +3,7 @@ # for complete details. import datetime +import enum import re import sys import typing @@ -378,6 +379,10 @@ def test_fields_of_variant_type(self) -> None: choice = declarative_asn1.Type.Choice(my_list) assert choice._0 is my_list + value_set = declarative_asn1.Type.ValueSet(type(None), ann_type) + assert value_set._0 is type(None) + assert value_set._1 is ann_type + def test_fields_of_variant_encoding(self) -> None: from cryptography.hazmat.bindings._rust import declarative_asn1 @@ -527,3 +532,46 @@ class Invalid: @asn1.set class Example: foo: Invalid + + +class TestValueSetAPI: + def test_fail_non_enum(self) -> None: + with pytest.raises( + TypeError, + match=re.escape( + "value sets can only be defined from enum.Enum subclasses" + ), + ): + + @asn1.value_set(int) + class Example: + pass + + def test_fail_empty_enum(self) -> None: + with pytest.raises( + TypeError, + match="value set 'Example' must have at least one member", + ): + + @asn1.value_set(int) + class Example(enum.Enum): + pass + + def test_fail_member_value_of_wrong_type(self) -> None: + with pytest.raises( + TypeError, + match="member 'B' of value set 'Example' must have a value " + "of type 'int', got: 'str'", + ): + + @asn1.value_set(int) + class Example(enum.Enum): + A = 1 + B = "b" + + def test_fail_unsupported_value_type(self) -> None: + with pytest.raises( + TypeError, + match="cannot handle type", + ): + asn1.value_set(float) diff --git a/tests/hazmat/asn1/test_serialization.py b/tests/hazmat/asn1/test_serialization.py index 7a60a2f6827e..d88828f04623 100644 --- a/tests/hazmat/asn1/test_serialization.py +++ b/tests/hazmat/asn1/test_serialization.py @@ -4,6 +4,7 @@ import dataclasses import datetime +import enum import os import re import sys @@ -2088,3 +2089,178 @@ class Example: with pytest.raises(TypeError): asn1.encode_der(Example(cert=9)) # type: ignore[arg-type] + + +@asn1.value_set(x509.ObjectIdentifier) +class Algorithm(enum.Enum): + A = x509.ObjectIdentifier("1.2.3.4") + B = x509.ObjectIdentifier("1.2.3.5") + + +class TestValueSet: + def test_ok_oid_value_set(self) -> None: + @asn1.sequence + @_comparable_dataclass + class Example: + algorithm: Algorithm + + assert_roundtrips( + [ + ( + Example(algorithm=Algorithm.A), + b"\x30\x05\x06\x03\x2a\x03\x04", + ), + ( + Example(algorithm=Algorithm.B), + b"\x30\x05\x06\x03\x2a\x03\x05", + ), + ] + ) + + # Decoding returns the enum member itself + decoded = asn1.decode_der(Example, b"\x30\x05\x06\x03\x2a\x03\x04") + assert decoded.algorithm is Algorithm.A + + def test_ok_int_value_set(self) -> None: + @asn1.value_set(int) + class Version(enum.Enum): + V1 = 1 + V2 = 2 + + @asn1.sequence + @_comparable_dataclass + class Example: + version: Version + + assert_roundtrips( + [ + (Example(version=Version.V1), b"\x30\x03\x02\x01\x01"), + (Example(version=Version.V2), b"\x30\x03\x02\x01\x02"), + ] + ) + + def test_ok_top_level_value_set(self) -> None: + assert_roundtrips( + [ + (Algorithm.A, b"\x06\x03\x2a\x03\x04"), + (Algorithm.B, b"\x06\x03\x2a\x03\x05"), + ] + ) + + def test_ok_value_set_implicit(self) -> None: + @asn1.sequence + @_comparable_dataclass + class Example: + algorithm: Annotated[Algorithm, asn1.Implicit(0)] + + assert_roundtrips( + [ + ( + Example(algorithm=Algorithm.A), + b"\x30\x05\x80\x03\x2a\x03\x04", + ), + ] + ) + + def test_ok_value_set_explicit(self) -> None: + @asn1.sequence + @_comparable_dataclass + class Example: + algorithm: Annotated[Algorithm, asn1.Explicit(0)] + + assert_roundtrips( + [ + ( + Example(algorithm=Algorithm.A), + b"\x30\x07\xa0\x05\x06\x03\x2a\x03\x04", + ), + ] + ) + + def test_ok_optional_value_set(self) -> None: + @asn1.sequence + @_comparable_dataclass + class Example: + algorithm: typing.Union[Algorithm, None] + + assert_roundtrips( + [ + ( + Example(algorithm=Algorithm.A), + b"\x30\x05\x06\x03\x2a\x03\x04", + ), + (Example(algorithm=None), b"\x30\x00"), + ] + ) + + def test_ok_value_set_default(self) -> None: + @asn1.sequence + @_comparable_dataclass + class Example: + algorithm: Annotated[Algorithm, asn1.Default(Algorithm.A)] + + assert_roundtrips( + [ + (Example(algorithm=Algorithm.A), b"\x30\x00"), + ( + Example(algorithm=Algorithm.B), + b"\x30\x05\x06\x03\x2a\x03\x05", + ), + ] + ) + + with pytest.raises( + ValueError, match="DEFAULT value was explicitly encoded" + ): + asn1.decode_der(Example, b"\x30\x05\x06\x03\x2a\x03\x04") + + def test_ok_value_set_in_choice(self) -> None: + @asn1.sequence + @_comparable_dataclass + class Example: + field: typing.Union[Algorithm, int] + + assert_roundtrips( + [ + ( + Example(field=Algorithm.A), + b"\x30\x05\x06\x03\x2a\x03\x04", + ), + (Example(field=9), b"\x30\x03\x02\x01\x09"), + ] + ) + + def test_fail_decode_non_member_value(self) -> None: + @asn1.sequence + @_comparable_dataclass + class Example: + algorithm: Algorithm + + with pytest.raises( + ValueError, match="is not a valid value for Algorithm" + ): + asn1.decode_der(Example, b"\x30\x05\x06\x03\x2a\x03\x06") + + def test_fail_decode_wrong_type(self) -> None: + @asn1.sequence + @_comparable_dataclass + class Example: + algorithm: Algorithm + + with pytest.raises(ValueError): + asn1.decode_der(Example, b"\x30\x03\x02\x01\x01") + + def test_fail_encode_non_member(self) -> None: + @asn1.sequence + @_comparable_dataclass + class Example: + algorithm: Algorithm + + with pytest.raises( + TypeError, + match="value set field must be an instance of Algorithm, " + "got: ObjectIdentifier", + ): + asn1.encode_der( + Example(algorithm=x509.ObjectIdentifier("1.2.3.4")) # type: ignore[arg-type] + )