Skip to content

Commit 16172e6

Browse files
authored
fix(pt): keep mapping not none during lmp steps when nghost == 0 (#4209)
enhancement on #4144 <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Enhanced tensor mapping capabilities with the addition of a new `mapping_tensor` variable. - Updated `compute` method to handle ghost atoms and support improved tensor creation logic. - Overloaded `computew` methods to support both double and float types. - **Bug Fixes** - Improved error handling in the `translate_error` method for better exception management. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 5c092e6 commit 16172e6

File tree

2 files changed

+1
-1
lines changed

2 files changed

+1
-1
lines changed

source/api_cc/include/DeepPotPT.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,7 @@ class DeepPotPT : public DeepPotBase {
338338
int do_message_passing; // 1:dpa2 model 0:others
339339
bool gpu_enabled;
340340
at::Tensor firstneigh_tensor;
341+
c10::optional<torch::Tensor> mapping_tensor;
341342
torch::Dict<std::string, torch::Tensor> comm_dict;
342343
/**
343344
* @brief Translate PyTorch exceptions to the DeePMD-kit exception.

source/api_cc/src/DeepPotPT.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,6 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
164164
std::vector<std::int64_t> atype_64(datype.begin(), datype.end());
165165
at::Tensor atype_Tensor =
166166
torch::from_blob(atype_64.data(), {1, nall_real}, int_option).to(device);
167-
c10::optional<torch::Tensor> mapping_tensor;
168167
if (ago == 0) {
169168
nlist_data.copy_from_nlist(lmp_list);
170169
nlist_data.shuffle_exclude_empty(fwd_map);

0 commit comments

Comments
 (0)