From d9217de78f0f1ce864a51287a08a0ceed8b553f3 Mon Sep 17 00:00:00 2001 From: medmekk Date: Wed, 10 Dec 2025 10:44:52 +0000 Subject: [PATCH 1/5] initial --- mamba-ssm/build.toml | 3 +- mamba-ssm/flake.lock | 99 +++---------------- mamba-ssm/flake.nix | 10 -- .../torch-ext/mamba_ssm/utils/generation.py | 3 +- 4 files changed, 16 insertions(+), 99 deletions(-) diff --git a/mamba-ssm/build.toml b/mamba-ssm/build.toml index a659ccf..b00fe6b 100644 --- a/mamba-ssm/build.toml +++ b/mamba-ssm/build.toml @@ -1,6 +1,7 @@ [general] name = "mamba_ssm" -universal = false +backends = ["cuda"] +python-depends = ["einops"] [torch] src = [ diff --git a/mamba-ssm/flake.lock b/mamba-ssm/flake.lock index e05ef3d..68fcaba 100644 --- a/mamba-ssm/flake.lock +++ b/mamba-ssm/flake.lock @@ -2,26 +2,11 @@ "nodes": { "flake-compat": { "locked": { - "lastModified": 1747046372, - "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", + "lastModified": 1761588595, + "narHash": "sha256-XKUZz9zewJNUj46b4AJdiRZJAvSZ0Dqj2BNfXvFlJC4=", "owner": "edolstra", "repo": "flake-compat", - "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", - "type": "github" - }, - "original": { - "owner": "edolstra", - "repo": "flake-compat", - "type": "github" - } - }, - "flake-compat_2": { - "locked": { - "lastModified": 1747046372, - "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", - "owner": "edolstra", - "repo": "flake-compat", - "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", + "rev": "f387cd2afec9419c8ee37694406ca490c3f34ee5", "type": "github" }, "original": { @@ -48,61 +33,18 @@ "type": "github" } }, - "flake-utils_2": { - "inputs": { - "systems": "systems_2" - }, - "locked": { - "lastModified": 1731533236, - "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", - "owner": "numtide", - "repo": "flake-utils", - "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", - "type": "github" - }, - "original": { - "owner": "numtide", - "repo": "flake-utils", - "type": "github" - } - }, - "hf-nix": { - "inputs": { - "flake-compat": "flake-compat_2", - "flake-utils": "flake-utils_2", - "nixpkgs": "nixpkgs" - }, - "locked": { - "lastModified": 1760814603, - "narHash": "sha256-i5uuhnJPxOrd0dC8+btp31WMfzPDL8Uwz0TPG2n6nHE=", - "owner": "huggingface", - "repo": "hf-nix", - "rev": "c0b62ec3d0abb11dd2d960e3dfee3a46fc46d111", - "type": "github" - }, - "original": { - "owner": "huggingface", - "repo": "hf-nix", - "type": "github" - } - }, "kernel-builder": { "inputs": { "flake-compat": "flake-compat", "flake-utils": "flake-utils", - "hf-nix": "hf-nix", - "nixpkgs": [ - "kernel-builder", - "hf-nix", - "nixpkgs" - ] + "nixpkgs": "nixpkgs" }, "locked": { - "lastModified": 1761645431, - "narHash": "sha256-Ns3m/L+FMAYnmKhwt4vlIf8lq6dOJWHAocFL23HasTM=", + "lastModified": 1765353242, + "narHash": "sha256-vkcX9frBYYdxJSI8DI1LF01XoQg1xsIJ1RbQ+thJON4=", "owner": "huggingface", "repo": "kernel-builder", - "rev": "289788986c318e6ccb92608f011c49d61b25b5b6", + "rev": "2ffd75b34db10b2c708635b578a8856b02916dab", "type": "github" }, "original": { @@ -113,17 +55,17 @@ }, "nixpkgs": { "locked": { - "lastModified": 1755963616, - "narHash": "sha256-6yD0ww/S8n+U2uPYcJZ3DRURP8Kx036GRpR2uPNZroE=", - "owner": "nixos", + "lastModified": 1763291491, + "narHash": "sha256-eEYvm+45PPmy+Qe+nZDpn1uhoMUjJwx3PwVVQoO9ksA=", + "owner": "NixOS", "repo": "nixpkgs", - "rev": "73e96df7cff5783f45e21342a75a1540c4eddce4", + "rev": "c543a59edf25ada193719764f3bc0c6ba835f94d", "type": "github" }, "original": { - "owner": "nixos", - "ref": "nixos-unstable-small", + "owner": "NixOS", "repo": "nixpkgs", + "rev": "c543a59edf25ada193719764f3bc0c6ba835f94d", "type": "github" } }, @@ -146,21 +88,6 @@ "repo": "default", "type": "github" } - }, - "systems_2": { - "locked": { - "lastModified": 1681028828, - "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", - "owner": "nix-systems", - "repo": "default", - "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", - "type": "github" - }, - "original": { - "owner": "nix-systems", - "repo": "default", - "type": "github" - } } }, "root": "root", diff --git a/mamba-ssm/flake.nix b/mamba-ssm/flake.nix index d50f88a..e916c64 100644 --- a/mamba-ssm/flake.nix +++ b/mamba-ssm/flake.nix @@ -13,15 +13,5 @@ kernel-builder.lib.genFlakeOutputs { inherit self; path = ./.; - # Has many external dependencies, see README.md, this kernel should - # probably be more lean. - doGetKernelCheck = false; - - pythonCheckInputs = - ps: with ps; [ - causal-conv1d - einops - transformers - ]; }; } diff --git a/mamba-ssm/torch-ext/mamba_ssm/utils/generation.py b/mamba-ssm/torch-ext/mamba_ssm/utils/generation.py index 330672a..873210c 100644 --- a/mamba-ssm/torch-ext/mamba_ssm/utils/generation.py +++ b/mamba-ssm/torch-ext/mamba_ssm/utils/generation.py @@ -11,8 +11,7 @@ from einops import rearrange, repeat from torch import Tensor from torch.profiler import ProfilerActivity, profile, record_function -from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer - +from transformers.generation import GenerateDecoderOnlyOutput, TextStreamer @dataclass class InferenceParams: From 82ce460ac3869bac46fddf18153ac64d51a67b48 Mon Sep 17 00:00:00 2001 From: medmekk Date: Wed, 10 Dec 2025 15:25:11 +0000 Subject: [PATCH 2/5] fix __init__s --- mamba-ssm/flake.lock | 6 +++--- mamba-ssm/flake.nix | 13 +++++++++++++ mamba-ssm/torch-ext/mamba_ssm/__init__.py | 14 +++++--------- mamba-ssm/torch-ext/mamba_ssm/models/__init__.py | 5 +++++ mamba-ssm/torch-ext/mamba_ssm/modules/__init__.py | 7 +++++++ mamba-ssm/torch-ext/mamba_ssm/ops/__init__.py | 1 + .../torch-ext/mamba_ssm/ops/triton/__init__.py | 2 ++ 7 files changed, 36 insertions(+), 12 deletions(-) diff --git a/mamba-ssm/flake.lock b/mamba-ssm/flake.lock index 68fcaba..e1c6b7d 100644 --- a/mamba-ssm/flake.lock +++ b/mamba-ssm/flake.lock @@ -40,11 +40,11 @@ "nixpkgs": "nixpkgs" }, "locked": { - "lastModified": 1765353242, - "narHash": "sha256-vkcX9frBYYdxJSI8DI1LF01XoQg1xsIJ1RbQ+thJON4=", + "lastModified": 1765373725, + "narHash": "sha256-bMZUewjVTRqUN0r0kcUJwm3vCiqhGXFpxo3xbaHiNS8=", "owner": "huggingface", "repo": "kernel-builder", - "rev": "2ffd75b34db10b2c708635b578a8856b02916dab", + "rev": "d7aa2703fd2df7b07d24d4a61b134c3b2d1e00aa", "type": "github" }, "original": { diff --git a/mamba-ssm/flake.nix b/mamba-ssm/flake.nix index e916c64..7167e23 100644 --- a/mamba-ssm/flake.nix +++ b/mamba-ssm/flake.nix @@ -13,5 +13,18 @@ kernel-builder.lib.genFlakeOutputs { inherit self; path = ./.; + + # Has many external dependencies, see README.md, this kernel should + # probably be more lean. + doGetKernelCheck = false; + + pythonCheckInputs = + ps: with ps; [ + causal-conv1d + einops + transformers + ]; }; + + } diff --git a/mamba-ssm/torch-ext/mamba_ssm/__init__.py b/mamba-ssm/torch-ext/mamba_ssm/__init__.py index a767386..7d8f7db 100644 --- a/mamba-ssm/torch-ext/mamba_ssm/__init__.py +++ b/mamba-ssm/torch-ext/mamba_ssm/__init__.py @@ -1,14 +1,10 @@ __version__ = "2.2.4" -from .ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn -from .modules.mamba_simple import Mamba -from .modules.mamba2 import Mamba2 -from .models.mixer_seq_simple import MambaLMHeadModel +from .ops import selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined +from .modules import Mamba, Mamba2 +from .models import MambaLMHeadModel __all__ = [ - "selective_scan_fn", - "mamba_inner_fn", - "Mamba", - "Mamba2", - "MambaLMHeadModel", + "selective_scan_fn", "mamba_inner_fn", "selective_state_update", "mamba_chunk_scan_combined", "mamba_split_conv1d_scan_combined", + "Mamba", "Mamba2", "MambaLMHeadModel", ] diff --git a/mamba-ssm/torch-ext/mamba_ssm/models/__init__.py b/mamba-ssm/torch-ext/mamba_ssm/models/__init__.py index e69de29..c64f5ff 100644 --- a/mamba-ssm/torch-ext/mamba_ssm/models/__init__.py +++ b/mamba-ssm/torch-ext/mamba_ssm/models/__init__.py @@ -0,0 +1,5 @@ +from .mixer_seq_simple import MambaLMHeadModel + +__all__ = [ + "MambaLMHeadModel", +] \ No newline at end of file diff --git a/mamba-ssm/torch-ext/mamba_ssm/modules/__init__.py b/mamba-ssm/torch-ext/mamba_ssm/modules/__init__.py index e69de29..ad51e45 100644 --- a/mamba-ssm/torch-ext/mamba_ssm/modules/__init__.py +++ b/mamba-ssm/torch-ext/mamba_ssm/modules/__init__.py @@ -0,0 +1,7 @@ +from .mamba_simple import Mamba +from .mamba2 import Mamba2 + +__all__ = [ + "Mamba", + "Mamba2", +] \ No newline at end of file diff --git a/mamba-ssm/torch-ext/mamba_ssm/ops/__init__.py b/mamba-ssm/torch-ext/mamba_ssm/ops/__init__.py index e69de29..1e994c1 100644 --- a/mamba-ssm/torch-ext/mamba_ssm/ops/__init__.py +++ b/mamba-ssm/torch-ext/mamba_ssm/ops/__init__.py @@ -0,0 +1 @@ +from .triton import selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined \ No newline at end of file diff --git a/mamba-ssm/torch-ext/mamba_ssm/ops/triton/__init__.py b/mamba-ssm/torch-ext/mamba_ssm/ops/triton/__init__.py index e69de29..b210bb0 100644 --- a/mamba-ssm/torch-ext/mamba_ssm/ops/triton/__init__.py +++ b/mamba-ssm/torch-ext/mamba_ssm/ops/triton/__init__.py @@ -0,0 +1,2 @@ +from .selective_state_update import selective_state_update +from .ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined \ No newline at end of file From ac9fad5e020f74c039248d2aefc320e56bdee17e Mon Sep 17 00:00:00 2001 From: medmekk Date: Thu, 11 Dec 2025 05:22:08 +0000 Subject: [PATCH 3/5] formatting --- mamba-ssm/flake.nix | 1 - 1 file changed, 1 deletion(-) diff --git a/mamba-ssm/flake.nix b/mamba-ssm/flake.nix index 7167e23..5af5789 100644 --- a/mamba-ssm/flake.nix +++ b/mamba-ssm/flake.nix @@ -26,5 +26,4 @@ ]; }; - } From ea443b644919f9e5caba792e67d67792c27e4258 Mon Sep 17 00:00:00 2001 From: medmekk Date: Thu, 11 Dec 2025 09:26:31 +0000 Subject: [PATCH 4/5] fix --- mamba-ssm/torch-ext/mamba_ssm/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mamba-ssm/torch-ext/mamba_ssm/__init__.py b/mamba-ssm/torch-ext/mamba_ssm/__init__.py index 7d8f7db..3761cb9 100644 --- a/mamba-ssm/torch-ext/mamba_ssm/__init__.py +++ b/mamba-ssm/torch-ext/mamba_ssm/__init__.py @@ -1,6 +1,7 @@ __version__ = "2.2.4" from .ops import selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined +from .ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn from .modules import Mamba, Mamba2 from .models import MambaLMHeadModel From 4d430b75daab6ea7f04d709107c62d1ae0bbaa06 Mon Sep 17 00:00:00 2001 From: medmekk Date: Thu, 11 Dec 2025 13:40:10 +0000 Subject: [PATCH 5/5] update for falcon --- mamba-ssm/torch-ext/mamba_ssm/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mamba-ssm/torch-ext/mamba_ssm/__init__.py b/mamba-ssm/torch-ext/mamba_ssm/__init__.py index 3761cb9..c209429 100644 --- a/mamba-ssm/torch-ext/mamba_ssm/__init__.py +++ b/mamba-ssm/torch-ext/mamba_ssm/__init__.py @@ -5,7 +5,8 @@ from .modules import Mamba, Mamba2 from .models import MambaLMHeadModel +falcon_mamba_inner_fn = mamba_inner_fn __all__ = [ - "selective_scan_fn", "mamba_inner_fn", "selective_state_update", "mamba_chunk_scan_combined", "mamba_split_conv1d_scan_combined", + "selective_scan_fn", "mamba_inner_fn", "falcon_mamba_inner_fn", "selective_state_update", "mamba_chunk_scan_combined", "mamba_split_conv1d_scan_combined", "Mamba", "Mamba2", "MambaLMHeadModel", ]