Skip to content

Commit 3aa9738

Browse files
quaglacopybara-github
authored andcommitted
Fast flex: Interpolate contact points directly.
For each contact point, we used to first compute the vertex weights on a triangle and then add a contact point per vertex using the flex interpolation (trilinear or quadratic), obtaining the total weight by multiplying the vertex weight and the basis function value at that vertex. After this change, each contact point is added directly by evaluating the basis function directly a the point. PiperOrigin-RevId: 834218909 Change-Id: Ic0fdb03fc0ef5478798c293df99204fd33116fd1
1 parent a2fd5fd commit 3aa9738

File tree

1 file changed

+94
-67
lines changed

1 file changed

+94
-67
lines changed

src/engine/engine_core_constraint.c

Lines changed: 94 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,16 @@ static int mj_elemBodyWeight(const mjModel* m, const mjData* d, int f, int e, in
183183

184184

185185
// compute body weights for a given contact vertex, return #bodies
186-
static int mj_vertBodyWeight(const mjModel* m, const mjData* d, int f, int v,
187-
int* body, mjtNum* weight, mjtNum bw) {
188-
mjtNum* coord = m->flex_vert0 + 3*v;
186+
static int mj_vertBodyWeight(const mjModel* m, const mjData* d, int f, int* v,
187+
int* body, mjtNum* bweight, const mjtNum* vweight, int nw) {
188+
if (nw == 0) {
189+
return 0;
190+
}
191+
192+
mjtNum coord[3] = {0, 0, 0};
193+
for (int i = 0; i < nw; i++) {
194+
mju_addToScl3(coord, m->flex_vert0 + 3*v[i], vweight[i]);
195+
}
189196
int nstart = m->flex_nodeadr[f];
190197
int nend = m->flex_nodeadr[f] + m->flex_nodenum[f];
191198
int nb = 0;
@@ -195,7 +202,7 @@ static int mj_vertBodyWeight(const mjModel* m, const mjData* d, int f, int v,
195202
if (w < 1e-5) {
196203
continue;
197204
}
198-
if (weight) weight[nb] = w * bw;
205+
if (bweight) bweight[nb] = w;
199206
body[nb++] = m->flex_nodebodyid[i];
200207
}
201208

@@ -899,43 +906,46 @@ int mj_contactJacobian(const mjModel* m, mjData* d, const mjContact* con, int di
899906
int bid[729]; // 729 = 27*27
900907
mjtNum bweight[729];
901908
for (int side=0; side < 2; side++) {
902-
int nw = 0;
903-
int vid[4];
904-
mjtNum bw[4];
905-
906909
// geom
907910
if (con->geom[side] >= 0) {
908911
bid[nb] = m->geom_bodyid[con->geom[side]];
909912
bweight[nb] = side ? +1 : -1;
910913
nb++;
911914
}
912915

913-
// flex vert
914-
else if (con->vert[side] >= 0) {
915-
vid[0] = m->flex_vertadr[con->flex[side]] + con->vert[side];
916-
bw[0] = side ? +1 : -1;
917-
nw = 1;
918-
}
919-
920-
// flex elem
916+
// flex
921917
else {
922-
nw = mj_elemBodyWeight(m, d, con->flex[side], con->elem[side],
923-
con->vert[1-side], con->pos, vid, bw);
918+
int nw = 0;
919+
int vid[4];
920+
mjtNum vweight[4];
924921

925-
// negative sign for first side of contact
926-
if (side == 0) {
927-
mju_scl(bw, bw, -1, nw);
922+
// vert
923+
if (con->vert[side] >= 0) {
924+
vid[0] = m->flex_vertadr[con->flex[side]] + con->vert[side];
925+
vweight[0] = side ? +1 : -1;
926+
nw = 1;
928927
}
929-
}
930928

931-
// get body or node ids and weights
932-
for (int k=0; k < nw; k++) {
929+
// elem
930+
else {
931+
nw = mj_elemBodyWeight(m, d, con->flex[side], con->elem[side],
932+
con->vert[1-side], con->pos, vid, vweight);
933+
934+
// negative sign for first side of contact
935+
if (side == 0) {
936+
mju_scl(vweight, vweight, -1, nw);
937+
}
938+
}
939+
940+
// get body or node ids and weights
933941
if (m->flex_interp[con->flex[side]] == 0) {
934-
bid[nb] = m->flex_vertbodyid[vid[k]];
935-
bweight[nb] = bw[k];
936-
nb++;
942+
for (int k=0; k < nw; k++) {
943+
bid[nb] = m->flex_vertbodyid[vid[k]];
944+
bweight[nb] = vweight[k];
945+
nb++;
946+
}
937947
} else {
938-
nb += mj_vertBodyWeight(m, d, con->flex[side], vid[k], bid+nb, bweight+nb, bw[k]);
948+
nb += mj_vertBodyWeight(m, d, con->flex[side], vid, bid+nb, bweight+nb, vweight, nw);
939949
}
940950
}
941951
}
@@ -1154,8 +1164,8 @@ void mj_diagApprox(const mjModel* m, mjData* d) {
11541164
tran = rot = 0;
11551165
for (int side=0; side < 2; side++) {
11561166
// get bodies and weights
1157-
int nb = 0, bid[729], vid[4], nw = 0;
1158-
mjtNum bweight[729], bw[4];
1167+
int nb = 0, bid[729];
1168+
mjtNum bweight[729];
11591169

11601170
// geom
11611171
if (con->geom[side] >= 0) {
@@ -1164,27 +1174,34 @@ void mj_diagApprox(const mjModel* m, mjData* d) {
11641174
nb = 1;
11651175
}
11661176

1167-
// flex vert
1168-
else if (con->vert[side] >= 0) {
1169-
vid[0] = m->flex_vertadr[con->flex[side]] + con->vert[side];
1170-
bw[0] = 1;
1171-
nw = 1;
1172-
}
1173-
1174-
// flex elem
1177+
// flex
11751178
else {
1176-
nw = mj_elemBodyWeight(m, d, con->flex[side], con->elem[side],
1177-
con->vert[1-side], con->pos, vid, bw);
1178-
}
1179+
int nw = 0;
1180+
int vid[4];
1181+
mjtNum vweight[4];
1182+
1183+
// vert
1184+
if (con->vert[side] >= 0) {
1185+
vid[0] = m->flex_vertadr[con->flex[side]] + con->vert[side];
1186+
vweight[0] = 1;
1187+
nw = 1;
1188+
}
11791189

1180-
// get body or node ids and weights
1181-
for (int k=0; k < nw; k++) {
1190+
// elem
1191+
else {
1192+
nw = mj_elemBodyWeight(m, d, con->flex[side], con->elem[side],
1193+
con->vert[1-side], con->pos, vid, vweight);
1194+
}
1195+
1196+
// convert verted ids and weights to body ids and weights
11821197
if (m->flex_interp[con->flex[side]] == 0) {
1183-
bid[k] = m->flex_vertbodyid[vid[k]];
1184-
bweight[k] = bw[k];
1185-
nb++;
1198+
for (int k=0; k < nw; k++) {
1199+
bid[k] = m->flex_vertbodyid[vid[k]];
1200+
bweight[k] = vweight[k];
1201+
nb++;
1202+
}
11861203
} else {
1187-
nb = mj_vertBodyWeight(m, d, con->flex[side], vid[k], bid, bweight, bw[k]);
1204+
nb += mj_vertBodyWeight(m, d, con->flex[side], vid, bid, bweight, vweight, nw);
11881205
}
11891206
}
11901207

@@ -1884,36 +1901,46 @@ static int mj_nc(const mjModel* m, mjData* d, int* nnz) {
18841901
// get bodies
18851902
int nb = 0, bid[729];
18861903
for (int side=0; side < 2; side++) {
1887-
int nw = 0;
1888-
int vid[4];
1889-
18901904
// geom
18911905
if (con->geom[side] >= 0) {
18921906
bid[nb++] = m->geom_bodyid[con->geom[side]];
18931907
}
18941908

1895-
// flex vert
1896-
else if (con->vert[side] >= 0) {
1897-
vid[nw++] = m->flex_vertadr[con->flex[side]] + con->vert[side];
1898-
}
1899-
1900-
// flex elem
1909+
// flex
19011910
else {
1902-
int f = con->flex[side];
1903-
int fdim = m->flex_dim[f];
1904-
const int* edata = m->flex_elem + m->flex_elemdataadr[f] + con->elem[side]*(fdim+1);
1905-
for (int k=0; k <= fdim; k++) {
1906-
vid[nw++] = m->flex_vertadr[f] + edata[k];
1911+
int nw = 0;
1912+
int vid[4];
1913+
mjtNum vweight[4];
1914+
1915+
// flex vert
1916+
if (con->vert[side] >= 0) {
1917+
vid[nw++] = m->flex_vertadr[con->flex[side]] + con->vert[side];
1918+
vweight[0] = 1;
19071919
}
1908-
}
19091920

1910-
// get body or node ids and weights
1911-
for (int k=0; k < nw; k++) {
1921+
// flex elem
1922+
else {
1923+
int f = con->flex[side];
1924+
int fdim = m->flex_dim[f];
1925+
const int* edata = m->flex_elem + m->flex_elemdataadr[f] + con->elem[side]*(fdim+1);
1926+
for (int k=0; k <= fdim; k++) {
1927+
vid[nw++] = m->flex_vertadr[f] + edata[k];
1928+
}
1929+
1930+
if (m->flex_interp[f]) {
1931+
nw = mj_elemBodyWeight(m, d, con->flex[side], con->elem[side],
1932+
con->vert[1-side], con->pos, vid, vweight);
1933+
}
1934+
}
1935+
1936+
// get body or node ids and weights
19121937
if (m->flex_interp[con->flex[side]] == 0) {
1913-
bid[nb] = m->flex_vertbodyid[vid[k]];
1914-
nb++;
1938+
for (int k=0; k < nw; k++) {
1939+
bid[nb] = m->flex_vertbodyid[vid[k]];
1940+
nb++;
1941+
}
19151942
} else {
1916-
nb += mj_vertBodyWeight(m, d, con->flex[side], vid[k], bid + nb, NULL, 0);
1943+
nb += mj_vertBodyWeight(m, d, con->flex[side], vid, bid+nb, NULL, vweight, nw);
19171944
}
19181945
}
19191946
}

0 commit comments

Comments
 (0)