Skip to content

Commit 8b45941

Browse files
Use executorlib directly (#371)
* Use executorlib directly * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix ruff * error handling * use contextlib * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes * extend tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * get_result() function * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * clean up * increase coverage * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove old interface * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add backend test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix test * more tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * use more recent lammps version for old tests * relax mpi4py * downgrade lammps * downgrade once more * working version * fix mpi4py * add test for installed packages * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add test for command * fix test * test get_file * extend test for extract_atom * provide executor as additional argument * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix import * use more recent version of executorlib * remove broadcast * support for external executor * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f47095f commit 8b45941

File tree

9 files changed

+195
-176
lines changed

9 files changed

+195
-176
lines changed

.ci_support/environment-old.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ dependencies:
55
- openmpi
66
- numpy =1.23.5
77
- mpi4py =4.0.1
8-
- executorlib =1.2.0
8+
- executorlib =1.3.0
99
- ase =3.23.0
1010
- scipy =1.9.3
1111
- hatchling

pylammpsmpi/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import pylammpsmpi._version
22
from pylammpsmpi.wrapper.base import LammpsBase
3-
from pylammpsmpi.wrapper.concurrent import LammpsConcurrent
3+
from pylammpsmpi.wrapper.concurrent import LammpsConcurrent, init_function
44
from pylammpsmpi.wrapper.extended import LammpsLibrary
55

6-
__all__ = ["LammpsLibrary", "LammpsConcurrent", "LammpsBase"]
6+
__all__ = ["LammpsLibrary", "LammpsConcurrent", "LammpsBase", "init_function"]
77
__version__ = pylammpsmpi._version.__version__
88

99

pylammpsmpi/mpi/lmpmpi.py

Lines changed: 10 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,9 @@
11
# Copyright (c) Max-Planck-Institut für Eisenforschung GmbH - Computational Materials Design (CM) Department
22
# Distributed under the terms of "New BSD License", see the LICENSE file.
33

4-
import sys
54
from ctypes import c_double, c_int
65

76
import numpy as np
8-
from executorlib.api import (
9-
interface_connect,
10-
interface_receive,
11-
interface_send,
12-
interface_shutdown,
13-
)
14-
from lammps import lammps
157
from mpi4py import MPI
168

179
__author__ = "Sarath Menon, Jan Janssen"
@@ -80,7 +72,10 @@ def convert_data(val, type, length, width):
8072
data=job.numpy.extract_compute(*filtered_args)
8173
)
8274
length = job.get_natoms()
83-
return convert_data(val=val, type=type, length=length, width=width)
75+
if MPI.COMM_WORLD.rank == 0:
76+
return convert_data(val=val, type=type, length=length, width=width)
77+
else:
78+
return val
8479
else: # Todo
8580
raise ValueError("Local style is currently not supported")
8681

@@ -156,7 +151,10 @@ def extract_variable(job, funct_args):
156151
data = _gather_data_from_all_processors(
157152
data=job.numpy.extract_variable(*funct_args)
158153
)
159-
return np.array(data)
154+
if MPI.COMM_WORLD.rank == 0:
155+
return np.array(data)
156+
else:
157+
return np.array([])
160158
else:
161159
# if type is 1 - reformat file
162160
try:
@@ -428,50 +426,7 @@ def select_cmd(argument):
428426

