Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
91 commits
Select commit Hold shift + click to select a range
835baf2
Simplify the XNAT data sink.
Mar 20, 2013
4782333
Make a SHA-1 digest directory name for long parameterizations.
Mar 20, 2013
c4cadb0
The ANTS 1.9 affine map is Affine.mat.
Mar 29, 2013
0617e6a
use_histogram_matching can be a Bool which applies to all transforms.…
May 21, 2013
e8eda0d
Add postfix file name input option.
May 21, 2013
0c2f849
Update the shape and slice dimension after a meta extension update.
May 21, 2013
bbdecd5
Remove unused line.
Mar 29, 2013
8113e55
Forego iterable param merge if there are no iterables.
May 21, 2013
45f66cf
Disable the ants Registration interface sigma_units vox value until i…
May 24, 2013
a6573c4
Format the smoothing-sigmas option without a unit if none is supplied.
May 24, 2013
a86e616
Set the input values.
Jun 14, 2013
63a0067
Remove unused line.
Mar 29, 2013
28df659
Remove unused kwargs Workflow parameter.
Jun 15, 2013
0001a9f
Add join feature.
Jul 24, 2013
d5fcb2d
Refactor join into JoinNode rather than a Join interface.
Jul 25, 2013
b897d8c
Clean up white space.
Jul 25, 2013
0b1d3da
Remove obsolete joinsource Node init parameter.
Jul 25, 2013
eb00354
Iterate with nodes_iter rather than nodes.
Jul 25, 2013
7976538
Support a Set join field.
Jul 25, 2013
cd30706
Support an identity join node.
Jul 26, 2013
3e3649a
The joinfield default is all input fields.
Jul 29, 2013
1382671
Support multiple join fields.
Jul 29, 2013
d620ba2
Support the unique flag.
Jul 29, 2013
c2bf451
Remove extraneous debug message.
Jul 29, 2013
cedcfac
Add new line.
Jul 30, 2013
3365581
Add itersource.
Jul 30, 2013
1bab9e1
Add itersource test.
Jul 30, 2013
385baec
The itersource iterables key is a tuple only if there is more than on…
Jul 30, 2013
306c78f
Remove spurious line.
Jul 30, 2013
f2b9305
Flush out the itersource test.
Jul 31, 2013
83d9d2e
Build the itersource iterables from the ancestor value lookup.
Jul 31, 2013
15012ff
Fix _iterable_nodes merge.
Jul 31, 2013
99b9f01
Support join on a node with an itersource.
Aug 1, 2013
6e92e9e
Add iterables synchronize.
Aug 2, 2013
a8e6e23
Clarify comment.
Aug 3, 2013
0830956
Refactor iterable standardization to account for the itersource alter…
Aug 3, 2013
088ec83
Test the alternate itersource iterables format.
Aug 3, 2013
37e01ce
Add the itersource field.
Aug 3, 2013
a9ba9c8
Support synchronize iterables tuple values.
Aug 3, 2013
766b5fc
Make the transposed synchronize iterable values a list.
Aug 4, 2013
114d026
Handle the alternate synchronize iterables format.
Aug 6, 2013
dee068d
Copy all join node fields from the override self._inputs to the base …
Aug 9, 2013
9246db6
Add a test for multiple join nodes.
Aug 9, 2013
96c9bac
The join node loop current node variable is jnode, not dest.
Aug 9, 2013
d8408f9
Use in rather than has_key for lookup.
Aug 9, 2013
80bf24e
Make joinsource a property and convert a setter node value to the nod…
Aug 14, 2013
cd340ea
Add whitespace.
Aug 16, 2013
8cd8347
Add join feature.
Jul 24, 2013
79bd1fd
Refactor join into JoinNode rather than a Join interface.
Jul 25, 2013
318a096
Clean up white space.
Jul 25, 2013
195abb1
Remove obsolete joinsource Node init parameter.
Jul 25, 2013
2a7bdda
Iterate with nodes_iter rather than nodes.
Jul 25, 2013
0a8caad
Support a Set join field.
Jul 25, 2013
ff7e157
Support an identity join node.
Jul 26, 2013
1928390
The joinfield default is all input fields.
Jul 29, 2013
49a37e9
Support multiple join fields.
Jul 29, 2013
3cdb9c1
Support the unique flag.
Jul 29, 2013
d289545
Remove extraneous debug message.
Jul 29, 2013
50b0ca6
Add new line.
Jul 30, 2013
2461482
Add itersource.
Jul 30, 2013
3d71609
Add itersource test.
Jul 30, 2013
8e8f0a9
The itersource iterables key is a tuple only if there is more than on…
Jul 30, 2013
e8ab121
Flush out the itersource test.
Jul 31, 2013
dbb091a
Build the itersource iterables from the ancestor value lookup.
Jul 31, 2013
8a90be5
Fix _iterable_nodes merge.
Jul 31, 2013
293b80f
Support join on a node with an itersource.
Aug 1, 2013
ae1a25a
Add iterables synchronize.
Aug 2, 2013
24b33d9
Clarify comment.
Aug 3, 2013
5f4a764
Refactor iterable standardization to account for the itersource alter…
Aug 3, 2013
d630887
Test the alternate itersource iterables format.
Aug 3, 2013
8ee3ea3
Add the itersource field.
Aug 3, 2013
aab06f5
Support synchronize iterables tuple values.
Aug 3, 2013
fa08d85
Make the transposed synchronize iterable values a list.
Aug 4, 2013
2a6bb8f
Handle the alternate synchronize iterables format.
Aug 6, 2013
9808609
Copy all join node fields from the override self._inputs to the base …
Aug 9, 2013
80383d0
Add a test for multiple join nodes.
Aug 9, 2013
9ece179
The join node loop current node variable is jnode, not dest.
Aug 9, 2013
c1eb0c1
Use in rather than has_key for lookup.
Aug 9, 2013
d7098af
Make joinsource a property and convert a setter node value to the nod…
Aug 14, 2013
192a7a9
Add whitespace.
Aug 16, 2013
a1c387c
Delete Nipype master merge artifact.
Sep 5, 2013
459cc3c
Add the JoinNode and itersource chapter to the index.
Sep 6, 2013
04e05d0
Add midstream itersource join tests.
Sep 6, 2013
bdfc4e4
Add JoinNode and itersource chapter to User Guide.
Sep 6, 2013
13a6bea
Improve JoinNode documentation.
Sep 6, 2013
d2106c0
Add JoinNode to pipeline init
Sep 9, 2013
c27f62d
Add JoinNode to nipype init
Sep 9, 2013
091b4a5
Flatten imports in JoinNode example.
Sep 9, 2013
8e26e7f
Fix the _add_join_item_fields doctest.
Sep 9, 2013
6d6a617
fix: tests
satra Sep 14, 2013
826cab1
sty: fixed white spaces
satra Sep 14, 2013
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Support join on a node with an itersource.
  • Loading branch information
