|
| 1 | +import os |
| 2 | +import json |
| 3 | +import numpy as np |
| 4 | +import torch |
| 5 | +from rot6d import robust_compute_rotation_matrix_from_ortho6d |
| 6 | +import pytorch_kinematics as pk |
| 7 | +import plotly.graph_objects as go |
| 8 | +import pytorch3d.structures |
| 9 | +import pytorch3d.ops |
| 10 | +import trimesh as tm |
| 11 | +from torchsdf import index_vertices_by_faces |
| 12 | + |
| 13 | + |
| 14 | +class HandModel: |
| 15 | + def __init__(self, |
| 16 | + mjcf_path, |
| 17 | + mesh_path, |
| 18 | + contact_points_path, |
| 19 | + penetration_points_path, |
| 20 | + n_surface_points=0, |
| 21 | + device='cpu', |
| 22 | + handedness=None): |
| 23 | + |
| 24 | + self.device = device |
| 25 | + self.handedness = handedness |
| 26 | + |
| 27 | + # load articulation |
| 28 | + |
| 29 | + self.chain = pk.build_chain_from_mjcf(open(mjcf_path).read()).to(dtype=torch.float, device=device) |
| 30 | + self.n_dofs = len(self.chain.get_joint_parameter_names()) |
| 31 | + |
| 32 | + # load contact points and penetration points |
| 33 | + |
| 34 | + contact_points = None |
| 35 | + if contact_points_path is not None: |
| 36 | + with open(contact_points_path, 'r') as f: |
| 37 | + contact_points = json.load(f) |
| 38 | + |
| 39 | + penetration_points = None |
| 40 | + if penetration_points_path is not None: |
| 41 | + with open(penetration_points_path, 'r') as f: |
| 42 | + penetration_points = json.load(f) |
| 43 | + with open(contact_points_path, 'r') as f: |
| 44 | + contact_points = json.load(f) |
| 45 | + |
| 46 | + penetration_points = None |
| 47 | + if penetration_points_path is not None: |
| 48 | + with open(penetration_points_path, 'r') as f: |
| 49 | + penetration_points = json.load(f) |
| 50 | + |
| 51 | + # build mesh |
| 52 | + |
| 53 | + self.mesh = {} |
| 54 | + areas = {} |
| 55 | + |
| 56 | + def build_mesh_recurse(body): |
| 57 | + if len(body.link.visuals) > 0: |
| 58 | + link_name = body.link.name |
| 59 | + link_vertices = [] |
| 60 | + link_faces = [] |
| 61 | + n_link_vertices = 0 |
| 62 | + for visual in body.link.visuals: |
| 63 | + scale = torch.tensor([1, 1, 1], dtype=torch.float, device=device) |
| 64 | + if visual.geom_type == "box": |
| 65 | + link_mesh = tm.load_mesh(os.path.join(mesh_path, 'box.obj'), process=False) |
| 66 | + link_mesh.vertices *= visual.geom_param.detach().cpu().numpy() |
| 67 | + elif visual.geom_type == "capsule": |
| 68 | + link_mesh = tm.primitives.Capsule(radius=visual.geom_param[0], height=visual.geom_param[1] * 2).apply_translation((0, 0, -visual.geom_param[1])) |
| 69 | + elif visual.geom_type == "mesh": |
| 70 | + link_mesh = tm.load_mesh(os.path.join(mesh_path, visual.geom_param[0].split(":")[1]+".obj"), process=False) |
| 71 | + if visual.geom_param[1] is not None: |
| 72 | + scale = torch.tensor(visual.geom_param[1], dtype=torch.float, device=device) |
| 73 | + vertices = torch.tensor(link_mesh.vertices, dtype=torch.float, device=device) |
| 74 | + faces = torch.tensor(link_mesh.faces, dtype=torch.long, device=device) |
| 75 | + pos = visual.offset.to(self.device) |
| 76 | + vertices = vertices * scale |
| 77 | + vertices = pos.transform_points(vertices) |
| 78 | + link_vertices.append(vertices) |
| 79 | + link_faces.append(faces + n_link_vertices) |
| 80 | + n_link_vertices += len(vertices) |
| 81 | + link_vertices = torch.cat(link_vertices, dim=0) |
| 82 | + link_faces = torch.cat(link_faces, dim=0) |
| 83 | + contact_candidates = torch.tensor(contact_points[link_name], dtype=torch.float32, device=device).reshape(-1, 3) if contact_points is not None else None |
| 84 | + penetration_keypoints = torch.tensor(penetration_points[link_name], dtype=torch.float32, device=device).reshape(-1, 3) if penetration_points is not None else None |
| 85 | + self.mesh[link_name] = { |
| 86 | + 'vertices': link_vertices, |
| 87 | + 'faces': link_faces, |
| 88 | + 'contact_candidates': contact_candidates, |
| 89 | + 'penetration_keypoints': penetration_keypoints, |
| 90 | + } |
| 91 | + if link_name in ['robot0:palm', 'robot0:palm_child', 'robot0:lfmetacarpal_child']: |
| 92 | + link_face_verts = index_vertices_by_faces(link_vertices, link_faces) |
| 93 | + self.mesh[link_name]['face_verts'] = link_face_verts |
| 94 | + else: |
| 95 | + self.mesh[link_name]['geom_param'] = body.link.visuals[0].geom_param |
| 96 | + areas[link_name] = tm.Trimesh(link_vertices.cpu().numpy(), link_faces.cpu().numpy()).area.item() |
| 97 | + for children in body.children: |
| 98 | + build_mesh_recurse(children) |
| 99 | + build_mesh_recurse(self.chain._root) |
| 100 | + |
| 101 | + # set joint limits |
| 102 | + |
| 103 | + self.joints_names = [] |
| 104 | + self.joints_lower = [] |
| 105 | + self.joints_upper = [] |
| 106 | + |
| 107 | + def set_joint_range_recurse(body): |
| 108 | + if body.joint.joint_type != "fixed": |
| 109 | + self.joints_names.append(body.joint.name) |
| 110 | + self.joints_lower.append(body.joint.range[0]) |
| 111 | + self.joints_upper.append(body.joint.range[1]) |
| 112 | + for children in body.children: |
| 113 | + set_joint_range_recurse(children) |
| 114 | + set_joint_range_recurse(self.chain._root) |
| 115 | + |
| 116 | + if self.handedness.lower() == 'right_hand': |
| 117 | + self.joints_lower = torch.stack(self.joints_lower).float().to(device) |
| 118 | + self.joints_upper = torch.stack(self.joints_upper).float().to(device) |
| 119 | + elif self.handedness.lower() == 'left_hand': |
| 120 | + k = self.joints_lower |
| 121 | + self.joints_lower = -torch.stack(self.joints_upper).float().to(device) |
| 122 | + self.joints_upper = -torch.stack(k).float().to(device) |
| 123 | + else: |
| 124 | + raise Exception("You have to declare the handedness of your hand model") |
| 125 | + # sample surface points |
| 126 | + |
| 127 | + total_area = sum(areas.values()) |
| 128 | + num_samples = dict([(link_name, int(areas[link_name] / total_area * n_surface_points)) for link_name in self.mesh]) |
| 129 | + num_samples[list(num_samples.keys())[0]] += n_surface_points - sum(num_samples.values()) |
| 130 | + for link_name in self.mesh: |
| 131 | + if num_samples[link_name] == 0: |
| 132 | + self.mesh[link_name]['surface_points'] = torch.tensor([], dtype=torch.float, device=device).reshape(0, 3) |
| 133 | + continue |
| 134 | + mesh = pytorch3d.structures.Meshes(self.mesh[link_name]['vertices'].unsqueeze(0), self.mesh[link_name]['faces'].unsqueeze(0)) |
| 135 | + dense_point_cloud = pytorch3d.ops.sample_points_from_meshes(mesh, num_samples=100 * num_samples[link_name]) |
| 136 | + surface_points = pytorch3d.ops.sample_farthest_points(dense_point_cloud, K=num_samples[link_name])[0][0] |
| 137 | + self.mesh[link_name]['surface_points'] = surface_points.to(dtype=float, device=device) |
| 138 | + self.mesh[link_name]['surface_points'] = surface_points |
| 139 | + |
| 140 | + # indexing |
| 141 | + |
| 142 | + self.link_name_to_link_index = dict(zip([link_name for link_name in self.mesh], range(len(self.mesh)))) |
| 143 | + |
| 144 | + self.contact_candidates = [self.mesh[link_name]['contact_candidates'] for link_name in self.mesh] |
| 145 | + self.global_index_to_link_index = sum([[i] * len(contact_candidates) for i, contact_candidates in enumerate(self.contact_candidates)], []) |
| 146 | + self.contact_candidates = torch.cat(self.contact_candidates, dim=0) |
| 147 | + self.global_index_to_link_index = torch.tensor(self.global_index_to_link_index, dtype=torch.long, device=device) |
| 148 | + self.n_contact_candidates = self.contact_candidates.shape[0] |
| 149 | + |
| 150 | + self.penetration_keypoints = [self.mesh[link_name]['penetration_keypoints'] for link_name in self.mesh] |
| 151 | + self.global_index_to_link_index_penetration = sum([[i] * len(penetration_keypoints) for i, penetration_keypoints in enumerate(self.penetration_keypoints)], []) |
| 152 | + self.penetration_keypoints = torch.cat(self.penetration_keypoints, dim=0) |
| 153 | + self.global_index_to_link_index_penetration = torch.tensor(self.global_index_to_link_index_penetration, dtype=torch.long, device=device) |
| 154 | + self.n_keypoints = self.penetration_keypoints.shape[0] |
| 155 | + |
| 156 | + # parameters |
| 157 | + |
| 158 | + self.hand_pose = None |
| 159 | + self.contact_point_indices = None |
| 160 | + self.global_translation = None |
| 161 | + self.global_rotation = None |
| 162 | + self.current_status = None |
| 163 | + self.contact_points = None |
| 164 | + |
| 165 | + def set_parameters(self, hand_pose, contact_point_indices=None): |
| 166 | + """ |
| 167 | + Set translation, rotation, joint angles, and contact points of grasps |
| 168 | + """ |
| 169 | + |
| 170 | + |
| 171 | + self.hand_pose = hand_pose |
| 172 | + if self.hand_pose.requires_grad: |
| 173 | + self.hand_pose.retain_grad() |
| 174 | + self.global_translation = self.hand_pose[:, 0:3] |
| 175 | + self.global_rotation = robust_compute_rotation_matrix_from_ortho6d(self.hand_pose[:, 3:9]) |
| 176 | + self.current_status = self.chain.forward_kinematics(self.hand_pose[:, 9:]) |
| 177 | + if contact_point_indices is not None: |
| 178 | + self.contact_point_indices = contact_point_indices |
| 179 | + batch_size, n_contact = contact_point_indices.shape |
| 180 | + self.contact_points = self.contact_candidates[self.contact_point_indices] |
| 181 | + link_indices = self.global_index_to_link_index[self.contact_point_indices] |
| 182 | + transforms = torch.zeros(batch_size, n_contact, 4, 4, dtype=torch.float, device=self.device) |
| 183 | + for link_name in self.mesh: |
| 184 | + mask = link_indices == self.link_name_to_link_index[link_name] |
| 185 | + cur = self.current_status[link_name].get_matrix().unsqueeze(1).expand(batch_size, n_contact, 4, 4) |
| 186 | + transforms[mask] = cur[mask] |
| 187 | + self.contact_points = torch.cat([self.contact_points, torch.ones(batch_size, n_contact, 1, dtype=torch.float, device=self.device)], dim=2) |
| 188 | + self.contact_points = (transforms @ self.contact_points.unsqueeze(3))[:, :, :3, 0] |
| 189 | + self.contact_points = self.contact_points @ self.global_rotation.transpose(1, 2) + self.global_translation.unsqueeze(1) |
| 190 | + |
| 191 | + self.surface_point = self.get_surface_points() |
| 192 | + |
| 193 | + |
| 194 | + |
| 195 | + |
| 196 | + def get_surface_points(self): |
| 197 | + """ |
| 198 | + Get surface points |
| 199 | + """ |
| 200 | + points = [] |
| 201 | + batch_size = self.global_translation.shape[0] |
| 202 | + for link_name in self.mesh: |
| 203 | + n_surface_points = self.mesh[link_name]['surface_points'].shape[0] |
| 204 | + points.append(self.current_status[link_name].transform_points(self.mesh[link_name]['surface_points'])) |
| 205 | + if 1 < batch_size != points[-1].shape[0]: |
| 206 | + points[-1] = points[-1].expand(batch_size, n_surface_points, 3) |
| 207 | + points = torch.cat(points, dim=-2).to(self.device) |
| 208 | + points = points @ self.global_rotation.transpose(1, 2) + self.global_translation.unsqueeze(1) |
| 209 | + return points |
| 210 | + |
| 211 | + |
| 212 | + |
| 213 | + def get_plotly_data(self, i, opacity=0.5, color='lightblue', with_contact_points=False, pose=None): |
| 214 | + """ |
| 215 | + Get visualization data for plotly.graph_objects |
| 216 | + """ |
| 217 | + if pose is not None: |
| 218 | + pose = np.array(pose, dtype=np.float32) |
| 219 | + data = [] |
| 220 | + for link_name in self.mesh: |
| 221 | + v = self.current_status[link_name].transform_points(self.mesh[link_name]['vertices']) |
| 222 | + if len(v.shape) == 3: |
| 223 | + v = v[i] |
| 224 | + v = v @ self.global_rotation[i].T + self.global_translation[i] |
| 225 | + v = v.detach().cpu() |
| 226 | + f = self.mesh[link_name]['faces'].detach().cpu() |
| 227 | + if pose is not None: |
| 228 | + v = v @ pose[:3, :3].T + pose[:3, 3] |
| 229 | + data.append(go.Mesh3d(x=v[:, 0], y=v[:, 1], z=v[:, 2], i=f[:, 0], j=f[:, 1], k=f[:, 2], color=color, opacity=opacity)) |
| 230 | + if with_contact_points: |
| 231 | + contact_points = self.contact_points[i].detach().cpu() |
| 232 | + if pose is not None: |
| 233 | + contact_points = contact_points @ pose[:3, :3].T + pose[:3, 3] |
| 234 | + data.append(go.Scatter3d(x=contact_points[:, 0], y=contact_points[:, 1], z=contact_points[:, 2], mode='markers', marker=dict(color='red', size=5))) |
| 235 | + return data |
0 commit comments