Skip to content

Commit d9c3782

Browse files
authored
Merge pull request #2768 from BioDataAnalysis/option_to_disable_temporary_object_in_assignment
Adding the ability to enable memory overlap check in assignment to avoid unneeded temporary memory allocation
2 parents a17f3de + 4507f14 commit d9c3782

File tree

6 files changed

+263
-0
lines changed

6 files changed

+263
-0
lines changed

CMakeLists.txt

+5
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ target_link_libraries(xtensor INTERFACE xtl)
199199

200200
OPTION(XTENSOR_ENABLE_ASSERT "xtensor bound check" OFF)
201201
OPTION(XTENSOR_CHECK_DIMENSION "xtensor dimension check" OFF)
202+
OPTION(XTENSOR_FORCE_TEMPORARY_MEMORY_IN_ASSIGNMENTS "xtensor force the use of temporary memory when assigning instead of an automatic overlap check" ON)
202203
OPTION(BUILD_TESTS "xtensor test suite" OFF)
203204
OPTION(BUILD_BENCHMARK "xtensor benchmark" OFF)
204205
OPTION(DOWNLOAD_GTEST "build gtest from downloaded sources" OFF)
@@ -219,6 +220,10 @@ if(XTENSOR_CHECK_DIMENSION)
219220
add_definitions(-DXTENSOR_ENABLE_CHECK_DIMENSION)
220221
endif()
221222

223+
if(XTENSOR_FORCE_TEMPORARY_MEMORY_IN_ASSIGNMENTS)
224+
add_definitions(-DXTENSOR_FORCE_TEMPORARY_MEMORY_IN_ASSIGNMENTS)
225+
endif()
226+
222227
if(DEFAULT_COLUMN_MAJOR)
223228
add_definitions(-DXTENSOR_DEFAULT_LAYOUT=layout_type::column_major)
224229
endif()

include/xtensor/xbroadcast.hpp

+23
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,29 @@ namespace xt
118118
return linear_end(c.expression());
119119
}
120120

121+
/*************************************
122+
* overlapping_memory_checker_traits *
123+
*************************************/
124+
125+
template <class E>
126+
struct overlapping_memory_checker_traits<
127+
E,
128+
std::enable_if_t<!has_memory_address<E>::value && is_specialization_of<xbroadcast, E>::value>>
129+
{
130+
static bool check_overlap(const E& expr, const memory_range& dst_range)
131+
{
132+
if (expr.size() == 0)
133+
{
134+
return false;
135+
}
136+
else
137+
{
138+
using ChildE = std::decay_t<decltype(expr.expression())>;
139+
return overlapping_memory_checker_traits<ChildE>::check_overlap(expr.expression(), dst_range);
140+
}
141+
}
142+
};
143+
121144
/**
122145
* @class xbroadcast
123146
* @brief Broadcasted xexpression to a specified shape.

include/xtensor/xfunction.hpp

+36
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,42 @@ namespace xt
162162
{
163163
};
164164

165+
/*************************************
166+
* overlapping_memory_checker_traits *
167+
*************************************/
168+
169+
template <class E>
170+
struct overlapping_memory_checker_traits<
171+
E,
172+
std::enable_if_t<!has_memory_address<E>::value && is_specialization_of<xfunction, E>::value>>
173+
{
174+
template <std::size_t I = 0, class... T, std::enable_if_t<(I == sizeof...(T)), int> = 0>
175+
static bool check_tuple(const std::tuple<T...>&, const memory_range&)
176+
{
177+
return false;
178+
}
179+
180+
template <std::size_t I = 0, class... T, std::enable_if_t<(I < sizeof...(T)), int> = 0>
181+
static bool check_tuple(const std::tuple<T...>& t, const memory_range& dst_range)
182+
{
183+
using ChildE = std::decay_t<decltype(std::get<I>(t))>;
184+
return overlapping_memory_checker_traits<ChildE>::check_overlap(std::get<I>(t), dst_range)
185+
|| check_tuple<I + 1>(t, dst_range);
186+
}
187+
188+
static bool check_overlap(const E& expr, const memory_range& dst_range)
189+
{
190+
if (expr.size() == 0)
191+
{
192+
return false;
193+
}
194+
else
195+
{
196+
return check_tuple(expr.arguments(), dst_range);
197+
}
198+
}
199+
};
200+
165201
/*************
166202
* xfunction *
167203
*************/

