|
1 | 1 | #![allow(clippy::enum_variant_names)] |
2 | 2 | use std::borrow::Cow; |
3 | | -use std::collections::HashSet; |
| 3 | +use std::collections::{HashMap, HashSet}; |
4 | 4 | use std::path::{Path, PathBuf}; |
| 5 | +use std::sync::LazyLock; |
5 | 6 |
|
6 | 7 | use convert_case::{Case, Casing}; |
7 | 8 | use derive_more::From; |
@@ -645,21 +646,31 @@ impl ToolDescription for ToolCatalog { |
645 | 646 | } |
646 | 647 | } |
647 | 648 | } |
648 | | -lazy_static::lazy_static! { |
649 | | - // Cache of all tool names |
650 | | - static ref FORGE_TOOLS: HashSet<ToolName> = ToolCatalog::iter() |
651 | | - .map(ToolName::new) |
652 | | - .collect(); |
653 | | -} |
| 649 | +// Cache of all tool names |
| 650 | +static FORGE_TOOLS: LazyLock<HashSet<ToolName>> = |
| 651 | + LazyLock::new(|| ToolCatalog::iter().map(ToolName::new).collect()); |
| 652 | + |
| 653 | +// Case-insensitive lookup map: lowercase tool name -> canonical tool name |
| 654 | +static FORGE_TOOLS_LOWER: LazyLock<HashMap<String, ToolName>> = LazyLock::new(|| { |
| 655 | + ToolCatalog::iter() |
| 656 | + .map(|tool| { |
| 657 | + let name = ToolName::new(tool.to_string()); |
| 658 | + (name.as_str().to_lowercase(), name) |
| 659 | + }) |
| 660 | + .collect() |
| 661 | +}); |
654 | 662 |
|
655 | | -/// Normalizes tool names for backward compatibility |
656 | | -/// Maps capitalized aliases to their lowercase canonical forms |
| 663 | +/// Normalizes a tool name received in a response before catalog matching. |
| 664 | +/// Trims surrounding whitespace and performs a case-insensitive lookup |
| 665 | +/// against all known catalog tool names, returning the canonical form when |
| 666 | +/// a match is found. |
657 | 667 | fn normalize_tool_name(name: &ToolName) -> ToolName { |
658 | | - match name.as_str() { |
659 | | - "Read" => ToolName::new("read"), |
660 | | - "Write" => ToolName::new("write"), |
661 | | - _ => name.clone(), |
662 | | - } |
| 668 | + let trimmed = name.as_str().trim(); |
| 669 | + let lower = trimmed.to_lowercase(); |
| 670 | + FORGE_TOOLS_LOWER |
| 671 | + .get(&lower) |
| 672 | + .cloned() |
| 673 | + .unwrap_or_else(|| ToolName::new(trimmed)) |
663 | 674 | } |
664 | 675 |
|
665 | 676 | impl ToolCatalog { |
@@ -931,15 +942,17 @@ impl TryFrom<ToolCallFull> for ToolCatalog { |
931 | 942 | type Error = crate::Error; |
932 | 943 |
|
933 | 944 | fn try_from(value: ToolCallFull) -> Result<Self, Self::Error> { |
| 945 | + // Normalize the tool name: trim whitespace and perform case-insensitive |
| 946 | + // catalog match so the serde deserialization receives the canonical name. |
| 947 | + let normalized_name = normalize_tool_name(&value.name); |
| 948 | + |
934 | 949 | let mut map = Map::new(); |
935 | | - map.insert("name".into(), value.name.as_str().into()); |
| 950 | + map.insert("name".into(), normalized_name.as_str().into()); |
936 | 951 |
|
937 | 952 | // Parse the arguments |
938 | 953 | let parsed_args = value.arguments.parse()?; |
939 | 954 |
|
940 | 955 | // Try to find the tool definition and coerce types based on schema |
941 | | - // Normalize the tool name for comparison |
942 | | - let normalized_name = normalize_tool_name(&value.name); |
943 | 956 | let coerced_args = ToolCatalog::iter() |
944 | 957 | .find(|tool| tool.definition().name == normalized_name) |
945 | 958 | .map(|tool| { |
@@ -1606,4 +1619,92 @@ mod tests { |
1606 | 1619 |
|
1607 | 1620 | assert_eq!(actual, expected); |
1608 | 1621 | } |
| 1622 | + |
| 1623 | + #[test] |
| 1624 | + fn test_normalize_tool_name_trims_whitespace() { |
| 1625 | + let actual = super::normalize_tool_name(&ToolName::new(" read ")); |
| 1626 | + let expected = ToolName::new("read"); |
| 1627 | + assert_eq!(actual, expected); |
| 1628 | + } |
| 1629 | + |
| 1630 | + #[test] |
| 1631 | + fn test_normalize_tool_name_case_insensitive_uppercase() { |
| 1632 | + let actual = super::normalize_tool_name(&ToolName::new("READ")); |
| 1633 | + let expected = ToolName::new("read"); |
| 1634 | + assert_eq!(actual, expected); |
| 1635 | + } |
| 1636 | + |
| 1637 | + #[test] |
| 1638 | + fn test_normalize_tool_name_case_insensitive_mixed() { |
| 1639 | + let actual = super::normalize_tool_name(&ToolName::new("FS_SEARCH")); |
| 1640 | + let expected = ToolName::new("fs_search"); |
| 1641 | + assert_eq!(actual, expected); |
| 1642 | + } |
| 1643 | + |
| 1644 | + #[test] |
| 1645 | + fn test_normalize_tool_name_trim_and_case_insensitive() { |
| 1646 | + let actual = super::normalize_tool_name(&ToolName::new(" SHELL ")); |
| 1647 | + let expected = ToolName::new("shell"); |
| 1648 | + assert_eq!(actual, expected); |
| 1649 | + } |
| 1650 | + |
| 1651 | + #[test] |
| 1652 | + fn test_normalize_tool_name_unknown_returns_trimmed() { |
| 1653 | + let actual = super::normalize_tool_name(&ToolName::new(" unknown_tool ")); |
| 1654 | + let expected = ToolName::new("unknown_tool"); |
| 1655 | + assert_eq!(actual, expected); |
| 1656 | + } |
| 1657 | + |
| 1658 | + #[test] |
| 1659 | + fn test_contains_case_insensitive() { |
| 1660 | + assert!(ToolCatalog::contains(&ToolName::new("READ"))); |
| 1661 | + assert!(ToolCatalog::contains(&ToolName::new("Shell"))); |
| 1662 | + assert!(ToolCatalog::contains(&ToolName::new("PATCH"))); |
| 1663 | + assert!(!ToolCatalog::contains(&ToolName::new("nonexistent"))); |
| 1664 | + } |
| 1665 | + |
| 1666 | + #[test] |
| 1667 | + fn test_contains_with_whitespace() { |
| 1668 | + assert!(ToolCatalog::contains(&ToolName::new(" read "))); |
| 1669 | + assert!(ToolCatalog::contains(&ToolName::new(" shell "))); |
| 1670 | + } |
| 1671 | + |
| 1672 | + #[test] |
| 1673 | + fn test_try_from_tool_call_uppercase_name() { |
| 1674 | + use crate::{ToolCallArguments, ToolCallFull}; |
| 1675 | + |
| 1676 | + let tool_call = ToolCallFull { |
| 1677 | + name: ToolName::new("SHELL"), |
| 1678 | + call_id: None, |
| 1679 | + arguments: ToolCallArguments::from_json(r#"{"command": "ls"}"#), |
| 1680 | + thought_signature: None, |
| 1681 | + }; |
| 1682 | + |
| 1683 | + let actual = ToolCatalog::try_from(tool_call); |
| 1684 | + |
| 1685 | + assert!(actual.is_ok(), "Should parse uppercase 'SHELL' tool name"); |
| 1686 | + assert!(matches!(actual.unwrap(), ToolCatalog::Shell(_))); |
| 1687 | + } |
| 1688 | + |
| 1689 | + #[test] |
| 1690 | + fn test_try_from_tool_call_with_whitespace_name() { |
| 1691 | + use crate::{ToolCallArguments, ToolCallFull}; |
| 1692 | + |
| 1693 | + let tool_call = ToolCallFull { |
| 1694 | + name: ToolName::new(" patch "), |
| 1695 | + call_id: None, |
| 1696 | + arguments: ToolCallArguments::from_json( |
| 1697 | + r#"{"file_path": "/test/file.rs", "new_string": "new", "old_string": "old"}"#, |
| 1698 | + ), |
| 1699 | + thought_signature: None, |
| 1700 | + }; |
| 1701 | + |
| 1702 | + let actual = ToolCatalog::try_from(tool_call); |
| 1703 | + |
| 1704 | + assert!( |
| 1705 | + actual.is_ok(), |
| 1706 | + "Should parse whitespace-padded 'patch' tool name" |
| 1707 | + ); |
| 1708 | + assert!(matches!(actual.unwrap(), ToolCatalog::Patch(_))); |
| 1709 | + } |
1609 | 1710 | } |
0 commit comments