Skip to content
This repository was archived by the owner on Feb 25, 2025. It is now read-only.

Commit 3b6b8fe

Browse files
authored
Extend py::class_ for nonconst read (#257)
1 parent fd89f23 commit 3b6b8fe

File tree

4 files changed

+38
-5
lines changed

4 files changed

+38
-5
lines changed

pycolmap/geometry/bindings.h

+4-3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
#include "pycolmap/geometry/homography_matrix.h"
77
#include "pycolmap/helpers.h"
8+
#include "pycolmap/pybind11_extension.h"
89

910
#include <sstream>
1011

@@ -20,7 +21,7 @@ using namespace pybind11::literals;
2021
void BindGeometry(py::module& m) {
2122
BindHomographyGeometry(m);
2223

23-
py::class_<Eigen::Quaterniond> PyRotation3d(m, "Rotation3d");
24+
py::class_ext_<Eigen::Quaterniond> PyRotation3d(m, "Rotation3d");
2425
PyRotation3d.def(py::init([]() { return Eigen::Quaterniond::Identity(); }))
2526
.def(py::init<const Eigen::Vector4d&>(),
2627
"xyzw"_a,
@@ -55,7 +56,7 @@ void BindGeometry(py::module& m) {
5556
py::implicitly_convertible<py::array, Eigen::Quaterniond>();
5657
MakeDataclass(PyRotation3d);
5758

58-
py::class_<Rigid3d> PyRigid3d(m, "Rigid3d");
59+
py::class_ext_<Rigid3d> PyRigid3d(m, "Rigid3d");
5960
PyRigid3d.def(py::init<>())
6061
.def(py::init<const Eigen::Quaterniond&, const Eigen::Vector3d&>())
6162
.def(py::init([](const Eigen::Matrix3x4d& matrix) {
@@ -87,7 +88,7 @@ void BindGeometry(py::module& m) {
8788
py::implicitly_convertible<py::array, Rigid3d>();
8889
MakeDataclass(PyRigid3d);
8990

90-
py::class_<Sim3d> PySim3d(m, "Sim3d");
91+
py::class_ext_<Sim3d> PySim3d(m, "Sim3d");
9192
PySim3d.def(py::init<>())
9293
.def(
9394
py::init<double, const Eigen::Quaterniond&, const Eigen::Vector3d&>())

pycolmap/pybind11_extension.h

+32
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,38 @@ struct type_caster<std::vector<Eigen::Matrix<Scalar, Size, 1>>> {
9696

9797
} // namespace detail
9898

99+
template <typename type_, typename... options>
100+
class class_ext_ : public class_<type_, options...> {
101+
public:
102+
using Parent = class_<type_, options...>;
103+
using Parent::class_; // inherit constructors
104+
using type = type_;
105+
106+
template <typename C, typename D, typename... Extra>
107+
class_ext_& def_readwrite(const char* name, D C::*pm, const Extra&... extra) {
108+
static_assert(
109+
std::is_same<C, type>::value || std::is_base_of<C, type>::value,
110+
"def_readwrite() requires a class member (or base class member)");
111+
cpp_function fget([pm](type&c) -> D& { return c.*pm; }, is_method(*this)),
112+
fset([pm](type&c, const D&value) { c.*pm = value; }, is_method(*this));
113+
this->def_property(
114+
name, fget, fset, return_value_policy::reference_internal, extra...);
115+
return *this;
116+
}
117+
118+
template <typename... Args>
119+
class_ext_& def(Args&&... args) {
120+
Parent::def(std::forward<Args>(args)...);
121+
return *this;
122+
}
123+
124+
template <typename... Args>
125+
class_ext_& def_property(Args&&... args) {
126+
Parent::def_property(std::forward<Args>(args)...);
127+
return *this;
128+
}
129+
};
130+
99131
// Fix long-standing bug https://github.com/pybind/pybind11/issues/4529
100132
// TODO(sarlinpe): remove when https://github.com/pybind/pybind11/pull/4972
101133
// appears in the next release of pybind11.

pycolmap/scene/point2D.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ void BindPoint2D(py::module& m) {
4545
return repr;
4646
});
4747

48-
py::class_<Point2D, std::shared_ptr<Point2D>> PyPoint2D(m, "Point2D");
48+
py::class_ext_<Point2D, std::shared_ptr<Point2D>> PyPoint2D(m, "Point2D");
4949
PyPoint2D.def(py::init<>())
5050
.def(py::init<const Eigen::Vector2d&, size_t>(),
5151
"xy"_a,

pycolmap/scene/point3D.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ void BindPoint3D(py::module& m) {
2626
std::to_string(self.size()) + ")";
2727
});
2828

29-
py::class_<Point3D, std::shared_ptr<Point3D>> PyPoint3D(m, "Point3D");
29+
py::class_ext_<Point3D, std::shared_ptr<Point3D>> PyPoint3D(m, "Point3D");
3030
PyPoint3D.def(py::init<>())
3131
.def_readwrite("xyz", &Point3D::xyz)
3232
.def_readwrite("color", &Point3D::color)

0 commit comments

Comments
 (0)