429427
def _gather_data_from_all_processors(data):
430428
data_gather = MPI.COMM_WORLD.gather(data, root=0)
431-
return [v for vl in data_gather for v in vl]
432-
433-
434-
def _run_lammps_mpi(argument_lst):
435-
index_selected = argument_lst.index("--zmqport")
436-
port_selected = argument_lst[index_selected + 1]
437-
if "--host" in argument_lst:
438-
index_selected = argument_lst.index("--host")
439-
host = argument_lst[index_selected + 1]
440-
else:
441-
host = "localhost"
442-
argument_red_lst = argument_lst[:index_selected]
443429
if MPI.COMM_WORLD.rank == 0:
444-
context, socket = interface_connect(host=host, port=port_selected)
430+
return [v for vl in data_gather for v in vl]
445431
else:
446-
context, socket = None, None
447-
# Lammps executable
448-
args = ["-screen", "none"]
449-
if len(argument_red_lst) > 1:
450-
args.extend(argument_red_lst[1:])
451-
job = lammps(cmdargs=args)
452-
while True:
453-
if MPI.COMM_WORLD.rank == 0:
454-
input_dict = interface_receive(socket=socket)
455-
else:
456-
input_dict = None
457-
input_dict = MPI.COMM_WORLD.bcast(input_dict, root=0)
458-
if "shutdown" in input_dict and input_dict["shutdown"]:
459-
job.close()
460-
if MPI.COMM_WORLD.rank == 0:
461-
interface_send(socket=socket, result_dict={"result": True})
462-
interface_shutdown(socket=socket, context=context)
463-
break
464-
try:
465-
output = select_cmd(input_dict["command"])(
466-
job=job, funct_args=input_dict["args"]
467-
)
468-
except Exception as error:
469-
if MPI.COMM_WORLD.rank == 0:
470-
interface_send(socket=socket, result_dict={"error": error})
471-
else:
472-
if MPI.COMM_WORLD.rank == 0 and output is not None:
473-
interface_send(socket=socket, result_dict={"result": output})
474-
475-
476-
if __name__ == "__main__":
477-
_run_lammps_mpi(argument_lst=sys.argv)
432+
return []

pylammpsmpi/wrapper/ase.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from ase.atoms import Atoms
99
from ase.calculators.lammps import Prism
1010
from ase.data import atomic_masses, atomic_numbers
11+
from executorlib import BaseExecutor
1112
from scipy import constants
1213

1314
from pylammpsmpi.wrapper.base import LammpsBase
@@ -25,6 +26,7 @@ class LammpsASELibrary:
2526
log_file (str, optional): The log file path. Defaults to None.
2627
library (object, optional): The LAMMPS library object. Defaults to None.
2728
disable_log_file (bool, optional): Whether to disable the log file. Defaults to True.
29+
executor: Executor to use for parallel execution (default: None)
2830
"""
2931

3032
def __init__(
@@ -36,6 +38,7 @@ def __init__(
3638
log_file: Optional[str] = None,
3739
library: Optional[object] = None,
3840
disable_log_file: bool = True,
41+
executor: Optional[BaseExecutor] = None,
3942
):
4043
self._logger = logger
4144
self._prism = None
@@ -60,7 +63,9 @@ def __init__(
6063
)
6164
else:
6265
self._interactive_library = LammpsBase(
63-
cores=self._cores, working_directory=working_directory
66+
cores=self._cores,
67+
working_directory=working_directory,
68+
executor=executor,
6469
)
6570

6671
def interactive_lib_command(self, command: str) -> None:

pylammpsmpi/wrapper/base.py

Lines changed: 43 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
# Copyright (c) Max-Planck-Institut für Eisenforschung GmbH - Computational Materials Design (CM) Department
22
# Distributed under the terms of "New BSD License", see the LICENSE file.
33

4-
from typing import Union
4+
from concurrent.futures import Future
5+
from typing import Any, Union
56

67
from pylammpsmpi.wrapper.concurrent import LammpsConcurrent
78

@@ -17,6 +18,13 @@
1718
__date__ = "Feb 28, 2020"
1819

1920

21+
def get_result(future: Future, cores: int) -> Any:
22+
if cores > 1:
23+
return future.result()[0]
24+
else:
25+
return future.result()
26+
27+
2028
class LammpsBase(LammpsConcurrent):
2129
@property
2230
def version(self) -> str:
@@ -27,7 +35,7 @@ def version(self) -> str:
2735
version: str
2836
version string of lammps
2937
"""
30-
return super().version.result()
38+
return get_result(future=super().version, cores=self.cores)
3139

3240
def file(self, inputfile: str) -> None:
3341
"""
@@ -55,7 +63,7 @@ def extract_setting(self, *args) -> Union[int, float, str]:
5563
value: int, float, or str
5664
extracted setting value
5765
"""
58-
return super().extract_setting(*args).result()
66+
return get_result(future=super().extract_setting(*args), cores=self.cores)
5967

