diff --git a/.github/workflows/badge.yml b/.github/workflows/badge.yml index f82cf24fdf..d010d21bf8 100644 --- a/.github/workflows/badge.yml +++ b/.github/workflows/badge.yml @@ -12,13 +12,13 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [ 3.7 ] + python-version: [ 3.8 ] env: GIST_ID: 3690cccd811e4c5f771075c2f785c7bb steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Download cloc diff --git a/.github/workflows/doc.yml b/.github/workflows/doc.yml index 9465203300..0dc0235c70 100644 --- a/.github/workflows/doc.yml +++ b/.github/workflows/doc.yml @@ -17,7 +17,7 @@ jobs: steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Generate diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index be7bb6e890..1ed80baa8e 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -10,12 +10,12 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.7, 3.8, 3.9] + python-version: ['3.8', '3.9', '3.10'] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: code style diff --git a/.github/workflows/unit_test.yml b/.github/workflows/unit_test.yml index 8acfaf57ec..7bb14a2350 100644 --- a/.github/workflows/unit_test.yml +++ b/.github/workflows/unit_test.yml @@ -11,7 +11,7 @@ jobs: if: ( !contains(github.event.head_commit.message, 'ci skip') && !contains(github.event.head_commit.message, 'ut skip')) strategy: matrix: - python-version: [3.7, 3.8, 3.9] + python-version: ['3.8', '3.9', '3.10'] steps: - uses: actions/checkout@v4 @@ -42,7 +42,7 @@ jobs: if: "!contains(github.event.head_commit.message, 'ci skip')" strategy: matrix: - python-version: [3.7, 3.8, 3.9] + python-version: ['3.8', '3.9', '3.10'] steps: - uses: actions/checkout@v4 diff --git a/ding/envs/env/env_implementation_check.py b/ding/envs/env/env_implementation_check.py index e7f0dfe03b..4bf8c2976c 100644 --- a/ding/envs/env/env_implementation_check.py +++ b/ding/envs/env/env_implementation_check.py @@ -24,38 +24,37 @@ def check_space_dtype(env: 'BaseEnv') -> None: # Util function -def check_array_space(ndarray: Union[np.ndarray, Sequence, Dict], space: Union['Space', Dict], name: str) -> None: - if isinstance(ndarray, np.ndarray): +def check_array_space(data: Union[np.ndarray, Sequence, Dict], space: Union['Space', Dict], name: str) -> None: + if isinstance(data, np.ndarray): # print("{}'s type should be np.ndarray".format(name)) - assert ndarray.dtype == space.dtype, "{}'s dtype is {}, but requires {}".format( - name, ndarray.dtype, space.dtype - ) - assert ndarray.shape == space.shape, "{}'s shape is {}, but requires {}".format( - name, ndarray.shape, space.shape - ) + assert data.dtype == space.dtype, "{}'s dtype is {}, but requires {}".format(name, data.dtype, space.dtype) + assert data.shape == space.shape, "{}'s shape is {}, but requires {}".format(name, data.shape, space.shape) if isinstance(space, Box): - assert (space.low <= ndarray).all() and (ndarray <= space.high).all( - ), "{}'s value is {}, but requires in range ({},{})".format(name, ndarray, space.low, space.high) + assert (space.low <= data).all() and (data <= space.high).all( + ), "{}'s value is {}, but requires in range ({},{})".format(name, data, space.low, space.high) elif isinstance(space, (Discrete, MultiDiscrete, MultiBinary)): - print(space.start, space.n) - assert (ndarray >= space.start) and (ndarray <= space.n) - elif isinstance(ndarray, Sequence): - for i in range(len(ndarray)): + if isinstance(space, Discrete): + assert (data >= space.start) and (data <= space.n) + else: + assert (data >= 0).all() + assert all([d < n for d, n in zip(data, space.nvec)]) + elif isinstance(data, Sequence): + for i in range(len(data)): try: - check_array_space(ndarray[i], space[i], name) + check_array_space(data[i], space[i], name) except AssertionError as e: print("The following error happens at {}-th index".format(i)) raise e - elif isinstance(ndarray, dict): - for k in ndarray.keys(): + elif isinstance(data, dict): + for k in data.keys(): try: - check_array_space(ndarray[k], space[k], name) + check_array_space(data[k], space[k], name) except AssertionError as e: print("The following error happens at key {}".format(k)) raise e else: raise TypeError( - "Input array should be np.ndarray or sequence/dict of np.ndarray, but found {}".format(type(ndarray)) + "Input array should be np.ndarray or sequence/dict of np.ndarray, but found {}".format(type(data)) ) diff --git a/ding/envs/env/tests/test_env_implementation_check.py b/ding/envs/env/tests/test_env_implementation_check.py index 02fa79e064..41e66b62b2 100644 --- a/ding/envs/env/tests/test_env_implementation_check.py +++ b/ding/envs/env/tests/test_env_implementation_check.py @@ -22,6 +22,11 @@ def test_check_array_space(): discrete_array = np.array(11, dtype=np.int64) with pytest.raises(AssertionError): check_array_space(discrete_array, discrete_space, 'test_discrete') + + multi_discrete_space = gym.spaces.MultiDiscrete([2, 3]) + multi_discrete_array = np.array([1, 2], dtype=np.int64) + check_array_space(multi_discrete_array, multi_discrete_space, 'test_multi_discrete') + seq_array = (np.array([1, 2, 3], dtype=np.int64), np.array([4., 5., 6.], dtype=np.float32)) seq_space = [gym.spaces.Box(low=0, high=10, shape=(3, ), dtype=np.int64) for _ in range(2)] with pytest.raises(AssertionError):