#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"

| | TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { |
| | ops.def("fwd(" |
| | "Tensor! q, " |
| | "Tensor k, " |
| | "Tensor v, " |
| | "Tensor(out_!)? out_, " |
| | "Tensor? alibi_slopes_, " |
| | "float p_dropout, " |
| | "float softmax_scale, " |
| | "bool is_causal," |
| | "int window_size_left, " |
| | "int window_size_right, " |
| | "float softcap, " |
| | "bool return_softmax, " |
| | "Generator? gen_) -> Tensor[]"); |
| | ops.impl("fwd", torch::kCUDA, &mha_fwd); |
| |
|
| | ops.def("varlen_fwd(" |
| | "Tensor! q, " |
| | "Tensor k, " |
| | "Tensor v, " |
| | "Tensor? out_, " |
| | "Tensor cu_seqlens_q, " |
| | "Tensor cu_seqlens_k, " |
| | "Tensor? seqused_k_, " |
| | "Tensor? leftpad_k_, " |
| | "Tensor? block_table_, " |
| | "Tensor? alibi_slopes_, " |
| | "int max_seqlen_q, " |
| | "int max_seqlen_k, " |
| | "float p_dropout, " |
| | "float softmax_scale, " |
| | "bool zero_tensors, " |
| | "bool is_causal, " |
| | "int window_size_left, " |
| | "int window_size_right, " |
| | "float softcap, " |
| | "bool return_softmax, " |
| | "Generator? gen_) -> Tensor[]"); |
| | ops.impl("varlen_fwd", torch::kCUDA, &mha_varlen_fwd); |
| |
|
| | ops.def("bwd(" |
| | "Tensor! dout, " |
| | "Tensor! q, " |
| | "Tensor! k, " |
| | "Tensor! v, " |
| | "Tensor! out, " |
| | "Tensor! " |
| | "softmax_lse, " |
| | "Tensor? dq_, " |
| | "Tensor? dk_, " |
| | "Tensor? dv_, " |
| | "Tensor? alibi_slopes_, " |
| | "float p_dropout, " |
| | "float softmax_scale, " |
| | "bool is_causal, " |
| | "int window_size_left, " |
| | "int window_size_right, " |
| | "float softcap, " |
| | "bool deterministic, " |
| | "Generator? gen_, " |
| | "Tensor? rng_state) -> Tensor[]"); |
| | ops.impl("bwd", torch::kCUDA, &mha_bwd); |
| |
|
| | ops.def("varlen_bwd(" |
| | "Tensor! dout, " |
| | "Tensor! q, " |
| | "Tensor! k, " |
| | "Tensor! v, " |
| | "Tensor! out, " |
| | "Tensor! softmax_lse, " |
| | "Tensor? dq_, " |
| | "Tensor? dk_, " |
| | "Tensor? dv_, " |
| | "Tensor cu_seqlens_q, " |
| | "Tensor cu_seqlens_k, " |
| | "Tensor? alibi_slopes_, " |
| | "int max_seqlen_q, " |
| | "int max_seqlen_k, " |
| | "float p_dropout, float softmax_scale, " |
| | "bool zero_tensors, " |
| | "bool is_causal, " |
| | "int window_size_left, " |
| | "int window_size_right, " |
| | "float softcap, " |
| | "bool deterministic, " |
| | "Generator? gen_, " |
| | "Tensor? rng_state) -> Tensor[]"); |
| | ops.impl("varlen_bwd", torch::kCUDA, &mha_varlen_bwd); |
| |
|
| | ops.def("fwd_kvcache(" |
| | "Tensor! q, " |
| | "Tensor! kcache, " |
| | "Tensor! vcache, " |
| | "Tensor? k_, " |
| | "Tensor? v_, " |
| | "Tensor? seqlens_k_, " |
| | "Tensor? rotary_cos_, " |
| | "Tensor? rotary_sin_, " |
| | "Tensor? cache_batch_idx_, " |
| | "Tensor? leftpad_k_, " |
| | "Tensor? block_table_, " |
| | "Tensor? alibi_slopes_, " |
| | "Tensor? out_, " |
| | "float softmax_scale, " |
| | "bool is_causal, " |
| | "int window_size_left, " |
| | "int window_size_right, " |
| | "float softcap, " |
| | "bool is_rotary_interleaved, " |
| | "int num_splits) -> Tensor[]"); |
| | ops.impl("fwd_kvcache", torch::kCUDA, &mha_fwd_kvcache); |
| | } |
| |
|
| | REGISTER_EXTENSION(TORCH_EXTENSION_NAME) |
| |
|