@@ -197,33 +197,15 @@ def test_wan_block(self):
197197 assert dummy_output .shape == dummy_hidden_states .shape
198198
199199 def test_wan_attention (self ):
200- pyconfig .initialize (
201- [
202- None ,
203- os .path .join (THIS_DIR , ".." , "configs" , "base_wan_14b.yml" ),
204- ],
205- unittest = True ,
206- )
207- config = pyconfig .config
200+ # pyconfig.initialize(
201+ # [
202+ # None,
203+ # os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
204+ # ],
205+ # unittest=True,
206+ # )
207+ # config = pyconfig.config
208208
209- batch_size = 1
210- channels = 16
211- frames = 21
212- height = 90
213- width = 160
214- hidden_states_shape = (batch_size , frames , height , width , channels )
215- dummy_hidden_states = jnp .ones (hidden_states_shape )
216- wan_rot_embed = WanRotaryPosEmbed (attention_head_dim = 128 , patch_size = [1 , 2 , 2 ], max_seq_len = 1024 )
217- dummy_rotary_emb = wan_rot_embed (dummy_hidden_states )
218-
219- key = jax .random .key (0 )
220- rngs = nnx .Rngs (key )
221- devices_array = create_device_mesh (config )
222-
223- mesh_axes = ['data' , 'fsdp' , 'tensor' ]
224- mesh = Mesh (devices_array , mesh_axes )
225- batch_size = 1
226- query_dim = 5120
227209 for attention_kernel in ["flash" , "tokamax_flash" ]:
228210 pyconfig .initialize (
229211 [
@@ -233,6 +215,22 @@ def test_wan_attention(self):
233215 ]
234216 )
235217 config = pyconfig .config
218+ batch_size = 1
219+ channels = 16
220+ frames = 21
221+ height = 90
222+ width = 160
223+ hidden_states_shape = (batch_size , frames , height , width , channels )
224+ dummy_hidden_states = jnp .ones (hidden_states_shape )
225+ wan_rot_embed = WanRotaryPosEmbed (attention_head_dim = 128 , patch_size = [1 , 2 , 2 ], max_seq_len = 1024 )
226+ dummy_rotary_emb = wan_rot_embed (dummy_hidden_states )
227+
228+ key = jax .random .key (0 )
229+ rngs = nnx .Rngs (key )
230+ devices_array = create_device_mesh (config )
231+ mesh = Mesh (devices_array , config .mesh_axes )
232+ batch_size = 1
233+ query_dim = 5120
236234 with mesh , nn_partitioning .axis_rules (config .logical_axis_rules ):
237235 flash_block_sizes = get_flash_block_sizes (config )
238236 attention = FlaxWanAttention (