Skip to content

Commit 1dc5b04

Browse files
fix: use tuple in xp.reshape (#4808)
This PR fixes the error in UT where the argument of `xp.reshape` should be a tuple, not a list. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Style** - Standardized the use of tuple syntax for shape arguments in array reshaping operations throughout the application, replacing previous list-based syntax. This change ensures consistency and aligns with best practices for array manipulation. No functional behavior is affected. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 1d95c18 commit 1dc5b04

File tree

14 files changed

+43
-41
lines changed

14 files changed

+43
-41
lines changed

deepmd/dpmodel/array_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def xp_take_along_axis(arr, indices, axis):
6060

6161
shape = list(arr.shape)
6262
shape.pop(-1)
63-
shape = [*shape, n]
63+
shape = (*shape, n)
6464

6565
arr = xp.reshape(arr, (-1,))
6666
if n != 0:

deepmd/dpmodel/descriptor/dpa1.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,7 @@ def call(
520520
type_embedding = self.type_embedding.call()
521521
# nf x nall x tebd_dim
522522
atype_embd_ext = xp.reshape(
523-
xp.take(type_embedding, xp.reshape(atype_ext, [-1]), axis=0),
523+
xp.take(type_embedding, xp.reshape(atype_ext, (-1,)), axis=0),
524524
(nf, nall, self.tebd_dim),
525525
)
526526
# nfnl x tebd_dim
@@ -1027,7 +1027,7 @@ def call(
10271027
xp.tile(
10281028
(xp.reshape(atype, (-1, 1)) * ntypes_with_padding), (1, nnei)
10291029
),
1030-
(-1),
1030+
(-1,),
10311031
)
10321032
idx_j = xp.reshape(nei_type, (-1,))
10331033
# (nf x nl x nnei) x ng

deepmd/dpmodel/descriptor/dpa2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -841,7 +841,7 @@ def call(
841841
type_embedding = self.type_embedding.call()
842842
# repinit
843843
g1_ext = xp.reshape(
844-
xp.take(type_embedding, xp.reshape(atype_ext, [-1]), axis=0),
844+
xp.take(type_embedding, xp.reshape(atype_ext, (-1,)), axis=0),
845845
(nframes, nall, self.tebd_dim),
846846
)
847847
g1_inp = g1_ext[:, :nloc, :]

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -562,12 +562,12 @@ def call(
562562
type_embedding = self.type_embedding.call()
563563
if self.use_loc_mapping:
564564
node_ebd_ext = xp.reshape(
565-
xp.take(type_embedding, xp.reshape(atype_ext[:, :nloc], [-1]), axis=0),
565+
xp.take(type_embedding, xp.reshape(atype_ext[:, :nloc], (-1,)), axis=0),
566566
(nframes, nloc, self.tebd_dim),
567567
)
568568
else:
569569
node_ebd_ext = xp.reshape(
570-
xp.take(type_embedding, xp.reshape(atype_ext, [-1]), axis=0),
570+
xp.take(type_embedding, xp.reshape(atype_ext, (-1,)), axis=0),
571571
(nframes, nall, self.tebd_dim),
572572
)
573573
node_ebd_inp = node_ebd_ext[:, :nloc, :]

deepmd/dpmodel/descriptor/se_t_tebd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ def call(
358358
type_embedding = self.type_embedding.call()
359359
# nf x nall x tebd_dim
360360
atype_embd_ext = xp.reshape(
361-
xp.take(type_embedding, xp.reshape(atype_ext, [-1]), axis=0),
361+
xp.take(type_embedding, xp.reshape(atype_ext, (-1,)), axis=0),
362362
(nf, nall, self.tebd_dim),
363363
)
364364
# nfnl x tebd_dim

deepmd/dpmodel/fitting/general_fitting.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ def _call_common(
412412
)
413413
fparam = (fparam - self.fparam_avg[...]) * self.fparam_inv_std[...]
414414
fparam = xp.tile(
415-
xp.reshape(fparam, [nf, 1, self.numb_fparam]), (1, nloc, 1)
415+
xp.reshape(fparam, (nf, 1, self.numb_fparam)), (1, nloc, 1)
416416
)
417417
xx = xp.concat(
418418
[xx, fparam],
@@ -431,7 +431,7 @@ def _call_common(
431431
f"get an input aparam of dim {aparam.shape[-1]}, "
432432
f"which is not consistent with {self.numb_aparam}."
433433
)
434-
aparam = xp.reshape(aparam, [nf, nloc, self.numb_aparam])
434+
aparam = xp.reshape(aparam, (nf, nloc, self.numb_aparam))
435435
aparam = (aparam - self.aparam_avg[...]) * self.aparam_inv_std[...]
436436
xx = xp.concat(
437437
[xx, aparam],
@@ -446,7 +446,7 @@ def _call_common(
446446
if self.dim_case_embd > 0:
447447
assert self.case_embd is not None
448448
case_embd = xp.tile(
449-
xp.reshape(self.case_embd[...], [1, 1, -1]), [nf, nloc, 1]
449+
xp.reshape(self.case_embd[...], (1, 1, -1)), (nf, nloc, 1)
450450
)
451451
xx = xp.concat(
452452
[xx, case_embd],
@@ -465,7 +465,7 @@ def _call_common(
465465
)
466466
for type_i in range(self.ntypes):
467467
mask = xp.tile(
468-
xp.reshape((atype == type_i), [nf, nloc, 1]), (1, 1, net_dim_out)
468+
xp.reshape((atype == type_i), (nf, nloc, 1)), (1, 1, net_dim_out)
469469
)
470470
atom_property = self.nets[(type_i,)](xx)
471471
if self.remove_vaccum_contribution is not None and not (
@@ -485,10 +485,10 @@ def _call_common(
485485
outs += xp.reshape(
486486
xp.take(
487487
xp.astype(self.bias_atom_e[...], outs.dtype),
488-
xp.reshape(atype, [-1]),
488+
xp.reshape(atype, (-1,)),
489489
axis=0,
490490
),
491-
[nf, nloc, net_dim_out],
491+
(nf, nloc, net_dim_out),
492492
)
493493
# nf x nloc
494494
exclude_mask = self.emask.build_type_exclude_mask(atype)

deepmd/dpmodel/fitting/polarizability_fitting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ def call(
289289
]
290290
# out = out * self.scale[atype, ...]
291291
scale_atype = xp.reshape(
292-
xp.take(xp.astype(self.scale, out.dtype), xp.reshape(atype, [-1]), axis=0),
292+
xp.take(xp.astype(self.scale, out.dtype), xp.reshape(atype, (-1,)), axis=0),
293293
(*atype.shape, 1),
294294
)
295295
out = out * scale_atype
@@ -315,7 +315,7 @@ def call(
315315
bias = xp.reshape(
316316
xp.take(
317317
xp.astype(self.constant_matrix, out.dtype),
318-
xp.reshape(atype, [-1]),
318+
xp.reshape(atype, (-1,)),
319319
axis=0,
320320
),
321321
(nframes, nloc),

deepmd/dpmodel/loss/ener.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -132,18 +132,18 @@ def call(
132132
atom_ener_coeff = xp.reshape(atom_ener_coeff, xp.shape(atom_ener))
133133
energy = xp.sum(atom_ener_coeff * atom_ener, 1)
134134
if self.has_f or self.has_pf or self.relative_f or self.has_gf:
135-
force_reshape = xp.reshape(force, [-1])
136-
force_hat_reshape = xp.reshape(force_hat, [-1])
135+
force_reshape = xp.reshape(force, (-1,))
136+
force_hat_reshape = xp.reshape(force_hat, (-1,))
137137
diff_f = force_hat_reshape - force_reshape
138138
else:
139139
diff_f = None
140140

141141
if self.relative_f is not None:
142-
force_hat_3 = xp.reshape(force_hat, [-1, 3])
143-
norm_f = xp.reshape(xp.norm(force_hat_3, axis=1), [-1, 1]) + self.relative_f
144-
diff_f_3 = xp.reshape(diff_f, [-1, 3])
142+
force_hat_3 = xp.reshape(force_hat, (-1, 3))
143+
norm_f = xp.reshape(xp.norm(force_hat_3, axis=1), (-1, 1)) + self.relative_f
144+
diff_f_3 = xp.reshape(diff_f, (-1, 3))
145145
diff_f_3 = diff_f_3 / norm_f
146-
diff_f = xp.reshape(diff_f_3, [-1])
146+
diff_f = xp.reshape(diff_f_3, (-1,))
147147

148148
atom_norm = 1.0 / natoms
149149
atom_norm_ener = 1.0 / natoms
@@ -184,15 +184,15 @@ def call(
184184
loss += pref_f * l2_force_loss
185185
else:
186186
l_huber_loss = custom_huber_loss(
187-
xp.reshape(force, [-1]),
188-
xp.reshape(force_hat, [-1]),
187+
xp.reshape(force, (-1,)),
188+
xp.reshape(force_hat, (-1,)),
189189
delta=self.huber_delta,
190190
)
191191
loss += pref_f * l_huber_loss
192192
more_loss["rmse_f"] = self.display_if_exist(l2_force_loss, find_force)
193193
if self.has_v:
194-
virial_reshape = xp.reshape(virial, [-1])
195-
virial_hat_reshape = xp.reshape(virial_hat, [-1])
194+
virial_reshape = xp.reshape(virial, (-1,))
195+
virial_hat_reshape = xp.reshape(virial_hat, (-1,))
196196
l2_virial_loss = xp.mean(
197197
xp.square(virial_hat_reshape - virial_reshape),
198198
)
@@ -207,8 +207,8 @@ def call(
207207
loss += pref_v * l_huber_loss
208208
more_loss["rmse_v"] = self.display_if_exist(l2_virial_loss, find_virial)
209209
if self.has_ae:
210-
atom_ener_reshape = xp.reshape(atom_ener, [-1])
211-
atom_ener_hat_reshape = xp.reshape(atom_ener_hat, [-1])
210+
atom_ener_reshape = xp.reshape(atom_ener, (-1,))
211+
atom_ener_hat_reshape = xp.reshape(atom_ener_hat, (-1,))
212212
l2_atom_ener_loss = xp.mean(
213213
xp.square(atom_ener_hat_reshape - atom_ener_reshape),
214214
)
@@ -225,7 +225,7 @@ def call(
225225
l2_atom_ener_loss, find_atom_ener
226226
)
227227
if self.has_pf:
228-
atom_pref_reshape = xp.reshape(atom_pref, [-1])
228+
atom_pref_reshape = xp.reshape(atom_pref, (-1,))
229229
l2_pref_force_loss = xp.mean(
230230
xp.multiply(xp.square(diff_f), atom_pref_reshape),
231231
)
@@ -236,10 +236,10 @@ def call(
236236
if self.has_gf:
237237
find_drdq = label_dict["find_drdq"]
238238
drdq = label_dict["drdq"]
239-
force_reshape_nframes = xp.reshape(force, [-1, natoms[0] * 3])
240-
force_hat_reshape_nframes = xp.reshape(force_hat, [-1, natoms[0] * 3])
239+
force_reshape_nframes = xp.reshape(force, (-1, natoms[0] * 3))
240+
force_hat_reshape_nframes = xp.reshape(force_hat, (-1, natoms[0] * 3))
241241
drdq_reshape = xp.reshape(
242-
drdq, [-1, natoms[0] * 3, self.numb_generalized_coord]
242+
drdq, (-1, natoms[0] * 3, self.numb_generalized_coord)
243243
)
244244
gen_force_hat = xp.einsum(
245245
"bij,bi->bj", drdq_reshape, force_hat_reshape_nframes

deepmd/dpmodel/model/transform_output.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,9 @@ def communicate_extended_output(
100100
if vdef.r_differentiable:
101101
if model_ret[kk_derv_r] is not None:
102102
derv_r_ext_dims = list(vdef.shape) + [3] # noqa:RUF005
103-
mapping = xp.reshape(mapping, (mldims + [1] * len(derv_r_ext_dims)))
103+
mapping = xp.reshape(
104+
mapping, tuple(mldims + [1] * len(derv_r_ext_dims))
105+
)
104106
mapping = xp.tile(mapping, [1] * len(mldims) + derv_r_ext_dims)
105107
force = xp.zeros(vldims + derv_r_ext_dims, dtype=vv.dtype)
106108
force = xp_scatter_sum(

deepmd/dpmodel/utils/env_mat_stat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def iter(
166166
self.last_dim,
167167
),
168168
)
169-
atype = xp.reshape(atype, (coord.shape[0] * coord.shape[1]))
169+
atype = xp.reshape(atype, (coord.shape[0] * coord.shape[1],))
170170
# (1, nloc) eq (ntypes, 1), so broadcast is possible
171171
# shape: (ntypes, nloc)
172172
type_idx = xp.equal(
@@ -189,7 +189,7 @@ def iter(
189189
for type_i in range(self.descriptor.get_ntypes()):
190190
dd = env_mat[type_idx[type_i, ...]]
191191
dd = xp.reshape(
192-
dd, [-1, self.last_dim]
192+
dd, (-1, self.last_dim)
193193
) # typen_atoms * unmasked_nnei, 4
194194
env_mats = {}
195195
env_mats[f"r_{type_i}"] = dd[:, :1]

deepmd/dpmodel/utils/exclude_mask.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def build_type_exclude_mask(
5353
xp = array_api_compat.array_namespace(atype)
5454
nf, natom = atype.shape
5555
return xp.reshape(
56-
xp.take(self.type_mask[...], xp.reshape(atype, [-1]), axis=0),
56+
xp.take(self.type_mask[...], xp.reshape(atype, (-1,)), axis=0),
5757
(nf, natom),
5858
)
5959

deepmd/dpmodel/utils/neighbor_stat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ def call(
8282
nall = coord1.shape[1] // 3
8383
coord0 = coord1[:, : nloc * 3]
8484
diff = (
85-
xp.reshape(coord1, [nframes, -1, 3])[:, None, :, :]
86-
- xp.reshape(coord0, [nframes, -1, 3])[:, :, None, :]
85+
xp.reshape(coord1, (nframes, -1, 3))[:, None, :, :]
86+
- xp.reshape(coord0, (nframes, -1, 3))[:, :, None, :]
8787
)
8888
assert list(diff.shape) == [nframes, nloc, nall, 3]
8989
# remove the diagonal elements

deepmd/dpmodel/utils/nlist.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,8 @@ def build_neighbor_list(
115115
nsel = sum(sel)
116116
coord0 = coord1[:, : nloc * 3]
117117
diff = (
118-
xp.reshape(coord1, [batch_size, -1, 3])[:, None, :, :]
119-
- xp.reshape(coord0, [batch_size, -1, 3])[:, :, None, :]
118+
xp.reshape(coord1, (batch_size, -1, 3))[:, None, :, :]
119+
- xp.reshape(coord0, (batch_size, -1, 3))[:, :, None, :]
120120
)
121121
assert list(diff.shape) == [batch_size, nloc, nall, 3]
122122
rr = xp.linalg.vector_norm(diff, axis=-1)

deepmd/dpmodel/utils/region.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,8 @@ def to_face_distance(
9393
"""
9494
xp = array_api_compat.array_namespace(cell)
9595
cshape = cell.shape
96-
dist = b_to_face_distance(xp.reshape(cell, [-1, 3, 3]))
97-
return xp.reshape(dist, list(cshape[:-2]) + [3]) # noqa:RUF005
96+
dist = b_to_face_distance(xp.reshape(cell, (-1, 3, 3)))
97+
return xp.reshape(dist, tuple(list(cshape[:-2]) + [3])) # noqa:RUF005
9898

9999

100100
def b_to_face_distance(cell):

0 commit comments

Comments
 (0)