diff --git a/scalecodec/base.py b/scalecodec/base.py index 2c22a10..1269f1a 100644 --- a/scalecodec/base.py +++ b/scalecodec/base.py @@ -164,7 +164,6 @@ def new(self, **kwargs) -> 'ScaleType': return obj - def impl(self, scale_type_cls: type = None, runtime_config=None) -> 'ScaleTypeDef': """ @@ -286,7 +285,7 @@ def deserialize(self, value_serialized: any): return self.value_object self.value_object = self.type_def.deserialize(value_serialized) - self.value_serialized = value_serialized + self.value_serialized = self.type_def.serialize(self.value_object) return self.value_object diff --git a/scalecodec/types.py b/scalecodec/types.py index 09e0a58..cf5dd17 100644 --- a/scalecodec/types.py +++ b/scalecodec/types.py @@ -655,7 +655,18 @@ def deserialize(self, value: str) -> str: return value +class ArrayObject(ScaleType): + + def to_bytes(self) -> bytes: + if self.type_def.type_def is not U8: + raise ScaleDeserializeException('Only an Array of U8 can be represented as bytes') + return self.value_object + + class Array(ScaleTypeDef): + + scale_type_cls = ArrayObject + def __init__(self, type_def: ScaleTypeDef, length: int): self.type_def = type_def self.length = length @@ -715,12 +726,27 @@ def serialize(self, value: Union[list, bytes]) -> Union[list, str]: return f'0x{value.hex()}' def deserialize(self, value: Union[list, str, bytes]) -> Union[list, bytes]: + + if type(value) not in [list, str, bytes]: + raise ScaleDeserializeException('value should be of type list, str or bytes') + if type(value) is str: if value[0:2] == '0x': - return bytes.fromhex(value[2:]) + value = bytes.fromhex(value[2:]) else: - return value.encode() - else: + value = value.encode() + + if len(value) != self.length: + raise ScaleDeserializeException('Length of array does not match size of value') + + if type(value) is bytes: + if self.type_def is not U8: + raise ScaleDeserializeException('Only an Array of U8 can be represented as (hex)bytes') + + return value + + if type(value) is list: + value_object = [] for item in value: @@ -793,7 +819,13 @@ def decode(self, data: ScaleBytes) -> list: return value def serialize(self, value: list) -> list: - return [(k.value_serialized, v.value_serialized) for k, v in value] + output = [] + for k, v in value: + if type(k) is ScaleType and type(v) is ScaleType: + output.append((k.value_serialized, v.value_serialized)) + else: + output.append((k, v)) + return output def deserialize(self, value: list) -> list: return [(self.key_def.deserialize(k), self.value_def.deserialize(v)) for k, v in value] @@ -833,12 +865,20 @@ def decode(self, data: ScaleBytes) -> bytearray: def serialize(self, value: bytearray) -> str: return f'0x{value.hex()}' - def deserialize(self, value: str) -> bytearray: - if type(value) is str: + def deserialize(self, value: Union[bytes, str, list]) -> bytes: + + if type(value) in (list, bytearray): + value = bytes(value) + + elif type(value) is str: if value[0:2] == '0x': - value = bytearray.fromhex(value[2:]) + value = bytes.fromhex(value[2:]) else: value = value.encode('utf-8') + + if type(value) is not bytes: + raise ScaleDeserializeException(f'Cannot deserialize type "{type(value)}"') + return value def example_value(self, _recursion_level: int = 0, max_recursion: int = TYPE_DECOMP_MAX_RECURSIVE): @@ -860,13 +900,19 @@ def serialize(self, value: str) -> str: def deserialize(self, value: str) -> str: return value - def create_example(self, _recursion_level: int = 0): return 'String' +class HashDefObject(ScaleType): + def to_bytes(self) -> bytes: + return self.value_object + + class HashDef(ScaleTypeDef): + scale_type_cls = HashDefObject + def __init__(self, bits: int): super().__init__() self.bits = bits @@ -897,6 +943,10 @@ def serialize(self, value: bytes) -> str: def deserialize(self, value: Union[str, bytes]) -> bytes: if type(value) is str: value = bytes.fromhex(value[2:]) + + if type(value) is not bytes: + raise ScaleDeserializeException('value should be of type str or bytes') + return value def example_value(self, _recursion_level: int = 0, max_recursion: int = TYPE_DECOMP_MAX_RECURSIVE): diff --git a/test/test_boolean.py b/test/test_boolean.py index 9deec2e..7f7d31d 100644 --- a/test/test_boolean.py +++ b/test/test_boolean.py @@ -47,6 +47,11 @@ def test_bool_encode_decode(self): self.assertEqual(value, scale_obj.value) + def test_bool_encode_false(self): + scale_obj = Bool().new() + data = scale_obj.encode(False) + self.assertEqual(ScaleBytes("0x00"), data) + if __name__ == '__main__': unittest.main()