FredLoney authored and satra committed Sep 14, 2013
commit 293b80fdb8ccb658bd94e2d8bea1d0894854611b
82 changes: 73 additions & 9 deletions nipype/pipeline/tests/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ def _list_outputs(self):
outputs['output1'] = self.inputs.input1 + self.inputs.inc
return outputs

_sum = 0
_sums = []

_sum_operands = None
_sum_operands = []

class SumInputSpec(nib.TraitedSpec):
input1 = nib.traits.List(nib.traits.Int, mandatory=True, desc='input')
Expand All @@ -58,8 +58,10 @@ def _list_outputs(self):
global _sum
global _sum_operands
outputs = self._outputs().get()
_sum_operands = outputs['operands'] = self.inputs.input1
_sum = outputs['output1'] = sum(self.inputs.input1)
outputs['operands'] = self.inputs.input1
_sum_operands.append(outputs['operands'])
outputs['output1'] = sum(self.inputs.input1)
_sums.append(outputs['output1'])
return outputs


Expand Down Expand Up @@ -150,9 +152,11 @@ def test_join_expansion():
# Nipype factors away the IdentityInterface.
assert_equal(len(result.nodes()), 8, "The number of expanded nodes is incorrect.")
# the join Sum result is (1 + 1 + 1) + (2 + 1 + 1)
assert_equal(_sum, 7, "The join Sum output value is incorrect: %s." % _sum)
assert_equal(len(_sums), 1,
"The number of join outputs is incorrect")
assert_equal(_sums[0], 7, "The join Sum output value is incorrect: %s." % _sums[0])
# the join input preserves the iterables input order
assert_equal(_sum_operands, [3, 4], "The join Sum input is incorrect: %s." % _sum_operands)
assert_equal(_sum_operands[0], [3, 4], "The join Sum input is incorrect: %s." % _sum_operands[0])
# there are two iterations of the post-join node in the iterable path
assert_equal(len(_products), 2,
"The number of iterated post-join outputs is incorrect")
Expand Down Expand Up @@ -209,12 +213,14 @@ def test_unique_join_node():
wf.run()

