diff --git a/xinference/model/llm/tests/test_utils.py b/xinference/model/llm/tests/test_utils.py index 6b0ccaf05d..5cf252639c 100644 --- a/xinference/model/llm/tests/test_utils.py +++ b/xinference/model/llm/tests/test_utils.py @@ -79,7 +79,16 @@ def test_transform_messages_preserves_tool_call_fields(): assert transformed[1] == { "role": "assistant", "content": None, - "tool_calls": messages[1]["tool_calls"], + "tool_calls": [ + { + "id": "call_bed4c5f1", + "function": { + "arguments": {"file_path": "README*"}, + "name": "view_file_in_detail", + }, + "type": "function", + } + ], } assert transformed[2] == { "role": "tool", diff --git a/xinference/model/llm/utils.py b/xinference/model/llm/utils.py index 700d75cffc..6cd4980da0 100644 --- a/xinference/model/llm/utils.py +++ b/xinference/model/llm/utils.py @@ -811,6 +811,26 @@ def _transform_messages( ) new_message = dict(msg) new_message["content"] = new_content if new_content else None + # Parse JSON-encoded arguments in tool_calls to dicts, + # so Jinja2 templates can iterate them with |items. + if new_message.get("tool_calls"): + tool_calls = [] + for tc in new_message["tool_calls"]: + tc = dict(tc) + func = tc.get("function") + if isinstance(func, dict) and isinstance( + func.get("arguments"), str + ): + func = dict(func) + try: + parsed_args = json.loads(func["arguments"]) + if isinstance(parsed_args, dict): + func["arguments"] = parsed_args + except (json.JSONDecodeError, TypeError): + pass + tc["function"] = func + tool_calls.append(tc) + new_message["tool_calls"] = tool_calls transformed_messages.append(new_message) return transformed_messages