Skip to content

Commit 03a09d6

Browse files
Merge remote-tracking branch 'origin/devel' into support_comm
2 parents 5b0997d + 98fb397 commit 03a09d6

File tree

11 files changed

+174
-66
lines changed

11 files changed

+174
-66
lines changed

.github/workflows/build_wheel.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ jobs:
7070
rm -rf .git
7171
if: matrix.dp_pkg_name == 'deepmd-kit-cu11'
7272
- name: Build wheels
73-
uses: pypa/cibuildwheel@v3.0
73+
uses: pypa/cibuildwheel@v3.1
7474
env:
7575
CIBW_BUILD_VERBOSITY: 1
7676
CIBW_ARCHS: all

.github/workflows/package_c.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ jobs:
2222
tensorflow_version: ""
2323
filename: libdeepmd_c.tar.gz
2424
- tensorflow_build_version: "2.14"
25-
tensorflow_version: ">=2.5.0rc0,<2.15"
25+
tensorflow_version: ">=2.5.0,<2.15"
2626
filename: libdeepmd_c_cu11.tar.gz
2727
steps:
2828
- uses: actions/checkout@v4

.pre-commit-config.yaml

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ repos:
2929
exclude: ^source/3rdparty
3030
- repo: https://github.com/astral-sh/ruff-pre-commit
3131
# Ruff version.
32-
rev: v0.12.3
32+
rev: v0.12.7
3333
hooks:
3434
- id: ruff
3535
args: ["--fix"]
@@ -74,7 +74,7 @@ repos:
7474
# exclude: ^(source/3rdparty|\.github/workflows|\.clang-format)
7575
# Shell
7676
- repo: https://github.com/scop/pre-commit-shfmt
77-
rev: v3.12.0-1
77+
rev: v3.12.0-2
7878
hooks:
7979
- id: shfmt
8080
# CMake
@@ -83,25 +83,25 @@ repos:
8383
hooks:
8484
- id: cmake-format
8585
#- id: cmake-lint
86-
# - repo: https://github.com/njzjz/mirrors-bibtex-tidy
87-
# rev: v1.13.0
88-
# hooks:
89-
# - id: bibtex-tidy
90-
# args:
91-
# - --curly
92-
# - --numeric
93-
# - --align=13
94-
# - --blank-lines
95-
# # disable sort: the order of keys and fields has explict meanings
96-
# #- --sort=key
97-
# - --duplicates=key,doi,citation,abstract
98-
# - --merge=combine
99-
# #- --sort-fields
100-
# #- --strip-comments
101-
# - --trailing-commas
102-
# - --encode-urls
103-
# - --remove-empty-fields
104-
# - --wrap=80
86+
- repo: https://github.com/njzjz/mirrors-bibtex-tidy
87+
rev: v1.14.0
88+
hooks:
89+
- id: bibtex-tidy
90+
args:
91+
- --curly
92+
- --numeric
93+
- --align=13
94+
- --blank-lines
95+
# disable sort: the order of keys and fields has explict meanings
96+
#- --sort=key
97+
- --duplicates=key,doi,citation,abstract
98+
- --merge=combine
99+
#- --sort-fields
100+
#- --strip-comments
101+
- --trailing-commas
102+
- --encode-urls
103+
- --remove-empty-fields
104+
- --wrap=80
105105
# license header
106106
- repo: https://github.com/Lucas-C/pre-commit-hooks
107107
rev: v1.5.5

backend/find_tensorflow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,14 +88,14 @@ def find_tensorflow() -> tuple[Optional[str], list[str]]:
8888
# CUDA 12.2, cudnn 9
8989
requires.extend(
9090
[
91-
"tensorflow-cpu>=2.18.0rc0; platform_machine=='x86_64' and platform_system == 'Linux'",
91+
"tensorflow-cpu>=2.18.0; platform_machine=='x86_64' and platform_system == 'Linux'",
9292
]
9393
)
9494
elif cuda_version in SpecifierSet(">=11,<12"):
9595
# CUDA 11.8, cudnn 8
9696
requires.extend(
9797
[
98-
"tensorflow-cpu>=2.5.0rc0,<2.15; platform_machine=='x86_64' and platform_system == 'Linux'",
98+
"tensorflow-cpu>=2.5.0,<2.15; platform_machine=='x86_64' and platform_system == 'Linux'",
9999
]
100100
)
101101
tf_version = "2.14.1"