# the join length is the number of unique inputs
assert_equal(_sum_operands, [4, 2, 3], "The unique join output value is incorrect: %s." % _sum_operands)
assert_equal(_sum_operands[0], [4, 2, 3], "The unique join output value is incorrect: %s." % _sum_operands[0])

os.chdir(cwd)
rmtree(wd)

def test_identity_join_node():
global _sum_operands
_sum_operands = []
cwd = os.getcwd()
wd = mkdtemp()
os.chdir(wd)
Expand All @@ -241,8 +247,9 @@ def test_identity_join_node():
# node and 1 post-join node. Nipype factors away the iterable input
# IdentityInterface but keeps the join IdentityInterface.
assert_equal(len(result.nodes()), 5, "The number of expanded nodes is incorrect.")
assert_equal(_sum_operands, [2, 3, 4],
"The join Sum input is incorrect: %s." %_sum_operands)
assert_equal(_sum_operands[0], [2, 3, 4],
"The join Sum input is incorrect: %s." %_sum_operands[0])

os.chdir(cwd)
rmtree(wd)

Expand Down Expand Up @@ -282,9 +289,66 @@ def test_multifield_join_node():
# the product inputs are [2, 4], [2, 5], [3, 4], [3, 5]
assert_equal(_products, [8, 10, 12, 15],
"The post-join products is incorrect: %s." % _products)

os.chdir(cwd)
rmtree(wd)

def test_itersource_join_source_node():
cwd = os.getcwd()
wd = mkdtemp()
os.chdir(wd)

# Make the workflow.
wf = pe.Workflow(name='test')
# the iterated input node
inputspec = pe.Node(IdentityInterface(fields=['n']), name='inputspec')
inputspec.iterables = [('n', [1, 2])]
# an intermediate node in the first iteration path
pre_join1 = pe.Node(IncrementInterface(), name='pre_join1')
wf.connect(inputspec, 'n', pre_join1, 'input1')
# an iterable pre-join node with an itersource
pre_join2 = pe.Node(ProductInterface(), name='pre_join2')
pre_join2.itersource = 'inputspec'
pre_join2.iterables = ('input1', {1: [3, 4], 2: [5, 6]})
wf.connect(pre_join1, 'output1', pre_join2, 'input2')
# an intermediate node in the second iteration path
pre_join3 = pe.Node(IncrementInterface(), name='pre_join3')
wf.connect(pre_join2, 'output1', pre_join3, 'input1')
# the join node
join = pe.JoinNode(IdentityInterface(fields=['vector']), joinsource='pre_join2',
joinfield='vector', name='join')
wf.connect(pre_join3, 'output1', join, 'vector')
# a join successor node
post_join1 = pe.Node(SumInterface(), name='post_join1')
wf.connect(join, 'vector', post_join1, 'input1')

result = wf.run()

# the expanded graph contains
# 1 pre_join1 replicate for each inputspec iteration,
# 2 pre_join2 replicates for each inputspec iteration,
# 1 pre_join3 for each pre_join2 iteration,
# 1 join replicate for each inputspec iteration and
# 1 post_join1 replicate for each join replicate =
# 2 + (2 * 2) + 4 + 2 + 2 = 14 expansion graph nodes.
# Nipype factors away the iterable input
# IdentityInterface but keeps the join IdentityInterface.
assert_equal(len(result.nodes()), 14, "The number of expanded nodes is incorrect.")
# The first join inputs are:
# 1 + (3 * 2) and 1 + (4 * 2)
# The second join inputs are:
# 1 + (5 * 3) and 1 + (6 * 3)
# the post-join nodes execution order is indeterminate;
# therefore, compare the lists item-wise.
assert_true([16, 19] in _sum_operands,
"The join Sum input is incorrect: %s." % _sum_operands)
assert_true([7, 9] in _sum_operands,
"The join Sum input is incorrect: %s." % _sum_operands)

os.chdir(cwd)
rmtree(wd)


if __name__ == "__main__":
import nose

Expand Down
76 changes: 43 additions & 33 deletions nipype/pipeline/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,13 +591,34 @@ def generate_expanded_graph(graph_in):

# the iterable nodes
inodes = _iterable_nodes(graph_in)
logger.debug("Detected iterable nodes %s" % inodes)
# record the iterable fields, since expansion removes them
iter_fld_dict = {inode.name: inode.iterables.keys()
for inode in inodes}
# while there is an iterable node, expand the iterable node's
# subgraphs
while inodes:
inode = inodes[0]
logger.debug("Expanding the iterable node %s..." % inode)

