Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 30 additions & 10 deletions nipype/interfaces/niftyreg/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import subprocess
from warnings import warn

from ..base import CommandLine, isdefined, CommandLineInputSpec, traits
from ..base import CommandLine, CommandLineInputSpec, traits, Undefined
from ...utils.filemanip import split_filename


Expand All @@ -47,8 +47,9 @@ def no_nifty_package(cmd='reg_f3d'):
class NiftyRegCommandInputSpec(CommandLineInputSpec):
"""Input Spec for niftyreg interfaces."""
# Set the number of omp thread to use
omp_core_val = traits.Int(desc='Number of openmp thread to use',
argstr='-omp %i')
omp_core_val = traits.Int(int(os.environ.get('OMP_NUM_THREADS', '1')),
desc='Number of openmp thread to use',
argstr='-omp %i', usedefault=True)


class NiftyRegCommand(CommandLine):
Expand All @@ -58,7 +59,10 @@ class NiftyRegCommand(CommandLine):
_suffix = '_nr'
_min_version = '1.5.30'

input_spec = NiftyRegCommandInputSpec

def __init__(self, required_version=None, **inputs):
self.num_threads = 1
super(NiftyRegCommand, self).__init__(**inputs)
self.required_version = required_version
_version = self.get_version()
Expand All @@ -73,6 +77,29 @@ def __init__(self, required_version=None, **inputs):
msg = 'The version of NiftyReg differs from the required'
msg += '(%s != %s)'
warn(msg % (_version, self.required_version))
self.inputs.on_trait_change(self._omp_update, 'omp_core_val')
self.inputs.on_trait_change(self._environ_update, 'environ')
self._omp_update()

def _omp_update(self):
if self.inputs.omp_core_val:
self.inputs.environ['OMP_NUM_THREADS'] = \
str(self.inputs.omp_core_val)
self.num_threads = self.inputs.omp_core_val
else:
if 'OMP_NUM_THREADS' in self.inputs.environ:
del self.inputs.environ['OMP_NUM_THREADS']
self.num_threads = 1

def _environ_update(self):
if self.inputs.environ:
if 'OMP_NUM_THREADS' in self.inputs.environ:
self.inputs.omp_core_val = \
int(self.inputs.environ['OMP_NUM_THREADS'])
else:
self.inputs.omp_core_val = Undefined
else:
self.inputs.omp_core_val = Undefined

def check_version(self):
_version = self.get_version()
Expand Down Expand Up @@ -102,13 +129,6 @@ def version(self):
def exists(self):
return self.get_version() is not None

def _run_interface(self, runtime):
# Update num threads estimate from OMP_NUM_THREADS env var
# Default to 1 if not set
if not isdefined(self.inputs.environ['OMP_NUM_THREADS']):
self.inputs.environ['OMP_NUM_THREADS'] = self.num_threads
return super(NiftyRegCommand, self)._run_interface(runtime)

