@@ -9,6 +9,7 @@ namespace torch_ipex {
99namespace cpu {
1010
1111IPEX_DEFINE_DISPATCH (causal_conv1d_update_kernel_stub);
12+ IPEX_DEFINE_DISPATCH (causal_conv1d_fn_kernel_stub);
1213std::vector<int64_t > calc_conv_output_size (
1314 at::IntArrayRef input_size,
1415 at::IntArrayRef kernel_size,
@@ -514,21 +515,55 @@ at::Tensor convolution_forward(
 * @param conv_weights (dim, width)
 * @param conv_bias (dim,)
 * @param silu_activation If true, apply the SiLU activation function.
 * @param cache_seqlens (batch,) or None
 * @return (hidden_states, conv_states)
 */
519521std::tuple<at::Tensor, at::Tensor> causal_conv1d_update (
520522 const at::Tensor& hidden_states,
521523 const at::Tensor& conv_states,
522524 const at::Tensor& conv_weights,
523525 const c10::optional<at::Tensor>& conv_bias,
524- bool silu_activation) {
526+ bool silu_activation,
527+ const c10::optional<at::Tensor>& cache_seqlens) {
525528 RECORD_FUNCTION (" causal_conv1d_update" , c10::ArrayRef<c10::IValue>({}));
526529 return causal_conv1d_update_kernel_stub (
527530 kCPU ,
528531 hidden_states,
529532 conv_states,
530533 conv_weights,
531534 conv_bias,
535+ silu_activation,
536+ cache_seqlens);
537+ }
538+
539+ /* *
540+ * Official Python implementation: causal_conv1d_ref:
541+ * https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py#L133
542+ * @param x (batch, dim, seqlen)
543+ * @param conv_weights (dim, width)
544+ * @param conv_bias (dim,)
545+ * @param initial_states (batch, dim, width - 1)
546+ * @param final_states_out (batch, dim, width - 1)
547+ * @param silu_activation If true, apply the SiLU activation function.
548+ * @return (out, final_states_out)
549+ * out: (batch, dim, seqlen)
550+ * final_states_out: (batch, dim, width - 1)
551+ */
552+ std::tuple<at::Tensor, at::Tensor> causal_conv1d_fn (
553+ const at::Tensor& x,
554+ const at::Tensor& conv_weights,
555+ const c10::optional<at::Tensor>& conv_bias,
556+ const c10::optional<at::Tensor>& initial_states,
557+ const c10::optional<at::Tensor>& final_states_out,
558+ bool silu_activation) {
559+ RECORD_FUNCTION (" causal_conv1d_fn" , c10::ArrayRef<c10::IValue>({}));
560+ return causal_conv1d_fn_kernel_stub (
561+ kCPU ,
562+ x,
563+ conv_weights,
564+ conv_bias,
565+ initial_states,
566+ final_states_out,
532567 silu_activation);
533568}
534569
@@ -589,11 +624,17 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
589624 c10::DispatchKey::CPU,
590625 torch_ipex::cpu::convolution_forward_impl);
591626 m.def (
592- " causal_conv1d_update(Tensor hidden_states, Tensor conv_states, Tensor conv_weights, Tensor? conv_bias, bool silu_activation) -> (Tensor, Tensor)" );
627+ " causal_conv1d_update(Tensor hidden_states, Tensor conv_states, Tensor conv_weights, Tensor? conv_bias, bool silu_activation, Tensor? cache_seqlens=None ) -> (Tensor, Tensor)" );
593628 m.impl (
594629 " causal_conv1d_update" ,
595630 c10::DispatchKey::CPU,
596631 torch_ipex::cpu::causal_conv1d_update);
632+ m.def (
633+ " causal_conv1d_fn(Tensor x, Tensor conv_weights, Tensor? conv_bias, Tensor? initial_states, Tensor? final_states_out, bool silu_activation) -> (Tensor, Tensor)" );
634+ m.impl (
635+ " causal_conv1d_fn" ,
636+ c10::DispatchKey::CPU,
637+ torch_ipex::cpu::causal_conv1d_fn);
597638 // bw
598639 m.def (
599640 " convolution_backward(Tensor input, Tensor weight, Tensor? bias, Tensor grad_output, bool[3] out_mask, "
0 commit comments