Skip to content

Commit

Permalink
Autogenerate mjSpec find_* methods.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 725255440
Change-Id: Ia9112f6c582ea9bb0af38aa97a917e72a98a84f5
  • Loading branch information
quagla authored and copybara-github committed Feb 10, 2025
1 parent 94115ec commit e54f3c2
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 85 deletions.
72 changes: 46 additions & 26 deletions python/mujoco/codegen/generate_spec_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,34 @@

SCALAR_TYPES = {'int', 'double', 'float', 'mjtByte', 'mjtNum'}

# key, parent, default, listname, objtype
SPECS = [
('mjsBody', 'Body', True, 'bodies', 'mjOBJ_BODY'),
('mjsSite', 'Body', True, 'sites', 'mjOBJ_SITE'),
('mjsGeom', 'Body', True, 'geoms', 'mjOBJ_GEOM'),
('mjsJoint', 'Body', True, 'joints', 'mjOBJ_JOINT'),
('mjsCamera', 'Body', True, 'cameras', 'mjOBJ_CAMERA'),
('mjsFrame', 'Body', True, 'frames', 'mjOBJ_FRAME'),
('mjsLight', 'Body', True, 'lights', 'mjOBJ_LIGHT'),
('mjsFlex', 'Spec', False, 'flexes', 'mjOBJ_FLEX'),
('mjsMesh', 'Spec', True, 'meshes', 'mjOBJ_MESH'),
('mjsSkin', 'Spec', False, 'skins', 'mjOBJ_SKIN'),
('mjsHField', 'Spec', False, 'hfields', 'mjOBJ_HFIELD'),
('mjsTexture', 'Spec', False, 'textures', 'mjOBJ_TEXTURE'),
('mjsMaterial', 'Spec', True, 'materials', 'mjOBJ_MATERIAL'),
('mjsPair', 'Spec', True, 'pairs', 'mjOBJ_PAIR'),
('mjsEquality', 'Spec', True, 'equalities', 'mjOBJ_EQUALITY'),
('mjsTendon', 'Spec', True, 'tendons', 'mjOBJ_TENDON'),
('mjsActuator', 'Spec', True, 'actuators', 'mjOBJ_ACTUATOR'),
('mjsSensor', 'Spec', False, 'sensors', 'mjOBJ_SENSOR'),
('mjsNumeric', 'Spec', False, 'numerics', 'mjOBJ_NUMERIC'),
('mjsText', 'Spec', False, 'texts', 'mjOBJ_TEXT'),
('mjsTuple', 'Spec', False, 'tuples', 'mjOBJ_TUPLE'),
('mjsKey', 'Spec', False, 'keys', 'mjOBJ_KEY'),
('mjsExclude', 'Spec', False, 'excludes', 'mjOBJ_EXCLUDE'),
('mjsPlugin', 'Spec', False, 'plugins', 'mjOBJ_PLUGIN'),
]


