@@ -142,23 +142,23 @@ def test_fsdp_v2_cpu_model(self):
142142 @unittest .skipIf (xr .device_type () !=  'TPU' , "This test only works on TPU." ) 
143143 def  test_fsdp_v2_multi_slice (self ):
144144 model  =  self .SimpleLinear ().to (xm .xla_device ())
145-  mesh  =  self ._get_mesh ((2 , self .n_devices  //  2 , 1 ), None , ('data' , 'fsdp' , 'tensor' ))
145+  mesh  =  self ._get_mesh ((2 , self .n_devices  //  2 , 1 ), None ,
146+  ('data' , 'fsdp' , 'tensor' ))
146147 model  =  FSDPv2 (model , mesh = mesh , extra_data_axis = "data" )
147148
148149 # Make sure all weights are sharded. 
149150 annotation  =  '{devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}' 
150151 self .assertEqual (annotation ,
151-    torch_xla ._XLAC ._get_xla_sharding_spec (model .fc1 .weight ))
152+  torch_xla ._XLAC ._get_xla_sharding_spec (model .fc1 .weight ))
152153 self .assertEqual (annotation ,
153-    torch_xla ._XLAC ._get_xla_sharding_spec (model .fc2 .weight ))
154+  torch_xla ._XLAC ._get_xla_sharding_spec (model .fc2 .weight ))
154155
155156 x  =  torch .randn (16 , 128 ).to (xm .xla_device ())
156157 xs .mark_sharding (x , mesh , (('data' , 'fsdp' ), None ))
157158 output  =  model (x )
158159 # Make sure the output is sharded. 
159160 annotation  =  '{devices=[4,1]0,2,1,3}' 
160-  self .assertEqual (annotation ,
161-  torch_xla ._XLAC ._get_xla_sharding_spec (output ))
161+  self .assertEqual (annotation , torch_xla ._XLAC ._get_xla_sharding_spec (output ))
162162
163163 # Make sure the model can execute without error. 
164164 xm .mark_step ()
@@ -169,7 +169,8 @@ def test_fsdp_v2_multi_slice_output_correctness(self):
169169 model_expected  =  self .SimpleLinear ().to (xm .xla_device ())
170170
171171 model  =  copy .deepcopy (model_expected )
172-  mesh  =  self ._get_mesh ((2 , self .n_devices  //  2 , 1 ), None , ('data' , 'fsdp' , 'tensor' ))
172+  mesh  =  self ._get_mesh ((2 , self .n_devices  //  2 , 1 ), None ,
173+  ('data' , 'fsdp' , 'tensor' ))
173174 model  =  FSDPv2 (model , mesh = mesh , extra_data_axis = "data" )
174175
175176 x_expected  =  torch .randn (16 , 128 ).to (xm .xla_device ())
@@ -183,9 +184,12 @@ def test_fsdp_v2_multi_slice_output_correctness(self):
183184
184185 def  test_fsdp_v2_multi_slice_error (self ):
185186 model  =  self .SimpleLinear ().to (xm .xla_device ())
186-  xs .set_global_mesh (self ._get_mesh ((2 , self .n_devices  //  2 , 1 ), None , ('data' , 'fsdp' , 'tensor' )))
187+  xs .set_global_mesh (
188+  self ._get_mesh ((2 , self .n_devices  //  2 , 1 ), None ,
189+  ('data' , 'fsdp' , 'tensor' )))
187190
188-  with  self .assertRaisesRegex (ValueError , "The provided ddp axis is not in the mesh." ):
191+  with  self .assertRaisesRegex (ValueError ,
192+  "The provided ddp axis is not in the mesh." ):
189193 model  =  FSDPv2 (model , extra_data_axis = 'ddp' )
190194
191195
0 commit comments