6068
def extract_global(self, name: str) -> Union[int, float, str]:
6169
"""
@@ -69,7 +77,7 @@ def extract_global(self, name: str) -> Union[int, float, str]:
6977
value: int, float, or str
7078
extracted value of the global parameter
7179
"""
72-
return super().extract_global(name=name).result()
80+
return get_result(future=super().extract_global(name=name), cores=self.cores)
7381

7482
def extract_box(self) -> list[Union[float, list[float], list[int]]]:
7583
"""
@@ -84,7 +92,7 @@ def extract_box(self) -> list[Union[float, list[float], list[int]]]:
8492
the box is periodic in three dimensions, and box_change is a list of booleans
8593
indicating if the box dimensions have changed
8694
"""
87-
return super().extract_box().result()
95+
return get_result(future=super().extract_box(), cores=self.cores)
8896

8997
def extract_atom(self, name: str) -> Union[list[int], list[float]]:
9098
"""
@@ -99,7 +107,7 @@ def extract_atom(self, name: str) -> Union[list[int], list[float]]:
99107
If the requested name has multiple dimensions, output
100108
will be a multi-dimensional list.
101109
"""
102-
return super().extract_atom(name=name).result()
110+
return get_result(future=super().extract_atom(name=name), cores=self.cores)
103111

104112
def extract_fix(self, *args) -> Union[int, float, list[Union[int, float]]]:
105113
"""
@@ -113,7 +121,7 @@ def extract_fix(self, *args) -> Union[int, float, list[Union[int, float]]]:
113121
value: int, float, or list of int or float
114122
extracted fix value corresponding to the requested dimensions
115123
"""
116-
return super().extract_fix(*args).result()
124+
return get_result(future=super().extract_fix(*args), cores=self.cores)
117125

118126
def extract_variable(self, *args) -> Union[int, float, list[Union[int, float]]]:
119127
"""
@@ -127,7 +135,7 @@ def extract_variable(self, *args) -> Union[int, float, list[Union[int, float]]]:
127135
data: int, float, or list of int or float
128136
value of the variable depending on the requested dimension
129137
"""
130-
return super().extract_variable(*args).result()
138+
return get_result(future=super().extract_variable(*args), cores=self.cores)
131139

