jax-flash-attn2

Flash Attention implementation with multiple backend support and sharding. This module provides a flexible implementation of Flash Attention with support for different backends (GPU, TPU, CPU) and platforms.
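For reference, a flash attention kernel computes the same result as standard scaled dot-product attention, softmax(QK^T / sqrt(d)) V, but block-wise and without materializing the full attention matrix. The sketch below is a plain JAX reference implementation of that computation (not this package's API) that any backend's flash attention kernel should match numerically:

    import jax
    import jax.numpy as jnp


    def reference_attention(q, k, v):
        """Standard scaled dot-product attention: softmax(Q K^T / sqrt(d)) V.

        q, k, v: arrays of shape [batch, heads, seq_len, head_dim].
        A flash attention kernel produces the same output without building
        the full [seq_len, seq_len] score matrix in memory.
        """
        d = q.shape[-1]
        scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(d)
        weights = jax.nn.softmax(scores, axis=-1)
        return jnp.einsum("bhqk,bhkd->bhqd", weights, v)


    # Example shapes: batch=2, heads=4, seq_len=128, head_dim=64.
    kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
    q = jax.random.normal(kq, (2, 4, 128, 64))
    k = jax.random.normal(kk, (2, 4, 128, 64))
    v = jax.random.normal(kv, (2, 4, 128, 64))
    out = reference_attention(q, k, v)
    print(out.shape)  # (2, 4, 128, 64)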

Installation

In a virtualenv (create one first if you need to):

pip3 install jax-flash-attn2
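A quick smoke test after installation; this assumes the import name is jax_flash_attn2 (hyphens in the PyPI name becoming underscores, as is conventional), so check the project documentation if the import fails:

    # Assumed import name; verify against the project's README.
    import jax_flash_attn2
    print(jax_flash_attn2.__name__)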

Releases

Version   Released
0.0.3     2025-03-04
0.0.1     2024-10-23