include/xtensor/xgenerator.hpp

+15
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,21 @@ namespace xt
7676
using size_type = std::size_t;
7777
};
7878

79+
/*************************************
80+
* overlapping_memory_checker_traits *
81+
*************************************/
82+
83+
template <class E>
84+
struct overlapping_memory_checker_traits<
85+
E,
86+
std::enable_if_t<!has_memory_address<E>::value && is_specialization_of<xgenerator, E>::value>>
87+
{
88+
static bool check_overlap(const E&, const memory_range&)
89+
{
90+
return false;
91+
}
92+
};
93+
7994
/**
8095
* @class xgenerator
8196
* @brief Multidimensional function operating on indices.

include/xtensor/xsemantic.hpp

+37
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,29 @@ namespace xt
217217
template <class E, class R = void>
218218
using disable_xcontainer_semantics = typename std::enable_if<!has_container_semantics<E>::value, R>::type;
219219

220+
221+
template <class D>
222+
class xview_semantic;
223+
224+
template <class E>
225+
struct overlapping_memory_checker_traits<
226+
E,
227+
std::enable_if_t<!has_memory_address<E>::value && is_crtp_base_of<xview_semantic, E>::value>>
228+
{
229+
static bool check_overlap(const E& expr, const memory_range& dst_range)
230+
{
231+
if (expr.size() == 0)
232+
{
233+
return false;
234+
}
235+
else
236+
{
237+
using ChildE = std::decay_t<decltype(expr.expression())>;
238+
return overlapping_memory_checker_traits<ChildE>::check_overlap(expr.expression(), dst_range);
239+
}
240+
}
241+
};
242+
220243
/**
221244
* @class xview_semantic
222245
* @brief Implementation of the xsemantic_base interface for
@@ -598,8 +621,22 @@ namespace xt
598621
template <class E>
599622
inline auto xsemantic_base<D>::operator=(const xexpression<E>& e) -> derived_type&
600623
{
624+
#ifdef XTENSOR_FORCE_TEMPORARY_MEMORY_IN_ASSIGNMENTS
601625
temporary_type tmp(e);
602626
return this->derived_cast().assign_temporary(std::move(tmp));
627+
#else
628+
auto&& this_derived = this->derived_cast();
629+
auto memory_checker = make_overlapping_memory_checker(this_derived);
630+
if (memory_checker.check_overlap(e.derived_cast()))
631+
{
632+
temporary_type tmp(e);
633+
return this_derived.assign_temporary(std::move(tmp));
634+
}
635+
else
636+
{
637+
return this->assign(e);
638+
}
639+
#endif
603640
}
604641

605642
/**************************************

include/xtensor/xutils.hpp

+147
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,20 @@ namespace xt
119119
using type = T;
120120
};
121121

122+
/***************************************
123+
* is_specialization_of implementation *
124+
***************************************/
125+
126+
template <template <class...> class TT, class T>
127+
struct is_specialization_of : std::false_type
128+
{
129+
};
130+
131+
template <template <class...> class TT, class... Ts>
132+
struct is_specialization_of<TT, TT<Ts...>> : std::true_type
133+
{
134+
};
135+
122136
/*******************************
123137
* remove_class implementation *
124138
*******************************/
@@ -860,6 +874,139 @@ namespace xt
860874
{
861875
};
862876

