diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index 3c9b8cbedcf0..e230493469ee 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -71,6 +71,7 @@ steps: - tests/test_inputs.py - tests/test_outputs.py - tests/multimodal + - tests/renderers - tests/standalone_tests/lazy_imports.py - tests/tokenizers_ - tests/tool_parsers @@ -82,6 +83,7 @@ steps: - pytest -v -s test_inputs.py - pytest -v -s test_outputs.py - pytest -v -s -m 'cpu_test' multimodal + - pytest -v -s renderers - pytest -v -s tokenizers_ - pytest -v -s tool_parsers - pytest -v -s transformers_utils diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 2dcca5711b3d..6e7723c67934 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -64,6 +64,7 @@ steps: - tests/test_inputs.py - tests/test_outputs.py - tests/multimodal + - tests/renderers - tests/standalone_tests/lazy_imports.py - tests/tokenizers_ - tests/tool_parsers @@ -75,6 +76,7 @@ steps: - pytest -v -s test_inputs.py - pytest -v -s test_outputs.py - pytest -v -s -m 'cpu_test' multimodal + - pytest -v -s renderers - pytest -v -s tokenizers_ - pytest -v -s tool_parsers - pytest -v -s transformers_utils diff --git a/.buildkite/test_areas/misc.yaml b/.buildkite/test_areas/misc.yaml index 252af1e56a10..b3b4566abffc 100644 --- a/.buildkite/test_areas/misc.yaml +++ b/.buildkite/test_areas/misc.yaml @@ -121,6 +121,7 @@ steps: - tests/test_inputs.py - tests/test_outputs.py - tests/multimodal + - tests/renderers - tests/standalone_tests/lazy_imports.py - tests/tokenizers_ - tests/tool_parsers @@ -132,6 +133,7 @@ steps: - pytest -v -s test_inputs.py - pytest -v -s test_outputs.py - pytest -v -s -m 'cpu_test' multimodal + - pytest -v -s renderers - pytest -v -s tokenizers_ - pytest -v -s tool_parsers - pytest -v -s transformers_utils diff --git a/tests/entrypoints/openai/test_chat_template.py b/tests/entrypoints/openai/test_chat_template.py deleted file mode 100644 index 77087ac21ea8..000000000000 --- a/tests/entrypoints/openai/test_chat_template.py +++ /dev/null @@ -1,156 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - -from vllm.config import ModelConfig -from vllm.entrypoints.chat_utils import apply_hf_chat_template, load_chat_template -from vllm.entrypoints.openai.protocol import ChatCompletionRequest -from vllm.tokenizers import get_tokenizer - -from ...models.registry import HF_EXAMPLE_MODELS -from ...utils import VLLM_PATH - -chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja" -assert chatml_jinja_path.exists() - -# Define models, templates, and their corresponding expected outputs -MODEL_TEMPLATE_GENERATION_OUTPUT = [ - ( - "facebook/opt-125m", - chatml_jinja_path, - True, - False, - """<|im_start|>user -Hello<|im_end|> -<|im_start|>assistant -Hi there!<|im_end|> -<|im_start|>user -What is the capital of<|im_end|> -<|im_start|>assistant -""", - ), - ( - "facebook/opt-125m", - chatml_jinja_path, - False, - False, - """<|im_start|>user -Hello<|im_end|> -<|im_start|>assistant -Hi there!<|im_end|> -<|im_start|>user -What is the capital of""", - ), - ( - "facebook/opt-125m", - chatml_jinja_path, - False, - True, - """<|im_start|>user -Hello<|im_end|> -<|im_start|>assistant -Hi there!<|im_end|> -<|im_start|>user -What is the capital of<|im_end|> -<|im_start|>assistant -The capital of""", - ), -] - -TEST_MESSAGES = [ - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hi 
there!"}, - {"role": "user", "content": "What is the capital of"}, -] -ASSISTANT_MESSAGE_TO_CONTINUE = {"role": "assistant", "content": "The capital of"} - - -def test_load_chat_template(): - # Testing chatml template - template_content = load_chat_template(chat_template=chatml_jinja_path) - - # Test assertions - assert template_content is not None - # Hard coded value for template_chatml.jinja - assert ( - template_content - == """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %} -{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}""" # noqa: E501 - ) - - -def test_no_load_chat_template_filelike(): - # Testing chatml template - template = "../../examples/does_not_exist" - - with pytest.raises(ValueError, match="looks like a file path"): - load_chat_template(chat_template=template) - - -def test_no_load_chat_template_literallike(): - # Testing chatml template - template = "{{ messages }}" - - template_content = load_chat_template(chat_template=template) - - assert template_content == template - - -@pytest.mark.parametrize( - "model,template,add_generation_prompt,continue_final_message,expected_output", - MODEL_TEMPLATE_GENERATION_OUTPUT, -) -def test_get_gen_prompt( - model, template, add_generation_prompt, continue_final_message, expected_output -): - model_info = HF_EXAMPLE_MODELS.find_hf_info(model) - model_info.check_available_online(on_fail="skip") - - model_config = ModelConfig( - model, - tokenizer=model_info.tokenizer or model, - tokenizer_mode=model_info.tokenizer_mode, - trust_remote_code=model_info.trust_remote_code, - revision=model_info.revision, - hf_overrides=model_info.hf_overrides, - skip_tokenizer_init=model_info.require_embed_inputs, - enable_prompt_embeds=model_info.require_embed_inputs, - enable_mm_embeds=model_info.require_embed_inputs, - enforce_eager=model_info.enforce_eager, - dtype=model_info.dtype, - ) - - # Initialize the tokenizer - tokenizer = get_tokenizer( - tokenizer_name=model_config.tokenizer, - trust_remote_code=model_config.trust_remote_code, - ) - template_content = load_chat_template(chat_template=template) - - # Create a mock request object using keyword arguments - mock_request = ChatCompletionRequest( - model=model, - messages=TEST_MESSAGES + [ASSISTANT_MESSAGE_TO_CONTINUE] - if continue_final_message - else TEST_MESSAGES, - add_generation_prompt=add_generation_prompt, - continue_final_message=continue_final_message, - ) - - # Call the function and get the result - result = apply_hf_chat_template( - tokenizer=tokenizer, - conversation=mock_request.messages, - chat_template=mock_request.chat_template or template_content, - model_config=model_config, - tools=None, - add_generation_prompt=mock_request.add_generation_prompt, - continue_final_message=mock_request.continue_final_message, - ) - - # Test assertion - assert result == expected_output, ( - f"The generated prompt does not match the expected output for " - f"model {model} and template {template}" - ) diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 444275e061c6..7e0cdf7b503a 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -20,7 +20,9 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_models import 
BaseModelPath, OpenAIServingModels from vllm.outputs import CompletionOutput, RequestOutput +from vllm.renderers.hf import HfRenderer from vllm.tokenizers import get_tokenizer +from vllm.tokenizers.registry import tokenizer_args_from_config from vllm.tool_parsers import ToolParserManager from vllm.v1.engine.async_llm import AsyncLLM @@ -379,6 +381,15 @@ def get_diff_sampling_param(self): return self.diff_sampling_param or {} +def _build_renderer(model_config: MockModelConfig): + _, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config) + + return HfRenderer( + model_config, + tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name}, + ) + + def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat: models = OpenAIServingModels( engine_client=engine, @@ -413,6 +424,7 @@ class MockEngine: model_config: MockModelConfig = field(default_factory=MockModelConfig) input_processor: MagicMock = field(default_factory=MagicMock) io_processor: MagicMock = field(default_factory=MagicMock) + renderer: MagicMock = field(default_factory=MagicMock) async def _async_serving_chat_init(): @@ -438,11 +450,11 @@ def test_async_serving_chat_init(): @pytest.mark.asyncio async def test_serving_chat_returns_correct_model_name(): mock_engine = MagicMock(spec=AsyncLLM) - mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False mock_engine.model_config = MockModelConfig() mock_engine.input_processor = MagicMock() mock_engine.io_processor = MagicMock() + mock_engine.renderer = _build_renderer(mock_engine.model_config) serving_chat = _build_serving_chat(mock_engine) messages = [{"role": "user", "content": "what is 1+1?"}] @@ -468,11 +480,11 @@ async def return_model_name(*args): @pytest.mark.asyncio async def test_serving_chat_should_set_correct_max_tokens(): mock_engine = MagicMock(spec=AsyncLLM) - mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False mock_engine.model_config = MockModelConfig() mock_engine.input_processor = MagicMock() mock_engine.io_processor = MagicMock() + mock_engine.renderer = _build_renderer(mock_engine.model_config) serving_chat = _build_serving_chat(mock_engine) @@ -501,11 +513,11 @@ async def test_serving_chat_should_set_correct_max_tokens(): # Reinitialize the engine with new settings mock_engine = MagicMock(spec=AsyncLLM) - mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False mock_engine.model_config = mock_model_config mock_engine.input_processor = MagicMock() mock_engine.io_processor = MagicMock() + mock_engine.renderer = _build_renderer(mock_engine.model_config) # Initialize the serving chat serving_chat = _build_serving_chat(mock_engine) @@ -546,11 +558,11 @@ async def test_serving_chat_should_set_correct_max_tokens(): # Reinitialize the engine with new settings mock_engine = MagicMock(spec=AsyncLLM) - mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False mock_engine.model_config = mock_model_config mock_engine.input_processor = MagicMock() mock_engine.io_processor = MagicMock() + mock_engine.renderer = _build_renderer(mock_engine.model_config) # Initialize the serving chat serving_chat = _build_serving_chat(mock_engine) @@ -592,11 +604,11 @@ async def test_serving_chat_could_load_correct_generation_config(): } mock_engine = MagicMock(spec=AsyncLLM) - mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False mock_engine.model_config = mock_model_config 
mock_engine.input_processor = MagicMock() mock_engine.io_processor = MagicMock() + mock_engine.renderer = _build_renderer(mock_engine.model_config) # Initialize the serving chat serving_chat = _build_serving_chat(mock_engine) @@ -638,11 +650,11 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type): mock_model_config.hf_config.model_type = model_type mock_engine = MagicMock(spec=AsyncLLM) - mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False mock_engine.model_config = mock_model_config mock_engine.input_processor = MagicMock() mock_engine.io_processor = MagicMock() + mock_engine.renderer = _build_renderer(mock_engine.model_config) serving_chat = _build_serving_chat(mock_engine) @@ -671,11 +683,11 @@ async def test_serving_chat_data_parallel_rank_extraction(): """Test that data_parallel_rank is properly extracted from header and passed to engine.""" mock_engine = MagicMock(spec=AsyncLLM) - mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False mock_engine.model_config = MockModelConfig() mock_engine.input_processor = MagicMock() mock_engine.io_processor = MagicMock() + mock_engine.renderer = _build_renderer(mock_engine.model_config) # Mock the generate method to return an async generator async def mock_generate(*args, **kwargs): diff --git a/tests/entrypoints/openai/test_serving_engine.py b/tests/entrypoints/openai/test_serving_engine.py deleted file mode 100644 index 192c7cafb749..000000000000 --- a/tests/entrypoints/openai/test_serving_engine.py +++ /dev/null @@ -1,71 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import asyncio -import time -from unittest.mock import Mock - -import pytest - -from vllm.config import ModelConfig -from vllm.entrypoints.openai.serving_engine import OpenAIServing -from vllm.entrypoints.openai.serving_models import OpenAIServingModels -from vllm.tokenizers.mistral import MistralTokenizer - - -@pytest.fixture() -def serving() -> OpenAIServing: - """Create a minimal OpenAIServing instance for testing.""" - - # Create minimal mocks - engine_client = Mock() - model_config = Mock(spec=ModelConfig) - model_config.max_model_len = 32768 - models = Mock(spec=OpenAIServingModels) - models.model_config = model_config - models.input_processor = Mock() - models.io_processor = Mock() - - serving = OpenAIServing( - engine_client=engine_client, - models=models, - request_logger=None, - ) - return serving - - -@pytest.mark.asyncio -async def test_async_mistral_tokenizer_does_not_block_event_loop( - serving: OpenAIServing, -): - expected_tokens = [1, 2, 3] - - # Mock the blocking version to sleep - def mocked_apply_chat_template(*_args, **_kwargs): - time.sleep(2) - return expected_tokens - - mock_tokenizer = Mock(spec=MistralTokenizer) - mock_tokenizer.apply_chat_template.side_effect = mocked_apply_chat_template - - task = serving._apply_mistral_chat_template_async( - tokenizer=mock_tokenizer, messages=[], chat_template=None, tools=[] - ) - - # Ensure the event loop is not blocked - blocked_count = 0 - for _i in range(20): # Check over ~2 seconds - start = time.perf_counter() - await asyncio.sleep(0) - elapsed = time.perf_counter() - start - - # an overly generous elapsed time for slow machines - if elapsed >= 0.5: - blocked_count += 1 - - await asyncio.sleep(0.1) - - # Ensure task completes - tokens = await task - assert tokens == expected_tokens, "Mocked blocking tokenizer was not called" - assert 
blocked_count == 0, "Event loop blocked during tokenization" diff --git a/tests/entrypoints/openai/test_serving_models.py b/tests/entrypoints/openai/test_serving_models.py index b585835a0667..2c135cd1dea3 100644 --- a/tests/entrypoints/openai/test_serving_models.py +++ b/tests/entrypoints/openai/test_serving_models.py @@ -32,6 +32,7 @@ async def _async_serving_models_init() -> OpenAIServingModels: mock_engine_client.model_config = mock_model_config mock_engine_client.input_processor = MagicMock() mock_engine_client.io_processor = MagicMock() + mock_engine_client.renderer = MagicMock() serving_models = OpenAIServingModels( engine_client=mock_engine_client, diff --git a/tests/entrypoints/openai/test_serving_responses.py b/tests/entrypoints/openai/test_serving_responses.py index 7d03dccec30d..2264227fdb61 100644 --- a/tests/entrypoints/openai/test_serving_responses.py +++ b/tests/entrypoints/openai/test_serving_responses.py @@ -130,6 +130,7 @@ async def serving_responses_instance(self): engine_client.input_processor = MagicMock() engine_client.io_processor = MagicMock() + engine_client.renderer = MagicMock() models = MagicMock() @@ -216,6 +217,7 @@ async def serving_responses_instance(self): engine_client.input_processor = MagicMock() engine_client.io_processor = MagicMock() + engine_client.renderer = MagicMock() models = MagicMock() diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index a87a4c35d3dc..43acc1fc14f9 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -7,21 +7,14 @@ import pytest import torch -from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset from vllm.config import ModelConfig from vllm.entrypoints.chat_utils import ( - _try_extract_ast, - apply_mistral_chat_template, - load_chat_template, parse_chat_messages, parse_chat_messages_futures, - resolve_chat_template_content_format, - resolve_chat_template_kwargs, - resolve_hf_chat_template, ) from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict from vllm.multimodal.utils import ( @@ -29,24 +22,11 @@ encode_image_base64, encode_video_base64, ) -from vllm.tokenizers import get_tokenizer -from vllm.tokenizers.mistral import MistralTokenizer from vllm.utils.serial_utils import tensor2base64 -from ..models.registry import HF_EXAMPLE_MODELS -from ..utils import VLLM_PATH - -EXAMPLES_DIR = VLLM_PATH / "examples" - PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct" -ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_5-llama-3_2-1b" QWEN2AUDIO_MODEL_ID = "Qwen/Qwen2-Audio-7B-Instruct" -QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct" -QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct" QWEN25OMNI_MODEL_ID = "Qwen/Qwen2.5-Omni-7B" -QWEN3_MODEL_ID = "Qwen/Qwen3-8B" -LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B" -HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B" MISTRAL_MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" @@ -2033,377 +2013,6 @@ def test_parse_chat_messages_multiple_images_interleave_with_placeholders( ) -@pytest.mark.parametrize( - "model", - [ - QWEN2VL_MODEL_ID, # tokenizer.chat_template is of type str - HERMES_MODEL_ID, # tokenizer.chat_template is of type dict - ], -) -@pytest.mark.parametrize("use_tools", [True, False]) -def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): - """checks that chat_template is a dict type for HF models.""" - model_info = 
HF_EXAMPLE_MODELS.find_hf_info(model) - model_info.check_available_online(on_fail="skip") - - model_config = ModelConfig( - model, - tokenizer=model_info.tokenizer or model, - tokenizer_mode=model_info.tokenizer_mode, - revision=model_info.revision, - trust_remote_code=model_info.trust_remote_code, - hf_overrides=model_info.hf_overrides, - skip_tokenizer_init=model_info.require_embed_inputs, - enable_prompt_embeds=model_info.require_embed_inputs, - enable_mm_embeds=model_info.require_embed_inputs, - enforce_eager=model_info.enforce_eager, - dtype=model_info.dtype, - ) - - # Build the tokenizer - tokenizer = get_tokenizer( - model, - trust_remote_code=model_config.trust_remote_code, - ) - - tools = ( - [ - { - "type": "function", - "function": { - "name": "dummy_function_name", - "description": "This is a dummy function", - "parameters": sample_json_schema, - }, - } - ] - if use_tools - else None - ) - - # Test detecting the tokenizer's chat_template - chat_template = resolve_hf_chat_template( - tokenizer, - chat_template=None, - tools=tools, - model_config=model_config, - ) - assert isinstance(chat_template, str) - - -@pytest.mark.parametrize( - "model, expected_kwargs", - [ - ( - QWEN2VL_MODEL_ID, - { - "add_vision_id", - "add_generation_prompt", - "continue_final_message", - "tools", - }, - ), - ( - QWEN3_MODEL_ID, - { - "enable_thinking", - "add_generation_prompt", - "continue_final_message", - "tools", - }, - ), - ], -) -def test_resolve_hf_chat_template_kwargs(sample_json_schema, model, expected_kwargs): - """checks that chat_template is a dict type for HF models.""" - model_info = HF_EXAMPLE_MODELS.find_hf_info(model) - model_info.check_available_online(on_fail="skip") - - tools = [ - { - "type": "function", - "function": { - "name": "dummy_function_name", - "description": "This is a dummy function", - "parameters": sample_json_schema, - }, - } - ] - - chat_template_kwargs = { - # both unused - "unsed_kwargs_1": 123, - "unsed_kwargs_2": "abc", - # should not appear - "chat_template": "{% Hello world! 
%}", - "tokenize": True, - # used by tokenizer - "continue_final_message": True, - "tools": tools, - # both used by Qwen2-VL and Qwen3 - "add_generation_prompt": True, - # only used by Qwen2-VL - "add_vision_id": True, - # only used by Qwen3 - "enable_thinking": True, - } - - model_config = ModelConfig( - model, - tokenizer=model_info.tokenizer or model, - tokenizer_mode=model_info.tokenizer_mode, - revision=model_info.revision, - trust_remote_code=model_info.trust_remote_code, - hf_overrides=model_info.hf_overrides, - skip_tokenizer_init=model_info.require_embed_inputs, - enable_prompt_embeds=model_info.require_embed_inputs, - enable_mm_embeds=model_info.require_embed_inputs, - enforce_eager=model_info.enforce_eager, - dtype=model_info.dtype, - ) - - # Build the tokenizer - tokenizer = get_tokenizer( - model, - trust_remote_code=model_config.trust_remote_code, - ) - - # Test detecting the tokenizer's chat_template - chat_template = resolve_hf_chat_template( - tokenizer, - chat_template=None, - tools=tools, - model_config=model_config, - ) - with pytest.raises( - ValueError, match="Found unexpected chat template kwargs from request" - ): - # should raise error if `chat_template_kwargs` contains - # `chat_template` or `tokenize` - resolve_chat_template_kwargs( - tokenizer, - chat_template=chat_template, - chat_template_kwargs=chat_template_kwargs, - ) - resolved_chat_template_kwargs = resolve_chat_template_kwargs( - tokenizer, - chat_template=chat_template, - chat_template_kwargs=chat_template_kwargs, - raise_on_unexpected=False, - ) - assert set(resolved_chat_template_kwargs.keys()) == expected_kwargs - - # Additional test: Verify HF base parameters work with **kwargs tokenizers - # This validates the fix for tokenizers like Kimi K2 that use **kwargs - # to receive standard HuggingFace parameters instead of declaring them explicitly - from vllm.entrypoints.chat_utils import _get_hf_base_chat_template_params - - hf_base_params = _get_hf_base_chat_template_params() - # Verify common HF parameters are in the base class - assert {"add_generation_prompt", "tools", "continue_final_message"}.issubset( - hf_base_params - ), f"Expected HF base params not found in {hf_base_params}" - - # Test with a mock tokenizer that uses **kwargs (like Kimi K2) - class MockTokenizerWithKwargs: - def apply_chat_template(self, conversation, **kwargs): - return "mocked_output" - - mock_tokenizer = MockTokenizerWithKwargs() - mock_kwargs = { - "add_generation_prompt": True, - "tools": tools, - "continue_final_message": False, - "unknown_param": "should_be_filtered", - } - resolved_mock = resolve_chat_template_kwargs( - mock_tokenizer, chat_template, mock_kwargs, raise_on_unexpected=False - ) - # HF base params should pass through even with **kwargs tokenizer - assert "add_generation_prompt" in resolved_mock - assert "tools" in resolved_mock - assert "continue_final_message" in resolved_mock - # Unknown params should be filtered out - assert "unknown_param" not in resolved_mock - - -# NOTE: Qwen2-Audio default chat template is specially defined inside -# processor class instead of using `tokenizer_config.json` -@pytest.mark.parametrize( - ("model", "expected_format"), - [ - (PHI3V_MODEL_ID, "string"), - (QWEN2VL_MODEL_ID, "openai"), - (QWEN25VL_MODEL_ID, "openai"), - (ULTRAVOX_MODEL_ID, "string"), - (QWEN2AUDIO_MODEL_ID, "openai"), - (LLAMA_GUARD_MODEL_ID, "openai"), - ], -) -def test_resolve_content_format_hf_defined(model, expected_format): - model_info = HF_EXAMPLE_MODELS.find_hf_info(model) - 
model_info.check_available_online(on_fail="skip") - - model_config = ModelConfig( - model, - tokenizer=model_info.tokenizer or model, - tokenizer_mode=model_info.tokenizer_mode, - revision=model_info.revision, - trust_remote_code=model_info.trust_remote_code, - hf_overrides=model_info.hf_overrides, - skip_tokenizer_init=model_info.require_embed_inputs, - enable_prompt_embeds=model_info.require_embed_inputs, - enable_mm_embeds=model_info.require_embed_inputs, - enforce_eager=model_info.enforce_eager, - dtype=model_info.dtype, - ) - - tokenizer = get_tokenizer( - model, - trust_remote_code=model_config.trust_remote_code, - ) - - # Test detecting the tokenizer's chat_template - chat_template = resolve_hf_chat_template( - tokenizer, - chat_template=None, - tools=None, - model_config=model_config, - ) - assert isinstance(chat_template, str) - - print("[TEXT]") - print(chat_template) - print("[AST]") - print(_try_extract_ast(chat_template)) - - resolved_format = resolve_chat_template_content_format( - None, # Test detecting the tokenizer's chat_template - None, - "auto", - tokenizer, - model_config=model_config, - ) - - assert resolved_format == expected_format - - -@pytest.mark.parametrize( - ("model", "expected_format"), - [ - ("Salesforce/blip2-opt-2.7b", "string"), - ("facebook/chameleon-7b", "string"), - ("deepseek-ai/deepseek-vl2-tiny", "string"), - ("adept/fuyu-8b", "string"), - ("google/paligemma-3b-mix-224", "string"), - ("Qwen/Qwen-VL", "string"), - ("Qwen/Qwen-VL-Chat", "string"), - ], -) -def test_resolve_content_format_fallbacks(model, expected_format): - model_info = HF_EXAMPLE_MODELS.find_hf_info(model) - model_info.check_available_online(on_fail="skip") - - model_config = ModelConfig( - model, - tokenizer=model_info.tokenizer or model, - tokenizer_mode=model_info.tokenizer_mode, - revision=model_info.revision, - trust_remote_code=model_info.trust_remote_code, - hf_overrides=model_info.hf_overrides, - skip_tokenizer_init=model_info.require_embed_inputs, - enable_prompt_embeds=model_info.require_embed_inputs, - enable_mm_embeds=model_info.require_embed_inputs, - enforce_eager=model_info.enforce_eager, - dtype=model_info.dtype, - ) - - tokenizer = get_tokenizer( - model_config.tokenizer, - trust_remote_code=model_config.trust_remote_code, - ) - - # Test detecting the tokenizer's chat_template - chat_template = resolve_hf_chat_template( - tokenizer, - chat_template=None, - tools=None, - model_config=model_config, - ) - assert isinstance(chat_template, str) - - print("[TEXT]") - print(chat_template) - print("[AST]") - print(_try_extract_ast(chat_template)) - - resolved_format = resolve_chat_template_content_format( - None, # Test detecting the tokenizer's chat_template - None, - "auto", - tokenizer, - model_config=model_config, - ) - - assert resolved_format == expected_format - - -@pytest.mark.parametrize( - ("template_path", "expected_format"), - [ - ("template_alpaca.jinja", "string"), - ("template_baichuan.jinja", "string"), - ("template_chatglm.jinja", "string"), - ("template_chatglm2.jinja", "string"), - ("template_chatml.jinja", "string"), - ("template_dse_qwen2_vl.jinja", "openai"), - ("template_falcon_180b.jinja", "string"), - ("template_falcon.jinja", "string"), - ("template_inkbot.jinja", "string"), - ("template_teleflm.jinja", "string"), - ("template_vlm2vec_phi3v.jinja", "openai"), - ("template_vlm2vec_qwen2vl.jinja", "openai"), - ("tool_chat_template_granite_20b_fc.jinja", "string"), - ("tool_chat_template_hermes.jinja", "string"), - 
("tool_chat_template_internlm2_tool.jinja", "string"), - ("tool_chat_template_llama3.1_json.jinja", "openai"), - ("tool_chat_template_llama3.2_json.jinja", "openai"), - ("tool_chat_template_mistral_parallel.jinja", "string"), - ("tool_chat_template_mistral.jinja", "string"), - ], -) -def test_resolve_content_format_examples(template_path, expected_format): - model_config = ModelConfig( - PHI3V_MODEL_ID, # Dummy - tokenizer=PHI3V_MODEL_ID, # Dummy - trust_remote_code=True, - ) - - dummy_tokenizer = get_tokenizer( - PHI3V_MODEL_ID, # Dummy - trust_remote_code=model_config.trust_remote_code, - ) - dummy_tokenizer.chat_template = None - - chat_template = load_chat_template(EXAMPLES_DIR / template_path) - assert isinstance(chat_template, str) - - print("[TEXT]") - print(chat_template) - print("[AST]") - print(_try_extract_ast(chat_template)) - - resolved_format = resolve_chat_template_content_format( - chat_template, - None, - "auto", - dummy_tokenizer, - model_config=model_config, - ) - - assert resolved_format == expected_format - - def test_parse_chat_messages_include_thinking_chunk(mistral_model_config): messages = [ { @@ -2465,56 +2074,6 @@ def test_parse_chat_messages_include_thinking_chunk(mistral_model_config): assert conversation_with_thinking == expected_conversation -def test_apply_mistral_chat_template_thinking_chunk(): - messages = [ - { - "role": "system", - "content": [ - {"type": "text", "text": "You are a helpful assistant."}, - { - "type": "thinking", - "closed": True, - "thinking": "Only return the answer when you are confident.", - }, - ], - }, - {"role": "user", "content": "What is 2+2?"}, - { - "role": "assistant", - "content": [ - {"type": "text", "text": "Let me think about it."}, - {"type": "thinking", "closed": True, "thinking": "2+2 = 4"}, - { - "type": "text", - "text": "The answer is 4.", - }, - ], - }, - {"role": "user", "content": "Thanks, what is 3+3?"}, - ] - mistral_tokenizer = MistralTokenizer.from_pretrained( - "mistralai/Magistral-Small-2509" - ) - - tokens_ids = apply_mistral_chat_template( - mistral_tokenizer, messages, chat_template=None, tools=None - ) - - string_tokens = mistral_tokenizer.mistral.decode( - tokens_ids, special_token_policy=SpecialTokenPolicy.KEEP - ) - - expected_tokens = ( - r"[SYSTEM_PROMPT]You are a helpful assistant.[THINK]Only return the" - r" answer when you are confident.[/THINK][/SYSTEM_PROMPT]" - r"[INST]What is 2+2?[/INST]" - r"Let me think about it.[THINK]2+2 = 4[/THINK]The answer is 4." 
- r"[INST]Thanks, what is 3+3?[/INST]" - ) - - assert string_tokens == expected_tokens - - def test_parse_chat_messages_single_empty_audio_with_uuid( qwen2_audio_model_config, ): diff --git a/tests/renderers/__init__.py b/tests/renderers/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/renderers/test_hf.py b/tests/renderers/test_hf.py new file mode 100644 index 000000000000..168dfaa4b403 --- /dev/null +++ b/tests/renderers/test_hf.py @@ -0,0 +1,537 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm.config import ModelConfig +from vllm.entrypoints.chat_utils import load_chat_template +from vllm.entrypoints.openai.protocol import ChatCompletionRequest +from vllm.renderers.hf import ( + _get_hf_base_chat_template_params, + _try_extract_ast, + resolve_chat_template, + resolve_chat_template_content_format, + resolve_chat_template_kwargs, + safe_apply_chat_template, +) +from vllm.tokenizers import get_tokenizer + +from ..models.registry import HF_EXAMPLE_MODELS +from ..utils import VLLM_PATH + +EXAMPLES_DIR = VLLM_PATH / "examples" + +chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja" +assert chatml_jinja_path.exists() + +# Define models, templates, and their corresponding expected outputs +MODEL_TEMPLATE_GENERATION_OUTPUT = [ + ( + "facebook/opt-125m", + chatml_jinja_path, + True, + False, + """<|im_start|>user +Hello<|im_end|> +<|im_start|>assistant +Hi there!<|im_end|> +<|im_start|>user +What is the capital of<|im_end|> +<|im_start|>assistant +""", + ), + ( + "facebook/opt-125m", + chatml_jinja_path, + False, + False, + """<|im_start|>user +Hello<|im_end|> +<|im_start|>assistant +Hi there!<|im_end|> +<|im_start|>user +What is the capital of""", + ), + ( + "facebook/opt-125m", + chatml_jinja_path, + False, + True, + """<|im_start|>user +Hello<|im_end|> +<|im_start|>assistant +Hi there!<|im_end|> +<|im_start|>user +What is the capital of<|im_end|> +<|im_start|>assistant +The capital of""", + ), +] + +TEST_MESSAGES = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "What is the capital of"}, +] +ASSISTANT_MESSAGE_TO_CONTINUE = {"role": "assistant", "content": "The capital of"} + + +def test_load_chat_template(): + # Testing chatml template + template_content = load_chat_template(chat_template=chatml_jinja_path) + + # Test assertions + assert template_content is not None + # Hard coded value for template_chatml.jinja + assert ( + template_content + == """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %} +{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}""" # noqa: E501 + ) + + +def test_no_load_chat_template_filelike(): + # Testing chatml template + template = "../../examples/does_not_exist" + + with pytest.raises(ValueError, match="looks like a file path"): + load_chat_template(chat_template=template) + + +def test_no_load_chat_template_literallike(): + # Testing chatml template + template = "{{ messages }}" + + template_content = load_chat_template(chat_template=template) + + assert template_content == template + + +@pytest.mark.parametrize( + "model", + [ + "Qwen/Qwen2-VL-2B-Instruct", # chat_template is of type str + "NousResearch/Hermes-3-Llama-3.1-8B", # 
chat_template is of type dict + ], +) +@pytest.mark.parametrize("use_tools", [True, False]) +def test_resolve_chat_template(sample_json_schema, model, use_tools): + """checks that chat_template is a dict type for HF models.""" + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + + model_config = ModelConfig( + model, + tokenizer=model_info.tokenizer or model, + tokenizer_mode=model_info.tokenizer_mode, + revision=model_info.revision, + trust_remote_code=model_info.trust_remote_code, + hf_overrides=model_info.hf_overrides, + skip_tokenizer_init=model_info.require_embed_inputs, + enable_prompt_embeds=model_info.require_embed_inputs, + enable_mm_embeds=model_info.require_embed_inputs, + enforce_eager=model_info.enforce_eager, + dtype=model_info.dtype, + ) + + # Build the tokenizer + tokenizer = get_tokenizer( + model, + trust_remote_code=model_config.trust_remote_code, + ) + + tools = ( + [ + { + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": sample_json_schema, + }, + } + ] + if use_tools + else None + ) + + # Test detecting the tokenizer's chat_template + chat_template = resolve_chat_template( + tokenizer, + chat_template=None, + tools=tools, + model_config=model_config, + ) + assert isinstance(chat_template, str) + + +@pytest.mark.parametrize( + "model, expected_kwargs", + [ + ( + "Qwen/Qwen2-VL-2B-Instruct", + { + "add_vision_id", + "add_generation_prompt", + "continue_final_message", + "tools", + }, + ), + ( + "Qwen/Qwen3-8B", + { + "enable_thinking", + "add_generation_prompt", + "continue_final_message", + "tools", + }, + ), + ], +) +def test_resolve_chat_template_kwargs(sample_json_schema, model, expected_kwargs): + """checks that chat_template is a dict type for HF models.""" + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + + tools = [ + { + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": sample_json_schema, + }, + } + ] + + chat_template_kwargs = { + # both unused + "unsed_kwargs_1": 123, + "unsed_kwargs_2": "abc", + # should not appear + "chat_template": "{% Hello world! 
%}", + "tokenize": True, + # used by tokenizer + "continue_final_message": True, + "tools": tools, + # both used by Qwen2-VL and Qwen3 + "add_generation_prompt": True, + # only used by Qwen2-VL + "add_vision_id": True, + # only used by Qwen3 + "enable_thinking": True, + } + + model_config = ModelConfig( + model, + tokenizer=model_info.tokenizer or model, + tokenizer_mode=model_info.tokenizer_mode, + revision=model_info.revision, + trust_remote_code=model_info.trust_remote_code, + hf_overrides=model_info.hf_overrides, + skip_tokenizer_init=model_info.require_embed_inputs, + enable_prompt_embeds=model_info.require_embed_inputs, + enable_mm_embeds=model_info.require_embed_inputs, + enforce_eager=model_info.enforce_eager, + dtype=model_info.dtype, + ) + + # Build the tokenizer + tokenizer = get_tokenizer( + model, + trust_remote_code=model_config.trust_remote_code, + ) + + # Test detecting the tokenizer's chat_template + chat_template = resolve_chat_template( + tokenizer, + chat_template=None, + tools=tools, + model_config=model_config, + ) + with pytest.raises( + ValueError, match="Found unexpected chat template kwargs from request" + ): + # should raise error if `chat_template_kwargs` contains + # `chat_template` or `tokenize` + resolve_chat_template_kwargs( + tokenizer, + chat_template=chat_template, + chat_template_kwargs=chat_template_kwargs, + ) + resolved_chat_template_kwargs = resolve_chat_template_kwargs( + tokenizer, + chat_template=chat_template, + chat_template_kwargs=chat_template_kwargs, + raise_on_unexpected=False, + ) + assert set(resolved_chat_template_kwargs.keys()) == expected_kwargs + + # Additional test: Verify HF base parameters work with **kwargs tokenizers + # This validates the fix for tokenizers like Kimi K2 that use **kwargs + # to receive standard HuggingFace parameters instead of declaring them explicitly + hf_base_params = _get_hf_base_chat_template_params() + # Verify common HF parameters are in the base class + assert {"add_generation_prompt", "tools", "continue_final_message"}.issubset( + hf_base_params + ), f"Expected HF base params not found in {hf_base_params}" + + # Test with a mock tokenizer that uses **kwargs (like Kimi K2) + class MockTokenizerWithKwargs: + def apply_chat_template(self, conversation, **kwargs): + return "mocked_output" + + mock_tokenizer = MockTokenizerWithKwargs() + mock_kwargs = { + "add_generation_prompt": True, + "tools": tools, + "continue_final_message": False, + "unknown_param": "should_be_filtered", + } + resolved_mock = resolve_chat_template_kwargs( + mock_tokenizer, chat_template, mock_kwargs, raise_on_unexpected=False + ) + # HF base params should pass through even with **kwargs tokenizer + assert "add_generation_prompt" in resolved_mock + assert "tools" in resolved_mock + assert "continue_final_message" in resolved_mock + # Unknown params should be filtered out + assert "unknown_param" not in resolved_mock + + +# NOTE: Qwen2-Audio default chat template is specially defined inside +# processor class instead of using `tokenizer_config.json` +@pytest.mark.parametrize( + ("model", "expected_format"), + [ + ("microsoft/Phi-3.5-vision-instruct", "string"), + ("Qwen/Qwen2-VL-2B-Instruct", "openai"), + ("Qwen/Qwen2.5-VL-3B-Instruct", "openai"), + ("fixie-ai/ultravox-v0_5-llama-3_2-1b", "string"), + ("Qwen/Qwen2-Audio-7B-Instruct", "openai"), + ("meta-llama/Llama-Guard-3-1B", "openai"), + ], +) +def test_resolve_content_format_hf_defined(model, expected_format): + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + 
model_info.check_available_online(on_fail="skip") + + model_config = ModelConfig( + model, + tokenizer=model_info.tokenizer or model, + tokenizer_mode=model_info.tokenizer_mode, + revision=model_info.revision, + trust_remote_code=model_info.trust_remote_code, + hf_overrides=model_info.hf_overrides, + skip_tokenizer_init=model_info.require_embed_inputs, + enable_prompt_embeds=model_info.require_embed_inputs, + enable_mm_embeds=model_info.require_embed_inputs, + enforce_eager=model_info.enforce_eager, + dtype=model_info.dtype, + ) + + tokenizer = get_tokenizer( + model, + trust_remote_code=model_config.trust_remote_code, + ) + + # Test detecting the tokenizer's chat_template + chat_template = resolve_chat_template( + tokenizer, + chat_template=None, + tools=None, + model_config=model_config, + ) + assert isinstance(chat_template, str) + + print("[TEXT]") + print(chat_template) + print("[AST]") + print(_try_extract_ast(chat_template)) + + resolved_format = resolve_chat_template_content_format( + None, # Test detecting the tokenizer's chat_template + None, + "auto", + tokenizer, + model_config=model_config, + ) + + assert resolved_format == expected_format + + +@pytest.mark.parametrize( + ("model", "expected_format"), + [ + ("Salesforce/blip2-opt-2.7b", "string"), + ("facebook/chameleon-7b", "string"), + ("deepseek-ai/deepseek-vl2-tiny", "string"), + ("adept/fuyu-8b", "string"), + ("google/paligemma-3b-mix-224", "string"), + ("Qwen/Qwen-VL", "string"), + ("Qwen/Qwen-VL-Chat", "string"), + ], +) +def test_resolve_content_format_fallbacks(model, expected_format): + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + + model_config = ModelConfig( + model, + tokenizer=model_info.tokenizer or model, + tokenizer_mode=model_info.tokenizer_mode, + revision=model_info.revision, + trust_remote_code=model_info.trust_remote_code, + hf_overrides=model_info.hf_overrides, + skip_tokenizer_init=model_info.require_embed_inputs, + enable_prompt_embeds=model_info.require_embed_inputs, + enable_mm_embeds=model_info.require_embed_inputs, + enforce_eager=model_info.enforce_eager, + dtype=model_info.dtype, + ) + + tokenizer = get_tokenizer( + model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code, + ) + + # Test detecting the tokenizer's chat_template + chat_template = resolve_chat_template( + tokenizer, + chat_template=None, + tools=None, + model_config=model_config, + ) + assert isinstance(chat_template, str) + + print("[TEXT]") + print(chat_template) + print("[AST]") + print(_try_extract_ast(chat_template)) + + resolved_format = resolve_chat_template_content_format( + None, # Test detecting the tokenizer's chat_template + None, + "auto", + tokenizer, + model_config=model_config, + ) + + assert resolved_format == expected_format + + +@pytest.mark.parametrize( + ("template_path", "expected_format"), + [ + ("template_alpaca.jinja", "string"), + ("template_baichuan.jinja", "string"), + ("template_chatglm.jinja", "string"), + ("template_chatglm2.jinja", "string"), + ("template_chatml.jinja", "string"), + ("template_dse_qwen2_vl.jinja", "openai"), + ("template_falcon_180b.jinja", "string"), + ("template_falcon.jinja", "string"), + ("template_inkbot.jinja", "string"), + ("template_teleflm.jinja", "string"), + ("template_vlm2vec_phi3v.jinja", "openai"), + ("template_vlm2vec_qwen2vl.jinja", "openai"), + ("tool_chat_template_granite_20b_fc.jinja", "string"), + ("tool_chat_template_hermes.jinja", "string"), + ("tool_chat_template_internlm2_tool.jinja", 
"string"), + ("tool_chat_template_llama3.1_json.jinja", "openai"), + ("tool_chat_template_llama3.2_json.jinja", "openai"), + ("tool_chat_template_mistral_parallel.jinja", "string"), + ("tool_chat_template_mistral.jinja", "string"), + ], +) +def test_resolve_content_format_examples(template_path, expected_format): + model = "Qwen/Qwen2-VL-2B-Instruct" # Dummy + model_config = ModelConfig( + model, + tokenizer=model, + trust_remote_code=True, + ) + + dummy_tokenizer = get_tokenizer( + model, + trust_remote_code=model_config.trust_remote_code, + ) + dummy_tokenizer.chat_template = None + + chat_template = load_chat_template(EXAMPLES_DIR / template_path) + assert isinstance(chat_template, str) + + print("[TEXT]") + print(chat_template) + print("[AST]") + print(_try_extract_ast(chat_template)) + + resolved_format = resolve_chat_template_content_format( + chat_template, + None, + "auto", + dummy_tokenizer, + model_config=model_config, + ) + + assert resolved_format == expected_format + + +@pytest.mark.parametrize( + "model,template,add_generation_prompt,continue_final_message,expected_output", + MODEL_TEMPLATE_GENERATION_OUTPUT, +) +def test_get_gen_prompt( + model, template, add_generation_prompt, continue_final_message, expected_output +): + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + + model_config = ModelConfig( + model, + tokenizer=model_info.tokenizer or model, + tokenizer_mode=model_info.tokenizer_mode, + trust_remote_code=model_info.trust_remote_code, + revision=model_info.revision, + hf_overrides=model_info.hf_overrides, + skip_tokenizer_init=model_info.require_embed_inputs, + enable_prompt_embeds=model_info.require_embed_inputs, + enable_mm_embeds=model_info.require_embed_inputs, + enforce_eager=model_info.enforce_eager, + dtype=model_info.dtype, + ) + + # Initialize the tokenizer + tokenizer = get_tokenizer( + tokenizer_name=model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code, + ) + template_content = load_chat_template(chat_template=template) + + # Create a mock request object using keyword arguments + mock_request = ChatCompletionRequest( + model=model, + messages=TEST_MESSAGES + [ASSISTANT_MESSAGE_TO_CONTINUE] + if continue_final_message + else TEST_MESSAGES, + add_generation_prompt=add_generation_prompt, + continue_final_message=continue_final_message, + ) + + # Call the function and get the result + result = safe_apply_chat_template( + model_config, + tokenizer, + mock_request.messages, + tools=None, + chat_template=mock_request.chat_template or template_content, + add_generation_prompt=mock_request.add_generation_prompt, + continue_final_message=mock_request.continue_final_message, + tokenize=False, + ) + + # Test assertion + assert result == expected_output, ( + f"The generated prompt does not match the expected output for " + f"model {model} and template {template}" + ) diff --git a/tests/renderers/test_mistral.py b/tests/renderers/test_mistral.py new file mode 100644 index 000000000000..0dc214ae939b --- /dev/null +++ b/tests/renderers/test_mistral.py @@ -0,0 +1,100 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import time +from unittest.mock import Mock + +import pytest +from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy + +from vllm.config import ModelConfig +from vllm.renderers.mistral import MistralRenderer, safe_apply_chat_template +from vllm.tokenizers.mistral import MistralTokenizer + + 
+@pytest.mark.asyncio +async def test_async_mistral_tokenizer_does_not_block_event_loop(): + expected_tokens = [1, 2, 3] + + # Mock the blocking version to sleep + def mocked_apply_chat_template(*_args, **_kwargs): + time.sleep(2) + return expected_tokens + + mock_tokenizer = Mock(spec=MistralTokenizer) + mock_tokenizer.apply_chat_template = mocked_apply_chat_template + mock_renderer = MistralRenderer(Mock(spec=ModelConfig), tokenizer_kwargs={}) + mock_renderer._tokenizer = mock_tokenizer + + task = mock_renderer.render_messages_async([]) + + # Ensure the event loop is not blocked + blocked_count = 0 + for _i in range(20): # Check over ~2 seconds + start = time.perf_counter() + await asyncio.sleep(0) + elapsed = time.perf_counter() - start + + # an overly generous elapsed time for slow machines + if elapsed >= 0.5: + blocked_count += 1 + + await asyncio.sleep(0.1) + + # Ensure task completes + _, prompt = await task + assert prompt["prompt_token_ids"] == expected_tokens, ( + "Mocked blocking tokenizer was not called" + ) + assert blocked_count == 0, "Event loop blocked during tokenization" + + +def test_apply_mistral_chat_template_thinking_chunk(): + messages = [ + { + "role": "system", + "content": [ + {"type": "text", "text": "You are a helpful assistant."}, + { + "type": "thinking", + "closed": True, + "thinking": "Only return the answer when you are confident.", + }, + ], + }, + {"role": "user", "content": "What is 2+2?"}, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Let me think about it."}, + {"type": "thinking", "closed": True, "thinking": "2+2 = 4"}, + { + "type": "text", + "text": "The answer is 4.", + }, + ], + }, + {"role": "user", "content": "Thanks, what is 3+3?"}, + ] + mistral_tokenizer = MistralTokenizer.from_pretrained( + "mistralai/Magistral-Small-2509" + ) + + tokens_ids = safe_apply_chat_template( + mistral_tokenizer, messages, chat_template=None, tools=None + ) + + string_tokens = mistral_tokenizer.mistral.decode( + tokens_ids, special_token_policy=SpecialTokenPolicy.KEEP + ) + + expected_tokens = ( + r"[SYSTEM_PROMPT]You are a helpful assistant.[THINK]Only return the" + r" answer when you are confident.[/THINK][/SYSTEM_PROMPT]" + r"[INST]What is 2+2?[/INST]" + r"Let me think about it.[THINK]2+2 = 4[/THINK]The answer is 4." 
+ r"[INST]Thanks, what is 3+3?[/INST]" + ) + + assert string_tokens == expected_tokens diff --git a/tests/test_inputs.py b/tests/test_inputs.py index 073be24a4a07..6ea4f465cdff 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -7,7 +7,6 @@ from vllm.inputs import zip_enc_dec_prompts from vllm.inputs.parse import parse_raw_prompts from vllm.inputs.preprocess import InputPreprocessor -from vllm.tokenizers import cached_tokenizer_from_config pytestmark = pytest.mark.cpu_test @@ -115,10 +114,10 @@ def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs): ) def test_preprocessor_always_mm_code_path(model_id, prompt): model_config = ModelConfig(model=model_id) - tokenizer = cached_tokenizer_from_config(model_config) - input_preprocessor = InputPreprocessor(model_config, tokenizer) + input_preprocessor = InputPreprocessor(model_config) # HF processor adds sep token + tokenizer = input_preprocessor.get_tokenizer() sep_token_id = tokenizer.vocab[tokenizer.sep_token] processed_inputs = input_preprocessor.preprocess(prompt) diff --git a/tests/v1/engine/test_llm_engine.py b/tests/v1/engine/test_llm_engine.py index c1d5f8af7917..7e5196efc873 100644 --- a/tests/v1/engine/test_llm_engine.py +++ b/tests/v1/engine/test_llm_engine.py @@ -224,7 +224,7 @@ def test_skip_tokenizer_initialization(model: str): ) sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True) - with pytest.raises(ValueError, match="cannot pass text prompts when"): + with pytest.raises(ValueError, match="`skip_tokenizer_init=True`"): llm.generate("abc", sampling_params) outputs = llm.generate( diff --git a/tests/v1/engine/test_process_multi_modal_uuids.py b/tests/v1/engine/test_process_multi_modal_uuids.py index 1b11b8af49d1..c5158998f3b2 100644 --- a/tests/v1/engine/test_process_multi_modal_uuids.py +++ b/tests/v1/engine/test_process_multi_modal_uuids.py @@ -5,7 +5,13 @@ from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset -from vllm.config import CacheConfig, DeviceConfig, ModelConfig, VllmConfig +from vllm.config import ( + CacheConfig, + DeviceConfig, + ModelConfig, + MultiModalConfig, + VllmConfig, +) from vllm.sampling_params import SamplingParams from vllm.v1.engine import input_processor as input_processor_mod from vllm.v1.engine.input_processor import InputProcessor @@ -44,27 +50,22 @@ def _mock_input_processor( monkeypatch.setattr(VllmConfig, "__post_init__", lambda self: None, raising=True) model_config = ModelConfig( + tokenizer="dummy", skip_tokenizer_init=True, max_model_len=128, mm_processor_cache_gb=mm_cache_gb, generation_config="vllm", - tokenizer="dummy", ) + model_config.runner_type = "generate" + model_config.multimodal_config = MultiModalConfig(mm_processor_cache_gb=mm_cache_gb) - # Minimal multimodal_config to satisfy references in - # Processor.process_inputs. 
- class _MockMMConfig: - def __init__(self, gb: float): - self.mm_processor_cache_gb = gb - - model_config.multimodal_config = _MockMMConfig(mm_cache_gb) # type: ignore[attr-defined] vllm_config = VllmConfig( model_config=model_config, cache_config=CacheConfig(enable_prefix_caching=enable_prefix_caching), device_config=DeviceConfig(device="cpu"), ) - return InputProcessor(vllm_config, tokenizer=None) + return InputProcessor(vllm_config) def test_multi_modal_uuids_length_mismatch_raises(monkeypatch): diff --git a/tools/pre_commit/mypy.py b/tools/pre_commit/mypy.py index 3f7e0a069f86..3c63b2178671 100755 --- a/tools/pre_commit/mypy.py +++ b/tools/pre_commit/mypy.py @@ -35,6 +35,7 @@ "vllm/multimodal", "vllm/platforms", "vllm/plugins", + "vllm/renderers", "vllm/tokenizers", "vllm/transformers_utils", "vllm/triton_utils", diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index d94951a0cffc..461c1f193c87 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -11,9 +11,9 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.plugins.io_processors import IOProcessor from vllm.pooling_params import PoolingParams +from vllm.renderers import RendererLike from vllm.sampling_params import SamplingParams from vllm.tasks import SupportedTask -from vllm.tokenizers import TokenizerLike from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.input_processor import InputProcessor @@ -26,6 +26,10 @@ class EngineClient(ABC): input_processor: InputProcessor io_processor: IOProcessor | None + @property + @abstractmethod + def renderer(self) -> RendererLike: ... + @property @abstractmethod def is_running(self) -> bool: ... @@ -84,11 +88,6 @@ async def abort(self, request_id: str | Iterable[str]) -> None: """ ... - @abstractmethod - async def get_tokenizer(self) -> TokenizerLike: - """Get the tokenizer""" - ... - @abstractmethod async def is_tracing_enabled(self) -> bool: ... 
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index ab055dfb1fb0..7907097c17f7 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -2,22 +2,15 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio -import inspect import json +import warnings from abc import ABC, abstractmethod -from collections import Counter, defaultdict, deque +from collections import Counter, defaultdict from collections.abc import Awaitable, Callable, Iterable from functools import cached_property, lru_cache, partial from pathlib import Path from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar, cast -import jinja2 -import jinja2.ext -import jinja2.meta -import jinja2.nodes -import jinja2.parser -import jinja2.sandbox -import transformers.utils.chat_template_utils as hf_chat_utils from openai.types.chat import ( ChatCompletionAssistantMessageParam, ChatCompletionContentPartImageParam, @@ -39,7 +32,6 @@ from openai_harmony import Message as OpenAIHarmonyMessage from PIL import Image from pydantic import BaseModel, ConfigDict, TypeAdapter -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin # pydantic needs the TypedDict from typing_extensions from typing_extensions import Required, TypedDict @@ -50,23 +42,35 @@ from vllm.model_executor.models import SupportsMultiModal from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector -from vllm.tokenizers import TokenizerLike -from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path -from vllm.transformers_utils.processor import cached_get_processor from vllm.utils import random_uuid from vllm.utils.collection_utils import is_list_of -from vllm.utils.func_utils import supports_kw from vllm.utils.import_utils import LazyLoader if TYPE_CHECKING: import torch - - from vllm.tokenizers.mistral import MistralTokenizer else: torch = LazyLoader("torch", globals(), "torch") logger = init_logger(__name__) + +def __getattr__(name: str): + if name == "resolve_hf_chat_template": + from vllm.renderers.hf import resolve_chat_template + + warnings.warn( + "`vllm.entrypoints.chat_utils.resolve_hf_chat_template` has been moved to " + "`vllm.renderers.hf.resolve_chat_template`. 
" + "The old name will be removed in v0.16.", + DeprecationWarning, + stacklevel=2, + ) + + return resolve_chat_template + + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + MODALITY_PLACEHOLDERS_MAP = { "image": "<##IMAGE##>", "audio": "<##AUDIO##>", @@ -311,325 +315,8 @@ class ConversationMessage(TypedDict, total=False): # Passed in by user ChatTemplateContentFormatOption = Literal["auto", "string", "openai"] -# Used internally -_ChatTemplateContentFormat = Literal["string", "openai"] - - -def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool: - if isinstance(node, jinja2.nodes.Name): - return node.ctx == "load" and node.name == varname - - return False - - -def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool: - if isinstance(node, jinja2.nodes.Getitem): - return ( - _is_var_access(node.node, varname) - and isinstance(node.arg, jinja2.nodes.Const) - and node.arg.value == key - ) - - if isinstance(node, jinja2.nodes.Getattr): - return _is_var_access(node.node, varname) and node.attr == key - - return False - - -def _is_var_or_elems_access( - node: jinja2.nodes.Node, - varname: str, - key: str | None = None, -) -> bool: - if isinstance(node, jinja2.nodes.Filter): - return node.node is not None and _is_var_or_elems_access( - node.node, varname, key - ) - if isinstance(node, jinja2.nodes.Test): - return _is_var_or_elems_access(node.node, varname, key) - - if isinstance(node, jinja2.nodes.Getitem) and isinstance( - node.arg, jinja2.nodes.Slice - ): - return _is_var_or_elems_access(node.node, varname, key) - - return _is_attr_access(node, varname, key) if key else _is_var_access(node, varname) - - -def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str): - # Global variable that is implicitly defined at the root - yield root, varname - - # Iterative BFS - related_varnames = deque([varname]) - while related_varnames: - related_varname = related_varnames.popleft() - - for assign_ast in root.find_all(jinja2.nodes.Assign): - lhs = assign_ast.target - rhs = assign_ast.node - - if _is_var_or_elems_access(rhs, related_varname): - assert isinstance(lhs, jinja2.nodes.Name) - yield assign_ast, lhs.name - - # Avoid infinite looping for self-assignment - if lhs.name != related_varname: - related_varnames.append(lhs.name) - - -# NOTE: The proper way to handle this is to build a CFG so that we can handle -# the scope in which each variable is defined, but that is too complicated -def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node): - messages_varnames = [ - varname for _, varname in _iter_nodes_assign_var_or_elems(root, "messages") - ] - - # Search for {%- for message in messages -%} loops - for loop_ast in root.find_all(jinja2.nodes.For): - loop_iter = loop_ast.iter - loop_target = loop_ast.target - - for varname in messages_varnames: - if _is_var_or_elems_access(loop_iter, varname): - assert isinstance(loop_target, jinja2.nodes.Name) - yield loop_ast, loop_target.name - break - - -def _iter_nodes_assign_content_item(root: jinja2.nodes.Node): - message_varnames = [ - varname for _, varname in _iter_nodes_assign_messages_item(root) - ] - - # Search for {%- for content in message['content'] -%} loops - for loop_ast in root.find_all(jinja2.nodes.For): - loop_iter = loop_ast.iter - loop_target = loop_ast.target - - for varname in message_varnames: - if _is_var_or_elems_access(loop_iter, varname, "content"): - assert isinstance(loop_target, jinja2.nodes.Name) - yield loop_ast, loop_target.name - break - - -def 
_try_extract_ast(chat_template: str) -> jinja2.nodes.Template | None: - try: - jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template) - return jinja_compiled.environment.parse(chat_template) - except Exception: - logger.exception("Error when compiling Jinja template") - return None - - -@lru_cache(maxsize=32) -def _detect_content_format( - chat_template: str, - *, - default: _ChatTemplateContentFormat, -) -> _ChatTemplateContentFormat: - jinja_ast = _try_extract_ast(chat_template) - if jinja_ast is None: - return default - - try: - next(_iter_nodes_assign_content_item(jinja_ast)) - except StopIteration: - return "string" - except Exception: - logger.exception("Error when parsing AST of Jinja template") - return default - else: - return "openai" - - -def resolve_mistral_chat_template( - chat_template: str | None, - **kwargs: Any, -) -> str | None: - if chat_template is not None or kwargs.get("chat_template_kwargs") is not None: - raise ValueError( - "'chat_template' or 'chat_template_kwargs' cannot be overridden " - "for mistral tokenizer." - ) - - return None - - -_PROCESSOR_CHAT_TEMPLATES = dict[tuple[str, bool], str | None]() -""" -Used in `_try_get_processor_chat_template` to avoid calling -`cached_get_processor` again if the processor fails to be loaded. - -This is needed because `lru_cache` does not cache when an exception happens. -""" - - -def _try_get_processor_chat_template( - tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, - model_config: ModelConfig, -) -> str | None: - cache_key = (tokenizer.name_or_path, model_config.trust_remote_code) - if cache_key in _PROCESSOR_CHAT_TEMPLATES: - return _PROCESSOR_CHAT_TEMPLATES[cache_key] - - try: - processor = cached_get_processor( - tokenizer.name_or_path, - processor_cls=( - PreTrainedTokenizer, - PreTrainedTokenizerFast, - ProcessorMixin, - ), - trust_remote_code=model_config.trust_remote_code, - ) - if ( - isinstance(processor, ProcessorMixin) - and hasattr(processor, "chat_template") - and (chat_template := processor.chat_template) is not None - ): - _PROCESSOR_CHAT_TEMPLATES[cache_key] = chat_template - return chat_template - except Exception: - logger.debug( - "Failed to load AutoProcessor chat template for %s", - tokenizer.name_or_path, - exc_info=True, - ) - - _PROCESSOR_CHAT_TEMPLATES[cache_key] = None - return None - - -def resolve_hf_chat_template( - tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, - chat_template: str | None, - tools: list[dict[str, Any]] | None, - *, - model_config: ModelConfig, -) -> str | None: - # 1st priority: The given chat template - if chat_template is not None: - return chat_template - - # 2nd priority: AutoProcessor chat template, unless tool calling is enabled - if tools is None: - chat_template = _try_get_processor_chat_template(tokenizer, model_config) - if chat_template is not None: - return chat_template - - # 3rd priority: AutoTokenizer chat template - try: - return tokenizer.get_chat_template(chat_template, tools=tools) - except Exception: - logger.debug( - "Failed to load AutoTokenizer chat template for %s", - tokenizer.name_or_path, - exc_info=True, - ) - - # 4th priority: Predefined fallbacks - path = get_chat_template_fallback_path( - model_type=model_config.hf_config.model_type, - tokenizer_name_or_path=model_config.tokenizer, - ) - if path is not None: - logger.info_once( - "Loading chat template fallback for %s as there isn't one " - "defined on HF Hub.", - tokenizer.name_or_path, - ) - chat_template = load_chat_template(path) - else: - 
logger.debug_once( - "There is no chat template fallback for %s", tokenizer.name_or_path - ) - - return chat_template - - -def _resolve_chat_template_content_format( - chat_template: str | None, - tools: list[dict[str, Any]] | None, - tokenizer: TokenizerLike | None, - *, - model_config: ModelConfig, -) -> _ChatTemplateContentFormat: - if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): - hf_chat_template = resolve_hf_chat_template( - tokenizer, - chat_template=chat_template, - tools=tools, - model_config=model_config, - ) - else: - hf_chat_template = None - - jinja_text = ( - hf_chat_template - if isinstance(hf_chat_template, str) - else load_chat_template(chat_template, is_literal=True) - ) - - detected_format = ( - "string" - if jinja_text is None - else _detect_content_format(jinja_text, default="string") - ) - - return detected_format - - -@lru_cache -def _log_chat_template_content_format( - chat_template: str | None, - given_format: ChatTemplateContentFormatOption, - detected_format: ChatTemplateContentFormatOption, -): - logger.info( - "Detected the chat template content format to be '%s'. " - "You can set `--chat-template-content-format` to override this.", - detected_format, - ) - - if given_format != "auto" and given_format != detected_format: - logger.warning( - "You specified `--chat-template-content-format %s` " - "which is different from the detected format '%s'. " - "If our automatic detection is incorrect, please consider " - "opening a GitHub issue so that we can improve it: " - "https://github.com/vllm-project/vllm/issues/new/choose", - given_format, - detected_format, - ) - - -def resolve_chat_template_content_format( - chat_template: str | None, - tools: list[dict[str, Any]] | None, - given_format: ChatTemplateContentFormatOption, - tokenizer: TokenizerLike | None, - *, - model_config: ModelConfig, -) -> _ChatTemplateContentFormat: - if given_format != "auto": - return given_format - - detected_format = _resolve_chat_template_content_format( - chat_template, - tools, - tokenizer, - model_config=model_config, - ) - - _log_chat_template_content_format( - chat_template, - given_format=given_format, - detected_format=detected_format, - ) - - return detected_format +# After resolving "auto" +ChatTemplateContentFormat = Literal["string", "openai"] ModalityStr = Literal["image", "audio", "video", "image_embeds", "audio_embeds"] @@ -1584,7 +1271,7 @@ def _parse_chat_message_content_part( def _parse_chat_message_content( message: ChatCompletionMessageParam, mm_tracker: BaseMultiModalItemTracker, - content_format: _ChatTemplateContentFormat, + content_format: ChatTemplateContentFormat, interleave_strings: bool, ) -> list[ConversationMessage]: role = message["role"] @@ -1660,7 +1347,7 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None: def parse_chat_messages( messages: list[ChatCompletionMessageParam], model_config: ModelConfig, - content_format: _ChatTemplateContentFormat, + content_format: ChatTemplateContentFormat, ) -> tuple[ list[ConversationMessage], MultiModalDataDict | None, @@ -1691,7 +1378,7 @@ def parse_chat_messages( def parse_chat_messages_futures( messages: list[ChatCompletionMessageParam], model_config: ModelConfig, - content_format: _ChatTemplateContentFormat, + content_format: ChatTemplateContentFormat, ) -> tuple[ list[ConversationMessage], Awaitable[MultiModalDataDict | None], @@ -1719,173 +1406,6 @@ def parse_chat_messages_futures( return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids() -# adapted 
from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412 -# only preserve the parse function used to resolve chat template kwargs -class AssistantTracker(jinja2.ext.Extension): - tags = {"generation"} - - def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock: - lineno = next(parser.stream).lineno - body = parser.parse_statements(["name:endgeneration"], drop_needle=True) - call = self.call_method("_generation_support") - call_block = jinja2.nodes.CallBlock(call, [], [], body) - return call_block.set_lineno(lineno) - - -def _resolve_chat_template_kwargs( - chat_template: str, -): - env = jinja2.sandbox.ImmutableSandboxedEnvironment( - trim_blocks=True, - lstrip_blocks=True, - extensions=[AssistantTracker, jinja2.ext.loopcontrols], - ) - parsed_content = env.parse(chat_template) - template_vars = jinja2.meta.find_undeclared_variables(parsed_content) - return template_vars - - -_cached_resolve_chat_template_kwargs = lru_cache(_resolve_chat_template_kwargs) - - -@lru_cache -def _get_hf_base_chat_template_params() -> frozenset[str]: - # Get standard parameters from HuggingFace's base tokenizer class. - # This dynamically extracts parameters from PreTrainedTokenizer's - # apply_chat_template method, ensuring compatibility with tokenizers - # that use **kwargs to receive standard parameters. - - # Read signature from HF's base class - the single source of truth - base_sig = inspect.signature(PreTrainedTokenizer.apply_chat_template) - # Exclude VAR_KEYWORD (**kwargs) and VAR_POSITIONAL (*args) placeholders - return frozenset( - p.name - for p in base_sig.parameters.values() - if p.kind - not in (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL) - ) - - -def resolve_chat_template_kwargs( - tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, - chat_template: str, - chat_template_kwargs: dict[str, Any], - raise_on_unexpected: bool = True, -) -> dict[str, Any]: - # We exclude chat_template from kwargs here, because - # chat template has been already resolved at this stage - unexpected_vars = {"chat_template", "tokenize"} - if raise_on_unexpected and ( - unexpected_in_kwargs := unexpected_vars & chat_template_kwargs.keys() - ): - raise ValueError( - "Found unexpected chat template kwargs from request: " - f"{unexpected_in_kwargs}" - ) - - fn_kw = { - k - for k in chat_template_kwargs - if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False) - } - template_vars = _cached_resolve_chat_template_kwargs(chat_template) - - # Allow standard HF parameters even if tokenizer uses **kwargs to receive them - hf_base_params = _get_hf_base_chat_template_params() - - accept_vars = (fn_kw | template_vars | hf_base_params) - unexpected_vars - return {k: v for k, v in chat_template_kwargs.items() if k in accept_vars} - - -def apply_hf_chat_template( - tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, - conversation: list[ConversationMessage], - chat_template: str | None, - tools: list[dict[str, Any]] | None, - *, - model_config: ModelConfig, - **kwargs: Any, -) -> str: - hf_chat_template = resolve_hf_chat_template( - tokenizer, - chat_template=chat_template, - tools=tools, - model_config=model_config, - ) - - if hf_chat_template is None: - raise ValueError( - "As of transformers v4.44, default chat template is no longer " - "allowed, so you must provide a chat template if the tokenizer " - "does not define one." 
- ) - - resolved_kwargs = resolve_chat_template_kwargs( - tokenizer=tokenizer, - chat_template=hf_chat_template, - chat_template_kwargs=kwargs, - ) - - try: - return tokenizer.apply_chat_template( - conversation=conversation, # type: ignore[arg-type] - tools=tools, # type: ignore[arg-type] - chat_template=hf_chat_template, - tokenize=False, - **resolved_kwargs, - ) - - # External library exceptions can sometimes occur despite the framework's - # internal exception management capabilities. - except Exception as e: - # Log and report any library-related exceptions for further - # investigation. - logger.exception( - "An error occurred in `transformers` while applying chat template" - ) - raise ValueError(str(e)) from e - - -def apply_mistral_chat_template( - tokenizer: "MistralTokenizer", - messages: list[ChatCompletionMessageParam], - chat_template: str | None, - tools: list[dict[str, Any]] | None, - **kwargs: Any, -) -> list[int]: - from mistral_common.exceptions import MistralCommonException - - # The return value of resolve_mistral_chat_template is always None, - # and we won't use it. - resolve_mistral_chat_template( - chat_template=chat_template, - **kwargs, - ) - - try: - return tokenizer.apply_chat_template( - messages=messages, - tools=tools, - **kwargs, - ) - # mistral-common uses assert statements to stop processing of input - # if input does not comply with the expected format. - # We convert those assertion errors to ValueErrors so they can be - # properly caught in the preprocessing_input step - except (AssertionError, MistralCommonException) as e: - raise ValueError(str(e)) from e - - # External library exceptions can sometimes occur despite the framework's - # internal exception management capabilities. - except Exception as e: - # Log and report any library-related exceptions for further - # investigation. 
- logger.exception( - "An error occurred in `mistral_common` while applying chat template" - ) - raise ValueError(str(e)) from e - - def get_history_tool_calls_cnt(conversation: list[ConversationMessage]): idx = 0 for msg in conversation: diff --git a/vllm/entrypoints/context.py b/vllm/entrypoints/context.py index a22ab02229cd..f6ceac768c86 100644 --- a/vllm/entrypoints/context.py +++ b/vllm/entrypoints/context.py @@ -39,9 +39,9 @@ from vllm.entrypoints.tool_server import ToolServer from vllm.outputs import RequestOutput from vllm.reasoning.abs_reasoning_parsers import ReasoningParser -from vllm.tokenizers.protocol import TokenizerLike +from vllm.renderers import RendererLike +from vllm.tokenizers import TokenizerLike from vllm.tool_parsers.abstract_tool_parser import ToolParser -from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import random_uuid if TYPE_CHECKING: @@ -229,8 +229,8 @@ def __init__( self, *, response_messages: list[ResponseInputOutputItem], - tokenizer: AnyTokenizer, - reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser] | None, + renderer: RendererLike, + reasoning_parser_cls: Callable[[TokenizerLike], ReasoningParser] | None, request: ResponsesRequest, available_tools: list[str] | None, tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None, @@ -248,6 +248,7 @@ def __init__( if reasoning_parser_cls is None: raise ValueError("reasoning_parser_cls must be provided.") + tokenizer = renderer.get_tokenizer() self.parser = get_responses_parser_for_simple_context( tokenizer=tokenizer, reasoning_parser_cls=reasoning_parser_cls, @@ -257,6 +258,7 @@ def __init__( ) self.tool_parser_cls = tool_parser_cls self.request = request + self.renderer = renderer self.tokenizer = tokenizer self.available_tools = available_tools or [] diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 31319cf64aeb..282f23617f0b 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -36,10 +36,6 @@ from vllm.entrypoints.chat_utils import ( ChatCompletionMessageParam, ChatTemplateContentFormatOption, - apply_hf_chat_template, - apply_mistral_chat_template, - parse_chat_messages, - resolve_chat_template_content_format, ) from vllm.entrypoints.score_utils import ( ScoreContentPartParam, @@ -73,7 +69,6 @@ from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams from vllm.tasks import PoolingTask from vllm.tokenizers import TokenizerLike -from vllm.tokenizers.mistral import MistralTokenizer from vllm.usage.usage_lib import UsageContext from vllm.utils.collection_utils import as_iter, is_list_of from vllm.utils.counter import Counter @@ -793,7 +788,7 @@ def preprocess_chat( tools: list[dict[str, Any]] | None = None, chat_template_kwargs: dict[str, Any] | None = None, mm_processor_kwargs: dict[str, Any] | None = None, - ) -> list[TokensPrompt]: + ) -> list[TextPrompt | TokensPrompt]: """ Generate prompt for a chat conversation. The pre-processed prompt can then be used as input for the other LLM methods. @@ -814,63 +809,27 @@ def preprocess_chat( # messages is list[...] 
list_of_messages = [cast(list[ChatCompletionMessageParam], messages)] - tokenizer = self.get_tokenizer() - model_config = self.model_config - resolved_content_format = resolve_chat_template_content_format( - chat_template, - tools, - chat_template_content_format, - tokenizer, - model_config=model_config, - ) + renderer = self.llm_engine.renderer - _chat_template_kwargs: dict[str, Any] = dict( - chat_template=chat_template, - add_generation_prompt=add_generation_prompt, - continue_final_message=continue_final_message, - tools=tools, - ) - _chat_template_kwargs.update(chat_template_kwargs or {}) + chat_template_kwargs = { + "chat_template": chat_template, + "add_generation_prompt": add_generation_prompt, + "continue_final_message": continue_final_message, + "tools": tools, + **(chat_template_kwargs or {}), + } - prompts: list[TokensPrompt] = [] + prompts = list[TextPrompt | TokensPrompt]() for msgs in list_of_messages: - # NOTE: _parse_chat_message_content_parts() currently doesn't + # NOTE: renderer.render_messages() currently doesn't # handle mm_processor_kwargs, since there is no implementation in # the chat message parsing for it. - conversation, mm_data, mm_uuids = parse_chat_messages( + _, prompt = renderer.render_messages( msgs, - model_config, - content_format=resolved_content_format, + chat_template_content_format=chat_template_content_format, + **chat_template_kwargs, ) - - if isinstance(tokenizer, MistralTokenizer): - prompt_token_ids = apply_mistral_chat_template( - tokenizer, - messages=msgs, - **_chat_template_kwargs, - ) - else: - prompt_str = apply_hf_chat_template( - tokenizer=tokenizer, - conversation=conversation, - model_config=model_config, - **_chat_template_kwargs, - ) - # Special tokens are already included in chat templates so - # should not be added by the tokenizer in this case. 
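
As a reference point for the `preprocess_chat` rewrite above, here is a minimal usage sketch of the renderer-based path. It assumes an instruct model whose tokenizer ships a chat template; the checkpoint name is illustrative only, not part of this change.

```python
# Minimal sketch of the renderer-based chat preprocessing path.
# The model name is a placeholder; any chat-capable HF checkpoint works.
from vllm import LLM

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")

# preprocess_chat() now delegates to llm_engine.renderer.render_messages()
# and may return either a TextPrompt or a TokensPrompt per conversation.
prompts = llm.preprocess_chat(
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    add_generation_prompt=True,
)

# The pre-processed prompts can then be fed to the other LLM methods.
outputs = llm.generate(prompts)
print(outputs[0].outputs[0].text)
```
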
- prompt_token_ids = tokenizer.encode( - prompt_str, add_special_tokens=False - ) - - prompt = TokensPrompt(prompt_token_ids=prompt_token_ids) - - if mm_data is not None: - prompt["multi_modal_data"] = mm_data - - if mm_uuids is not None: - prompt["multi_modal_uuids"] = mm_uuids - if mm_processor_kwargs is not None: prompt["mm_processor_kwargs"] = mm_processor_kwargs @@ -1299,9 +1258,6 @@ def _cross_encoding_score( ) -> list[ScoringRequestOutput]: model_config = self.model_config - if isinstance(tokenizer, MistralTokenizer): - raise ValueError("Score API is not supported for Mistral tokenizer") - if len(data_1) == 1: data_1 = data_1 * len(data_2) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 5d0eacae34dd..028aafce3753 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -41,6 +41,7 @@ AnthropicMessagesResponse, ) from vllm.entrypoints.anthropic.serving_messages import AnthropicServingMessages +from vllm.entrypoints.chat_utils import load_chat_template from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args @@ -87,7 +88,6 @@ cli_env_setup, load_aware_call, log_non_default_args, - process_chat_template, process_lora_modules, with_cancellation, ) @@ -1098,9 +1098,7 @@ async def init_app_state( supported_tasks = await engine_client.get_supported_tasks() logger.info("Supported tasks: %s", supported_tasks) - resolved_chat_template = await process_chat_template( - args.chat_template, engine_client, vllm_config.model_config - ) + resolved_chat_template = load_chat_template(args.chat_template) if args.tool_server == "demo": tool_server: ToolServer | None = DemoToolServer() diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 98fc7810faf9..0ae741388f36 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -192,7 +192,8 @@ async def create_chat_completion( model_name = self.models.model_name(lora_request) - tokenizer = await self.engine_client.get_tokenizer() + renderer = self.engine_client.renderer + tokenizer = renderer.tokenizer tool_parser = self.tool_parser @@ -234,9 +235,10 @@ async def create_chat_completion( ) if error_check_ret is not None: return error_check_ret + conversation, engine_prompts = await self._preprocess_chat( request, - tokenizer, + renderer, request.messages, chat_template=request.chat_template or self.chat_template, chat_template_content_format=self.chat_template_content_format, @@ -1700,7 +1702,7 @@ def _create_chat_logprobs( else: if tokenizer is None: raise ValueError( - "Tokenizer not available when `skip_tokenizer_init=True`" + "Unable to get tokenizer because `skip_tokenizer_init=True`" ) token = tokenizer.decode(token_id) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 1be0afc8c74e..11885548e641 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -125,13 +125,7 @@ async def create_completion( try: lora_request = self._maybe_get_adapters(request) - - if self.model_config.skip_tokenizer_init: - tokenizer = None - else: - tokenizer = await self.engine_client.get_tokenizer() - renderer = self._get_renderer(tokenizer) - + renderer = self._get_completion_renderer() engine_prompts = await renderer.render_prompt_and_embeds( 
prompt_or_prompts=request.prompt, prompt_embeds=request.prompt_embeds, @@ -258,6 +252,8 @@ async def create_completion( stream = request.stream and not request.use_beam_search # Streaming response + tokenizer = self.renderer.tokenizer + if stream: return self.completion_stream_generator( request, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 5f7cfaa53ec1..f3856519f0d4 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -6,10 +6,9 @@ import time import traceback from collections.abc import AsyncGenerator, Callable, Iterable, Mapping -from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from http import HTTPStatus -from typing import Any, ClassVar, Generic, TypeAlias, TypeVar +from typing import Any, ClassVar, Generic, TypeAlias, TypeVar, cast import numpy as np from fastapi import Request @@ -26,10 +25,6 @@ ChatCompletionMessageParam, ChatTemplateContentFormatOption, ConversationMessage, - apply_hf_chat_template, - apply_mistral_chat_template, - parse_chat_messages_futures, - resolve_chat_template_content_format, ) from vllm.entrypoints.context import ( ConversationContext, @@ -99,10 +94,9 @@ from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.reasoning import ReasoningParser, ReasoningParserManager +from vllm.renderers import RendererLike from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.tokenizers import TokenizerLike -from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer -from vllm.tokenizers.mistral import MistralTokenizer from vllm.tool_parsers import ToolParser, ToolParserManager from vllm.tracing import ( contains_trace_headers, @@ -113,10 +107,8 @@ from vllm.utils.async_utils import ( AsyncMicrobatchTokenizer, collect_from_async_generator, - make_async, merge_async_iterators, ) -from vllm.utils.collection_utils import is_list_of from vllm.v1.engine import EngineCoreRequest @@ -201,7 +193,6 @@ class ResponseGenerationMixin: @dataclass(kw_only=True) class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, Generic[RequestT]): - # Shared across all requests request: RequestT raw_request: Request | None = None model_name: str @@ -209,9 +200,6 @@ class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, Generic[Requ created_time: int = field(default_factory=lambda: int(time.time())) lora_request: LoRARequest | None = None - # Shared across most requests - tokenizer: TokenizerLike | None = None - @dataclass(kw_only=True) class ClassificationServeContext(ServeContext[ClassificationRequest]): @@ -247,16 +235,13 @@ def __init__( self.request_logger = request_logger self.return_tokens_as_token_ids = return_tokens_as_token_ids - self._tokenizer_executor = ThreadPoolExecutor(max_workers=1) - self._apply_mistral_chat_template_async = make_async( - apply_mistral_chat_template, executor=self._tokenizer_executor - ) self._async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer] = {} self.log_error_stack = log_error_stack self.input_processor = self.models.input_processor self.io_processor = self.models.io_processor + self.renderer = self.models.renderer self.model_config = self.models.model_config self.max_model_len = self.model_config.max_model_len @@ -541,14 +526,14 @@ async def beam_search( prompt_logprobs=None, ) - def _get_renderer(self, tokenizer: TokenizerLike | None) -> BaseRenderer: + def 
_get_completion_renderer(self) -> BaseRenderer: """ Get a Renderer instance with the provided tokenizer. Uses shared async tokenizer pool for efficiency. """ return CompletionRenderer( model_config=self.model_config, - tokenizer=tokenizer, + tokenizer=self.renderer.tokenizer, async_tokenizer_pool=self._async_tokenizer_pool, ) @@ -1102,7 +1087,7 @@ def _validate_chat_template( async def _preprocess_chat( self, request: ChatLikeRequest | ResponsesRequest, - tokenizer: TokenizerLike | None, + renderer: RendererLike, messages: list[ChatCompletionMessageParam], chat_template: str | None, chat_template_content_format: ChatTemplateContentFormatOption, @@ -1114,56 +1099,46 @@ async def _preprocess_chat( tool_parser: Callable[[TokenizerLike], ToolParser] | None = None, add_special_tokens: bool = False, ) -> tuple[list[ConversationMessage], list[TokensPrompt]]: - model_config = self.model_config - - resolved_content_format = resolve_chat_template_content_format( - chat_template, - tool_dicts, - chat_template_content_format, - tokenizer, - model_config=model_config, - ) - conversation, mm_data_future, mm_uuids = parse_chat_messages_futures( + chat_template_kwargs = { + "chat_template": chat_template, + "add_generation_prompt": add_generation_prompt, + "continue_final_message": continue_final_message, + "tools": tool_dicts, + "documents": documents, + **(chat_template_kwargs or {}), + } + + # Use the async tokenizer in `OpenAIServing` if possible. + # Later we can move it into the renderer so that we can return both + # text and token IDs in the same prompt from `render_messages_async` + # which is used for logging and `enable_response_messages`. + from vllm.tokenizers.mistral import MistralTokenizer + + conversation, engine_prompt = await renderer.render_messages_async( messages, - model_config, - content_format=resolved_content_format, + chat_template_content_format=chat_template_content_format, + tokenize=isinstance(renderer.tokenizer, MistralTokenizer), + **chat_template_kwargs, ) - _chat_template_kwargs: dict[str, Any] = dict( - chat_template=chat_template, - add_generation_prompt=add_generation_prompt, - continue_final_message=continue_final_message, - tools=tool_dicts, - documents=documents, - ) - _chat_template_kwargs.update(chat_template_kwargs or {}) + if "prompt_token_ids" not in engine_prompt: + extra_data = engine_prompt + engine_prompt = await self._tokenize_prompt_input_async( + request, + renderer.get_tokenizer(), + engine_prompt["prompt"], + add_special_tokens=add_special_tokens, + ) - request_prompt: str | list[int] + # Fill in other keys like MM data + engine_prompt.update(extra_data) # type: ignore - if tokenizer is None: - request_prompt = "placeholder" - elif isinstance(tokenizer, MistralTokenizer): - request_prompt = await self._apply_mistral_chat_template_async( - tokenizer, - messages=messages, - **_chat_template_kwargs, - ) - elif isinstance(tokenizer, DeepseekV32Tokenizer): - request_prompt = tokenizer.apply_chat_template( - conversation=conversation, - messages=messages, - model_config=model_config, - **_chat_template_kwargs, - ) - else: - request_prompt = apply_hf_chat_template( - tokenizer=tokenizer, - conversation=conversation, - model_config=model_config, - **_chat_template_kwargs, - ) + engine_prompt = cast(TokensPrompt, engine_prompt) - mm_data = await mm_data_future + if request.mm_processor_kwargs is not None: + engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs + if (cache_salt := getattr(request, "cache_salt", None)) is not None: + 
engine_prompt["cache_salt"] = cache_salt # tool parsing is done only if a tool_parser has been set and if # tool_choice is not "none" (if tool_choice is "none" but a tool_parser @@ -1179,46 +1154,9 @@ async def _preprocess_chat( "or Responses API requests." ) raise NotImplementedError(msg) - request = tool_parser(tokenizer).adjust_request(request=request) # type: ignore - - if tokenizer is None: - assert isinstance(request_prompt, str), ( - "Prompt has to be a string", - "when the tokenizer is not initialised", - ) - prompt_inputs = TokensPrompt(prompt=request_prompt, prompt_token_ids=[1]) - elif isinstance(request_prompt, str): - prompt_inputs = await self._tokenize_prompt_input_async( - request, - tokenizer, - request_prompt, - add_special_tokens=add_special_tokens, - ) - else: - # For MistralTokenizer - assert is_list_of(request_prompt, int), ( - "Prompt has to be either a string or a list of token ids" - ) - prompt_inputs = TokensPrompt( - prompt=tokenizer.decode(request_prompt), - prompt_token_ids=request_prompt, - ) - engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["prompt_token_ids"]) - if "prompt" in prompt_inputs: - engine_prompt["prompt"] = prompt_inputs["prompt"] - - if mm_data is not None: - engine_prompt["multi_modal_data"] = mm_data - - if mm_uuids is not None: - engine_prompt["multi_modal_uuids"] = mm_uuids - - if request.mm_processor_kwargs is not None: - engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs - - if hasattr(request, "cache_salt") and request.cache_salt is not None: - engine_prompt["cache_salt"] = request.cache_salt + tokenizer = renderer.get_tokenizer() + request = tool_parser(tokenizer).adjust_request(request=request) # type: ignore return conversation, [engine_prompt] @@ -1252,7 +1190,7 @@ async def _process_inputs( async def _render_next_turn( self, request: ResponsesRequest, - tokenizer: TokenizerLike | None, + renderer: RendererLike, messages: list[ResponseInputOutputItem], tool_dicts: list[dict[str, Any]] | None, tool_parser, @@ -1265,7 +1203,7 @@ async def _render_next_turn( _, engine_prompts = await self._preprocess_chat( request, - tokenizer, + renderer, new_messages, tool_dicts=tool_dicts, tool_parser=tool_parser, @@ -1342,7 +1280,7 @@ async def _generate_with_builtin_tools( elif isinstance(context, ParsableContext): engine_prompts = await self._render_next_turn( context.request, - context.tokenizer, + context.renderer, context.parser.response_messages, context.tool_dicts, context.tool_parser_cls, diff --git a/vllm/entrypoints/openai/serving_models.py b/vllm/entrypoints/openai/serving_models.py index 953398a9a72a..d4335016f635 100644 --- a/vllm/entrypoints/openai/serving_models.py +++ b/vllm/entrypoints/openai/serving_models.py @@ -71,6 +71,7 @@ def __init__( self.input_processor = self.engine_client.input_processor self.io_processor = self.engine_client.io_processor + self.renderer = self.engine_client.renderer self.model_config = self.engine_client.model_config self.max_model_len = self.model_config.max_model_len diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index fb2a6440daf0..9d22939267ed 100644 --- a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -112,6 +112,7 @@ from vllm.logprobs import Logprob as SampleLogprob from vllm.logprobs import SampleLogprobs from vllm.outputs import CompletionOutput +from vllm.renderers import RendererLike from vllm.sampling_params import SamplingParams, StructuredOutputsParams 
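
For context on the `_preprocess_chat` changes above, the sketch below walks the same flow using the renderer abstraction directly: render the conversation, then tokenize the text prompt if the renderer did not already return token IDs. It assumes an existing `ModelConfig` and an HF-style tokenizer; multimodal futures and error handling are omitted.

```python
# Rough sketch of the renderer-driven chat flow; `model_config` is assumed
# to be an existing ModelConfig. Run with asyncio.run(render_chat(model_config)).
from vllm.inputs import TokensPrompt
from vllm.renderers import renderer_from_config


async def render_chat(model_config):
    renderer = renderer_from_config(model_config)

    # Ask the renderer for an untokenized prompt (the HF path), mirroring
    # how the serving code requests text and tokenizes it separately.
    conversation, engine_prompt = await renderer.render_messages_async(
        [{"role": "user", "content": "Hello"}],
        chat_template_content_format="auto",
        add_generation_prompt=True,
        tokenize=False,
    )

    # Fall back to manual tokenization when only a text prompt came back.
    if "prompt_token_ids" not in engine_prompt:
        tokenizer = renderer.get_tokenizer()
        engine_prompt = TokensPrompt(
            prompt_token_ids=tokenizer.encode(
                engine_prompt["prompt"], add_special_tokens=False
            )
        )

    return conversation, engine_prompt
```
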
from vllm.tokenizers import TokenizerLike from vllm.utils import random_uuid @@ -350,7 +351,8 @@ async def create_responses( try: lora_request = self._maybe_get_adapters(request) model_name = self.models.model_name(lora_request) - tokenizer = await self.engine_client.get_tokenizer() + renderer = self.engine_client.renderer + tokenizer = renderer.get_tokenizer() if self.use_harmony: messages, engine_prompts = self._make_request_with_harmony( @@ -358,7 +360,7 @@ async def create_responses( ) else: messages, engine_prompts = await self._make_request( - request, prev_response, tokenizer + request, prev_response, renderer ) except ( @@ -424,7 +426,7 @@ async def create_responses( # tokens during generation instead of at the end context = ParsableContext( response_messages=messages, - tokenizer=tokenizer, + renderer=renderer, reasoning_parser_cls=self.reasoning_parser, request=request, tool_parser_cls=self.tool_parser, @@ -553,7 +555,7 @@ async def _make_request( self, request: ResponsesRequest, prev_response: ResponsesResponse | None, - tokenizer: TokenizerLike, + renderer: RendererLike, ): tool_dicts = construct_tool_dicts(request.tools, request.tool_choice) # Construct the input messages. @@ -565,7 +567,7 @@ async def _make_request( ) _, engine_prompts = await self._preprocess_chat( request, - tokenizer, + renderer, messages, tool_dicts=tool_dicts, tool_parser=self.tool_parser, @@ -583,6 +585,7 @@ def _make_request_with_harmony( raise NotImplementedError( "Only 'auto' tool_choice is supported in response API with Harmony" ) + messages = self._construct_input_messages_with_harmony(request, prev_response) prompt_token_ids = render_for_completion(messages) engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids) diff --git a/vllm/entrypoints/pooling/classify/serving.py b/vllm/entrypoints/pooling/classify/serving.py index e166405a6f05..9ddf9b7bb2b7 100644 --- a/vllm/entrypoints/pooling/classify/serving.py +++ b/vllm/entrypoints/pooling/classify/serving.py @@ -52,8 +52,6 @@ async def _preprocess( """ ctx = cast(ClassificationServeContext, ctx) try: - ctx.tokenizer = await self.engine_client.get_tokenizer() - request_obj = ctx.request if isinstance(request_obj, ClassificationChatRequest): @@ -74,7 +72,7 @@ async def _preprocess( _, engine_prompts = await self._preprocess_chat( cast(ChatCompletionRequest, chat_request), - ctx.tokenizer, + self.renderer, messages, chat_template=( chat_request.chat_template @@ -102,7 +100,7 @@ async def _preprocess( ctx.engine_prompts = [] return None - renderer = self._get_renderer(ctx.tokenizer) + renderer = self._get_completion_renderer() prompt_input = cast(str | list[str], input_data) ctx.engine_prompts = await renderer.render_prompt( prompt_or_prompts=prompt_input, diff --git a/vllm/entrypoints/pooling/embed/serving.py b/vllm/entrypoints/pooling/embed/serving.py index f5a21208ed80..662db531db5c 100644 --- a/vllm/entrypoints/pooling/embed/serving.py +++ b/vllm/entrypoints/pooling/embed/serving.py @@ -78,13 +78,10 @@ async def _preprocess( try: ctx.lora_request = self._maybe_get_adapters(ctx.request) - tokenizer = await self.engine_client.get_tokenizer() - renderer = self._get_renderer(tokenizer) - if isinstance(ctx.request, EmbeddingChatRequest): _, ctx.engine_prompts = await self._preprocess_chat( ctx.request, - tokenizer, + self.renderer, ctx.request.messages, chat_template=ctx.request.chat_template or ctx.chat_template, chat_template_content_format=ctx.chat_template_content_format, @@ -93,6 +90,7 @@ async def _preprocess( 
add_special_tokens=ctx.request.add_special_tokens, ) else: + renderer = self._get_completion_renderer() ctx.engine_prompts = await renderer.render_prompt( prompt_or_prompts=ctx.request.input, config=self._build_render_config(ctx.request), diff --git a/vllm/entrypoints/pooling/pooling/serving.py b/vllm/entrypoints/pooling/pooling/serving.py index 4e1b326806ea..504083ffb271 100644 --- a/vllm/entrypoints/pooling/pooling/serving.py +++ b/vllm/entrypoints/pooling/pooling/serving.py @@ -94,12 +94,6 @@ async def create_pooling( try: lora_request = self._maybe_get_adapters(request) - if self.model_config.skip_tokenizer_init: - tokenizer = None - else: - tokenizer = await self.engine_client.get_tokenizer() - renderer = self._get_renderer(tokenizer) - if getattr(request, "dimensions", None) is not None: return self.create_error_response( "dimensions is currently not supported" @@ -140,7 +134,7 @@ async def create_pooling( _, engine_prompts = await self._preprocess_chat( request, - tokenizer, + self.renderer, request.messages, chat_template=request.chat_template or self.chat_template, chat_template_content_format=self.chat_template_content_format, @@ -151,6 +145,7 @@ async def create_pooling( add_special_tokens=request.add_special_tokens, ) elif isinstance(request, PoolingCompletionRequest): + renderer = self._get_completion_renderer() engine_prompts = await renderer.render_prompt( prompt_or_prompts=request.input, config=self._build_render_config(request), diff --git a/vllm/entrypoints/pooling/score/serving.py b/vllm/entrypoints/pooling/score/serving.py index edbfcd03ac92..0407d466d714 100644 --- a/vllm/entrypoints/pooling/score/serving.py +++ b/vllm/entrypoints/pooling/score/serving.py @@ -3,6 +3,7 @@ import asyncio import time from collections.abc import AsyncGenerator, Mapping +from concurrent.futures import ThreadPoolExecutor from typing import Any from fastapi import Request @@ -61,6 +62,8 @@ def __init__( log_error_stack=log_error_stack, ) + self._tokenizer_executor = ThreadPoolExecutor(max_workers=1) + async def _embedding_score( self, tokenizer: TokenizerLike, @@ -280,8 +283,7 @@ async def _run_scoring( raw_request: Request | None = None, ) -> list[PoolingRequestOutput] | ErrorResponse: lora_request = self._maybe_get_adapters(request) - - tokenizer = await self.engine_client.get_tokenizer() + tokenizer = self.renderer.get_tokenizer() truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None) diff --git a/vllm/entrypoints/serve/tokenize/serving.py b/vllm/entrypoints/serve/tokenize/serving.py index 0b07f0b18dfd..055c8a41a2c7 100644 --- a/vllm/entrypoints/serve/tokenize/serving.py +++ b/vllm/entrypoints/serve/tokenize/serving.py @@ -65,9 +65,6 @@ async def create_tokenize( try: lora_request = self._maybe_get_adapters(request) - tokenizer = await self.engine_client.get_tokenizer() - renderer = self._get_renderer(tokenizer) - if isinstance(request, TokenizeChatRequest): tool_dicts = ( None @@ -84,7 +81,7 @@ async def create_tokenize( _, engine_prompts = await self._preprocess_chat( request, - tokenizer, + self.renderer, request.messages, tool_dicts=tool_dicts, chat_template=request.chat_template or self.chat_template, @@ -95,6 +92,7 @@ async def create_tokenize( add_special_tokens=request.add_special_tokens, ) else: + renderer = self._get_completion_renderer() engine_prompts = await renderer.render_prompt( prompt_or_prompts=request.prompt, config=self._build_render_config(request), @@ -114,6 +112,7 @@ async def create_tokenize( token_strs = None if request.return_token_strs: + 
tokenizer = self.renderer.get_tokenizer() token_strs = tokenizer.convert_ids_to_tokens(input_ids) return TokenizeResponse( @@ -135,8 +134,7 @@ async def create_detokenize( request_id = f"tokn-{self._base_request_id(raw_request)}" lora_request = self._maybe_get_adapters(request) - - tokenizer = await self.engine_client.get_tokenizer() + tokenizer = self.renderer.get_tokenizer() self._log_inputs( request_id, @@ -159,7 +157,7 @@ async def get_tokenizer_info( ) -> TokenizerInfoResponse | ErrorResponse: """Get comprehensive tokenizer information.""" try: - tokenizer = await self.engine_client.get_tokenizer() + tokenizer = self.renderer.get_tokenizer() info = TokenizerInfo(tokenizer, self.chat_template).to_dict() return TokenizerInfoResponse(**info) except Exception as e: diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py index f4a633c69cb0..949684d52814 100644 --- a/vllm/entrypoints/utils.py +++ b/vllm/entrypoints/utils.py @@ -6,21 +6,13 @@ import functools import os from argparse import Namespace -from pathlib import Path from typing import Any from fastapi import Request from fastapi.responses import JSONResponse, StreamingResponse from starlette.background import BackgroundTask, BackgroundTasks -from vllm.config import ModelConfig from vllm.engine.arg_utils import EngineArgs -from vllm.engine.protocol import EngineClient -from vllm.entrypoints.chat_utils import ( - load_chat_template, - resolve_hf_chat_template, - resolve_mistral_chat_template, -) from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.entrypoints.openai.protocol import ( ChatCompletionRequest, @@ -30,7 +22,6 @@ from vllm.entrypoints.openai.serving_models import LoRAModulePath from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.tokenizers.mistral import MistralTokenizer from vllm.utils.argparse_utils import FlexibleArgumentParser logger = init_logger(__name__) @@ -283,37 +274,3 @@ def process_lora_modules( else: lora_modules += default_mm_lora_paths return lora_modules - - -async def process_chat_template( - args_chat_template: Path | str | None, - engine_client: EngineClient, - model_config: ModelConfig, -) -> str | None: - resolved_chat_template = load_chat_template(args_chat_template) - if resolved_chat_template is not None: - # Get the tokenizer to check official template - tokenizer = await engine_client.get_tokenizer() - - if isinstance(tokenizer, MistralTokenizer): - # The warning is logged in resolve_mistral_chat_template. - resolved_chat_template = resolve_mistral_chat_template( - chat_template=resolved_chat_template - ) - else: - hf_chat_template = resolve_hf_chat_template( - tokenizer=tokenizer, - chat_template=None, - tools=None, - model_config=model_config, - ) - - if hf_chat_template != resolved_chat_template: - logger.warning( - "Using supplied chat template: %s\n" - "It is different from official chat template '%s'. 
" - "This discrepancy may lead to performance degradation.", - resolved_chat_template, - model_config.model, - ) - return resolved_chat_template diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 0372b06d0017..8e8878c21fdf 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -17,6 +17,7 @@ MultiModalUUIDDict, ) from vllm.multimodal.processing import BaseMultiModalProcessor +from vllm.renderers import renderer_from_config from vllm.tokenizers import TokenizerLike from vllm.utils.jsontree import json_iter_leaves from vllm.v1.metrics.stats import MultiModalCacheStats @@ -46,26 +47,24 @@ class InputPreprocessor: def __init__( self, model_config: ModelConfig, - tokenizer: TokenizerLike | None, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_processor_cache: BaseMultiModalProcessorCache | None = None, ) -> None: super().__init__() self.model_config = model_config - self.tokenizer = tokenizer + self.renderer = renderer_from_config(model_config) self.mm_registry = mm_registry self.mm_processor_cache = mm_processor_cache self.mm_cache_stats = MultiModalCacheStats() if mm_processor_cache else None - def get_tokenizer(self) -> TokenizerLike: - if self.tokenizer is None: - raise ValueError( - "You cannot pass text prompts when `skip_tokenizer_init=True`" - ) + @property + def tokenizer(self) -> TokenizerLike | None: + return self.renderer.tokenizer - return self.tokenizer + def get_tokenizer(self) -> TokenizerLike: + return self.renderer.get_tokenizer() def get_bos_token_id(self) -> int | None: if self.tokenizer is None: diff --git a/vllm/renderers/__init__.py b/vllm/renderers/__init__.py new file mode 100644 index 000000000000..cd6a11dcc833 --- /dev/null +++ b/vllm/renderers/__init__.py @@ -0,0 +1,7 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from .protocol import RendererLike +from .registry import RendererRegistry, renderer_from_config + +__all__ = ["RendererLike", "RendererRegistry", "renderer_from_config"] diff --git a/vllm/renderers/deepseekv32.py b/vllm/renderers/deepseekv32.py new file mode 100644 index 000000000000..6f7b7898c024 --- /dev/null +++ b/vllm/renderers/deepseekv32.py @@ -0,0 +1,119 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + +from vllm.config import ModelConfig +from vllm.entrypoints.chat_utils import ( + ChatCompletionMessageParam, + ConversationMessage, + parse_chat_messages, + parse_chat_messages_futures, +) +from vllm.inputs import TextPrompt, TokensPrompt +from vllm.logger import init_logger +from vllm.tokenizers import cached_get_tokenizer +from vllm.tokenizers.deepseekv32 import DeepseekV32Tokenizer + +from .protocol import RendererLike + +logger = init_logger(__name__) + + +class DeepseekV32Renderer(RendererLike): + @classmethod + def from_config( + cls, + config: ModelConfig, + tokenizer_kwargs: dict[str, Any], + ) -> "RendererLike": + return cls(config, tokenizer_kwargs) + + def __init__( + self, + config: ModelConfig, + tokenizer_kwargs: dict[str, Any], + ) -> None: + super().__init__() + + self.config = config + + if config.skip_tokenizer_init: + tokenizer = None + else: + tokenizer = cached_get_tokenizer( + tokenizer_cls=DeepseekV32Tokenizer, + **tokenizer_kwargs, + ) + + self._tokenizer = tokenizer + + @property + def tokenizer(self) -> DeepseekV32Tokenizer | None: + return self._tokenizer + + def get_tokenizer(self) -> DeepseekV32Tokenizer: + tokenizer 
= self.tokenizer + if tokenizer is None: + raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`") + + return tokenizer + + def render_messages( + self, + messages: list[ChatCompletionMessageParam], + **kwargs, + ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]: + tokenizer = self.get_tokenizer() + conversation, mm_data, mm_uuids = parse_chat_messages( + messages, + self.config, + content_format="string", + ) + + prompt_raw = tokenizer.apply_chat_template( + conversation=conversation, + messages=messages, + **kwargs, + ) + + prompt = ( + TextPrompt(prompt=prompt_raw) + if isinstance(prompt_raw, str) + else TokensPrompt(prompt_token_ids=prompt_raw) + ) + if mm_data is not None: + prompt["multi_modal_data"] = mm_data + if mm_uuids is not None: + prompt["multi_modal_uuids"] = mm_uuids + + return conversation, prompt # type: ignore[return-value] + + async def render_messages_async( + self, + messages: list[ChatCompletionMessageParam], + **kwargs, + ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]: + tokenizer = self.get_tokenizer() + conversation, mm_data_future, mm_uuids = parse_chat_messages_futures( + messages, + self.config, + content_format="string", + ) + + prompt_raw = tokenizer.apply_chat_template( + conversation=conversation, + messages=messages, + **kwargs, + ) + + prompt = ( + TextPrompt(prompt=prompt_raw) + if isinstance(prompt_raw, str) + else TokensPrompt(prompt_token_ids=prompt_raw) + ) + if mm_data_future is not None: + prompt["multi_modal_data"] = await mm_data_future + if mm_uuids is not None: + prompt["multi_modal_uuids"] = mm_uuids + + return conversation, prompt # type: ignore[return-value] diff --git a/vllm/renderers/hf.py b/vllm/renderers/hf.py new file mode 100644 index 000000000000..dfbfdc7d8bcf --- /dev/null +++ b/vllm/renderers/hf.py @@ -0,0 +1,599 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import inspect +from collections import deque +from collections.abc import Set +from functools import lru_cache +from typing import Any, cast + +import jinja2 +import jinja2.ext +import jinja2.meta +import jinja2.nodes +import jinja2.parser +import jinja2.sandbox + +from vllm.config import ModelConfig +from vllm.entrypoints.chat_utils import ( + ChatCompletionMessageParam, + ChatTemplateContentFormat, + ChatTemplateContentFormatOption, + ConversationMessage, + load_chat_template, + parse_chat_messages, + parse_chat_messages_futures, +) +from vllm.inputs import TextPrompt, TokensPrompt +from vllm.logger import init_logger +from vllm.tokenizers import cached_get_tokenizer +from vllm.tokenizers.hf import CachedHfTokenizer, HfTokenizer +from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path +from vllm.transformers_utils.processor import cached_get_processor +from vllm.utils.func_utils import supports_kw + +from .protocol import RendererLike + +logger = init_logger(__name__) + + +_PROCESSOR_CHAT_TEMPLATES = dict[tuple[str, bool], str | None]() +""" +Used in `_try_get_processor_chat_template` to avoid calling +`cached_get_processor` again if the processor fails to be loaded. + +This is needed because `lru_cache` does not cache when an exception happens. 
+""" + + +def _try_get_processor_chat_template( + tokenizer: HfTokenizer, + *, + trust_remote_code: bool, +) -> str | None: + cache_key = (tokenizer.name_or_path, trust_remote_code) + if cache_key in _PROCESSOR_CHAT_TEMPLATES: + return _PROCESSOR_CHAT_TEMPLATES[cache_key] + + from transformers import ( + PreTrainedTokenizer, + PreTrainedTokenizerFast, + ProcessorMixin, + ) + + try: + processor = cached_get_processor( + tokenizer.name_or_path, + processor_cls=( + PreTrainedTokenizer, + PreTrainedTokenizerFast, + ProcessorMixin, + ), + trust_remote_code=trust_remote_code, + ) + if ( + isinstance(processor, ProcessorMixin) + and hasattr(processor, "chat_template") + and (chat_template := processor.chat_template) is not None + ): + _PROCESSOR_CHAT_TEMPLATES[cache_key] = chat_template + return chat_template + except Exception: + logger.debug( + "Failed to load AutoProcessor chat template for %s", + tokenizer.name_or_path, + exc_info=True, + ) + + _PROCESSOR_CHAT_TEMPLATES[cache_key] = None + return None + + +def resolve_chat_template( + tokenizer: HfTokenizer, + chat_template: str | None, + tools: list[dict[str, Any]] | None, + *, + model_config: "ModelConfig", +) -> str | None: + # 1st priority: The given chat template + if chat_template is not None: + return chat_template + + # 2nd priority: AutoProcessor chat template, unless tool calling is enabled + if tools is None: + chat_template = _try_get_processor_chat_template( + tokenizer, + trust_remote_code=model_config.trust_remote_code, + ) + if chat_template is not None: + return chat_template + + # 3rd priority: AutoTokenizer chat template + try: + return tokenizer.get_chat_template(chat_template, tools=tools) + except Exception: + logger.debug( + "Failed to load AutoTokenizer chat template for %s", + tokenizer.name_or_path, + exc_info=True, + ) + + # 4th priority: Predefined fallbacks + path = get_chat_template_fallback_path( + model_type=model_config.hf_config.model_type, + tokenizer_name_or_path=tokenizer.name_or_path, + ) + if path is not None: + logger.info_once( + "Loading chat template fallback for %s as there isn't one " + "defined on HF Hub.", + tokenizer.name_or_path, + ) + chat_template = load_chat_template(path) + else: + logger.debug_once( + "There is no chat template fallback for %s", tokenizer.name_or_path + ) + + return chat_template + + +def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool: + if isinstance(node, jinja2.nodes.Name): + return node.ctx == "load" and node.name == varname + + return False + + +def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool: + if isinstance(node, jinja2.nodes.Getitem): + return ( + _is_var_access(node.node, varname) + and isinstance(node.arg, jinja2.nodes.Const) + and node.arg.value == key + ) + + if isinstance(node, jinja2.nodes.Getattr): + return _is_var_access(node.node, varname) and node.attr == key + + return False + + +def _is_var_or_elems_access( + node: jinja2.nodes.Node, + varname: str, + key: str | None = None, +) -> bool: + if isinstance(node, jinja2.nodes.Filter): + return node.node is not None and _is_var_or_elems_access( + node.node, varname, key + ) + if isinstance(node, jinja2.nodes.Test): + return _is_var_or_elems_access(node.node, varname, key) + + if isinstance(node, jinja2.nodes.Getitem) and isinstance( + node.arg, jinja2.nodes.Slice + ): + return _is_var_or_elems_access(node.node, varname, key) + + return _is_attr_access(node, varname, key) if key else _is_var_access(node, varname) + + +def _iter_nodes_assign_var_or_elems(root: 
jinja2.nodes.Node, varname: str): + # Global variable that is implicitly defined at the root + yield root, varname + + # Iterative BFS + related_varnames = deque([varname]) + while related_varnames: + related_varname = related_varnames.popleft() + + for assign_ast in root.find_all(jinja2.nodes.Assign): + lhs = assign_ast.target + rhs = assign_ast.node + + if _is_var_or_elems_access(rhs, related_varname): + assert isinstance(lhs, jinja2.nodes.Name) + yield assign_ast, lhs.name + + # Avoid infinite looping for self-assignment + if lhs.name != related_varname: + related_varnames.append(lhs.name) + + +# NOTE: The proper way to handle this is to build a CFG so that we can handle +# the scope in which each variable is defined, but that is too complicated +def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node): + messages_varnames = [ + varname for _, varname in _iter_nodes_assign_var_or_elems(root, "messages") + ] + + # Search for {%- for message in messages -%} loops + for loop_ast in root.find_all(jinja2.nodes.For): + loop_iter = loop_ast.iter + loop_target = loop_ast.target + + for varname in messages_varnames: + if _is_var_or_elems_access(loop_iter, varname): + assert isinstance(loop_target, jinja2.nodes.Name) + yield loop_ast, loop_target.name + break + + +def _iter_nodes_assign_content_item(root: jinja2.nodes.Node): + message_varnames = [ + varname for _, varname in _iter_nodes_assign_messages_item(root) + ] + + # Search for {%- for content in message['content'] -%} loops + for loop_ast in root.find_all(jinja2.nodes.For): + loop_iter = loop_ast.iter + loop_target = loop_ast.target + + for varname in message_varnames: + if _is_var_or_elems_access(loop_iter, varname, "content"): + assert isinstance(loop_target, jinja2.nodes.Name) + yield loop_ast, loop_target.name + break + + +def _try_extract_ast(chat_template: str) -> jinja2.nodes.Template | None: + import transformers.utils.chat_template_utils as hf_chat_utils + + try: + jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template) + return jinja_compiled.environment.parse(chat_template) + except Exception: + logger.exception("Error when compiling Jinja template") + return None + + +@lru_cache(maxsize=32) +def _detect_content_format( + chat_template: str, + *, + default: ChatTemplateContentFormat, +) -> ChatTemplateContentFormat: + jinja_ast = _try_extract_ast(chat_template) + if jinja_ast is None: + return default + + try: + next(_iter_nodes_assign_content_item(jinja_ast)) + except StopIteration: + return "string" + except Exception: + logger.exception("Error when parsing AST of Jinja template") + return default + else: + return "openai" + + +def _resolve_chat_template_content_format( + chat_template: str | None, + tools: list[dict[str, Any]] | None, + tokenizer: HfTokenizer, + *, + model_config: "ModelConfig", +) -> ChatTemplateContentFormat: + resolved_chat_template = resolve_chat_template( + tokenizer, + chat_template=chat_template, + tools=tools, + model_config=model_config, + ) + + jinja_text = ( + resolved_chat_template + if isinstance(resolved_chat_template, str) + else load_chat_template(chat_template, is_literal=True) + ) + + detected_format = ( + "string" + if jinja_text is None + else _detect_content_format(jinja_text, default="string") + ) + + return detected_format + + +@lru_cache +def _log_chat_template_content_format( + chat_template: str | None, # For caching purposes + given_format: ChatTemplateContentFormatOption, + detected_format: ChatTemplateContentFormatOption, +): + logger.info( + "Detected the chat 
template content format to be '%s'. " + "You can set `--chat-template-content-format` to override this.", + detected_format, + ) + + if given_format != "auto" and given_format != detected_format: + logger.warning( + "You specified `--chat-template-content-format %s` " + "which is different from the detected format '%s'. " + "If our automatic detection is incorrect, please consider " + "opening a GitHub issue so that we can improve it: " + "https://github.com/vllm-project/vllm/issues/new/choose", + given_format, + detected_format, + ) + + +def resolve_chat_template_content_format( + chat_template: str | None, + tools: list[dict[str, Any]] | None, + given_format: ChatTemplateContentFormatOption, + tokenizer: HfTokenizer, + *, + model_config: "ModelConfig", +) -> ChatTemplateContentFormat: + if given_format != "auto": + return given_format + + detected_format = _resolve_chat_template_content_format( + chat_template, + tools, + tokenizer, + model_config=model_config, + ) + + _log_chat_template_content_format( + chat_template, + given_format=given_format, + detected_format=detected_format, + ) + + return detected_format + + +# adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412 +# only preserve the parse function used to resolve chat template kwargs +class AssistantTracker(jinja2.ext.Extension): + tags = {"generation"} + + def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.Node: + lineno = next(parser.stream).lineno + body = parser.parse_statements(("name:endgeneration",), drop_needle=True) + call = self.call_method("_generation_support") + call_block = jinja2.nodes.CallBlock(call, [], [], body) + return call_block.set_lineno(lineno) + + +def _resolve_chat_template_kwargs(chat_template: str) -> Set[str]: + env = jinja2.sandbox.ImmutableSandboxedEnvironment( + trim_blocks=True, + lstrip_blocks=True, + extensions=[AssistantTracker, jinja2.ext.loopcontrols], + ) + parsed_content = env.parse(chat_template) + template_vars = jinja2.meta.find_undeclared_variables(parsed_content) + return template_vars + + +_cached_resolve_chat_template_kwargs = lru_cache(_resolve_chat_template_kwargs) + + +@lru_cache +def _get_hf_base_chat_template_params() -> frozenset[str]: + from transformers import PreTrainedTokenizer + + # Get standard parameters from HuggingFace's base tokenizer class. + # This dynamically extracts parameters from PreTrainedTokenizer's + # apply_chat_template method, ensuring compatibility with tokenizers + # that use **kwargs to receive standard parameters. 
+
+    # Read signature from HF's base class - the single source of truth
+    base_sig = inspect.signature(PreTrainedTokenizer.apply_chat_template)
+
+    # Exclude VAR_KEYWORD (**kwargs) and VAR_POSITIONAL (*args) placeholders
+    return frozenset(
+        p.name
+        for p in base_sig.parameters.values()
+        if p.kind
+        not in (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL)
+    )
+
+
+def resolve_chat_template_kwargs(
+    tokenizer: HfTokenizer,
+    chat_template: str,
+    chat_template_kwargs: dict[str, Any],
+    raise_on_unexpected: bool = True,
+) -> dict[str, Any]:
+    # We exclude chat_template from kwargs here, because
+    # chat template has been already resolved at this stage
+    unexpected_vars = {"chat_template", "tokenize"}
+    if raise_on_unexpected and (
+        unexpected_in_kwargs := unexpected_vars & chat_template_kwargs.keys()
+    ):
+        raise ValueError(
+            "Found unexpected chat template kwargs from request: "
+            f"{unexpected_in_kwargs}"
+        )
+
+    fn_kw = {
+        k
+        for k in chat_template_kwargs
+        if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False)
+    }
+    template_vars = _cached_resolve_chat_template_kwargs(chat_template)
+
+    # Allow standard HF parameters even if tokenizer uses **kwargs to receive them
+    hf_base_params = _get_hf_base_chat_template_params()
+
+    accept_vars = (fn_kw | template_vars | hf_base_params) - unexpected_vars
+    return {k: v for k, v in chat_template_kwargs.items() if k in accept_vars}
+
+
+def safe_apply_chat_template(
+    model_config: "ModelConfig",
+    tokenizer: HfTokenizer,
+    conversation: list[ConversationMessage],
+    *,
+    tools: list[dict[str, Any]] | None = None,
+    chat_template: str | None = None,
+    tokenize: bool = True,
+    **kwargs,
+) -> str | list[int]:
+    chat_template = resolve_chat_template(
+        tokenizer,
+        chat_template=chat_template,
+        tools=tools,
+        model_config=model_config,
+    )
+    if chat_template is None:
+        raise ValueError(
+            "As of transformers v4.44, default chat template is no longer "
+            "allowed, so you must provide a chat template if the tokenizer "
+            "does not define one."
+        )
+
+    resolved_kwargs = resolve_chat_template_kwargs(
+        tokenizer=tokenizer,
+        chat_template=chat_template,
+        chat_template_kwargs=kwargs,
+    )
+
+    try:
+        return tokenizer.apply_chat_template(
+            conversation=conversation,  # type: ignore[arg-type]
+            tools=tools,  # type: ignore[arg-type]
+            chat_template=chat_template,
+            tokenize=tokenize,
+            **resolved_kwargs,
+        )
+    # External library exceptions can sometimes occur despite the framework's
+    # internal exception management capabilities.
+    except Exception as e:
+        # Log and report any library-related exceptions for further
+        # investigation.
+        logger.exception(
+            "An error occurred in `transformers` while applying chat template"
+        )
+        raise ValueError(str(e)) from e
+
+
+class HfRenderer(RendererLike):
+    @classmethod
+    def from_config(
+        cls,
+        config: ModelConfig,
+        tokenizer_kwargs: dict[str, Any],
+    ) -> "RendererLike":
+        return cls(config, tokenizer_kwargs)
+
+    def __init__(
+        self,
+        config: ModelConfig,
+        tokenizer_kwargs: dict[str, Any],
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+
+        if config.skip_tokenizer_init:
+            tokenizer = None
+        else:
+            tokenizer = cast(
+                HfTokenizer,
+                cached_get_tokenizer(
+                    tokenizer_cls=CachedHfTokenizer,  # type: ignore[type-abstract]
+                    **tokenizer_kwargs,
+                ),
+            )
+
+        self._tokenizer = tokenizer
+
+    @property
+    def tokenizer(self) -> HfTokenizer | None:
+        return self._tokenizer
+
+    def get_tokenizer(self) -> HfTokenizer:
+        tokenizer = self.tokenizer
+        if tokenizer is None:
+            raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")
+
+        return tokenizer
+
+    def render_messages(
+        self,
+        messages: list[ChatCompletionMessageParam],
+        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
+        **kwargs,
+    ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
+        model_config = self.config
+        tokenizer = self.get_tokenizer()
+
+        conversation, mm_data, mm_uuids = parse_chat_messages(
+            messages,
+            model_config,
+            content_format=resolve_chat_template_content_format(
+                chat_template=kwargs.get("chat_template"),
+                tools=kwargs.get("tools"),
+                given_format=chat_template_content_format,
+                tokenizer=tokenizer,
+                model_config=model_config,
+            ),
+        )
+
+        prompt_raw = safe_apply_chat_template(
+            model_config,
+            tokenizer,
+            conversation,
+            **kwargs,
+        )
+
+        prompt = (
+            TextPrompt(prompt=prompt_raw)
+            if isinstance(prompt_raw, str)
+            else TokensPrompt(prompt_token_ids=prompt_raw)
+        )
+        if mm_data is not None:
+            prompt["multi_modal_data"] = mm_data
+        if mm_uuids is not None:
+            prompt["multi_modal_uuids"] = mm_uuids
+
+        return conversation, prompt  # type: ignore[return-value]
+
+    async def render_messages_async(
+        self,
+        messages: list[ChatCompletionMessageParam],
+        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
+        **kwargs,
+    ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
+        model_config = self.config
+        tokenizer = self.get_tokenizer()
+
+        conversation, mm_data_future, mm_uuids = parse_chat_messages_futures(
+            messages,
+            model_config,
+            content_format=resolve_chat_template_content_format(
+                chat_template=kwargs.get("chat_template"),
+                tools=kwargs.get("tools"),
+                given_format=chat_template_content_format,
+                tokenizer=tokenizer,
+                model_config=model_config,
+            ),
+        )
+
+        prompt_raw = safe_apply_chat_template(
+            model_config,
+            tokenizer,
+            conversation,
+            **kwargs,
+        )
+
+        prompt = (
+            TextPrompt(prompt=prompt_raw)
+            if isinstance(prompt_raw, str)
+            else TokensPrompt(prompt_token_ids=prompt_raw)
+        )
+        if mm_data_future is not None:
+            prompt["multi_modal_data"] = await mm_data_future
+        if mm_uuids is not None:
+            prompt["multi_modal_uuids"] = mm_uuids
+
+        return conversation, prompt  # type: ignore[return-value]
diff --git a/vllm/renderers/mistral.py b/vllm/renderers/mistral.py
new file mode 100644
index 000000000000..9e57fffb0800
--- /dev/null
+++ b/vllm/renderers/mistral.py
@@ -0,0 +1,147 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any
+
+from vllm.config import ModelConfig
+from vllm.entrypoints.chat_utils import (
+    ChatCompletionMessageParam,
+    ConversationMessage,
+    parse_chat_messages,
+    parse_chat_messages_futures,
+)
+from vllm.inputs import TextPrompt, TokensPrompt
+from vllm.logger import init_logger
+from vllm.tokenizers import cached_get_tokenizer
+from vllm.tokenizers.mistral import MistralTokenizer
+from vllm.utils.async_utils import make_async
+
+from .protocol import RendererLike
+
+logger = init_logger(__name__)
+
+
+def safe_apply_chat_template(
+    tokenizer: MistralTokenizer,
+    messages: list[ChatCompletionMessageParam],
+    **kwargs,
+) -> str | list[int]:
+    from mistral_common.exceptions import MistralCommonException
+
+    try:
+        return tokenizer.apply_chat_template(messages, **kwargs)
+    # mistral-common uses assert statements to stop processing of input
+    # if input does not comply with the expected format.
+    # We convert those assertion errors to ValueErrors so they can be
+    # properly caught in the preprocessing_input step
+    except (AssertionError, MistralCommonException) as e:
+        raise ValueError(str(e)) from e
+
+    # External library exceptions can sometimes occur despite the framework's
+    # internal exception management capabilities.
+    except Exception as e:
+        # Log and report any library-related exceptions for further
+        # investigation.
+        logger.exception(
+            "An error occurred in `mistral_common` while applying chat template"
+        )
+        raise ValueError(str(e)) from e
+
+
+class MistralRenderer(RendererLike):
+    @classmethod
+    def from_config(
+        cls,
+        config: ModelConfig,
+        tokenizer_kwargs: dict[str, Any],
+    ) -> "RendererLike":
+        return cls(config, tokenizer_kwargs)
+
+    def __init__(
+        self,
+        config: ModelConfig,
+        tokenizer_kwargs: dict[str, Any],
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+
+        if config.skip_tokenizer_init:
+            tokenizer = None
+        else:
+            tokenizer = cached_get_tokenizer(
+                tokenizer_cls=MistralTokenizer,
+                **tokenizer_kwargs,
+            )
+
+        self._tokenizer = tokenizer
+
+        self._apply_chat_template_executor = ThreadPoolExecutor(max_workers=1)
+        self._apply_chat_template_async = make_async(
+            safe_apply_chat_template, executor=self._apply_chat_template_executor
+        )
+
+    @property
+    def tokenizer(self) -> MistralTokenizer | None:
+        return self._tokenizer
+
+    def get_tokenizer(self) -> MistralTokenizer:
+        tokenizer = self.tokenizer
+        if tokenizer is None:
+            raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")
+
+        return tokenizer
+
+    def render_messages(
+        self,
+        messages: list[ChatCompletionMessageParam],
+        **kwargs,
+    ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
+        tokenizer = self.get_tokenizer()
+        conversation, mm_data, mm_uuids = parse_chat_messages(
+            messages,
+            self.config,
+            content_format="string",
+        )
+
+        prompt_raw = safe_apply_chat_template(tokenizer, messages, **kwargs)
+
+        prompt = (
+            TextPrompt(prompt=prompt_raw)
+            if isinstance(prompt_raw, str)
+            else TokensPrompt(prompt_token_ids=prompt_raw)
+        )
+        if mm_data is not None:
+            prompt["multi_modal_data"] = mm_data
+        if mm_uuids is not None:
+            prompt["multi_modal_uuids"] = mm_uuids
+
+        return conversation, prompt  # type: ignore[return-value]
+
+    async def render_messages_async(
+        self,
+        messages: list[ChatCompletionMessageParam],
+        **kwargs,
+    ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
+        tokenizer = self.get_tokenizer()
+        conversation, mm_data_future, mm_uuids = parse_chat_messages_futures(
+            messages,
+            self.config,
+            content_format="string",
+        )
+
+        prompt_raw = await self._apply_chat_template_async(
+            tokenizer, messages, **kwargs
+        )
+
+        prompt = (
+            TextPrompt(prompt=prompt_raw)
+            if isinstance(prompt_raw, str)
+            else TokensPrompt(prompt_token_ids=prompt_raw)
+        )
+        if mm_data_future is not None:
+            prompt["multi_modal_data"] = await mm_data_future
+        if mm_uuids is not None:
+            prompt["multi_modal_uuids"] = mm_uuids
+
+        return conversation, prompt  # type: ignore[return-value]
diff --git a/vllm/renderers/protocol.py b/vllm/renderers/protocol.py
new file mode 100644
index 000000000000..e788f431b0f8
--- /dev/null
+++ b/vllm/renderers/protocol.py
@@ -0,0 +1,48 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import TYPE_CHECKING, Any, Protocol
+
+from vllm.inputs import TextPrompt, TokensPrompt
+from vllm.tokenizers import TokenizerLike
+
+if TYPE_CHECKING:
+    from vllm.config import ModelConfig
+    from vllm.entrypoints.chat_utils import (
+        ChatCompletionMessageParam,
+        ConversationMessage,
+    )
+
+
+class RendererLike(Protocol):
+    @classmethod
+    def from_config(
+        cls,
+        config: "ModelConfig",
+        tokenizer_kwargs: dict[str, Any],
+    ) -> "RendererLike":
+        raise NotImplementedError
+
+    @property
+    def tokenizer(self) -> TokenizerLike | None:
+        raise NotImplementedError
+
+    def get_tokenizer(self) -> TokenizerLike:
+        tokenizer = self.tokenizer
+        if tokenizer is None:
+            raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")
+
+        return tokenizer
+
+    def render_messages(
+        self,
+        messages: list["ChatCompletionMessageParam"],
+        **kwargs,
+    ) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt]:
+        raise NotImplementedError
+
+    async def render_messages_async(
+        self,
+        messages: list["ChatCompletionMessageParam"],
+        **kwargs,
+    ) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt]:
+        return self.render_messages(messages, **kwargs)
diff --git a/vllm/renderers/registry.py b/vllm/renderers/registry.py
new file mode 100644
index 000000000000..a6c402edf32a
--- /dev/null
+++ b/vllm/renderers/registry.py
@@ -0,0 +1,87 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Any
+
+from vllm.logger import init_logger
+from vllm.tokenizers.registry import tokenizer_args_from_config
+from vllm.utils.import_utils import resolve_obj_by_qualname
+
+from .protocol import RendererLike
+
+if TYPE_CHECKING:
+    from vllm.config import ModelConfig
+
+logger = init_logger(__name__)
+
+
+_VLLM_RENDERERS = {
+    "deepseekv32": ("deepseekv32", "DeepseekV32Renderer"),
+    "hf": ("hf", "HfRenderer"),
+    "mistral": ("mistral", "MistralRenderer"),
+    "terratorch": ("terratorch", "TerratorchRenderer"),
+}
+
+
+@dataclass
+class RendererRegistry:
+    # Renderer mode -> (renderer module, renderer class)
+    renderers: dict[str, tuple[str, str]] = field(default_factory=dict)
+
+    def register(self, renderer_mode: str, module: str, class_name: str) -> None:
+        if renderer_mode in self.renderers:
+            logger.warning(
+                "%s.%s is already registered for renderer_mode=%r. "
+                "It is overwritten by the new one.",
+                module,
+                class_name,
+                renderer_mode,
+            )
+
+        self.renderers[renderer_mode] = (module, class_name)
+
+        return None
+
+    def load_renderer_cls(self, renderer_mode: str) -> type[RendererLike]:
+        if renderer_mode not in self.renderers:
+            raise ValueError(f"No renderer registered for {renderer_mode=!r}.")
+
+        module, class_name = self.renderers[renderer_mode]
+        logger.debug_once(f"Loading {class_name} for {renderer_mode=!r}")
+
+        return resolve_obj_by_qualname(f"{module}.{class_name}")
+
+    def load_renderer(
+        self,
+        renderer_mode: str,
+        config: "ModelConfig",
+        tokenizer_kwargs: dict[str, Any],
+    ) -> RendererLike:
+        renderer_cls = self.load_renderer_cls(renderer_mode)
+        return renderer_cls.from_config(config, tokenizer_kwargs)
+
+
+RENDERER_REGISTRY = RendererRegistry(
+    {
+        mode: (f"vllm.renderers.{mod_relname}", cls_name)
+        for mode, (mod_relname, cls_name) in _VLLM_RENDERERS.items()
+    }
+)
+"""The global `RendererRegistry` instance."""
+
+
+def renderer_from_config(config: "ModelConfig", **kwargs):
+    tokenizer_mode, tokenizer_name, args, kwargs = tokenizer_args_from_config(
+        config, **kwargs
+    )
+
+    if config.tokenizer_mode == "auto" and config.model_impl == "terratorch":
+        renderer_mode = "terratorch"
+    else:
+        renderer_mode = tokenizer_mode
+
+    return RENDERER_REGISTRY.load_renderer(
+        renderer_mode,
+        config,
+        tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
+    )
diff --git a/vllm/renderers/terratorch.py b/vllm/renderers/terratorch.py
new file mode 100644
index 000000000000..6a3f28c7aa07
--- /dev/null
+++ b/vllm/renderers/terratorch.py
@@ -0,0 +1,85 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any
+
+from vllm.config import ModelConfig
+from vllm.entrypoints.chat_utils import (
+    ChatCompletionMessageParam,
+    ConversationMessage,
+    parse_chat_messages,
+    parse_chat_messages_futures,
+)
+from vllm.inputs import TextPrompt, TokensPrompt
+from vllm.logger import init_logger
+from vllm.tokenizers import TokenizerLike
+
+from .protocol import RendererLike
+
+logger = init_logger(__name__)
+
+
+class TerratorchRenderer(RendererLike):
+    @classmethod
+    def from_config(
+        cls,
+        config: "ModelConfig",
+        tokenizer_kwargs: dict[str, Any],
+    ) -> "RendererLike":
+        return cls(config)
+
+    def __init__(self, config: ModelConfig) -> None:
+        super().__init__()
+
+        self.config = config
+
+        if not config.skip_tokenizer_init:
+            raise ValueError("Terratorch renderer requires `skip_tokenizer_init=True`")
+
+    @property
+    def tokenizer(self) -> TokenizerLike | None:
+        return None
+
+    def get_tokenizer(self) -> TokenizerLike:
+        raise ValueError("Tokenizer not available for Terratorch renderer")
+
+    def render_messages(
+        self,
+        messages: list[ChatCompletionMessageParam],
+        **kwargs,
+    ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
+        model_config = self.config
+
+        conversation, mm_data, mm_uuids = parse_chat_messages(
+            messages,
+            model_config,
+            content_format="string",
+        )
+
+        prompt = TokensPrompt(prompt_token_ids=[1])
+        if mm_data is not None:
+            prompt["multi_modal_data"] = mm_data
+        if mm_uuids is not None:
+            prompt["multi_modal_uuids"] = mm_uuids
+
+        return conversation, prompt
+
+    async def render_messages_async(
+        self,
+        messages: list[ChatCompletionMessageParam],
+        **kwargs,
+    ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
+        model_config = self.config
+
+        conversation, mm_data_future, mm_uuids = parse_chat_messages_futures(
+            messages,
+            model_config,
+            content_format="string",
+        )
+
+        prompt = TokensPrompt(prompt_token_ids=[1])  # Dummy token IDs
+        if mm_data_future is not None:
+            prompt["multi_modal_data"] = await mm_data_future
+        if mm_uuids is not None:
+            prompt["multi_modal_uuids"] = mm_uuids
+
+        return conversation, prompt
diff --git a/vllm/tokenizers/deepseek_v32.py b/vllm/tokenizers/deepseek_v32.py
index bf279a5cf67c..5089e2569c64 100644
--- a/vllm/tokenizers/deepseek_v32.py
+++ b/vllm/tokenizers/deepseek_v32.py
@@ -63,6 +63,7 @@ def apply_chat_template(
         drop_thinking = messages[-1]["role"] == "user"
 
         encode_config = dict(thinking_mode=thinking_mode, drop_thinking=drop_thinking)
+
         prompt_str = encode_messages(messages, **encode_config)  # type: ignore
 
         if kwargs.get("tokenize", True):
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index a6ee241c4115..c24c8a5c8c07 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -24,9 +24,10 @@
 from vllm.outputs import PoolingRequestOutput, RequestOutput
 from vllm.plugins.io_processors import get_io_processor
 from vllm.pooling_params import PoolingParams
+from vllm.renderers import RendererLike
 from vllm.sampling_params import SamplingParams
 from vllm.tasks import SupportedTask
-from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
+from vllm.tokenizers import TokenizerLike
 from vllm.tracing import init_tracer
 from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
 from vllm.usage.usage_lib import UsageContext
@@ -108,12 +109,7 @@ def __init__(
                 "enabling logging without default stat loggers."
             )
 
-        if self.model_config.skip_tokenizer_init:
-            tokenizer = None
-        else:
-            tokenizer = cached_tokenizer_from_config(self.model_config)
-
-        self.input_processor = InputProcessor(self.vllm_config, tokenizer)
+        self.input_processor = InputProcessor(self.vllm_config)
         self.io_processor = get_io_processor(
             self.vllm_config,
             self.model_config.io_processor_plugin,
@@ -701,13 +697,12 @@ async def encode(
     def tokenizer(self) -> TokenizerLike | None:
         return self.input_processor.tokenizer
 
-    async def get_tokenizer(self) -> TokenizerLike:
-        if self.tokenizer is None:
-            raise ValueError(
-                "Unable to get tokenizer because `skip_tokenizer_init=True`"
-            )
+    def get_tokenizer(self) -> TokenizerLike:
+        return self.input_processor.get_tokenizer()
 
-        return self.tokenizer
+    @property
+    def renderer(self) -> RendererLike:
+        return self.input_processor.renderer
 
     async def is_tracing_enabled(self) -> bool:
         return self.observability_config.otlp_traces_endpoint is not None  # type: ignore
diff --git a/vllm/v1/engine/input_processor.py b/vllm/v1/engine/input_processor.py
index 65e0c845b0af..50d82688bd99 100644
--- a/vllm/v1/engine/input_processor.py
+++ b/vllm/v1/engine/input_processor.py
@@ -18,6 +18,7 @@
 from vllm.multimodal.processing import EncDecMultiModalProcessor
 from vllm.multimodal.utils import argsort_mm_positions
 from vllm.pooling_params import PoolingParams
+from vllm.renderers import RendererLike
 from vllm.sampling_params import SamplingParams
 from vllm.tokenizers import TokenizerLike
 from vllm.tokenizers.mistral import MistralTokenizer
@@ -40,7 +41,6 @@ class InputProcessor:
     def __init__(
         self,
         vllm_config: VllmConfig,
-        tokenizer: TokenizerLike | None,
         mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
     ) -> None:
         self.vllm_config = vllm_config
@@ -56,7 +56,6 @@ def __init__(
 
         self.input_preprocessor = InputPreprocessor(
             self.model_config,
-            tokenizer,
             mm_registry,
             mm_processor_cache=self.mm_processor_cache,
         )
@@ -65,6 +64,13 @@
     def tokenizer(self) -> TokenizerLike | None:
         return self.input_preprocessor.tokenizer
 
+    def get_tokenizer(self) -> TokenizerLike:
+        return self.input_preprocessor.get_tokenizer()
+
+    @property
+    def renderer(self) -> RendererLike:
+        return self.input_preprocessor.renderer
+
     def _validate_logprobs(
         self,
         params: SamplingParams,
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index 1011317b706d..72098e7c9cef 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -21,9 +21,10 @@
 from vllm.outputs import PoolingRequestOutput, RequestOutput
 from vllm.plugins.io_processors import get_io_processor
 from vllm.pooling_params import PoolingParams
+from vllm.renderers import RendererLike
 from vllm.sampling_params import SamplingParams
 from vllm.tasks import SupportedTask
-from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
+from vllm.tokenizers import TokenizerLike
 from vllm.tracing import init_tracer
 from vllm.usage.usage_lib import UsageContext
@@ -83,12 +84,7 @@ def __init__(
         self.dp_group = None
         self.should_execute_dummy_batch = False
 
-        if self.model_config.skip_tokenizer_init:
-            tokenizer = None
-        else:
-            tokenizer = cached_tokenizer_from_config(self.model_config)
-
-        self.input_processor = InputProcessor(self.vllm_config, tokenizer)
+        self.input_processor = InputProcessor(self.vllm_config)
         self.io_processor = get_io_processor(
             self.vllm_config,
             self.model_config.io_processor_plugin,
@@ -359,12 +355,11 @@ def tokenizer(self) -> TokenizerLike | None:
         return self.input_processor.tokenizer
 
     def get_tokenizer(self) -> TokenizerLike:
-        if self.tokenizer is None:
-            raise ValueError(
-                "Unable to get tokenizer because `skip_tokenizer_init=True`"
-            )
+        return self.input_processor.get_tokenizer()
 
-        return self.tokenizer
+    @property
+    def renderer(self) -> RendererLike:
+        return self.input_processor.renderer
 
     def do_log_stats(self) -> None:
         """Log stats if logging is enabled."""