From 4dbac3b49ffcdafa85057c418cec3caac00d85ac Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 14 Feb 2025 08:10:27 +0100 Subject: [PATCH 1/2] Fix domain pickle --- src/gt4py/next/common.py | 6 ++++++ .../embedded_tests/test_domain_pickle.py | 13 +++++++++++++ 2 files changed, 19 insertions(+) create mode 100644 tests/next_tests/regression_tests/embedded_tests/test_domain_pickle.py diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 9b2870e1c0..4fc155951e 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -574,6 +574,12 @@ def replace(self, index: int | Dimension, *named_ranges: NamedRange) -> Domain: return Domain(dims=dims, ranges=ranges) + def __getstate__(self): + state = self.__dict__.copy() + # remove cached property + state.pop("slice_at", None) + return state + FiniteDomain: TypeAlias = Domain[FiniteUnitRange] diff --git a/tests/next_tests/regression_tests/embedded_tests/test_domain_pickle.py b/tests/next_tests/regression_tests/embedded_tests/test_domain_pickle.py new file mode 100644 index 0000000000..1eb12ec98c --- /dev/null +++ b/tests/next_tests/regression_tests/embedded_tests/test_domain_pickle.py @@ -0,0 +1,13 @@ +import pickle + +from gt4py.next import common + +I = common.Dimension("I") +J = common.Dimension("J") + +def test_domain_pickle_after_slice(): + domain = common.domain(((I, (2, 4)), (J, (3, 5)))) + # use slice_at to populate cached property + domain.slice_at[2:5, 5:7] + + pickle.dumps(domain) \ No newline at end of file From 06842c2f3f6f7504829c5820775579002da1a9b1 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 14 Feb 2025 08:12:06 +0100 Subject: [PATCH 2/2] Cleanup --- src/gt4py/next/common.py | 2 +- .../embedded_tests/test_domain_pickle.py | 11 ++++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 4fc155951e..e5b393f1ae 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -574,7 +574,7 @@ def replace(self, index: int | Dimension, *named_ranges: NamedRange) -> Domain: return Domain(dims=dims, ranges=ranges) - def __getstate__(self): + def __getstate__(self) -> dict[str, Any]: state = self.__dict__.copy() # remove cached property state.pop("slice_at", None) diff --git a/tests/next_tests/regression_tests/embedded_tests/test_domain_pickle.py b/tests/next_tests/regression_tests/embedded_tests/test_domain_pickle.py index 1eb12ec98c..b69950928d 100644 --- a/tests/next_tests/regression_tests/embedded_tests/test_domain_pickle.py +++ b/tests/next_tests/regression_tests/embedded_tests/test_domain_pickle.py @@ -1,3 +1,11 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + import pickle from gt4py.next import common @@ -5,9 +13,10 @@ I = common.Dimension("I") J = common.Dimension("J") + def test_domain_pickle_after_slice(): domain = common.domain(((I, (2, 4)), (J, (3, 5)))) # use slice_at to populate cached property domain.slice_at[2:5, 5:7] - pickle.dumps(domain) \ No newline at end of file + pickle.dumps(domain)