
Commit 260bd25

BF: updated engine to prevent dual connections to same input port; DOC: several docstring cleanups in engine code
1 parent 7ab5333

2 files changed: +131, -25 lines
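The behavioral change is exercised by the new tests at the bottom of the diff: connecting a second source to an input port that already has a connection now raises an Exception, and disconnecting the last connection between two nodes now removes the graph edge entirely. A minimal sketch of the new behavior, assuming only the IdentityInterface-based setup used in the added tests:

    import nipype.pipeline.engine as pe
    from nipype.interfaces.utility import IdentityInterface

    a = pe.Node(IdentityInterface(fields=['a', 'b']), name='a')
    b = pe.Node(IdentityInterface(fields=['a', 'b']), name='b')
    flow = pe.Workflow(name='test')

    flow.connect(a, 'a', b, 'a')        # first connection to input 'a' of node b
    try:
        flow.connect(a, 'b', b, 'a')    # second connection to the same input port
    except Exception as exc:
        print(exc)                      # the new guard reports the port as already connected

    flow.disconnect(a, 'a', b, 'a')     # removing the last connection now drops the edge too
    print(flow._graph.edges())          # []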

nipype/pipeline/engine.py

Lines changed: 109 additions & 24 deletions
@@ -118,16 +118,19 @@ def outputs(self):
         raise NotImplementedError
 
     def clone(self, name):
-        if name is None:
+        """Clone a workflowbase object
+
+        Parameters
+        ----------
+
+        name : string (mandatory)
+            A clone of node or workflow must have a new name
+        """
+        if (name is None) or (name == self.name):
             raise Exception('Cloning requires a new name')
-        if hasattr(self, '_flatgraph'):
-            self._flatgraph = None
-        if hasattr(self, '_execgraph'):
-            self._execgraph = None
         clone = deepcopy(self)
         clone.name = name
         clone._id = name
-        clone._reset_hierarchy()
         return clone
 
     def _check_outputs(self, parameter):
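Related to the clone() docstring above: the guard now rejects reuse of the current name as well as a missing name. A small illustrative sketch (the node setup here is hypothetical, mirroring the tests below):

    import nipype.pipeline.engine as pe
    from nipype.interfaces.utility import IdentityInterface

    n = pe.Node(IdentityInterface(fields=['a']), name='n')
    n_copy = n.clone(name='n_copy')     # fine: a new name
    try:
        n.clone(name='n')               # reusing the current name is now rejected
    except Exception as exc:
        print(exc)                      # 'Cloning requires a new name'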
@@ -154,7 +157,7 @@ def save(self, filename=None):
         np.savez(filename, object=self)
 
     def load(self, filename):
-        np.load(filename)
+        return np.load(filename)
 
     def _report_crash(self, traceback=None, execgraph=None):
         """Writes crash related information to a file
@@ -205,7 +208,8 @@ def __init__(self, **kwargs):
         self._init_runtime_fields()
 
     def _init_runtime_fields(self):
-        # attributes for running with manager
+        """Reset runtime attributes to none
+        """
        self.procs = None
        self.depidx = None
        self.refidx = None
@@ -218,10 +222,30 @@ def _init_runtime_fields(self):
 
     # PUBLIC API
     def clone(self, name):
+        """Clone a workflow
+
+        .. note::
+
+          Will reset attributes used for executing workflow. See
+          _init_runtime_fields.
+
+        Parameters
+        ----------
+
+        name: string (mandatory)
+            every clone requires a new name
+
+        """
         self._init_runtime_fields()
-        return super(Workflow, self).clone(name)
+        clone = super(Workflow, self).clone(name)
+        clone._reset_hierarchy()
+        return clone
 
     def disconnect(self, *args):
+        """Disconnect two nodes
+
+        See the docstring for connect for format.
+        """
         # yoh: explicit **dict was introduced for compatibility with Python 2.5
         return self.connect(*args, **dict(disconnect=True))
 
@@ -237,6 +261,7 @@ def connect(self, *args, **kwargs):
 
         Parameters
         ----------
+
         args : list or a set of four positional arguments
 
         Four positional arguments of the form::
@@ -286,10 +311,21 @@ def connect(self, *args, **kwargs):
             if node._hierarchy is None:
                 node._hierarchy = self.name
         for srcnode, destnode, connects in connection_list:
+            connected_ports = []
+            # check to see which ports of destnode are already
+            # connected.
+            if not disconnect and (destnode in self._graph.nodes()):
+                for edge in self._graph.in_edges_iter(destnode):
+                    data = self._graph.get_edge_data(*edge)
+                    for sourceinfo, destname in data['connect']:
+                        connected_ports += [destname]
             for source, dest in connects:
                 # Currently datasource/sink/grabber.io modules
                 # determine their inputs/outputs depending on
                 # connection settings. Skip these modules in the check
+                if dest in connected_ports:
+                    raise Exception('Input %s of node %s is already ' \
+                                    'connected'%(dest,destnode))
                 if not (hasattr(destnode, '_interface') and '.io' in str(destnode._interface.__class__)):
                     if not destnode._check_inputs(dest):
                         not_found.append(['in', destnode.name, dest])
@@ -323,7 +359,12 @@ def connect(self, *args, **kwargs):
                     if disconnect:
                         logger.debug('Removing connection: %s'%str(data))
                         edge_data['connect'].remove(data)
-                        self._graph.add_edges_from([(srcnode, destnode, edge_data)])
+                        if edge_data['connect']:
+                            self._graph.add_edges_from([(srcnode, destnode, edge_data)])
+                        else:
+                            #pass
+                            logger.debug('Removing connection: %s->%s'%(srcnode,destnode))
+                            self._graph.remove_edges_from([(srcnode, destnode)])
                 elif not disconnect:
                     logger.debug('(%s, %s): No edge data' % (srcnode, destnode))
                     self._graph.add_edges_from([(srcnode, destnode,
@@ -333,8 +374,7 @@ def connect(self, *args, **kwargs):
                                  str(edge_data)))
 
     def add_nodes(self, nodes):
-        """ Wraps the networkx functionality in a more semantically
-        relevant function name
+        """ Add nodes to a workflow
 
         Parameters
         ----------
@@ -354,6 +394,16 @@ def add_nodes(self, nodes):
                 node._hierarchy = self.name
         self._graph.add_nodes_from(newnodes)
 
+    def remove_nodes(self, nodes):
+        """ Remove nodes from a workflow
+
+        Parameters
+        ----------
+        nodes : list
+            A list of WorkflowBase-based objects
+        """
+        self._graph.remove_nodes_from(nodes)
+
     @property
     def inputs(self):
         return self._get_inputs()
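The new remove_nodes() method is a thin wrapper around networkx's remove_nodes_from(), so removing a node also drops any edges attached to it. A one-line usage sketch, reusing the workflow from the example near the top of this page:

    flow.remove_nodes([b])              # node b and any edges touching it are removed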
@@ -386,9 +436,18 @@ def get_node(self, name):
             outnode = None
         return outnode
 
-    def write_graph(self, dotfilename='graph.dot', graph2use='orig'):
-        """
-        graph2use = 'orig', 'flat', 'exec'
+    def write_graph(self, dotfilename='graph.dot', graph2use='flat'):
+        """Generates a graphviz dot file and a png file
+
+        Parameters
+        ----------
+
+        graph2use: 'orig', 'flat' (default), 'exec'
+            orig - creates a top level graph without expanding internal
+                   workflow nodes
+            flat - expands workflow nodes recursively
+            exec - expands workflows to depict iterables
+
         """
         graph = self._graph
         if graph2use in ['flat', 'exec']:
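A usage sketch for the reworked write_graph() above; the dot file name is whatever is passed as dotfilename, and the png rendering assumes graphviz is installed locally:

    flow.write_graph()                                          # 'flat' is now the default
    flow.write_graph(dotfilename='top.dot', graph2use='orig')   # top-level graph, workflows unexpanded
    flow.write_graph(graph2use='exec')                          # also expands iterables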
@@ -423,7 +482,9 @@ def run(self, inseries=False, updatehash=False):
     # PRIVATE API AND FUNCTIONS
 
     def _check_nodes(self, nodes):
-        "docstring for _check_nodes"
+        """Checks if any of the nodes are already in the graph
+
+        """
         node_names = [node.name for node in self._graph.nodes()]
         node_lineage = [node._hierarchy for node in self._graph.nodes()]
         for node in nodes:
@@ -435,6 +496,8 @@ def _check_nodes(self, nodes):
                 node_names.append(node.name)
 
     def _has_attr(self, parameter, subtype='in'):
+        """Checks if a parameter is available as an input or output
+        """
         if subtype == 'in':
             subobject = self.inputs
         else:
@@ -448,6 +511,9 @@ def _has_attr(self, parameter, subtype='in'):
         return True
 
     def _get_parameter_node(self, parameter, subtype='in'):
+        """Returns the underlying node corresponding to an input or
+        output parameter
+        """
         if subtype == 'in':
             subobject = self.inputs
         else:
@@ -465,6 +531,10 @@ def _check_inputs(self, parameter):
         return self._has_attr(parameter, subtype='in')
 
     def _get_inputs(self):
+        """Returns the inputs of a workflow
+
+        This function does not return any input ports that are already connected
+        """
         inputdict = TraitedSpec()
         for node in self._graph.nodes():
             inputdict.add_trait(node.name, traits.Instance(TraitedSpec))
@@ -486,6 +556,8 @@ def _get_inputs(self):
         return inputdict
 
     def _get_outputs(self):
+        """Returns all possible output ports that are not already connected
+        """
         outputdict = TraitedSpec()
         for node in self._graph.nodes():
             outputdict.add_trait(node.name, traits.Instance(TraitedSpec))
@@ -500,6 +572,8 @@ def _get_outputs(self):
         return outputdict
 
     def _set_input(self, object, name, newvalue):
+        """Trait callback function to update a node input
+        """
         object.traits()[name].node.set_input(name, newvalue)
 
     def _set_node_input(self, node, param, source, sourceinfo):
@@ -519,13 +593,17 @@ def _set_node_input(self, node, param, source, sourceinfo):
         node.set_input(param, deepcopy(newval))
 
     def _create_flat_graph(self):
+        """Turn a hierarchical DAG into a simple DAG where no node is a workflow
+        """
         logger.debug('Creating flat graph for workflow: %s', self.name)
         self._init_runtime_fields()
         workflowcopy = deepcopy(self)
-        workflowcopy._generate_execgraph()
+        workflowcopy._generate_flatgraph()
         self._flatgraph = workflowcopy._graph
 
     def _reset_hierarchy(self):
+        """Reset the hierarchy on a graph
+        """
         for node in self._graph.nodes():
             if isinstance(node, Workflow):
                 node._reset_hierarchy()
@@ -534,7 +612,9 @@ def _reset_hierarchy(self):
             else:
                 node._hierarchy = self.name
 
-    def _generate_execgraph(self):
+    def _generate_flatgraph(self):
+        """Generate a graph containing only Nodes or MapNodes
+        """
         logger.debug('expanding workflow: %s', self)
         nodes2remove = []
         if not nx.is_directed_acyclic_graph(self._graph):
@@ -544,7 +624,10 @@ def _generate_execgraph(self):
             logger.debug('processing node: %s'%node)
             if isinstance(node, Workflow):
                 nodes2remove.append(node)
-                for u, _, d in self._graph.in_edges_iter(nbunch=node, data=True):
+                # use in_edges instead of in_edges_iter to allow
+                # disconnections to take place properly. otherwise, the
+                # edge dict is modified.
+                for u, _, d in self._graph.in_edges(nbunch=node, data=True):
                     logger.debug('in: connections-> %s'%str(d['connect']))
                     for cd in deepcopy(d['connect']):
                         logger.debug("in: %s" % str (cd))
@@ -555,7 +638,8 @@ def _generate_execgraph(self):
                         logger.debug('in edges: %s %s %s %s'%(srcnode, srcout, dstnode, dstin))
                         self.disconnect(u, cd[0], node, cd[1])
                         self.connect(srcnode, srcout, dstnode, dstin)
-                for _, v, d in self._graph.out_edges_iter(nbunch=node, data=True):
+                # do not use out_edges_iter for reasons stated in in_edges
+                for _, v, d in self._graph.out_edges(nbunch=node, data=True):
                     logger.debug('out: connections-> %s'%str(d['connect']))
                     for cd in deepcopy(d['connect']):
                         logger.debug("out: %s" % str (cd))
@@ -577,7 +661,7 @@ def _generate_execgraph(self):
                         self.connect(srcnode, srcout, dstnode, dstin)
                 # expand the workflow node
                 #logger.debug('expanding workflow: %s', node)
-                node._generate_execgraph()
+                node._generate_flatgraph()
                 for innernode in node._graph.nodes():
                     innernode._hierarchy = '.'.join((self.name,innernode._hierarchy))
                 self._graph.add_nodes_from(node._graph.nodes())
@@ -701,6 +785,8 @@ def _remove_node_deps(self, jobid, crashfile):
                                crashfile = crashfile)
 
     def _remove_node_dirs(self):
+        """Removes directories whos outputs have already been used up
+        """
         if config.getboolean('execution', 'remove_node_directories'):
             for idx in np.nonzero(np.all(self.refidx==0,axis=1))[0]:
                 if self.proc_done[idx] and (not self.proc_pending[idx]):
@@ -739,7 +825,6 @@ def _execute_with_manager(self):
             self._execute_in_series()
             return
         logger.info("Running in parallel.")
-        # self.taskclient.clear()
         # in the absence of a dirty bit on the object, generate the
         # parameterization each time before running
         # Generate appropriate structures for worker-manager model
@@ -944,10 +1029,9 @@ def run(self, updatehash=None, force_execute=False):
         """Executes an interface within a directory.
         """
         # check to see if output directory and hash exist
-        logger.info("Node: %s"%self._id)
         outdir = self._output_directory()
         outdir = make_output_dir(outdir)
-        logger.info("in dir: %s"%outdir)
+        logger.info("Executing node %s in dir: %s"%(self._id,outdir))
         # Get a dictionary with hashed filenames and a hashvalue
         # of the dictionary itself.
         hashed_inputs, hashvalue = self._get_hashval()
@@ -959,6 +1043,7 @@ def run(self, updatehash=None, force_execute=False):
             self._save_hashfile(hashfile, hashed_inputs)
         if force_execute or (not updatehash and (self.overwrite or not os.path.exists(hashfile))):
             logger.debug("Node hash: %s"%hashvalue)
+
             hashfile_unfinished = os.path.join(outdir, '_0x%s_unfinished.json' % hashvalue)
             if os.path.exists(outdir) and not (os.path.exists(hashfile_unfinished) and self._interface.can_resume):
                 logger.debug("Removing old %s and its contents"%outdir)

nipype/pipeline/tests/test_engine.py

Lines changed: 22 additions & 1 deletion
@@ -152,7 +152,6 @@ def test4():
     mod1.iterables = dict(input1=lambda:[1,2])
     mod2.iterables = {}
     pipe.connect([(mod1,mod2,[('output1','input2')])])
-    pipe.connect([(mod1,mod2,[('output1','input2')])])
     pipe._create_flat_graph()
     pipe._execgraph = pe._generate_expanded_graph(deepcopy(pipe._flatgraph))
     yield assert_equal(len(pipe._execgraph.nodes()), 4)
@@ -263,3 +262,25 @@ def test_iterable_expansion():
         wf3.add_nodes([wf1.clone(name='test%d'%i)])
     wf3._create_flat_graph()
     yield assert_equal(len(pe._generate_expanded_graph(wf3._flatgraph).nodes()),12)
+
+def test_disconnect():
+    import nipype.pipeline.engine as pe
+    from nipype.interfaces.utility import IdentityInterface
+    a = pe.Node(IdentityInterface(fields=['a','b']),name='a')
+    b = pe.Node(IdentityInterface(fields=['a','b']),name='b')
+    flow1 = pe.Workflow(name='test')
+    flow1.connect(a,'a',b,'a')
+    flow1.disconnect(a,'a',b,'a')
+    yield assert_equal, flow1._graph.edges(), []
+
+def test_doubleconnect():
+    import nipype.pipeline.engine as pe
+    from nipype.interfaces.utility import IdentityInterface
+    a = pe.Node(IdentityInterface(fields=['a','b']),name='a')
+    b = pe.Node(IdentityInterface(fields=['a','b']),name='b')
+    flow1 = pe.Workflow(name='test')
+    flow1.connect(a,'a',b,'a')
+    x = lambda: flow1.connect(a,'b',b,'a')
+    yield assert_raises, Exception, x
+
+