132140
@property
133141
def natoms(self) -> int:
@@ -148,7 +156,7 @@ def get_natoms(self) -> int:
148156
natoms : int
149157
number of atoms
150158
"""
151-
return super().get_natoms().result()
159+
return get_result(future=super().get_natoms(), cores=self.cores)
152160

153161
def set_variable(self, *args) -> int:
154162
"""
@@ -162,7 +170,7 @@ def set_variable(self, *args) -> int:
162170
flag : int
163171
0 if successful, -1 otherwise
164172
"""
165-
return super().set_variable(*args).result()
173+
return get_result(future=super().set_variable(*args), cores=self.cores)
166174

167175
def reset_box(self, *args) -> None:
168176
"""
@@ -270,27 +278,27 @@ def create_atoms(
270278
@property
271279
def has_exceptions(self) -> bool:
272280
"""Return whether the LAMMPS shared library was compiled with C++ exceptions handling enabled"""
273-
return super().has_exceptions.result()
281+
return get_result(future=super().has_exceptions, cores=self.cores)
274282

275283
@property
276284
def has_gzip_support(self) -> bool:
277-
return super().has_gzip_support.result()
285+
return get_result(future=super().has_gzip_support, cores=self.cores)
278286

279287
@property
280288
def has_png_support(self) -> bool:
281-
return super().has_png_support.result()
289+
return get_result(future=super().has_png_support, cores=self.cores)
282290

283291
@property
284292
def has_jpeg_support(self) -> bool:
285-
return super().has_jpeg_support.result()
293+
return get_result(future=super().has_jpeg_support, cores=self.cores)
286294

287295
@property
288296
def has_ffmpeg_support(self) -> bool:
289-
return super().has_ffmpeg_support.result()
297+
return get_result(future=super().has_ffmpeg_support, cores=self.cores)
290298

291299
@property
292300
def installed_packages(self) -> list[str]:
293-
return super().installed_packages.result()
301+
return get_result(future=super().installed_packages, cores=self.cores)
294302

295303
def set_fix_external_callback(self, *args) -> None:
296304
_ = super().set_fix_external_callback(*args).result()
@@ -302,7 +310,7 @@ def get_neighlist(self, *args):
302310
:return: an instance of :class:`NeighList` wrapping access to neighbor list data
303311
:rtype: NeighList
304312
"""
305-
return super().get_neighlist(*args).result()
313+
return get_result(future=super().get_neighlist(*args), cores=self.cores)
306314

307315
def find_pair_neighlist(self, *args) -> int:
308316
"""Find neighbor list index of pair style neighbor list
@@ -324,7 +332,7 @@ def find_pair_neighlist(self, *args) -> int:
324332
:return: neighbor list index if found, otherwise -1
325333
:rtype: int
326334
"""
327-
return super().find_pair_neighlist(*args).result()
335+
return get_result(future=super().find_pair_neighlist(*args), cores=self.cores)
328336

329337
def find_fix_neighlist(self, *args):
330338
"""Find neighbor list index of fix neighbor list
@@ -335,7 +343,7 @@ def find_fix_neighlist(self, *args):
335343
:return: neighbor list index if found, otherwise -1
336344
:rtype: int
337345
"""
338-
return super().find_fix_neighlist(*args).result()
346+
return get_result(future=super().find_fix_neighlist(*args), cores=self.cores)
339347

340348
def find_compute_neighlist(self, *args):
341349
"""Find neighbor list index of compute neighbor list
@@ -346,7 +354,9 @@ def find_compute_neighlist(self, *args):
346354
:return: neighbor list index if found, otherwise -1
347355
:rtype: int
348356
"""
349-
return super().find_compute_neighlist(*args).result()
357+
return get_result(
358+
future=super().find_compute_neighlist(*args), cores=self.cores
359+
)
350360

351361
def get_neighlist_size(self, *args):
352362
"""Return the number of elements in neighbor list with the given index
@@ -355,10 +365,12 @@ def get_neighlist_size(self, *args):
355365
:return: number of elements in neighbor list with index idx
356366
:rtype: int
357367
"""
358-
return super().get_neighlist_size(*args).result()
368+
return get_result(future=super().get_neighlist_size(*args), cores=self.cores)
359369

360370
def get_neighlist_element_neighbors(self, *args):
361-
return super().get_neighlist_element_neighbors(*args).result()
371+
return get_result(
372+
future=super().get_neighlist_element_neighbors(*args), cores=self.cores
373+
)
362374

363375
def command(self, cmd):
364376
"""
@@ -407,7 +419,9 @@ def gather_atoms(self, *args, concat=False, ids=None):
407419
--------
408420
extract_atoms
409421
"""
410-
return super().gather_atoms(*args, concat=concat, ids=ids).result()
422+
return get_result(
423+
future=super().gather_atoms(*args, concat=concat, ids=ids), cores=self.cores
424+
)
411425

412426
def scatter_atoms(self, *args, ids=None):
413427
"""
@@ -433,7 +447,7 @@ def get_thermo(self, *args):
433447
value of the thermo keyword
434448
435449
"""
436-
return super().get_thermo(*args).result()
450+
return get_result(future=super().get_thermo(*args), cores=self.cores)
437451

438452
# TODO
439453
def extract_compute(self, id, style, type, length=0, width=0):
@@ -468,8 +482,9 @@ def extract_compute(self, id, style, type, length=0, width=0):
468482
data computed by the fix depending on the chosen inputs
469483
470484
"""
471-
return (
472-
super()
473-
.extract_compute(id=id, style=style, type=type, length=length, width=width)
474-
.result()
485+
return get_result(
486+
future=super().extract_compute(
487+
id=id, style=style, type=type, length=length, width=width
488+
),
489+
cores=self.cores,
475490
)

0 commit comments

Comments
 (0)