
Commit 4c264ae

1 parent b836327 commit 4c264ae

33 files changed: +216884 −0 lines changed

README.md

Lines changed: 2 additions & 0 deletions
@@ -32,3 +32,5 @@ First, download the release dataset and object meshes from https://github.com/Ts
 # This outputs 3d html visualization. Check the dataset for object names (<object_name>.npy), and use --object_name <object_name> --num <num> to visualize the <num>th pose of the object named by <object_name>.
 
 python visualization.py --object_name <object_name> --num <num>
+
+One visualization example (both a 2D screenshot and a 3D html visualization) is in the directory examples/.
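
For instance, a hypothetical invocation that visualizes pose number 0 (the object name is a placeholder; use any name for which <object_name>.npy exists in the dataset):

python visualization.py --object_name <object_name> --num 0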

examples/example.html

Lines changed: 71 additions & 0 deletions
Large diffs are not rendered by default.

examples/example.png

99.6 KB

hand_model.py

Lines changed: 235 additions & 0 deletions
@@ -0,0 +1,235 @@
import os
import json
import numpy as np
import torch
from rot6d import robust_compute_rotation_matrix_from_ortho6d
import pytorch_kinematics as pk
import plotly.graph_objects as go
import pytorch3d.structures
import pytorch3d.ops
import trimesh as tm
from torchsdf import index_vertices_by_faces


class HandModel:
    def __init__(self,
                 mjcf_path,
                 mesh_path,
                 contact_points_path,
                 penetration_points_path,
                 n_surface_points=0,
                 device='cpu',
                 handedness=None):
        self.device = device
        self.handedness = handedness

        # load articulation
        with open(mjcf_path) as f:
            self.chain = pk.build_chain_from_mjcf(f.read()).to(dtype=torch.float, device=device)
        self.n_dofs = len(self.chain.get_joint_parameter_names())

        # load contact points and penetration points
        contact_points = None
        if contact_points_path is not None:
            with open(contact_points_path, 'r') as f:
                contact_points = json.load(f)

        penetration_points = None
        if penetration_points_path is not None:
            with open(penetration_points_path, 'r') as f:
                penetration_points = json.load(f)

        # build mesh: walk the kinematic tree and collect per-link vertices, faces,
        # contact candidates, and penetration keypoints
        self.mesh = {}
        areas = {}

        def build_mesh_recurse(body):
            if len(body.link.visuals) > 0:
                link_name = body.link.name
                link_vertices = []
                link_faces = []
                n_link_vertices = 0
                for visual in body.link.visuals:
                    scale = torch.tensor([1, 1, 1], dtype=torch.float, device=device)
                    if visual.geom_type == "box":
                        link_mesh = tm.load_mesh(os.path.join(mesh_path, 'box.obj'), process=False)
                        link_mesh.vertices *= visual.geom_param.detach().cpu().numpy()
                    elif visual.geom_type == "capsule":
                        link_mesh = tm.primitives.Capsule(radius=visual.geom_param[0], height=visual.geom_param[1] * 2).apply_translation((0, 0, -visual.geom_param[1]))
                    elif visual.geom_type == "mesh":
                        link_mesh = tm.load_mesh(os.path.join(mesh_path, visual.geom_param[0].split(":")[1] + ".obj"), process=False)
                        if visual.geom_param[1] is not None:
                            scale = torch.tensor(visual.geom_param[1], dtype=torch.float, device=device)
                    vertices = torch.tensor(link_mesh.vertices, dtype=torch.float, device=device)
                    faces = torch.tensor(link_mesh.faces, dtype=torch.long, device=device)
                    pos = visual.offset.to(self.device)
                    vertices = vertices * scale
                    vertices = pos.transform_points(vertices)
                    link_vertices.append(vertices)
                    link_faces.append(faces + n_link_vertices)
                    n_link_vertices += len(vertices)
                link_vertices = torch.cat(link_vertices, dim=0)
                link_faces = torch.cat(link_faces, dim=0)
                contact_candidates = torch.tensor(contact_points[link_name], dtype=torch.float32, device=device).reshape(-1, 3) if contact_points is not None else None
                penetration_keypoints = torch.tensor(penetration_points[link_name], dtype=torch.float32, device=device).reshape(-1, 3) if penetration_points is not None else None
                self.mesh[link_name] = {
                    'vertices': link_vertices,
                    'faces': link_faces,
                    'contact_candidates': contact_candidates,
                    'penetration_keypoints': penetration_keypoints,
                }
                if link_name in ['robot0:palm', 'robot0:palm_child', 'robot0:lfmetacarpal_child']:
                    self.mesh[link_name]['face_verts'] = index_vertices_by_faces(link_vertices, link_faces)
                else:
                    self.mesh[link_name]['geom_param'] = body.link.visuals[0].geom_param
                areas[link_name] = tm.Trimesh(link_vertices.cpu().numpy(), link_faces.cpu().numpy()).area.item()
            for children in body.children:
                build_mesh_recurse(children)
        build_mesh_recurse(self.chain._root)

        # set joint limits
        self.joints_names = []
        self.joints_lower = []
        self.joints_upper = []

        def set_joint_range_recurse(body):
            if body.joint.joint_type != "fixed":
                self.joints_names.append(body.joint.name)
                self.joints_lower.append(body.joint.range[0])
                self.joints_upper.append(body.joint.range[1])
            for children in body.children:
                set_joint_range_recurse(children)
        set_joint_range_recurse(self.chain._root)

        if self.handedness is None or self.handedness.lower() not in ('right_hand', 'left_hand'):
            raise ValueError("You have to declare the handedness of your hand model ('right_hand' or 'left_hand')")
        if self.handedness.lower() == 'right_hand':
            self.joints_lower = torch.stack(self.joints_lower).float().to(device)
            self.joints_upper = torch.stack(self.joints_upper).float().to(device)
        else:
            # mirror the right-hand limits for a left hand: negate and swap lower/upper
            joints_lower = self.joints_lower
            self.joints_lower = -torch.stack(self.joints_upper).float().to(device)
            self.joints_upper = -torch.stack(joints_lower).float().to(device)

        # sample surface points, allocating the budget to links in proportion to surface area
        total_area = sum(areas.values())
        num_samples = dict([(link_name, int(areas[link_name] / total_area * n_surface_points)) for link_name in self.mesh])
        # give the rounding remainder to the first link so the counts sum to n_surface_points
        num_samples[list(num_samples.keys())[0]] += n_surface_points - sum(num_samples.values())
        for link_name in self.mesh:
            if num_samples[link_name] == 0:
                self.mesh[link_name]['surface_points'] = torch.tensor([], dtype=torch.float, device=device).reshape(0, 3)
                continue
            mesh = pytorch3d.structures.Meshes(self.mesh[link_name]['vertices'].unsqueeze(0), self.mesh[link_name]['faces'].unsqueeze(0))
            # oversample densely, then keep a well-spread subset via farthest point sampling
            dense_point_cloud = pytorch3d.ops.sample_points_from_meshes(mesh, num_samples=100 * num_samples[link_name])
            surface_points = pytorch3d.ops.sample_farthest_points(dense_point_cloud, K=num_samples[link_name])[0][0]
            self.mesh[link_name]['surface_points'] = surface_points.to(dtype=torch.float, device=device)

        # indexing: flatten per-link candidate points into global arrays and remember
        # which link each global index belongs to
        self.link_name_to_link_index = {link_name: i for i, link_name in enumerate(self.mesh)}

        self.contact_candidates = [self.mesh[link_name]['contact_candidates'] for link_name in self.mesh]
        self.global_index_to_link_index = sum([[i] * len(contact_candidates) for i, contact_candidates in enumerate(self.contact_candidates)], [])
        self.contact_candidates = torch.cat(self.contact_candidates, dim=0)
        self.global_index_to_link_index = torch.tensor(self.global_index_to_link_index, dtype=torch.long, device=device)
        self.n_contact_candidates = self.contact_candidates.shape[0]

        self.penetration_keypoints = [self.mesh[link_name]['penetration_keypoints'] for link_name in self.mesh]
        self.global_index_to_link_index_penetration = sum([[i] * len(penetration_keypoints) for i, penetration_keypoints in enumerate(self.penetration_keypoints)], [])
        self.penetration_keypoints = torch.cat(self.penetration_keypoints, dim=0)
        self.global_index_to_link_index_penetration = torch.tensor(self.global_index_to_link_index_penetration, dtype=torch.long, device=device)
        self.n_keypoints = self.penetration_keypoints.shape[0]

        # parameters, filled in by set_parameters()
        self.hand_pose = None
        self.contact_point_indices = None
        self.global_translation = None
        self.global_rotation = None
        self.current_status = None
        self.contact_points = None

    def set_parameters(self, hand_pose, contact_point_indices=None):
        """
        Set translation, rotation, joint angles, and contact points of grasps.

        hand_pose: (B, 3 + 6 + n_dofs) tensor of global translation,
        6D (ortho6d) rotation, and joint angles.
        """
        self.hand_pose = hand_pose
        if self.hand_pose.requires_grad:
            self.hand_pose.retain_grad()
        self.global_translation = self.hand_pose[:, 0:3]
        self.global_rotation = robust_compute_rotation_matrix_from_ortho6d(self.hand_pose[:, 3:9])
        self.current_status = self.chain.forward_kinematics(self.hand_pose[:, 9:])
        if contact_point_indices is not None:
            self.contact_point_indices = contact_point_indices
            batch_size, n_contact = contact_point_indices.shape
            self.contact_points = self.contact_candidates[self.contact_point_indices]
            link_indices = self.global_index_to_link_index[self.contact_point_indices]
            # gather the link-to-root transform for each selected contact point
            transforms = torch.zeros(batch_size, n_contact, 4, 4, dtype=torch.float, device=self.device)
            for link_name in self.mesh:
                mask = link_indices == self.link_name_to_link_index[link_name]
                cur = self.current_status[link_name].get_matrix().unsqueeze(1).expand(batch_size, n_contact, 4, 4)
                transforms[mask] = cur[mask]
            # homogeneous coordinates -> link transform -> global rotation and translation
            self.contact_points = torch.cat([self.contact_points, torch.ones(batch_size, n_contact, 1, dtype=torch.float, device=self.device)], dim=2)
            self.contact_points = (transforms @ self.contact_points.unsqueeze(3))[:, :, :3, 0]
            self.contact_points = self.contact_points @ self.global_rotation.transpose(1, 2) + self.global_translation.unsqueeze(1)

        self.surface_point = self.get_surface_points()

    def get_surface_points(self):
        """
        Get surface points in the world frame.
        """
        points = []
        batch_size = self.global_translation.shape[0]
        for link_name in self.mesh:
            n_surface_points = self.mesh[link_name]['surface_points'].shape[0]
            points.append(self.current_status[link_name].transform_points(self.mesh[link_name]['surface_points']))
            if batch_size > 1 and points[-1].shape[0] != batch_size:
                points[-1] = points[-1].expand(batch_size, n_surface_points, 3)
        points = torch.cat(points, dim=-2).to(self.device)
        points = points @ self.global_rotation.transpose(1, 2) + self.global_translation.unsqueeze(1)
        return points

    def get_plotly_data(self, i, opacity=0.5, color='lightblue', with_contact_points=False, pose=None):
        """
        Get visualization data for plotly.graph_objects: one Mesh3d trace per link
        for the i-th grasp in the batch, optionally with contact points and an
        extra rigid transform `pose` applied.
        """
        if pose is not None:
            pose = np.array(pose, dtype=np.float32)
        data = []
        for link_name in self.mesh:
            v = self.current_status[link_name].transform_points(self.mesh[link_name]['vertices'])
            if len(v.shape) == 3:
                v = v[i]
            v = v @ self.global_rotation[i].T + self.global_translation[i]
            v = v.detach().cpu()
            f = self.mesh[link_name]['faces'].detach().cpu()
            if pose is not None:
                v = v @ pose[:3, :3].T + pose[:3, 3]
            data.append(go.Mesh3d(x=v[:, 0], y=v[:, 1], z=v[:, 2], i=f[:, 0], j=f[:, 1], k=f[:, 2], color=color, opacity=opacity))
        if with_contact_points:
            contact_points = self.contact_points[i].detach().cpu()
            if pose is not None:
                contact_points = contact_points @ pose[:3, :3].T + pose[:3, 3]
            data.append(go.Scatter3d(x=contact_points[:, 0], y=contact_points[:, 1], z=contact_points[:, 2], mode='markers', marker=dict(color='red', size=5)))
        return data
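
As a quick reference, a minimal usage sketch of the HandModel class added above. The asset paths, surface-point budget, and batch size are illustrative assumptions, not files guaranteed by this commit; substitute the MJCF, mesh directory, and JSON files from your checkout:

import torch
import plotly.graph_objects as go
from hand_model import HandModel

# hypothetical asset paths, for illustration only
hand = HandModel(
    mjcf_path='mjcf/shadow_hand.xml',
    mesh_path='mjcf/meshes',
    contact_points_path='mjcf/contact_points.json',
    penetration_points_path='mjcf/penetration_points.json',
    n_surface_points=128,
    device='cpu',
    handedness='right_hand',
)

# hand_pose is (B, 3 + 6 + n_dofs): translation, ortho6d rotation, joint angles
hand_pose = torch.zeros(1, 9 + hand.n_dofs)
hand_pose[:, 3:9] = torch.tensor([1., 0., 0., 0., 1., 0.])  # identity rotation in 6D form
hand.set_parameters(hand_pose)

surface_points = hand.get_surface_points()  # (1, 128, 3), world frame
fig = go.Figure(data=hand.get_plotly_data(i=0))
fig.write_html('hand_vis.html')  # standalone 3D html visualization

The 6D rotation slice hand_pose[:, 3:9] follows the ortho6d convention consumed by robust_compute_rotation_matrix_from_ortho6d: the first two columns of the rotation matrix, re-orthonormalized, which keeps the rotation parameterization continuous for gradient-based grasp optimization.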
