@@ -584,23 +584,15 @@ def generate_expanded_graph(graph_in):
584584 """
585585 logger .debug ("PE: expanding iterables" )
586586 graph_in = _remove_nonjoin_identity_nodes (graph_in , keep_iterables = True )
587- # convert list of tuples to dict fields
587+ # standardize the iterables as {(field, function)} dictionaries
588588 for node in graph_in .nodes_iter ():
589- if isinstance (node .iterables , tuple ):
590- node .iterables = [node .iterables ]
591- for node in graph_in .nodes_iter ():
592- if isinstance (node .iterables , list ):
593- node .iterables = dict (map (lambda (x ): (x [0 ],
594- lambda : x [1 ]),
595- node .iterables ))
589+ if node .iterables :
590+ _standardize_iterables (node )
596591 allprefixes = list ('abcdefghijklmnopqrstuvwxyz' )
597592
598593 # the iterable nodes
599594 inodes = _iterable_nodes (graph_in )
600595 logger .debug ("Detected iterable nodes %s" % inodes )
601- # record the iterable fields, since expansion removes them
602- iter_fld_dict = {inode .name : inode .iterables .keys ()
603- for inode in inodes }
604596 # while there is an iterable node, expand the iterable node's
605597 # subgraphs
606598 while inodes :
@@ -626,22 +618,25 @@ def generate_expanded_graph(graph_in):
626618 % (src , dest ))
627619
628620 if inode .itersource :
621+ # the itersource is a (node name, fields) tuple
622+ src_name , src_fields = inode .itersource
623+ # convert a single field to a list
624+ if isinstance (src_fields , str ):
625+ src_fields = [src_fields ]
629626 # find the unique iterable source node in the graph
630- iter_src = None
631- for node in graph_in .nodes_iter ():
632- if (node .name == inode .itersource
633- and nx .has_path (graph_in , node , inode )):
634- iter_src = node
635- break
636- if not iter_src or not iter_fld_dict .has_key (inode .itersource ):
627+ try :
628+ iter_src = next ((node for node in graph_in .nodes_iter ()
629+ if node .name == src_name
630+ and nx .has_path (graph_in , node , inode )))
631+ except StopIteration :
637632 raise ValueError ("The node %s itersource %s was not found"
638- " among the iterable nodes %s"
639- % (inode , inode .itersource , iter_fld_dict .keys ()))
633+ " among the iterable predecessor nodes"
634+ % (inode , src_name ))
635+ logger .debug ("The node %s has iterable source node %s"
636+ % (inode , iter_src ))
640637 # look up the iterables for this particular itersource descendant
641638 # using the iterable source ancestor values as a key
642639 iterables = {}
643- # the source node iterables fields
644- src_fields = iter_fld_dict [inode .itersource ]
645640 # the source node iterables values
646641 src_values = [getattr (iter_src .inputs , field ) for field in src_fields ]
647642 # if there is one source field, then the key is the the source value,
@@ -714,9 +709,12 @@ def generate_expanded_graph(graph_in):
714709 for src_id , edge_data in old_edge_dict .iteritems ():
715710 if node ._id .startswith (src_id ):
716711 expansions [src_id ].append (node )
712+ for in_id , in_nodes in expansions .iteritems ():
713+ logger .debug ("The join node %s input %s was expanded"
714+ " to %d nodes." % (jnode , in_id , len (in_nodes )))
717715 # preserve the node iteration order by sorting on the node id
718- for src_nodes in expansions .itervalues ():
719- src_nodes .sort (key = lambda node : node ._id )
716+ for in_nodes in expansions .itervalues ():
717+ in_nodes .sort (key = lambda node : node ._id )
720718
721719 # the number of iterations.
722720 iter_cnt = count_iterables (iterables , inode .synchronize )
@@ -731,28 +729,28 @@ def generate_expanded_graph(graph_in):
731729 # field 'in' are qualified as ('out_file', 'in1') and
732730 # ('out_file', 'in2'), resp. This preserves connection port
733731 # integrity.
734- for old_id , src_nodes in expansions .iteritems ():
732+ for old_id , in_nodes in expansions .iteritems ():
735733 # reconnect each replication of the current join in-edge
736734 # source
737- for si , src in enumerate (src_nodes ):
735+ for in_idx , in_node in enumerate (in_nodes ):
738736 olddata = old_edge_dict [old_id ]
739737 newdata = deepcopy (olddata )
740738 connects = newdata ['connect' ]
741739 join_fields = [field for _ , field in connects
742740 if field in dest .joinfield ]
743- slots = slot_dicts [si ]
744- for ci , connect in enumerate (connects ):
741+ slots = slot_dicts [in_idx ]
742+ for con_idx , connect in enumerate (connects ):
745743 src_field , dest_field = connect
746744 # qualify a join destination field name
747745 if dest_field in slots :
748746 slot_field = slots [dest_field ]
749- connects [ci ] = (src_field , slot_field )
747+ connects [con_idx ] = (src_field , slot_field )
750748 logger .debug ("Qualified the %s -> %s join field"
751749 " %s as %s." %
752- (src , jnode , dest_field , slot_field ))
753- graph_in .add_edge (src , jnode , newdata )
750+ (in_node , jnode , dest_field , slot_field ))
751+ graph_in .add_edge (in_node , jnode , newdata )
754752 logger .debug ("Connected the join node %s subgraph to the"
755- " expanded join point %s" % (jnode , src ))
753+ " expanded join point %s" % (jnode , in_node ))
756754
757755 #nx.write_dot(graph_in, '%s_post.dot' % node)
758756 # the remaining iterable nodes
@@ -792,6 +790,65 @@ def _iterable_nodes(graph_in):
792790 inodes_src = [node for node in inodes if node .itersource ]
793791 inodes_no_src .reverse ()
794792 return inodes_no_src + inodes_src
793+
794+ def _standardize_iterables (node ):
795+ """Converts the given iterables to a {field: function} dictionary,
796+ if necessary, where the function returns a list."""
797+ # trivial case
798+ if not node .iterables :
799+ return
800+ iterables = node .iterables
801+ # The candidate iterable fields
802+ fields = set (node .inputs .copyable_trait_names ())
803+
804+ # Convert a tuple to a list
805+ if isinstance (iterables , tuple ):
806+ iterables = [iterables ]
807+ # Convert a list to a dictionary
808+ if isinstance (iterables , list ):
809+ # Synchronize iterables can be in [fields, value tuples] format
810+ # rather than [(field, value list), (field, value list), ...]
811+ if node .synchronize and len (iterables ) == 2 :
812+ first , last = iterables
813+ if all ((isinstance (item , str ) and item in fields
814+ for item in first )):
815+ iterables = _transpose_iterables (first , last )
816+ # Validate the format
817+ for item in iterables :
818+ try :
819+ if len (item ) != 2 :
820+ raise ValueError ("The %s iterables do not consist of"
821+ " (field, values) pairs" % node .name )
822+ except TypeError , e :
823+ raise TypeError ("The %s iterables is not iterable: %s"
824+ % (node .name , e ))
825+ # Convert the values to functions. This is a legacy Nipype
826+ # requirement with unknown rationale.
827+ iter_items = map (lambda (field , value ): (field , lambda : value ),
828+ iterables )
829+ # Make the iterables dictionary
830+ iterables = dict (iter_items )
831+ elif not isinstance (iterables , dict ):
832+ raise ValueError ("The %s iterables type is not a list or a dictionary:"
833+ " %s" % (node .name , iterables .__class__ ))
834+
835+ # Validate the iterable fields
836+ for field in iterables .iterkeys ():
837+ if field not in fields :
838+ raise ValueError ("The %s iterables field is unrecognized: %s"
839+ % (node .name , field ))
840+
841+ # Assign to the standard form
842+ node .iterables = iterables
843+
844+ def _transpose_iterables (fields , values ):
845+ """
846+ Converts the given fields and tuple values into a list of
847+ iterable (field: value list) pairs, suitable for setting
848+ a node iterables property.
849+ """
850+ return zip (fields , [filter (lambda (v ): v != None , transpose )
851+ for transpose in zip (* values )])
795852
796853def export_graph (graph_in , base_dir = None , show = False , use_execgraph = False ,
797854 show_connectinfo = False , dotfilename = 'graph.dot' , format = 'png' ,
0 commit comments