def _value_binding_code(
field: ast_nodes.ValueType, classname: str = '', varname: str = ''
Expand Down Expand Up @@ -251,32 +279,7 @@ def generate() -> None:

def generate_add() -> None:
"""Generate add constructors with optional keyword arguments."""
for key, parent, default, listname, objtype in [
('mjsSite', 'Body', True, 'sites', 'mjOBJ_SITE'),
('mjsGeom', 'Body', True, 'geoms', 'mjOBJ_GEOM'),
('mjsJoint', 'Body', True, 'joints', 'mjOBJ_JOINT'),
('mjsLight', 'Body', True, 'lights', 'mjOBJ_LIGHT'),
('mjsCamera', 'Body', True, 'cameras', 'mjOBJ_CAMERA'),
('mjsBody', 'Body', True, 'bodies', 'mjOBJ_BODY'),
('mjsFrame', 'Body', True, 'frames', 'mjOBJ_FRAME'),
('mjsMaterial', 'Spec', True, 'materials', 'mjOBJ_MATERIAL'),
('mjsMesh', 'Spec', True, 'meshes', 'mjOBJ_MESH'),
('mjsPair', 'Spec', True, 'pairs', 'mjOBJ_PAIR'),
('mjsEquality', 'Spec', True, 'equalities', 'mjOBJ_EQUALITY'),
('mjsTendon', 'Spec', True, 'tendons', 'mjOBJ_TENDON'),
('mjsActuator', 'Spec', True, 'actuators', 'mjOBJ_ACTUATOR'),
('mjsSkin', 'Spec', False, 'skins', 'mjOBJ_SKIN'),
('mjsTexture', 'Spec', False, 'textures', 'mjOBJ_TEXTURE'),
('mjsText', 'Spec', False, 'texts', 'mjOBJ_TEXT'),
('mjsTuple', 'Spec', False, 'tuples', 'mjOBJ_TUPLE'),
('mjsFlex', 'Spec', False, 'flexes', 'mjOBJ_FLEX'),
('mjsHField', 'Spec', False, 'hfields', 'mjOBJ_HFIELD'),
('mjsKey', 'Spec', False, 'keys', 'mjOBJ_KEY'),
('mjsNumeric', 'Spec', False, 'numerics', 'mjOBJ_NUMERIC'),
('mjsExclude', 'Spec', False, 'excludes', 'mjOBJ_EXCLUDE'),
('mjsSensor', 'Spec', False, 'sensors', 'mjOBJ_SENSOR'),
('mjsPlugin', 'Spec', False, 'plugins', 'mjOBJ_PLUGIN'),
]:
for key, parent, default, listname, objtype in SPECS:

def _field(f: ast_nodes.StructFieldDecl):
if f.type == ast_nodes.PointerType(
Expand Down Expand Up @@ -573,11 +576,28 @@ def _field(f: ast_nodes.StructFieldDecl):
print(code)


def generate_find() -> None:
"""Generate find functions."""
for key, _, _, _, objtype in SPECS:
elem = key.removeprefix('mjs')
elemlower = elem.lower()
titlecase = 'Mjs' + elem
code = f"""\n
mjSpec.def("find_{elemlower}",
[](MjSpec& self, std::string& name) -> raw::{titlecase}* {{
return mjs_as{elem}(
mjs_findElement(self.ptr, {objtype}, name.c_str()));
}}, py::return_value_policy::reference_internal);
"""
print(code)


def main(argv: Sequence[str]) -> None:
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')
generate()
generate_add()
generate_find()


if __name__ == '__main__':
Expand Down
59 changes: 0 additions & 59 deletions python/mujoco/specs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -395,71 +395,12 @@ PYBIND11_MODULE(_specs, m) {
return mjs_findBody(self.ptr, "world");
},
py::return_value_policy::reference_internal);
mjSpec.def(
"find_body",
[](MjSpec& self, std::string& name) -> raw::MjsBody* {
return mjs_findBody(self.ptr, name.c_str());
},
py::return_value_policy::reference_internal);
mjSpec.def(
"find_frame",
[](MjSpec& self, std::string& name) -> raw::MjsFrame* {
return mjs_findFrame(self.ptr, name.c_str());
},
py::return_value_policy::reference_internal);
mjSpec.def(
"find_site",
[](MjSpec& self, std::string& name) -> raw::MjsSite* {
return mjs_asSite(mjs_findElement(self.ptr, mjOBJ_SITE, name.c_str()));
},
py::return_value_policy::reference_internal);
mjSpec.def(
"find_actuator",
[](MjSpec& self, std::string& name) -> raw::MjsActuator* {
return mjs_asActuator(
mjs_findElement(self.ptr, mjOBJ_ACTUATOR, name.c_str()));
},
py::return_value_policy::reference_internal);
mjSpec.def(
"find_sensor",
[](MjSpec& self, std::string& name) -> raw::MjsSensor* {
return mjs_asSensor(
mjs_findElement(self.ptr, mjOBJ_SENSOR, name.c_str()));
},
py::return_value_policy::reference_internal);
mjSpec.def(
"find_default",
[](MjSpec& self, std::string& classname) -> const raw::MjsDefault* {
return mjs_findDefault(self.ptr, classname.c_str());
},
py::return_value_policy::reference_internal);
mjSpec.def(
"find_geom",
[](MjSpec& self, std::string& name) -> raw::MjsGeom* {
return mjs_asGeom(mjs_findElement(self.ptr, mjOBJ_GEOM, name.c_str()));
},
py::return_value_policy::reference_internal);
mjSpec.def(
"find_joint",
[](MjSpec& self, std::string& name) -> raw::MjsJoint* {
return mjs_asJoint(
mjs_findElement(self.ptr, mjOBJ_JOINT, name.c_str()));
},
py::return_value_policy::reference_internal);
mjSpec.def(
"find_light",
[](MjSpec& self, std::string& name) -> raw::MjsLight* {
return mjs_asLight(
mjs_findElement(self.ptr, mjOBJ_LIGHT, name.c_str()));
},
py::return_value_policy::reference_internal);
mjSpec.def(
"find_camera",
[](MjSpec& self, std::string& name) -> raw::MjsCamera* {
return mjs_asCamera(
mjs_findElement(self.ptr, mjOBJ_CAMERA, name.c_str()));
},
py::return_value_policy::reference_internal);
mjSpec.def("compile", [mjmodel_from_spec_ptr](MjSpec& self) -> py::object {
if (self.assets.empty()) {
return mjmodel_from_spec_ptr(reinterpret_cast<uintptr_t>(self.ptr));
Expand Down

0 comments on commit e54f3c2

Please sign in to comment.