deepmd/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
"select_idx_map",
4545
]
4646

47-
_PRECISION = Literal["default", "float16", "float32", "float64"]
47+
_PRECISION = Literal["default", "float16", "bfloat16", "float32", "float64"]
4848
_ACTIVATION = Literal[
4949
"relu",
5050
"relu6",

deepmd/pt/loss/property.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def __init__(
4242
var_name : str
4343
The atomic property to fit, 'energy', 'dipole', and 'polar'.
4444
loss_func : str
45-
The loss function, such as "smooth_mae", "mae", "rmse".
45+
The loss function, such as "smooth_mae", "mae", "rmse", "mape".
4646
metric : list
4747
The metric such as mae, rmse which will be printed.
4848
beta : float
@@ -151,6 +151,12 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False
151151
reduction="mean",
152152
)
153153
)
154+
elif self.loss_func == "mape":
155+
loss += torch.mean(
156+
torch.abs(
157+
(label[var_name] - model_pred[var_name]) / (label[var_name] + 1e-3)
158+
)
159+
)
154160
else:
155161
raise RuntimeError(f"Unknown loss function : {self.loss_func}")
156162

@@ -182,6 +188,12 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False
182188
reduction="mean",
183189
)
184190
).detach()
191+
if "mape" in self.metric:
192+
more_loss["mape"] = torch.mean(
193+
torch.abs(
194+
(label[var_name] - model_pred[var_name]) / (label[var_name] + 1e-3)
195+
)
196+
).detach()
185197

186198
return model_pred, loss, more_loss
187199

deepmd/pt/train/training.py

