@@ -387,7 +387,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
387387 relative_buckets += (relative_position > 0 ) * num_buckets
388388 relative_position = jnp .abs (relative_position )
389389 else :
390- relative_position = - jnp .clip (relative_position , max = 0 )
390+ relative_position = - jnp .clip (relative_position , a_max = 0 )
391391 # now relative_position is in the range [0, inf)
392392
393393 # half of the buckets are for exact increments in positions
@@ -398,7 +398,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
398398 relative_position_if_large = max_exact + (
399399 jnp .log (relative_position / max_exact ) / jnp .log (max_distance / max_exact ) * (num_buckets - max_exact )
400400 )
401- relative_position_if_large = jnp .clip (relative_position_if_large , max = num_buckets - 1 )
401+ relative_position_if_large = jnp .clip (relative_position_if_large , a_max = num_buckets - 1 )
402402
403403 relative_buckets += jnp .where (is_small , relative_position , relative_position_if_large )
404404
@@ -672,7 +672,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
672672 relative_buckets += (relative_position > 0 ) * num_buckets
673673 relative_position = jnp .abs (relative_position )
674674 else :
675- relative_position = - jnp .clip (relative_position , max = 0 )
675+ relative_position = - jnp .clip (relative_position , a_max = 0 )
676676 # now relative_position is in the range [0, inf)
677677
678678 # half of the buckets are for exact increments in positions
@@ -683,7 +683,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
683683 relative_position_if_large = max_exact + (
684684 jnp .log (relative_position / max_exact ) / jnp .log (max_distance / max_exact ) * (num_buckets - max_exact )
685685 )
686- relative_position_if_large = jnp .clip (relative_position_if_large , max = num_buckets - 1 )
686+ relative_position_if_large = jnp .clip (relative_position_if_large , a_max = num_buckets - 1 )
687687
688688 relative_buckets += jnp .where (is_small , relative_position , relative_position_if_large )
689689
@@ -895,7 +895,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
895895 relative_buckets += (relative_position > 0 ) * num_buckets
896896 relative_position = jnp .abs (relative_position )
897897 else :
898- relative_position = - jnp .clip (relative_position , max = 0 )
898+ relative_position = - jnp .clip (relative_position , a_max = 0 )
899899 # now relative_position is in the range [0, inf)
900900
901901 # half of the buckets are for exact increments in positions
@@ -906,7 +906,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
906906 relative_position_if_large = max_exact + (
907907 jnp .log (relative_position / max_exact ) / jnp .log (max_distance / max_exact ) * (num_buckets - max_exact )
908908 )
909- relative_position_if_large = jnp .clip (relative_position_if_large , max = num_buckets - 1 )
909+ relative_position_if_large = jnp .clip (relative_position_if_large , a_max = num_buckets - 1 )
910910
911911 relative_buckets += jnp .where (is_small , relative_position , relative_position_if_large )
912912
0 commit comments