@@ -397,15 +397,19 @@ def attribute(
397397 if show_progress :
398398 attr_progress .update ()
399399 if agg_output_mode :
400- eval_diff = modified_eval - prev_results
400+ eval_diff = (modified_eval - prev_results ).to (
401+ inputs_tuple [0 ].device
402+ )
401403 prev_results = modified_eval
402404 else :
403405 # when perturb_per_eval > 1, every num_examples stands for
404406 # one perturb. Since the perturbs are from a consecutive
405407 # perumuation, each diff of a perturb is its eval minus
406408 # the eval of the previous perturb
407409 all_eval = torch .cat ((prev_results , modified_eval ), dim = 0 )
408- eval_diff = all_eval [num_examples :] - all_eval [:- num_examples ]
410+ eval_diff = (
411+ all_eval [num_examples :] - all_eval [:- num_examples ]
412+ ).to (inputs_tuple [0 ].device )
409413 prev_results = all_eval [- num_examples :]
410414
411415 for j in range (len (total_attrib )):
@@ -689,7 +693,7 @@ def _evalFutToPrevResultsTuple(
689693 agg_output_mode ,
690694 ) = prev_results_tuple
691695 if agg_output_mode :
692- eval_diff = modified_eval - prev_results
696+ eval_diff = ( modified_eval - prev_results ). to ( inputs_tuple [ 0 ]. device )
693697 prev_results = modified_eval
694698 else :
695699 # when perturb_per_eval > 1, every num_examples stands for
@@ -698,7 +702,9 @@ def _evalFutToPrevResultsTuple(
698702 # the eval of the previous perturb
699703
700704 all_eval = torch .cat ((prev_results , modified_eval ), dim = 0 )
701- eval_diff = all_eval [num_examples :] - all_eval [:- num_examples ]
705+ eval_diff = (all_eval [num_examples :] - all_eval [:- num_examples ]).to (
706+ inputs_tuple [0 ].device
707+ )
702708 prev_results = all_eval [- num_examples :]
703709
704710 for j in range (len (total_attrib )):
@@ -799,7 +805,10 @@ def _perturbation_generator(
799805 )
800806 current_tensors_list .append (current_tensors )
801807 current_mask_list .append (
802- tuple (mask == feature_permutation [i ] for mask in input_masks )
808+ tuple (
809+ (mask == feature_permutation [i ]).to (inputs [0 ].device )
810+ for mask in input_masks
811+ )
803812 )
804813 if len (current_tensors_list ) == perturbations_per_eval :
805814 combined_inputs = tuple (
0 commit comments