Lines changed: 111 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def __init__(
140140
self.num_steps = training_params["numb_steps"]
141141
self.disp_file = training_params.get("disp_file", "lcurve.out")
142142
self.disp_freq = training_params.get("disp_freq", 1000)
143+
self.disp_avg = training_params.get("disp_avg", False)
143144
self.save_ckpt = training_params.get("save_ckpt", "model.ckpt")
144145
self.save_freq = training_params.get("save_freq", 1000)
145146
self.max_ckpt_keep = training_params.get("max_ckpt_keep", 5)
@@ -808,23 +809,75 @@ def fake_model():
808809
else:
809810
raise ValueError(f"Not supported optimizer type '{self.opt_type}'")
810811

812+
if self.disp_avg:
813+
# Accumulate loss for averaging over display interval
814+
self.step_count_in_interval += 1
815+
if not self.multi_task:
816+
# Accumulate loss for single task
817+
if not self.train_loss_accu:
818+
# Initialize accumulator with current loss structure
819+
for item in more_loss:
820+
if "l2_" not in item:
821+
self.train_loss_accu[item] = 0.0
822+
for item in more_loss:
823+
if "l2_" not in item:
824+
self.train_loss_accu[item] += more_loss[item]
825+
else:
826+
# Accumulate loss for multi-task
827+
if task_key not in self.train_loss_accu:
828+
self.train_loss_accu[task_key] = {}
829+
if task_key not in self.step_count_per_task:
830+
self.step_count_per_task[task_key] = 0
831+
self.step_count_per_task[task_key] += 1
832+
833+
for item in more_loss:
834+
if "l2_" not in item:
835+
if item not in self.train_loss_accu[task_key]:
836+
self.train_loss_accu[task_key][item] = 0.0
837+
self.train_loss_accu[task_key][item] += more_loss[item]
838+
811839
# Log and persist
812840
display_step_id = _step_id + 1
813841
if self.display_in_training and (
814842
display_step_id % self.disp_freq == 0 or display_step_id == 1
815843
):
816844
self.wrapper.eval() # Will set to train mode before fininshing validation
817845

818-
def log_loss_train(_loss, _more_loss, _task_key="Default"):
819-
results = {}
820-
rmse_val = {
821-
item: _more_loss[item]
822-
for item in _more_loss
823-
if "l2_" not in item
824-
}
825-
for item in sorted(rmse_val.keys()):
826-
results[item] = rmse_val[item]
827-
return results
846+
if self.disp_avg:
847+
848+
def log_loss_train(_loss, _more_loss, _task_key="Default"):
849+
results = {}
850+
if not self.multi_task:
851+
# Use accumulated average loss for single task
852+
for item in self.train_loss_accu:
853+
results[item] = (
854+
self.train_loss_accu[item]
855+
/ self.step_count_in_interval
856+
)
857+
else:
858+
# Use accumulated average loss for multi-task
859+
if (
860+
_task_key in self.train_loss_accu
861+
and _task_key in self.step_count_per_task
862+
):
863+
for item in self.train_loss_accu[_task_key]:
864+
results[item] = (
865+
self.train_loss_accu[_task_key][item]
866+
/ self.step_count_per_task[_task_key]
867+
)
868+
return results
869+
else:
870+
871+
def log_loss_train(_loss, _more_loss, _task_key="Default"):
872+
results = {}
873+
rmse_val = {
874+
item: _more_loss[item]
875+
for item in _more_loss
876+
if "l2_" not in item
877+
}
878+
for item in sorted(rmse_val.keys()):
879+
results[item] = rmse_val[item]
880+
return results
828881

829882
def log_loss_valid(_task_key="Default"):
830883
single_results = {}
@@ -882,24 +935,31 @@ def log_loss_valid(_task_key="Default"):
882935
else:
883936
train_results = {_key: {} for _key in self.model_keys}
884937
valid_results = {_key: {} for _key in self.model_keys}
885-
train_results[task_key] = log_loss_train(
886-
loss, more_loss, _task_key=task_key
887-
)
888-
for _key in self.model_keys:
889-
if _key != task_key:
890-
self.optimizer.zero_grad()
891-
input_dict, label_dict, _ = self.get_data(
892-
is_train=True, task_key=_key
893-
)
894-
_, loss, more_loss = self.wrapper(
895-
**input_dict,
896-
cur_lr=pref_lr,
897-
label=label_dict,
898-
task_key=_key,
899-
)
938+
if self.disp_avg:
939+
# For multi-task, use accumulated average loss for all tasks
940+
for _key in self.model_keys:
900941
train_results[_key] = log_loss_train(
901942
loss, more_loss, _task_key=_key
902943
)
944+
else:
945+
train_results[task_key] = log_loss_train(
946+
loss, more_loss, _task_key=task_key
947+
)
948+
for _key in self.model_keys:
949+
if _key != task_key:
950+
self.optimizer.zero_grad()
951+
input_dict, label_dict, _ = self.get_data(
952+
is_train=True, task_key=_key
953+
)
954+
_, loss, more_loss = self.wrapper(
955+
**input_dict,
956+
cur_lr=pref_lr,
957+
label=label_dict,
958+
task_key=_key,
959+
)
960+
train_results[_key] = log_loss_train(
961+
loss, more_loss, _task_key=_key
962+
)
903963
valid_results[_key] = log_loss_valid(_task_key=_key)
904964
if self.rank == 0:
905965
log.info(
@@ -921,6 +981,21 @@ def log_loss_valid(_task_key="Default"):
921981
)
922982
self.wrapper.train()
923983

984+
if self.disp_avg:
985+
# Reset loss accumulators after display
986+
if not self.multi_task:
987+
for item in self.train_loss_accu:
988+
self.train_loss_accu[item] = 0.0
989+
else:
990+
for task_key in self.model_keys:
991+
if task_key in self.train_loss_accu:
992+
for item in self.train_loss_accu[task_key]:
993+
self.train_loss_accu[task_key][item] = 0.0
994+
if task_key in self.step_count_per_task:
995+
self.step_count_per_task[task_key] = 0
996+
self.step_count_in_interval = 0
997+
self.last_display_step = display_step_id
998+
924999
current_time = time.time()
9251000
train_time = current_time - self.t0
9261001
self.t0 = current_time
@@ -993,6 +1068,17 @@ def log_loss_valid(_task_key="Default"):
9931068
self.t0 = time.time()
9941069
self.total_train_time = 0.0
9951070
self.timed_steps = 0
1071+
1072+
if self.disp_avg:
1073+
# Initialize loss accumulators
1074+
if not self.multi_task:
1075+
self.train_loss_accu = {}
1076+
else:
1077+
self.train_loss_accu = {key: {} for key in self.model_keys}
1078+
self.step_count_per_task = dict.fromkeys(self.model_keys, 0)
1079+
self.step_count_in_interval = 0
1080+
self.last_display_step = 0
1081+
9961082
for step_id in range(self.start_step, self.num_steps):
9971083
step(step_id)
9981084
if JIT:

deepmd/utils/argcheck.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3137,6 +3137,9 @@ def training_args(
31373137
)
31383138
doc_disp_training = "Displaying verbose information during training."
31393139
doc_time_training = "Timing during training."
3140+
doc_disp_avg = (
3141+
"Display the average loss over the display interval for training sets."
3142+
)
31403143
doc_profiling = "Export the profiling results to the Chrome JSON file for performance analysis, driven by the legacy TensorFlow profiling API or PyTorch Profiler. The output file will be saved to `profiling_file`."
31413144
doc_profiling_file = "Output file for profiling."
31423145
doc_enable_profiler = "Export the profiling results to the TensorBoard log for performance analysis, driven by TensorFlow Profiler (available in TensorFlow 2.3) or PyTorch Profiler. The log will be saved to `tensorboard_log_dir`."
@@ -3213,6 +3216,13 @@ def training_args(
32133216
Argument(
32143217
"time_training", bool, optional=True, default=True, doc=doc_time_training
32153218
),
3219+
Argument(
3220+
"disp_avg",
3221+
bool,
3222+
optional=True,
3223+
default=False,
3224+
doc=doc_only_pt_supported + doc_disp_avg,
3225+
),
32163226
Argument(
32173227
"profiling",
32183228
bool,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ test = [
9090
docs = [
9191
"sphinx>=3.1.1",
9292
"sphinx-book-theme",
93-
"myst-nb>=1.0.0rc0",
93+
"myst-nb>=1.0.0",
9494
"myst-parser>=0.19.2",
9595
"sphinx-design",
9696
"breathe",

source/api_cc/src/DeepPotPT.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -197,12 +197,12 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
197197
std::accumulate(lmp_list.sendnum, lmp_list.sendnum + nswap, 0);
198198
torch::Tensor sendlist_tensor =
199199
torch::from_blob(lmp_list.sendlist, {total_send}, int32_option);
200-
comm_dict.insert("send_list", sendlist_tensor);
201-
comm_dict.insert("send_proc", sendproc_tensor);
202-
comm_dict.insert("recv_proc", recvproc_tensor);
203-
comm_dict.insert("send_num", sendnum_tensor);
204-
comm_dict.insert("recv_num", recvnum_tensor);
205-
comm_dict.insert("communicator", communicator_tensor);
200+
comm_dict.insert_or_assign("send_list", sendlist_tensor);
201+
comm_dict.insert_or_assign("send_proc", sendproc_tensor);
202+
comm_dict.insert_or_assign("recv_proc", recvproc_tensor);
203+
comm_dict.insert_or_assign("send_num", sendnum_tensor);
204+
comm_dict.insert_or_assign("recv_num", recvnum_tensor);
205+
comm_dict.insert_or_assign("communicator", communicator_tensor);
206206
}
207207
if (lmp_list.mapping) {
208208
std::vector<std::int64_t> mapping(nall_real);

0 commit comments

Comments
 (0)