diff --git a/src/codeflare_sdk/common/utils/unit_test_support.py b/src/codeflare_sdk/common/utils/unit_test_support.py index 9345fbc37..8e034378f 100644 --- a/src/codeflare_sdk/common/utils/unit_test_support.py +++ b/src/codeflare_sdk/common/utils/unit_test_support.py @@ -55,7 +55,7 @@ def createClusterWrongType(): config = ClusterConfiguration( name="unit-test-cluster", namespace="ns", - num_workers=2, + num_workers=True, worker_cpu_requests=[], worker_cpu_limits=4, worker_memory_requests=5, diff --git a/src/codeflare_sdk/ray/cluster/config.py b/src/codeflare_sdk/ray/cluster/config.py index f321c278a..b8b097f8c 100644 --- a/src/codeflare_sdk/ray/cluster/config.py +++ b/src/codeflare_sdk/ray/cluster/config.py @@ -242,13 +242,15 @@ def _memory_to_resource(self): def _validate_types(self): """Validate the types of all fields in the ClusterConfiguration dataclass.""" + errors = [] for field_info in fields(self): value = getattr(self, field_info.name) expected_type = field_info.type if not self._is_type(value, expected_type): - raise TypeError( - f"'{field_info.name}' should be of type {expected_type}" - ) + errors.append(f"'{field_info.name}' should be of type {expected_type}.") + + if errors: + raise TypeError("Type validation failed:\n" + "\n".join(errors)) @staticmethod def _is_type(value, expected_type): @@ -268,6 +270,10 @@ def check_type(value, expected_type): ) if origin_type is tuple: return all(check_type(elem, etype) for elem, etype in zip(value, args)) + if expected_type is int: + return isinstance(value, int) and not isinstance(value, bool) + if expected_type is bool: + return isinstance(value, bool) return isinstance(value, expected_type) return check_type(value, expected_type) diff --git a/src/codeflare_sdk/ray/cluster/test_config.py b/src/codeflare_sdk/ray/cluster/test_config.py index 1423fc2b5..3416fc28c 100644 --- a/src/codeflare_sdk/ray/cluster/test_config.py +++ b/src/codeflare_sdk/ray/cluster/test_config.py @@ -108,9 +108,11 @@ def test_all_config_params_aw(mocker): def test_config_creation_wrong_type(): - with pytest.raises(TypeError): + with pytest.raises(TypeError) as error_info: createClusterWrongType() + assert len(str(error_info.value).splitlines()) == 4 + def test_cluster_config_deprecation_conversion(mocker): config = ClusterConfiguration(