Commit aeb59b5
Fix linters
1 parent 5f8063a
File tree: 2 files changed (+17, -11 lines)


test/spmd/test_fsdp_v2.py

12 additions, 8 deletions
@@ -142,23 +142,23 @@ def test_fsdp_v2_cpu_model(self):
   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
   def test_fsdp_v2_multi_slice(self):
     model = self.SimpleLinear().to(xm.xla_device())
-    mesh = self._get_mesh((2, self.n_devices // 2, 1), None, ('data', 'fsdp', 'tensor'))
+    mesh = self._get_mesh((2, self.n_devices // 2, 1), None,
+                          ('data', 'fsdp', 'tensor'))
     model = FSDPv2(model, mesh=mesh, extra_data_axis="data")
 
     # Make sure all weights are sharded.
     annotation = '{devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}'
     self.assertEqual(annotation,
-        torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight))
+                     torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight))
     self.assertEqual(annotation,
-        torch_xla._XLAC._get_xla_sharding_spec(model.fc2.weight))
+                     torch_xla._XLAC._get_xla_sharding_spec(model.fc2.weight))
 
     x = torch.randn(16, 128).to(xm.xla_device())
     xs.mark_sharding(x, mesh, (('data', 'fsdp'), None))
     output = model(x)
     # Make sure output are sharded.
     annotation = '{devices=[4,1]0,2,1,3}'
-    self.assertEqual(annotation,
-                     torch_xla._XLAC._get_xla_sharding_spec(output))
+    self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(output))
 
     # Make sure the model can execute without error.
     xm.mark_step()
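
Note: for context on what the reformatted lines above construct, a minimal sketch of building the same (data, fsdp, tensor) multi-slice mesh with the public SPMD API. Assumptions not in the diff: 4 devices, SPMD mode enabled, and xs.Mesh standing in for the test's _get_mesh helper.

    import numpy as np
    import torch_xla.runtime as xr
    import torch_xla.distributed.spmd as xs

    xr.use_spmd()  # enable SPMD execution mode (assumed prerequisite)
    num_devices = xr.global_runtime_device_count()  # e.g. 4 in this test

    # Shape (2, num_devices // 2, 1): two data-parallel slices, each
    # sharding parameters along its own 'fsdp' axis; 'tensor' is size 1.
    mesh = xs.Mesh(
        np.arange(num_devices), (2, num_devices // 2, 1),
        ('data', 'fsdp', 'tensor'))

Reading the asserted annotation '{devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}' against this mesh: each weight's first dimension is split two ways along 'fsdp', and the remaining factor of two (the 'data' slices) is replicated.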
@@ -169,7 +169,8 @@ def test_fsdp_v2_multi_slice_output_correctness(self):
     model_expected = self.SimpleLinear().to(xm.xla_device())
 
     model = copy.deepcopy(model_expected)
-    mesh = self._get_mesh((2, self.n_devices // 2, 1), None, ('data', 'fsdp', 'tensor'))
+    mesh = self._get_mesh((2, self.n_devices // 2, 1), None,
+                          ('data', 'fsdp', 'tensor'))
     model = FSDPv2(model, mesh=mesh, extra_data_axis="data")
 
     x_expected = torch.randn(16, 128).to(xm.xla_device())
@@ -183,9 +184,12 @@ def test_fsdp_v2_multi_slice_output_correctness(self):
 
   def test_fsdp_v2_multi_slice_error(self):
     model = self.SimpleLinear().to(xm.xla_device())
-    xs.set_global_mesh(self._get_mesh((2, self.n_devices // 2, 1), None, ('data', 'fsdp', 'tensor')))
+    xs.set_global_mesh(
+        self._get_mesh((2, self.n_devices // 2, 1), None,
+                       ('data', 'fsdp', 'tensor')))
 
-    with self.assertRaisesRegex(ValueError, "The provided ddp axis is not in the mesh."):
+    with self.assertRaisesRegex(ValueError,
+                                "The provided ddp axis is not in the mesh."):
       model = FSDPv2(model, extra_data_axis='ddp')
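
This error test exercises the global-mesh fallback: when no mesh= argument is passed, FSDPv2 falls back to the mesh registered via xs.set_global_mesh and validates extra_data_axis against its axis names (the check reformatted in the second file below). A hedged usage sketch, reusing the mesh built in the earlier sketch; the import path matches the file changed in this commit:

    import torch_xla.distributed.spmd as xs
    from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
        SpmdFullyShardedDataParallel as FSDPv2)

    xs.set_global_mesh(mesh)  # axes: ('data', 'fsdp', 'tensor')

    # Valid: 'data' is a mesh axis, so inputs/outputs can be sharded
    # jointly over ('data', 'fsdp').
    model = FSDPv2(model, extra_data_axis='data')

    # Invalid: 'ddp' is not an axis name, so __init__ raises
    # ValueError("The provided ddp axis is not in the mesh.")
    FSDPv2(model, extra_data_axis='ddp')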

torch_xla/experimental/spmd_fully_sharded_data_parallel.py

5 additions, 3 deletions
@@ -79,7 +79,8 @@ def __init__(
     if "fsdp" not in mesh.axis_names:
       raise ValueError("The mesh must have an axis named 'fsdp'.")
     if extra_data_axis and extra_data_axis not in mesh.axis_names:
-      raise ValueError(f"The provided {extra_data_axis} axis is not in the mesh.")
+      raise ValueError(
+          f"The provided {extra_data_axis} axis is not in the mesh.")
 
     super().__init__()

@@ -136,8 +137,9 @@ def shard_output_impl(output, mesh):
             f"The output type is not supported: {type(output)}. Please provide your own shard_output callable."
         )
 
-      spmd.mark_sharding(real_output, mesh,
-                         _prepare_spmd_partition_spec(real_output, extra_data_axis))
+      spmd.mark_sharding(
+          real_output, mesh,
+          _prepare_spmd_partition_spec(real_output, extra_data_axis))
 
     shard_output = shard_output_impl
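
For reference, the reformatted mark_sharding call is what re-shards the module's output on every forward pass. A standalone sketch of the equivalent annotation, assuming the 4-device mesh from the first sketch; _prepare_spmd_partition_spec is internal to this file, so the explicit partition spec below mirrors what the test asserts instead:

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.spmd as xs

    x = torch.randn(16, 128).to(xm.xla_device())
    # Batch dim sharded jointly over 'data' and 'fsdp' (2 * 2 = 4 shards),
    # feature dim replicated.
    xs.mark_sharding(x, mesh, (('data', 'fsdp'), None))

    # Matches the test's expectation: '{devices=[4,1]0,2,1,3}'.
    print(torch_xla._XLAC._get_xla_sharding_spec(x))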
