Skip to content

Commit

Permalink
core[patch]: allow passing JSON schema as args_schema to tools (#29812)
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarda authored Feb 18, 2025
1 parent 5034a8d commit d04fa1a
Show file tree
Hide file tree
Showing 6 changed files with 219 additions and 40 deletions.
43 changes: 36 additions & 7 deletions libs/core/langchain_core/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,9 @@ class ToolException(Exception): # noqa: N818
"""


ArgsSchema = Union[TypeBaseModel, dict[str, Any]]


class BaseTool(RunnableSerializable[Union[str, dict, ToolCall], Any]):
"""Interface LangChain tools must implement."""

Expand Down Expand Up @@ -354,7 +357,7 @@ class ChildTool(BaseTool):
You can provide few-shot examples as a part of the description.
"""

args_schema: Annotated[Optional[TypeBaseModel], SkipValidation()] = Field(
args_schema: Annotated[Optional[ArgsSchema], SkipValidation()] = Field(
default=None, description="The tool schema."
)
"""Pydantic model class to validate and parse the tool's input arguments.
Expand All @@ -364,6 +367,8 @@ class ChildTool(BaseTool):
- A subclass of pydantic.BaseModel.
or
- A subclass of pydantic.v1.BaseModel if accessing v1 namespace in pydantic 2
or
- A JSON schema dict
"""
return_direct: bool = False
"""Whether to return the tool's output directly.
Expand Down Expand Up @@ -423,10 +428,11 @@ def __init__(self, **kwargs: Any) -> None:
"args_schema" in kwargs
and kwargs["args_schema"] is not None
and not is_basemodel_subclass(kwargs["args_schema"])
and not isinstance(kwargs["args_schema"], dict)
):
msg = (
f"args_schema must be a subclass of pydantic BaseModel. "
f"Got: {kwargs['args_schema']}."
"args_schema must be a subclass of pydantic BaseModel or "
f"a JSON schema dict. Got: {kwargs['args_schema']}."
)
raise TypeError(msg)
super().__init__(**kwargs)
Expand All @@ -443,10 +449,18 @@ def is_single_input(self) -> bool:

@property
def args(self) -> dict:
return self.get_input_schema().model_json_schema()["properties"]
if isinstance(self.args_schema, dict):
json_schema = self.args_schema
else:
input_schema = self.get_input_schema()
json_schema = input_schema.model_json_schema()
return json_schema["properties"]

@property
def tool_call_schema(self) -> type[BaseModel]:
def tool_call_schema(self) -> ArgsSchema:
if isinstance(self.args_schema, dict):
return self.args_schema

full_schema = self.get_input_schema()
fields = []
for name, type_ in get_all_basemodel_annotations(full_schema).items():
Expand All @@ -470,6 +484,8 @@ def get_input_schema(
The input schema for the tool.
"""
if self.args_schema is not None:
if isinstance(self.args_schema, dict):
return super().get_input_schema(config)
return self.args_schema
else:
return create_schema_from_function(self.name, self._run)
Expand Down Expand Up @@ -505,6 +521,12 @@ def _parse_input(
input_args = self.args_schema
if isinstance(tool_input, str):
if input_args is not None:
if isinstance(input_args, dict):
msg = (
"String tool inputs are not allowed when "
"using tools with JSON schema args_schema."
)
raise ValueError(msg)
key_ = next(iter(get_fields(input_args).keys()))
if hasattr(input_args, "model_validate"):
input_args.model_validate({key_: tool_input})
Expand All @@ -513,7 +535,9 @@ def _parse_input(
return tool_input
else:
if input_args is not None:
if issubclass(input_args, BaseModel):
if isinstance(input_args, dict):
return tool_input
elif issubclass(input_args, BaseModel):
for k, v in get_all_basemodel_annotations(input_args).items():
if (
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
Expand Down Expand Up @@ -605,7 +629,12 @@ async def _arun(self, *args: Any, **kwargs: Any) -> Any:
def _to_args_and_kwargs(
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
) -> tuple[tuple, dict]:
if self.args_schema is not None and not get_fields(self.args_schema):
if (
self.args_schema is not None
and isinstance(self.args_schema, type)
and is_basemodel_subclass(self.args_schema)
and not get_fields(self.args_schema)
):
# StructuredTool with no args
return (), {}
tool_input = self._parse_input(tool_input, tool_call_id)
Expand Down
11 changes: 7 additions & 4 deletions libs/core/langchain_core/tools/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@
Union,
)

from pydantic import BaseModel

from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.messages import ToolCall
from langchain_core.runnables import RunnableConfig, run_in_executor
from langchain_core.tools.base import (
ArgsSchema,
BaseTool,
ToolException,
_get_runnable_config_param,
Expand Down Expand Up @@ -57,7 +56,11 @@ def args(self) -> dict:
The input arguments for the tool.
"""
if self.args_schema is not None:
return self.args_schema.model_json_schema()["properties"]
if isinstance(self.args_schema, dict):
json_schema = self.args_schema
else:
json_schema = self.args_schema.model_json_schema()
return json_schema["properties"]
# For backwards compatibility, if the function signature is ambiguous,
# assume it takes a single string input.
return {"tool_input": {"type": "string"}}
Expand Down Expand Up @@ -132,7 +135,7 @@ def from_function(
name: str, # We keep these required to support backwards compatibility
description: str,
return_direct: bool = False,
args_schema: Optional[type[BaseModel]] = None,
args_schema: Optional[ArgsSchema] = None,
coroutine: Optional[
Callable[..., Awaitable[Any]]
] = None, # This is last for compatibility, but should be after func
Expand Down
15 changes: 10 additions & 5 deletions libs/core/langchain_core/tools/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
Union,
)

from pydantic import BaseModel, Field, SkipValidation
from pydantic import Field, SkipValidation

from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
Expand All @@ -22,18 +22,18 @@
from langchain_core.runnables import RunnableConfig, run_in_executor
from langchain_core.tools.base import (
FILTERED_ARGS,
ArgsSchema,
BaseTool,
_get_runnable_config_param,
create_schema_from_function,
)
from langchain_core.utils.pydantic import TypeBaseModel


class StructuredTool(BaseTool):
"""Tool that can operate on any number of inputs."""

description: str = ""
args_schema: Annotated[TypeBaseModel, SkipValidation()] = Field(
args_schema: Annotated[ArgsSchema, SkipValidation()] = Field(
..., description="The tool schema."
)
"""The input arguments' schema."""
Expand Down Expand Up @@ -62,7 +62,12 @@ async def ainvoke(
@property
def args(self) -> dict:
"""The tool's input arguments."""
return self.args_schema.model_json_schema()["properties"]
if isinstance(self.args_schema, dict):
json_schema = self.args_schema
else:
input_schema = self.get_input_schema()
json_schema = input_schema.model_json_schema()
return json_schema["properties"]

def _run(
self,
Expand Down Expand Up @@ -110,7 +115,7 @@ def from_function(
name: Optional[str] = None,
description: Optional[str] = None,
return_direct: bool = False,
args_schema: Optional[type[BaseModel]] = None,
args_schema: Optional[ArgsSchema] = None,
infer_schema: bool = True,
*,
response_format: Literal["content", "content_and_artifact"] = "content",
Expand Down
63 changes: 50 additions & 13 deletions libs/core/langchain_core/utils/function_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
)

from pydantic import BaseModel
from pydantic.v1 import BaseModel as BaseModelV1
from typing_extensions import TypedDict, get_args, get_origin, is_typeddict

from langchain_core._api import beta, deprecated
Expand Down Expand Up @@ -75,8 +76,8 @@ def _rm_titles(kv: dict, prev_key: str = "") -> dict:
return new_kv


def _convert_pydantic_to_openai_function(
model: type,
def _convert_json_schema_to_openai_function(
schema: dict,
*,
name: Optional[str] = None,
description: Optional[str] = None,
Expand All @@ -85,7 +86,7 @@ def _convert_pydantic_to_openai_function(
"""Converts a JSON schema to a function description for the OpenAI API.
Args:
model: The Pydantic model to convert.
schema: The JSON schema to convert.
name: The name of the function. If not provided, the title of the schema will be
used.
description: The description of the function. If not provided, the description
Expand All @@ -95,13 +96,6 @@ def _convert_pydantic_to_openai_function(
Returns:
The function description.
"""
if hasattr(model, "model_json_schema"):
schema = model.model_json_schema() # Pydantic 2
elif hasattr(model, "schema"):
schema = model.schema() # Pydantic 1
else:
msg = "Model must be a Pydantic model."
raise TypeError(msg)
schema = dereference_refs(schema)
if "definitions" in schema: # pydantic 1
schema.pop("definitions", None)
Expand All @@ -116,6 +110,38 @@ def _convert_pydantic_to_openai_function(
}


def _convert_pydantic_to_openai_function(
    model: type,
    *,
    name: Optional[str] = None,
    description: Optional[str] = None,
    rm_titles: bool = True,
) -> FunctionDescription:
    """Build an OpenAI function description from a Pydantic model class.

    The model is rendered to a JSON schema first; that schema, together with
    the remaining keyword options, is then handed unchanged to
    ``_convert_json_schema_to_openai_function``.

    Args:
        model: The Pydantic model to convert.
        name: The name of the function. If not provided, the title of the
            schema will be used.
        description: The description of the function. If not provided, the
            description of the schema will be used.
        rm_titles: Whether to remove titles from the schema. Defaults to True.

    Returns:
        The function description.

    Raises:
        TypeError: If ``model`` exposes neither ``model_json_schema`` nor
            ``schema``.
    """
    # Prefer the Pydantic 2 entry point; only probe for the Pydantic 1 one
    # when the newer API is absent.
    has_v2_api = hasattr(model, "model_json_schema")
    if not has_v2_api and not hasattr(model, "schema"):
        msg = "Model must be a Pydantic model."
        raise TypeError(msg)

    json_schema = (
        model.model_json_schema()  # Pydantic 2
        if has_v2_api
        else model.schema()  # Pydantic 1
    )
    return _convert_json_schema_to_openai_function(
        json_schema,
        name=name,
        description=description,
        rm_titles=rm_titles,
    )


convert_pydantic_to_openai_function = deprecated(
"0.1.16",
alternative="langchain_core.utils.function_calling.convert_to_openai_function()",
Expand Down Expand Up @@ -289,9 +315,20 @@ def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:

is_simple_oai_tool = isinstance(tool, simple.Tool) and not tool.args_schema
if tool.tool_call_schema and not is_simple_oai_tool:
return _convert_pydantic_to_openai_function(
tool.tool_call_schema, name=tool.name, description=tool.description
)
if isinstance(tool.tool_call_schema, dict):
return _convert_json_schema_to_openai_function(
tool.tool_call_schema, name=tool.name, description=tool.description
)
elif issubclass(tool.tool_call_schema, (BaseModel, BaseModelV1)):
return _convert_pydantic_to_openai_function(
tool.tool_call_schema, name=tool.name, description=tool.description
)
else:
error_msg = (
f"Unsupported tool call schema: {tool.tool_call_schema}. "
"Tool call schema must be a JSON schema dict or a Pydantic model."
)
raise ValueError(error_msg)
else:
return {
"name": tool.name,
Expand Down
Loading

0 comments on commit d04fa1a

Please sign in to comment.