Skip to content

Commit 8405a67

Browse files
committed
MAINT: add assignCrystVector to convert from array.
Re-enable _setInterpPoints function, but use shared helper instead of spelling out numpy array copy. Handle input arrays with non-unit strides. Handle conversion from Python lists of floats.
1 parent fb4db24 commit 8405a67

File tree

3 files changed

+54
-22
lines changed

3 files changed

+54
-22
lines changed

extensions/helpers.cpp

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,20 @@
1717
*****************************************************************************/
1818

1919
#include "helpers.hpp"
20+
#include <boost/python/stl_iterator.hpp>
21+
2022
#include <iostream>
23+
#include <list>
2124

22-
namespace bp = boost::python;
25+
#include <ObjCryst/CrystVector/CrystVector.h>
26+
27+
// Use numpy here, but initialize it later in the extension module.
28+
#include "pyobjcryst_numpy_setup.hpp"
29+
#define NO_IMPORT_ARRAY
30+
#include <numpy/arrayobject.h>
2331

24-
using namespace std;
32+
33+
namespace bp = boost::python;
2534

2635
void swapstdout(std::ostream& buf)
2736
{
@@ -30,3 +39,34 @@ void swapstdout(std::ostream& buf)
3039
std::cout.rdbuf(buf.rdbuf());
3140
buf.rdbuf(cout_strbuf);
3241
}
42+
43+
44+
void assignCrystVector(CrystVector<double>& cv, bp::object obj)
45+
{
46+
// copy data directly if it is a numpy array of doubles
47+
PyArrayObject* a = PyArray_Check(obj.ptr()) ?
48+
reinterpret_cast<PyArrayObject*>(obj.ptr()) : NULL;
49+
bool isdoublenumpyarray = a &&
50+
(1 == PyArray_NDIM(a)) &&
51+
(NPY_DOUBLE == PyArray_TYPE(a));
52+
if (isdoublenumpyarray)
53+
{
54+
const double* src = static_cast<double*>(PyArray_DATA(a));
55+
npy_intp stride = PyArray_STRIDE(a, 0) / PyArray_ITEMSIZE(a);
56+
cv.resize(PyArray_SIZE(a));
57+
double* dst = cv.data();
58+
const double* last = dst + cv.size();
59+
for (; dst != last; ++dst, src += stride) *dst = *src;
60+
}
61+
// otherwise copy elementwise converting each element to a double
62+
else
63+
{
64+
bp::stl_input_iterator<double> begin(obj), end;
65+
// use intermediate list to preserve cv when conversion fails.
66+
std::list<double> values(begin, end);
67+
cv.resize(values.size());
68+
std::list<double>::const_iterator vv = values.begin();
69+
double* dst = cv.data();
70+
for (; vv != values.end(); ++vv, ++dst) *dst = *vv;
71+
}
72+
}

extensions/helpers.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,4 +109,9 @@ bp::list setToPyList(std::set<T>& v)
109109
}
110110

111111

112+
// Extract CrystVector from a Python object
113+
template <class T> class CrystVector;
114+
115+
void assignCrystVector(CrystVector<double>& cv, bp::object obj);
116+
112117
#endif

extensions/powderpatternbackground_ext.cpp

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,35 +21,22 @@
2121
#include <boost/python/args.hpp>
2222
#include <boost/python/copy_const_reference.hpp>
2323

24-
#include <numpy/noprefix.h>
25-
#include <numpy/arrayobject.h>
26-
2724
#include <ObjCryst/ObjCryst/PowderPattern.h>
2825

26+
#include "helpers.hpp"
27+
2928
namespace bp = boost::python;
3029
using namespace ObjCryst;
3130

3231
namespace {
3332

3433
void _SetInterpPoints(PowderPatternBackground& b,
35-
PyObject* tth, PyObject* backgd)
34+
bp::object tth, bp::object backgd)
3635
{
37-
// FIXME -- adjust for NumPy C-API 1.7
38-
39-
// // cout << "_SetInterpPoints:" << tth << ", " << backgd << endl;
40-
// // cout << "dimensions = " << PyArray_NDIM(tth) << endl;
41-
// const unsigned long nb = *(PyArray_DIMS((PyObject*)tth));
42-
// // cout << "nbPoints = " << nb << endl;
43-
// CrystVector_REAL tth2(nb), backgd2(nb);
44-
// // FIXME -- reuse some conversion function here
45-
// //:TODO: We assume the arrays are contiguous & double (float64) !
46-
// double* p = (double*) (PyArray_DATA(tth));
47-
// double* p2 = (double*) (tth2.data());
48-
// for (unsigned long i = 0; i < nb; i++) *p2++ = *p++;
49-
// p = (double*) (PyArray_DATA(backgd));
50-
// p2 = (double*) (backgd2.data());
51-
// for (unsigned long i = 0; i < nb; i++) *p2++ = *p++;
52-
// b.SetInterpPoints(tth2, backgd2);
36+
CrystVector_REAL cvtth, cvbackg;
37+
assignCrystVector(cvtth, tth);
38+
assignCrystVector(cvbackg, backgd);
39+
b.SetInterpPoints(cvtth, cvbackg);
5340
}
5441

5542
} // namespace

0 commit comments

Comments
 (0)