Show HN: Mamba2-Jax; Mamba2 implemented in pure JAX/Flax
Key topics: JAX, Flax, Mamba2, state-space models, machine learning, deep learning
– Core Mamba2 block with LM (Mamba2ForCausalLM) and time-series (Mamba2Forecaster) heads
– Pure JAX/Flax (no Triton or custom CUDA); runs on CPU, CUDA, and TPU via standard JAX backends
– Small CPU-only parity test vs mamba2-torch: similar loss curves, final MSE diff ≈ 0.012, prediction correlation ≈ 0.99; after JIT warmup, JAX was ≈ 2× faster per step
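For readers who haven't seen Mamba2's internals, here is a minimal sketch, in plain JAX with made-up names, of the sequential SSD recurrence a single head computes. This is not the package's actual kernel (a practical implementation would use a chunked parallel scan), but the math is the same:

    import jax
    import jax.numpy as jnp

    def ssd_recurrence(log_a, B, C, x):
        # Sequential form of the Mamba2 (SSD) state update for one head:
        #   h_t = exp(log_a_t) * h_{t-1} + outer(B_t, x_t),   y_t = C_t @ h_t
        # Shapes: log_a (L,), B and C (L, d_state), x (L, d_head) -> y (L, d_head)
        def step(h, inputs):
            la_t, B_t, C_t, x_t = inputs
            h = jnp.exp(la_t) * h + jnp.outer(B_t, x_t)  # decay + rank-1 input
            return h, C_t @ h                            # linear readout
        h0 = jnp.zeros((B.shape[-1], x.shape[-1]))
        _, y = jax.lax.scan(step, h0, (log_a, B, C, x))
        return y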
I’d really appreciate feedback on:
– API design, especially for streaming/stateful inference (sketched below)
– Performance gotchas you hit if you try it
– Any hooks you’d want exposed for research use
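To make the streaming/stateful question concrete, this is roughly the decode-loop shape I have in mind; every name below (init_state, the state keyword) is a placeholder for illustration, not the package's confirmed API:

    import jax.numpy as jnp

    # Placeholder API sketch -- init_state / apply(..., state=...) are
    # illustrative names, not necessarily what mamba2-jax exposes today.
    def decode_stream(model, params, prompt_ids, n_new):
        # A Mamba2 "cache" is a fixed-size recurrent state per layer (plus a
        # short conv buffer), so it stays O(1) in sequence length, unlike an
        # attention KV cache.
        state = model.init_state(batch_size=1)
        logits = None
        for tok in prompt_ids:  # prime the state on the prompt, token by token
            logits, state = model.apply(params, jnp.array([[tok]]), state=state)
        out, tok = [], int(jnp.argmax(logits[0, -1]))
        for _ in range(n_new):
            out.append(tok)
            logits, state = model.apply(params, jnp.array([[tok]]), state=state)
            tok = int(jnp.argmax(logits[0, -1]))  # greedy; sampling also works
        return out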
PyPI: https://pypi.org/project/mamba2-jax/
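Install with the usual: pip install mamba2-jax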
Thanks, Cosmo