LongLM isn't compatible with gemma-2-27b-it or gemma-2b-it #46

@uebian

I found that the current version of LongLM cannot load Gemma 1 or Gemma 2 models successfully. I wrote a minimal test to reproduce the issue:

# transformers version 4.38.2
# This example was tested with 4 RTX 3090s, 24 GB of memory each
import warnings
warnings.filterwarnings("ignore")

import time

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

import SelfExtend 

window_size = 1024
group_size = 32

model_name = '/tmp/gemma-2b-it/'
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()
SelfExtend.apply(model, group_size, window_size, enable_flash_attention=False)
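prompt = "How are you?"
# Put the prompt on the model's first device (device_map="auto" may spread layers across GPUs)
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

start_time = time.time()
tokens = model.generate(input_ids, max_new_tokens=4096)
answer = tokenizer.decode(tokens[0].tolist()[input_ids.shape[1]:], skip_special_tokens=True)
print(answer)
print(f"Generation took {time.time() - start_time:.1f} s")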

The checkpoint loads, but the call to SelfExtend.apply then fails with the error below:

$ python3 test.py 
`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.07it/s]
Traceback (most recent call last):
  File "/var/lib/condor/execute/slot1/dir_2652801/test.py", line 22, in <module>
    SelfExtend.apply(model, group_size, window_size, enable_flash_attention=False)
  File "/var/lib/condor/execute/slot1/dir_2652801/SelfExtend.py", line 160, in apply
    raise Exception(f"Failed to modify the attention method of {arch_name}")
Exception: Failed to modify the attention method of GemmaForCausalLM
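
A quick first check, assuming SelfExtend locates the attention layers by their class, is to list which attention module classes the loaded model actually contains. In transformers, the `attn_implementation` argument of `from_pretrained` (e.g. "eager", "sdpa", "flash_attention_2") can change which attention class is instantiated, so the class present at runtime may not be the one a patcher expects. The helper below, `attention_class_names`, is a hypothetical sketch; it only relies on `torch.nn.Module.named_modules()`.

```python
from collections import Counter


def attention_class_names(model):
    """Count the module classes in `model` whose class name contains "Attention".

    `model` is expected to be a torch.nn.Module (e.g. the GemmaForCausalLM
    loaded above); only its named_modules() method is used.
    """
    return Counter(
        type(module).__name__
        for _, module in model.named_modules()
        if "Attention" in type(module).__name__
    )


# Usage, after from_pretrained(...) and before SelfExtend.apply(...):
# print(attention_class_names(model))
```

If the printed class name differs from the one SelfExtend targets for Gemma, that would be consistent with the exception above; this is an assumption to verify, not a confirmed diagnosis.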

I traced the failure to the duplicate check on line 24 of SelfExtend.py. When it fails, the value of `instance` is False.
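
To illustrate how a patcher of this kind can end up reporting failure, here is a hypothetical, self-contained sketch of the general pattern: walk an object tree, skip objects already visited (tracked by `id()`), and rebind a method on every object whose class name matches a target exactly. It is not the code in SelfExtend.py, and `patch_by_class_name`, `ToyAttention`, and `ToySdpaAttention` are illustrative names. It shows that an exact class-name match returns False when the model holds a differently named attention class, even a subclass.

```python
from types import MethodType


def patch_by_class_name(obj, target_class_name, method_name, new_method, visited=None):
    """Recursively rebind `method_name` on objects whose class name is `target_class_name`.

    Hypothetical helper for illustration only. Returns True if anything was patched.
    """
    if visited is None:
        visited = set()
    # Duplicate check: skip objects that were already walked (identity-based).
    if id(obj) in visited:
        return False
    visited.add(id(obj))

    # Exact name match only: subclasses with other names are not matched.
    if type(obj).__name__ == target_class_name:
        setattr(obj, method_name, MethodType(new_method, obj))
        return True

    if isinstance(obj, (list, tuple, set)):
        children = list(obj)
    elif isinstance(obj, dict):
        children = list(obj.values())
    elif hasattr(obj, "__dict__"):
        children = list(vars(obj).values())
    else:
        children = []

    found = False
    for child in children:
        if patch_by_class_name(child, target_class_name, method_name, new_method, visited):
            found = True
    return found


class ToyAttention:
    def forward(self, x):
        return "original"


class ToySdpaAttention(ToyAttention):  # same behavior, different class name
    pass


class ToyLayer:
    def __init__(self, attn):
        self.self_attn = attn


class ToyModel:
    def __init__(self, attn_cls):
        self.layers = [ToyLayer(attn_cls()) for _ in range(2)]


def extended_forward(self, x):
    return "patched"


m1 = ToyModel(ToyAttention)
print(patch_by_class_name(m1, "ToyAttention", "forward", extended_forward))  # True
print(m1.layers[0].self_attn.forward(None))  # patched

m2 = ToyModel(ToySdpaAttention)
print(patch_by_class_name(m2, "ToyAttention", "forward", extended_forward))  # False
```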

Below is the conda env export output for my Python environment, including package versions:

channels:
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - bzip2=1.0.8=h5eee18b_6
  - ca-certificates=2024.7.2=h06a4308_0
  - ld_impl_linux-64=2.38=h1181459_1
  - libffi=3.4.4=h6a678d5_1
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libstdcxx-ng=11.2.0=h1234567_1
  - libuuid=1.41.5=h5eee18b_0
  - ncurses=6.4=h6a678d5_0
  - openssl=3.0.14=h5eee18b_0
  - pip=24.0=py310h06a4308_0
  - python=3.10.14=h955ad1f_1
  - readline=8.2=h5eee18b_0
  - setuptools=69.5.1=py310h06a4308_0
  - sqlite=3.45.3=h5eee18b_0
  - tk=8.6.14=h39e8969_0
  - wheel=0.43.0=py310h06a4308_0
  - xz=5.4.6=h5eee18b_1
  - zlib=1.2.13=h5eee18b_1
  - pip:
      - accelerate==0.33.0
      - aiohttp==3.9.5
      - aiosignal==1.3.1
      - annotated-types==0.7.0
      - anyio==4.4.0
      - async-timeout==4.0.3
      - attrs==23.2.0
      - certifi==2024.7.4
      - charset-normalizer==3.3.2
      - click==8.1.7
      - cloudpickle==3.0.0
      - cmake==3.30.1
      - datasets==2.20.0
      - dill==0.3.8
      - diskcache==5.6.3
      - distro==1.9.0
      - dnspython==2.6.1
      - einops==0.8.0
      - email-validator==2.2.0
      - exceptiongroup==1.2.2
      - fastapi==0.111.1
      - fastapi-cli==0.0.4
      - filelock==3.15.4
      - flash-attn==2.6.3
      - frozenlist==1.4.1
      - fsspec==2024.5.0
      - h11==0.14.0
      - httpcore==1.0.5
      - httptools==0.6.1
      - httpx==0.27.0
      - huggingface-hub==0.24.2
      - idna==3.7
      - interegular==0.3.3
      - jinja2==3.1.4
      - jsonschema==4.23.0
      - jsonschema-specifications==2023.12.1
      - lark==1.1.9
      - llvmlite==0.43.0
      - lm-format-enforcer==0.10.3
      - markdown-it-py==3.0.0
      - markupsafe==2.1.5
      - mdurl==0.1.2
      - mpmath==1.3.0
      - msgpack==1.0.8
      - multidict==6.0.5
      - multiprocess==0.70.16
      - nest-asyncio==1.6.0
      - networkx==3.3
      - ninja==1.11.1.1
      - numba==0.60.0
      - numpy==1.26.4
      - nvidia-cublas-cu12==12.1.3.1
      - nvidia-cuda-cupti-cu12==12.1.105
      - nvidia-cuda-nvrtc-cu12==12.1.105
      - nvidia-cuda-runtime-cu12==12.1.105
      - nvidia-cudnn-cu12==8.9.2.26
      - nvidia-cufft-cu12==11.0.2.54
      - nvidia-curand-cu12==10.3.2.106
      - nvidia-cusolver-cu12==11.4.5.107
      - nvidia-cusparse-cu12==12.1.0.106
      - nvidia-ml-py==12.555.43
      - nvidia-nccl-cu12==2.20.5
      - nvidia-nvjitlink-cu12==12.5.82
      - nvidia-nvtx-cu12==12.1.105
      - openai==1.37.1
      - outlines==0.0.46
      - packaging==24.1
      - pandas==2.2.2
      - pillow==10.4.0
      - prometheus-client==0.20.0
      - prometheus-fastapi-instrumentator==7.0.0
      - protobuf==5.27.2
      - psutil==6.0.0
      - py-cpuinfo==9.0.0
      - pyairports==2.1.1
      - pyarrow==17.0.0
      - pyarrow-hotfix==0.6
      - pycountry==24.6.1
      - pydantic==2.8.2
      - pydantic-core==2.20.1
      - pygments==2.18.0
      - python-dateutil==2.9.0.post0
      - python-dotenv==1.0.1
      - python-multipart==0.0.9
      - pytz==2024.1
      - pyyaml==6.0.1
      - pyzmq==26.0.3
      - ray==2.33.0
      - referencing==0.35.1
      - regex==2024.7.24
      - requests==2.32.3
      - rich==13.7.1
      - rpds-py==0.19.1
      - safetensors==0.4.3
      - sentencepiece==0.2.0
      - shellingham==1.5.4
      - six==1.16.0
      - sniffio==1.3.1
      - starlette==0.37.2
      - sympy==1.13.1
      - tiktoken==0.7.0
      - tokenizers==0.19.1
      - torch==2.3.1
      - torchvision==0.18.1
      - tqdm==4.66.4
      - transformers==4.43.3
      - triton==2.3.1
      - typer==0.12.3
      - typing-extensions==4.12.2
      - tzdata==2024.1
      - urllib3==2.2.2
      - uvicorn==0.30.3
      - uvloop==0.19.0
      - vllm==0.5.3.post1
      - vllm-flash-attn==2.5.9.post1
      - watchfiles==0.22.0
      - websockets==12.0
      - xformers==0.0.27
      - xxhash==3.4.1
      - yarl==1.9.4
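
Note that the header comment of the reproduction script says transformers 4.38.2, while this environment lists transformers==4.43.3. Since attention implementations in transformers may change between releases, it can help to confirm which version the failing run actually imports. A small sketch using only the standard library's `importlib.metadata`; `installed_version` and `SCRIPT_TESTED_WITH` are illustrative names:

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package):
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# Version named in the header comment of the reproduction script.
SCRIPT_TESTED_WITH = "4.38.2"

found = installed_version("transformers")
if found is None:
    print("transformers is not installed")
elif found != SCRIPT_TESTED_WITH:
    print(f"transformers {found} is installed; the script header says {SCRIPT_TESTED_WITH}")
else:
    print(f"transformers {found} matches the script header")
```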