# the join successor nodes of the current iterable node
jnodes = [node for node in graph_in.nodes_iter()
if hasattr(node, 'joinsource')
and inode.name == node.joinsource
and nx.has_path(graph_in, inode, node)]

# excise the join in-edges. save the excised edges in a
# {jnode: {source name: (destination name, edge data)}}
# dictionary
jedge_dict = {}
for jnode in jnodes:
in_edges = jedge_dict[jnode] = {}
for src, dest, data in graph_in.in_edges_iter(jnode, True):
in_edges[src._id] = data
graph_in.remove_edge(src, dest)
logger.debug("Excised the %s -> %s join node in-edge."
% (src, dest))

if inode.itersource:
# find the unique iterable source node in the graph
iter_src = None
Expand All @@ -607,8 +628,9 @@ def generate_expanded_graph(graph_in):
iter_src = node
break
if not iter_src or not iter_fld_dict.has_key(inode.itersource):
raise ValueError("Iterable node %s source node not found: %s"
% (inode, inode.itersource))
raise ValueError("The node %s itersource %s was not found"
" among the iterable nodes %s"
% (inode, inode.itersource, iter_fld_dict.keys()))
# look up the iterables for this particular itersource descendant
# using the iterable source ancestor values as a key
iterables = {}
Expand Down Expand Up @@ -674,16 +696,22 @@ def generate_expanded_graph(graph_in):
graph_in = _merge_graphs(graph_in, subnodes,
subgraph, inode._hierarchy + inode._id,
iterables, iterable_prefix)

# reconnect the join nodes
for jnode in jnodes:
# the {node name: replicated nodes} dictionary
node_name_dict = defaultdict(list)
# the {node id: edge data} dictionary for edges connecting
# to the join node in the unexpanded graph
old_edge_dict = jedge_dict[jnode]
# the edge source node replicates
expansions = defaultdict(list)
for node in graph_in.nodes_iter():
node_name_dict[node.name].append(node)
for src_id, edge_data in old_edge_dict.iteritems():
if node._id.startswith(src_id):
expansions[src_id].append(node)
# preserve the node iteration order by sorting on the node id
for nodes in node_name_dict.values():
nodes.sort(key=str)
for src_nodes in expansions.itervalues():
src_nodes.sort(key=lambda node: node._id)

# the number of iterations. this magic formula is borrowed
# from _merge_graphs.
iter_cnt = len(list(walk(iterables.items())))
Expand All @@ -698,30 +726,12 @@ def generate_expanded_graph(graph_in):
# field 'in' are qualified as ('out_file', 'in1') and
# ('out_file', 'in2'), resp. This preserves connection port
# integrity.
# the {source name: (destination name, edge data)} dictionary
in_edges = jedge_dict[jnode]
# for each join in-edge source name, the replicated source
# nodes are the expansion graph nodes which match on the
# source name. reconnect each replicated source node to the
# join node.
for src_name, tgt in in_edges.iteritems():
dest_name, edge_data = tgt
# there is a single join destination, since the join node
# is not expanded
dests = node_name_dict[dest_name]
if not dests:
raise Exception("The execution graph does not contain"
" the join node: %s" % dest_name)
elif len(dests) > 1:
raise Exception("The execution graph contains more than"
" one join node named %s: %s"
% (dest_name, dests))
else:
dest = dests[0]
for old_id, src_nodes in expansions.iteritems():
# reconnect each replication of the current join in-edge
# source
for si, src in enumerate(node_name_dict[src_name]):
newdata = deepcopy(edge_data)
for si, src in enumerate(src_nodes):
olddata = old_edge_dict[old_id]
newdata = deepcopy(olddata)
connects = newdata['connect']
join_fields = [field for _, field in connects
if field in dest.joinfield]
Expand All @@ -734,10 +744,10 @@ def generate_expanded_graph(graph_in):
connects[ci] = (src_field, slot_field)
logger.debug("Qualified the %s -> %s join field"
" %s as %s." %
(src, dest, dest_field, slot_field))
graph_in.add_edge(src, dest, newdata)
(src, jnode, dest_field, slot_field))
graph_in.add_edge(src, jnode, newdata)
logger.debug("Connected the join node %s subgraph to the"
" expanded join point %s" % (dest, src))
" expanded join point %s" % (jnode, src))

#nx.write_dot(graph_in, '%s_post.dot' % node)
# the remaining iterable nodes
Expand Down