Skip to content

Commit 2303ff0

Browse files
authored
fix(pt): optimize createNlistTensor (#4403)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Enhanced tensor creation process for improved performance and efficiency. - **Bug Fixes** - Improved error handling for PyTorch-related exceptions, providing clearer error messages. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Jinzhe Zeng <[email protected]>
1 parent 5a93798 commit 2303ff0

File tree

2 files changed

+22
-22
lines changed

2 files changed

+22
-22
lines changed

source/api_cc/src/DeepPotPT.cc

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,20 +31,20 @@ void DeepPotPT::translate_error(std::function<void()> f) {
3131
}
3232

3333
torch::Tensor createNlistTensor(const std::vector<std::vector<int>>& data) {
34-
std::vector<torch::Tensor> row_tensors;
35-
34+
size_t total_size = 0;
3635
for (const auto& row : data) {
37-
torch::Tensor row_tensor = torch::tensor(row, torch::kInt32).unsqueeze(0);
38-
row_tensors.push_back(row_tensor);
36+
total_size += row.size();
3937
}
40-
41-
torch::Tensor tensor;
42-
if (row_tensors.size() > 0) {
43-
tensor = torch::cat(row_tensors, 0).unsqueeze(0);
44-
} else {
45-
tensor = torch::empty({1, 0, 0}, torch::kInt32);
38+
std::vector<int> flat_data;
39+
flat_data.reserve(total_size);
40+
for (const auto& row : data) {
41+
flat_data.insert(flat_data.end(), row.begin(), row.end());
4642
}
47-
return tensor;
43+
44+
torch::Tensor flat_tensor = torch::tensor(flat_data, torch::kInt32);
45+
int nloc = data.size();
46+
int nnei = nloc > 0 ? total_size / nloc : 0;
47+
return flat_tensor.view({1, nloc, nnei});
4848
}
4949
DeepPotPT::DeepPotPT() : inited(false) {}
5050
DeepPotPT::DeepPotPT(const std::string& model,

source/api_cc/src/DeepSpinPT.cc

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,20 +31,20 @@ void DeepSpinPT::translate_error(std::function<void()> f) {
3131
}
3232

3333
torch::Tensor createNlistTensor2(const std::vector<std::vector<int>>& data) {
34-
std::vector<torch::Tensor> row_tensors;
35-
34+
size_t total_size = 0;
3635
for (const auto& row : data) {
37-
torch::Tensor row_tensor = torch::tensor(row, torch::kInt32).unsqueeze(0);
38-
row_tensors.push_back(row_tensor);
36+
total_size += row.size();
3937
}
40-
41-
torch::Tensor tensor;
42-
if (row_tensors.size() > 0) {
43-
tensor = torch::cat(row_tensors, 0).unsqueeze(0);
44-
} else {
45-
tensor = torch::empty({1, 0, 0}, torch::kInt32);
38+
std::vector<int> flat_data;
39+
flat_data.reserve(total_size);
40+
for (const auto& row : data) {
41+
flat_data.insert(flat_data.end(), row.begin(), row.end());
4642
}
47-
return tensor;
43+
44+
torch::Tensor flat_tensor = torch::tensor(flat_data, torch::kInt32);
45+
int nloc = data.size();
46+
int nnei = nloc > 0 ? total_size / nloc : 0;
47+
return flat_tensor.view({1, nloc, nnei});
4848
}
4949
DeepSpinPT::DeepSpinPT() : inited(false) {}
5050
DeepSpinPT::DeepSpinPT(const std::string& model,

0 commit comments

Comments
 (0)