def _format_arg(self, name, spec, value):
if name == 'omp_core_val':
self.numthreads = value
Expand Down
3 changes: 3 additions & 0 deletions nipype/interfaces/niftyreg/tests/test_auto_NiftyRegCommand.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ def test_NiftyRegCommand_inputs():
ignore_exception=dict(nohash=True,
usedefault=True,
),
omp_core_val=dict(argstr='-omp %i',
usedefault=True,
),
terminal_output=dict(nohash=True,
),
)
Expand Down
1 change: 1 addition & 0 deletions nipype/interfaces/niftyreg/tests/test_auto_RegAladin.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def test_RegAladin_inputs():
nosym_flag=dict(argstr='-noSym',
),
omp_core_val=dict(argstr='-omp %i',
usedefault=True,
),
platform_val=dict(argstr='-platf %i',
),
Expand Down
1 change: 1 addition & 0 deletions nipype/interfaces/niftyreg/tests/test_auto_RegAverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def test_RegAverage_inputs():
usedefault=True,
),
omp_core_val=dict(argstr='-omp %i',
usedefault=True,
),
out_file=dict(argstr='%s',
genfile=True,
Expand Down
1 change: 1 addition & 0 deletions nipype/interfaces/niftyreg/tests/test_auto_RegF3D.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def test_RegF3D_inputs():
noz_flag=dict(argstr='-noz',
),
omp_core_val=dict(argstr='-omp %i',
usedefault=True,
),
pad_val=dict(argstr='-pad %f',
),
Expand Down
1 change: 1 addition & 0 deletions nipype/interfaces/niftyreg/tests/test_auto_RegJacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def test_RegJacobian_inputs():
usedefault=True,
),
omp_core_val=dict(argstr='-omp %i',
usedefault=True,
),
out_file=dict(argstr='%s',
name_source=['trans_file'],
Expand Down
1 change: 1 addition & 0 deletions nipype/interfaces/niftyreg/tests/test_auto_RegMeasure.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def test_RegMeasure_inputs():
mandatory=True,
),
omp_core_val=dict(argstr='-omp %i',
usedefault=True,
),
out_file=dict(argstr='-out %s',
name_source=['flo_file'],
Expand Down
1 change: 1 addition & 0 deletions nipype/interfaces/niftyreg/tests/test_auto_RegResample.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def test_RegResample_inputs():
inter_val=dict(argstr='-inter %d',
),
omp_core_val=dict(argstr='-omp %i',
usedefault=True,
),
out_file=dict(argstr='%s',
name_source=['flo_file'],
Expand Down
1 change: 1 addition & 0 deletions nipype/interfaces/niftyreg/tests/test_auto_RegTools.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def test_RegTools_inputs():
noscl_flag=dict(argstr='-noscl',
),
omp_core_val=dict(argstr='-omp %i',
usedefault=True,
),
out_file=dict(argstr='-out %s',
name_source=['in_file'],
Expand Down
1 change: 1 addition & 0 deletions nipype/interfaces/niftyreg/tests/test_auto_RegTransform.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def test_RegTransform_inputs():
xor=['def_input', 'disp_input', 'flow_input', 'comp_input', 'upd_s_form_input', 'inv_aff_input', 'inv_nrr_input', 'half_input', 'aff_2_rig_input', 'flirt_2_nr_input'],
),
omp_core_val=dict(argstr='-omp %i',
usedefault=True,
),
out_file=dict(argstr='%s',
genfile=True,
Expand Down
29 changes: 17 additions & 12 deletions nipype/interfaces/niftyreg/tests/test_regutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ def test_reg_average():
two_file = example_data('im2.nii')
three_file = example_data('im3.nii')
nr_average.inputs.avg_files = [one_file, two_file, three_file]
nr_average.inputs.omp_core_val = 1
generated_cmd = nr_average.cmdline

# Read the reg_average_cmd
Expand All @@ -198,10 +199,10 @@ def test_reg_average():
argv = f_obj.read()
os.remove(reg_average_cmd)

expected_argv = '%s %s -avg %s %s %s' % (get_custom_path('reg_average'),
os.path.join(os.getcwd(),
'avg_out.nii.gz'),
one_file, two_file, three_file)
expected_argv = '%s %s -avg %s %s %s -omp 1' % (
get_custom_path('reg_average'),
os.path.join(os.getcwd(), 'avg_out.nii.gz'),
one_file, two_file, three_file)

assert argv.decode('utf-8') == expected_argv

Expand All @@ -217,6 +218,7 @@ def test_reg_average():
two_file = example_data('ants_Affine.txt')
three_file = example_data('elastix.txt')
nr_average_2.inputs.avg_files = [one_file, two_file, three_file]
nr_average_2.inputs.omp_core_val = 1
generated_cmd = nr_average_2.cmdline

# Read the reg_average_cmd
Expand All @@ -225,10 +227,10 @@ def test_reg_average():
argv = f_obj.read()
os.remove(reg_average_cmd)

expected_argv = '%s %s -avg %s %s %s' % (get_custom_path('reg_average'),
os.path.join(os.getcwd(),
'avg_out.txt'),
one_file, two_file, three_file)
expected_argv = '%s %s -avg %s %s %s -omp 1' % (
get_custom_path('reg_average'),
os.path.join(os.getcwd(), 'avg_out.txt'),
one_file, two_file, three_file)

assert argv.decode('utf-8') == expected_argv

Expand All @@ -238,6 +240,7 @@ def test_reg_average():
two_file = example_data('ants_Affine.txt')
three_file = example_data('elastix.txt')
nr_average_3.inputs.avg_lts_files = [one_file, two_file, three_file]
nr_average_3.inputs.omp_core_val = 1
generated_cmd = nr_average_3.cmdline

# Read the reg_average_cmd
Expand All @@ -246,7 +249,7 @@ def test_reg_average():
argv = f_obj.read()
os.remove(reg_average_cmd)

expected_argv = ('%s %s -avg_lts %s %s %s'
expected_argv = ('%s %s -avg_lts %s %s %s -omp 1'
% (get_custom_path('reg_average'),
os.path.join(os.getcwd(), 'avg_out.txt'),
one_file, two_file, three_file))
Expand All @@ -266,6 +269,7 @@ def test_reg_average():
trans2_file, two_file,
trans3_file, three_file]
nr_average_4.inputs.avg_ref_file = ref_file
nr_average_4.inputs.omp_core_val = 1
generated_cmd = nr_average_4.cmdline

# Read the reg_average_cmd
Expand All @@ -274,12 +278,12 @@ def test_reg_average():
argv = f_obj.read()
os.remove(reg_average_cmd)

expected_argv = ('%s %s -avg_tran %s %s %s %s %s %s %s'
expected_argv = ('%s %s -avg_tran %s -omp 1 %s %s %s %s %s %s'
% (get_custom_path('reg_average'),
os.path.join(os.getcwd(), 'avg_out.nii.gz'),
ref_file, trans1_file, one_file, trans2_file, two_file,
trans3_file, three_file))

assert argv.decode('utf-8') == expected_argv

# Test Reg Average: demean3
Expand All @@ -298,6 +302,7 @@ def test_reg_average():
aff2_file, trans2_file, two_file,
aff3_file, trans3_file, three_file]
nr_average_5.inputs.demean3_ref_file = ref_file
nr_average_5.inputs.omp_core_val = 1
generated_cmd = nr_average_5.cmdline

# Read the reg_average_cmd
Expand All @@ -306,7 +311,7 @@ def test_reg_average():
argv = f_obj.read()
os.remove(reg_average_cmd)

expected_argv = ('%s %s -demean3 %s %s %s %s %s %s %s %s %s %s'
expected_argv = ('%s %s -demean3 %s -omp 1 %s %s %s %s %s %s %s %s %s'
% (get_custom_path('reg_average'),
os.path.join(os.getcwd(), 'avg_out.nii.gz'),
ref_file,
Expand Down
5 changes: 3 additions & 2 deletions nipype/interfaces/niftyseg/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
See the docstrings of the individual classes for examples.
"""

from nipype.interfaces.niftyreg.base import NiftyRegCommand, no_nifty_package
from nipype.interfaces.niftyreg.base import no_nifty_package
from nipype.interfaces.niftyfit.base import NiftyFitCommand
import subprocess
import warnings

Expand All @@ -25,7 +26,7 @@
warnings.filterwarnings('always', category=UserWarning)


class NiftySegCommand(NiftyRegCommand):
class NiftySegCommand(NiftyFitCommand):
"""
Base support interface for NiftySeg commands.
"""
Expand Down