877+
/*************************************
878+
* overlapping_memory_checker_traits *
879+
*************************************/
880+
881+
template <class T, class Enable = void>
882+
struct has_memory_address : std::false_type
883+
{
884+
};
885+
886+
template <class T>
887+
struct has_memory_address<T, void_t<decltype(std::addressof(*std::declval<T>().begin()))>> : std::true_type
888+
{
889+
};
890+
891+
struct memory_range
892+
{
893+
// Checking pointer overlap is more correct in integer values,
894+
// for more explanation check https://devblogs.microsoft.com/oldnewthing/20170927-00/?p=97095
895+
const uintptr_t m_first = 0;
896+
const uintptr_t m_last = 0;
897+
898+
explicit memory_range() = default;
899+
900+
template <class T>
901+
explicit memory_range(T* first, T* last)
902+
: m_first(reinterpret_cast<uintptr_t>(last < first ? last : first))
903+
, m_last(reinterpret_cast<uintptr_t>(last < first ? first : last))
904+
{
905+
}
906+
907+
template <class T>
908+
bool overlaps(T* first, T* last) const
909+
{
910+
if (first <= last)
911+
{
912+
return reinterpret_cast<uintptr_t>(first) <= m_last
913+
&& reinterpret_cast<uintptr_t>(last) >= m_first;
914+
}
915+
else
916+
{
917+
return reinterpret_cast<uintptr_t>(last) <= m_last
918+
&& reinterpret_cast<uintptr_t>(first) >= m_first;
919+
}
920+
}
921+
};
922+
923+
template <class E, class Enable = void>
924+
struct overlapping_memory_checker_traits
925+
{
926+
static bool check_overlap(const E&, const memory_range&)
927+
{
928+
return true;
929+
}
930+
};
931+
932+
template <class E>
933+
struct overlapping_memory_checker_traits<E, std::enable_if_t<has_memory_address<E>::value>>
934+
{
935+
static bool check_overlap(const E& expr, const memory_range& dst_range)
936+
{
937+
if (expr.size() == 0)
938+
{
939+
return false;
940+
}
941+
else
942+
{
943+
return dst_range.overlaps(std::addressof(*expr.begin()), std::addressof(*expr.rbegin()));
944+
}
945+
}
946+
};
947+
948+
struct overlapping_memory_checker_base
949+
{
950+
memory_range m_dst_range;
951+
952+
explicit overlapping_memory_checker_base() = default;
953+
954+
explicit overlapping_memory_checker_base(memory_range dst_memory_range)
955+
: m_dst_range(std::move(dst_memory_range))
956+
{
957+
}
958+
959+
template <class E>
960+
bool check_overlap(const E& expr) const
961+
{
962+
if (!m_dst_range.m_first || !m_dst_range.m_last)
963+
{
964+
return false;
965+
}
966+
else
967+
{
968+
return overlapping_memory_checker_traits<E>::check_overlap(expr, m_dst_range);
969+
}
970+
}
971+
};
972+
973+
template <class Dst, class Enable = void>
974+
struct overlapping_memory_checker : overlapping_memory_checker_base
975+
{
976+
explicit overlapping_memory_checker(const Dst&)
977+
: overlapping_memory_checker_base()
978+
{
979+
}
980+
};
981+
982+
template <class Dst>
983+
struct overlapping_memory_checker<Dst, std::enable_if_t<has_memory_address<Dst>::value>>
984+
: overlapping_memory_checker_base
985+
{
986+
explicit overlapping_memory_checker(const Dst& aDst)
987+
: overlapping_memory_checker_base(
988+
[&]()
989+
{
990+
if (aDst.size() == 0)
991+
{
992+
return memory_range();
993+
}
994+
else
995+
{
996+
return memory_range(std::addressof(*aDst.begin()), std::addressof(*aDst.rbegin()));
997+
}
998+
}()
999+
)
1000+
{
1001+
}
1002+
};
1003+
1004+
template <class Dst>
1005+
auto make_overlapping_memory_checker(const Dst& a_dst)
1006+
{
1007+
return overlapping_memory_checker<Dst>(a_dst);
1008+
}
1009+
8631010
/********************
8641011
* rebind_container *
8651012
********************/

0 commit comments

Comments
 (0)