Skip to content

Commit 63b9d11

Browse files
authored
Merge pull request #525 from jcarpent/topic/devel
Fix common issues in custom types
2 parents 598032a + 5d952d2 commit 63b9d11

File tree

6 files changed

+55
-24
lines changed

6 files changed

+55
-24
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
99
### Fixed
1010

1111
- Fix Python library linkage for Debug build on Windows ([#514](https://github.com/stack-of-tasks/eigenpy/pull/514))
12+
- Fix np.ones when dtype is a custom user type ([#525](https://github.com/stack-of-tasks/eigenpy/pull/525))
1213

1314
## [3.10.1] - 2024-10-30
1415

include/eigenpy/ufunc.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
//
2-
// Copyright (c) 2020-2021 INRIA
2+
// Copyright (c) 2020-2025 INRIA
33
// code aptapted from
44
// https://github.com/numpy/numpy/blob/41977b24ae011a51f64faa75cb524c7350fdedd9/numpy/core/src/umath/_rational_tests.c.src
55
//
@@ -151,6 +151,7 @@ EIGENPY_REGISTER_BINARY_OPERATOR(greater_equal, >=)
151151
}
152152

153153
EIGENPY_REGISTER_UNARY_OPERATOR(negative, -)
154+
EIGENPY_REGISTER_UNARY_OPERATOR(square, x *)
154155

155156
} // namespace internal
156157

@@ -258,6 +259,7 @@ void registerCommonUfunc() {
258259

259260
// Unary operators
260261
EIGENPY_REGISTER_UNARY_UFUNC(negative, type_code, Scalar, Scalar);
262+
EIGENPY_REGISTER_UNARY_UFUNC(square, type_code, Scalar, Scalar);
261263

262264
Py_DECREF(numpy);
263265
}

include/eigenpy/user-type.hpp

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -132,34 +132,43 @@ struct SpecialMethods<T, NPY_USERDEF> {
132132
eigenpy::Exception("Cannot retrieve the type stored in the array.");
133133
return -1;
134134
}
135+
135136
PyArrayObject* py_array = static_cast<PyArrayObject*>(array);
136137
PyArray_Descr* descr = PyArray_DTYPE(py_array);
137138
PyTypeObject* array_scalar_type = descr->typeobj;
138139
PyTypeObject* src_obj_type = Py_TYPE(src_obj);
139140

141+
T& dest = *static_cast<T*>(dest_ptr);
140142
if (array_scalar_type != src_obj_type) {
141-
std::stringstream ss;
142-
ss << "The input type is of wrong type. ";
143-
ss << "The expected type is " << bp::type_info(typeid(T)).name()
144-
<< std::endl;
145-
eigenpy::Exception(ss.str());
146-
return -1;
147-
}
143+
long long src_value = PyLong_AsLongLong(src_obj);
144+
if (src_value == -1 && PyErr_Occurred()) {
145+
std::stringstream ss;
146+
ss << "The input type is of wrong type. ";
147+
ss << "The expected type is " << bp::type_info(typeid(T)).name()
148+
<< std::endl;
149+
eigenpy::Exception(ss.str());
150+
return -1;
151+
}
152+
153+
dest = T(src_value);
148154

149-
bp::extract<T&> extract_src_obj(src_obj);
150-
if (!extract_src_obj.check()) {
151-
std::stringstream ss;
152-
ss << "The input type is of wrong type. ";
153-
ss << "The expected type is " << bp::type_info(typeid(T)).name()
154-
<< std::endl;
155-
eigenpy::Exception(ss.str());
156-
return -1;
155+
} else {
156+
bp::extract<T&> extract_src_obj(src_obj);
157+
if (!extract_src_obj.check()) {
158+
std::cout << "if (!extract_src_obj.check())" << std::endl;
159+
std::stringstream ss;
160+
ss << "The input type is of wrong type. ";
161+
ss << "The expected type is " << bp::type_info(typeid(T)).name()
162+
<< std::endl;
163+
eigenpy::Exception(ss.str());
164+
return -1;
165+
}
166+
167+
const T& src = extract_src_obj();
168+
T& dest = *static_cast<T*>(dest_ptr);
169+
dest = src;
157170
}
158171

159-
const T& src = extract_src_obj();
160-
T& dest = *static_cast<T*>(dest_ptr);
161-
dest = src;
162-
163172
return 0;
164173
}
165174

src/register.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ int Register::registerNewType(
6464
}
6565

6666
PyArray_DescrProto* descr_ptr = new PyArray_DescrProto();
67-
Py_SET_TYPE(descr_ptr, &PyArrayDescr_Type);
6867
PyArray_DescrProto& descr = *descr_ptr;
6968
descr.typeobj = py_type_ptr;
7069
descr.kind = 'V';
@@ -92,6 +91,7 @@ int Register::registerNewType(
9291
funcs.fill = fill;
9392
funcs.fillwithscalar = fillwithscalar;
9493
// f->cast = cast;
94+
Py_SET_TYPE(descr_ptr, &PyArrayDescr_Type);
9595

9696
const int code = call_PyArray_RegisterDataType(descr_ptr);
9797
assert(code >= 0 && "The return code should be positive");

unittest/python/test_user_type.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
def test(dtype):
1111
rng = np.random.default_rng()
12-
mat = np.array(np.ones((rows, cols)).astype(np.int32), dtype=dtype)
12+
mat = np.ones((rows, cols), dtype=dtype)
1313
mat = rng.random((rows, cols)).astype(dtype)
1414
mat_copy = mat.copy()
1515
assert (mat == mat_copy).all()
@@ -44,6 +44,11 @@ def test(dtype):
4444
mat2 = np.matmul(mat, mat.T)
4545
assert np.isclose(mat2.astype(np.double), mat2_ref).all()
4646

47+
vec = np.ones((rows,), dtype=dtype)
48+
norm = np.linalg.norm(vec)
49+
norm_ref = np.linalg.norm(vec.astype(np.double))
50+
assert norm == norm_ref
51+
4752

4853
def test_cast(from_dtype, to_dtype):
4954
np.can_cast(from_dtype, to_dtype)
@@ -63,8 +68,17 @@ def test_cast(from_dtype, to_dtype):
6368
test_cast(user_type.CustomDouble, np.int32)
6469
test_cast(np.int32, user_type.CustomDouble)
6570

66-
test(user_type.CustomFloat)
67-
6871
v = user_type.CustomDouble(1)
6972
a = np.array(v)
7073
assert type(v) is a.dtype.type
74+
75+
test(user_type.CustomFloat)
76+
77+
test_cast(user_type.CustomFloat, np.float32)
78+
test_cast(np.double, user_type.CustomFloat)
79+
80+
test_cast(user_type.CustomFloat, np.int64)
81+
test_cast(np.int64, user_type.CustomFloat)
82+
83+
test_cast(user_type.CustomFloat, np.int32)
84+
test_cast(np.int32, user_type.CustomFloat)

unittest/user_type.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,14 +201,19 @@ BOOST_PYTHON_MODULE(user_type) {
201201

202202
eigenpy::registerCast<DoubleType, double>(true);
203203
eigenpy::registerCast<double, DoubleType>(true);
204+
eigenpy::registerCast<DoubleType, float>(false);
205+
eigenpy::registerCast<float, DoubleType>(true);
204206
eigenpy::registerCast<DoubleType, int>(false);
205207
eigenpy::registerCast<int, DoubleType>(true);
206208
eigenpy::registerCast<DoubleType, long long>(false);
207209
eigenpy::registerCast<long long, DoubleType>(true);
208210
eigenpy::registerCast<DoubleType, long>(false);
209211
eigenpy::registerCast<long, DoubleType>(true);
212+
210213
eigenpy::registerCast<FloatType, double>(true);
211214
eigenpy::registerCast<double, FloatType>(false);
215+
eigenpy::registerCast<FloatType, float>(true);
216+
eigenpy::registerCast<float, FloatType>(true);
212217
eigenpy::registerCast<FloatType, long long>(false);
213218
eigenpy::registerCast<long long, FloatType>(true);
214219
eigenpy::registerCast<FloatType, int>(false);

0 commit comments

Comments
 (0)