
Commit c379559

Test on attention kernel type and automatically modify the flash block sizes object when 'tokamax_flash' is requested
Signed-off-by: Kunjan Patel <kunjanp@google.com>
1 parent 7238105 commit c379559


src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 24 additions & 26 deletions
@@ -197,33 +197,15 @@ def test_wan_block(self):
     assert dummy_output.shape == dummy_hidden_states.shape

   def test_wan_attention(self):
-    pyconfig.initialize(
-        [
-            None,
-            os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
-        ],
-        unittest=True,
-    )
-    config = pyconfig.config
+    # pyconfig.initialize(
+    #     [
+    #         None,
+    #         os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
+    #     ],
+    #     unittest=True,
+    # )
+    # config = pyconfig.config

-    batch_size = 1
-    channels = 16
-    frames = 21
-    height = 90
-    width = 160
-    hidden_states_shape = (batch_size, frames, height, width, channels)
-    dummy_hidden_states = jnp.ones(hidden_states_shape)
-    wan_rot_embed = WanRotaryPosEmbed(attention_head_dim=128, patch_size=[1, 2, 2], max_seq_len=1024)
-    dummy_rotary_emb = wan_rot_embed(dummy_hidden_states)
-
-    key = jax.random.key(0)
-    rngs = nnx.Rngs(key)
-    devices_array = create_device_mesh(config)
-
-    mesh_axes = ['data', 'fsdp', 'tensor']
-    mesh = Mesh(devices_array, mesh_axes)
-    batch_size = 1
-    query_dim = 5120
     for attention_kernel in ["flash", "tokamax_flash"]:
       pyconfig.initialize(
           [
@@ -233,6 +215,22 @@ def test_wan_attention(self):
          ]
      )
      config = pyconfig.config
+      batch_size = 1
+      channels = 16
+      frames = 21
+      height = 90
+      width = 160
+      hidden_states_shape = (batch_size, frames, height, width, channels)
+      dummy_hidden_states = jnp.ones(hidden_states_shape)
+      wan_rot_embed = WanRotaryPosEmbed(attention_head_dim=128, patch_size=[1, 2, 2], max_seq_len=1024)
+      dummy_rotary_emb = wan_rot_embed(dummy_hidden_states)
+
+      key = jax.random.key(0)
+      rngs = nnx.Rngs(key)
+      devices_array = create_device_mesh(config)
+      mesh = Mesh(devices_array, config.mesh_axes)
+      batch_size = 1
+      query_dim = 5120
       with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
         flash_block_sizes = get_flash_block_sizes(config)
         attention = FlaxWanAttention(
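
For context on the commit message: the loop the diff introduces rebuilds every kernel-dependent object, including flash_block_sizes, from a freshly initialized config on each iteration, so that get_flash_block_sizes(config) can return different block sizes for "flash" versus "tokamax_flash". The sketch below is illustrative only and is not the maxdiffusion API; BlockSizes and pick_flash_block_sizes are hypothetical stand-ins that show the per-kernel selection pattern in plain Python.

# Illustrative sketch only: BlockSizes and pick_flash_block_sizes are
# hypothetical stand-ins, not the maxdiffusion API. The point is the test's
# loop structure: anything derived from the attention kernel choice is
# recomputed inside the loop rather than hoisted above it.
from dataclasses import dataclass


@dataclass
class BlockSizes:
  # Hypothetical flash-attention tiling parameters.
  block_q: int
  block_kv: int


def pick_flash_block_sizes(attention_kernel: str) -> BlockSizes:
  """Return block sizes suited to the requested kernel (assumed behavior)."""
  if attention_kernel == "tokamax_flash":
    # Assumed: the tokamax kernel prefers different (here, smaller) tiles.
    return BlockSizes(block_q=512, block_kv=512)
  return BlockSizes(block_q=1024, block_kv=1024)


if __name__ == "__main__":
  # Analogue of the test loop: flash block sizes are re-derived per kernel.
  for attention_kernel in ["flash", "tokamax_flash"]:
    flash_block_sizes = pick_flash_block_sizes(attention_kernel)
    print(attention_kernel, flash_block_sizes)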
