Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/transformers/utils/chat_template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,13 +450,25 @@ def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False)
# We also expose some options like custom indents and separators
return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)

def fromjson(x):
    """Parse a JSON string into a Python object, passing anything else through.

    Useful in chat templates for values such as tool call arguments that may
    arrive as JSON-encoded strings.

    Args:
        x: The value to decode. Only `str` inputs are parsed.

    Returns:
        The decoded object when `x` is a string containing valid JSON.
        Otherwise `x` itself is returned unchanged: non-string input,
        malformed JSON, or JSON nested too deeply to decode.
    """
    if not isinstance(x, str):
        return x
    try:
        return json.loads(x)
    except (json.JSONDecodeError, RecursionError):
        # Best-effort filter: if parsing fails, return the original string.
        # `json.loads` on a `str` cannot raise TypeError, so only decode errors
        # and pathological nesting depth need to be handled here.
        return x

def strftime_now(format):
    # Format the current time (as reported by `datetime.now()`) with the given strftime pattern
    current_time = datetime.now()
    return current_time.strftime(format)

# Build the sandboxed Jinja environment used to compile the chat template,
# enabling the AssistantTracker and loop-controls extensions
jinja_env = ImmutableSandboxedEnvironment(
trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker, jinja2.ext.loopcontrols]
)
# Filters usable from templates as `value | tojson` / `value | fromjson`
jinja_env.filters["tojson"] = tojson
jinja_env.filters["fromjson"] = fromjson
# Helper callables exposed to templates as globals
jinja_env.globals["raise_exception"] = raise_exception
jinja_env.globals["strftime_now"] = strftime_now
# Compile the template string and hand the compiled template back to the caller
return jinja_env.from_string(chat_template)
Expand Down
42 changes: 42 additions & 0 deletions tests/test_tokenization_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1303,6 +1303,48 @@ def test_jinja_strftime(self):
self.assertEqual(len(strftime_output), 10)
self.assertEqual(len(strftime_output.split("-")), 3)

@require_jinja
def test_jinja_fromjson(self):
    # Test the fromjson filter for parsing JSON strings in chat templates.
    # Covers: an inline JSON literal, tool-call arguments stored as a JSON string,
    # non-string input, and a string that is not valid JSON.

    # Inline JSON literal decoded inside the template, then read by attribute
    fromjson_template = (
        """{% set args = '{"name": "test_func", "value": 42}' | fromjson %}{{ args.name }}: {{ args.value }}"""
    )

    # Tool calls whose `arguments` field is a JSON-encoded string
    tool_call_template = """{% for message in messages %}{% if message.tool_calls %}{% for tc in message.tool_calls %}{% set args = tc.function.arguments | fromjson %}Function: {{ tc.function.name }}, Args: {% for k, v in args.items() %}{{ k }}={{ v }}{% if not loop.last %}, {% endif %}{% endfor %}{% endfor %}{% endif %}{% endfor %}"""

    # Non-string input is expected to pass through the filter unchanged
    graceful_template = """{{ 123 | fromjson }}"""

    # A string that is not valid JSON is expected to come back unchanged as well
    invalid_json_template = """{{ 'not valid json' | fromjson }}"""

    dummy_conversation = [{"role": "user", "content": "test"}]

    tool_conversation = [
        {
            "role": "assistant",
            "content": "I'll help with that.",
            "tool_calls": [{"function": {"name": "search", "arguments": '{"query": "hello world", "limit": 10}'}}],
        }
    ]

    tokenizers = self.get_tokenizers()
    for tokenizer in tokenizers:
        with self.subTest(f"{tokenizer.__class__.__name__}"):
            # Test basic fromjson usage
            fromjson_output = tokenizer.apply_chat_template(
                dummy_conversation, chat_template=fromjson_template, tokenize=False
            )
            self.assertEqual(fromjson_output, "test_func: 42")

            # Test fromjson with tool calls
            tool_output = tokenizer.apply_chat_template(
                tool_conversation, chat_template=tool_call_template, tokenize=False
            )
            self.assertEqual(tool_output, "Function: search, Args: query=hello world, limit=10")

            # Test that fromjson handles non-string inputs gracefully
            graceful_output = tokenizer.apply_chat_template(
                dummy_conversation, chat_template=graceful_template, tokenize=False
            )
            self.assertEqual(graceful_output, "123")

            # Test that fromjson falls back to the original string on malformed JSON
            invalid_json_output = tokenizer.apply_chat_template(
                dummy_conversation, chat_template=invalid_json_template, tokenize=False
            )
            self.assertEqual(invalid_json_output, "not valid json")

@require_torch
@require_jinja
def test_chat_template_return_assistant_tokens_mask(self):
Expand Down