|
1 | 1 | from __future__ import annotations |
| 2 | +from collections import defaultdict |
2 | 3 | import inspect |
3 | 4 | from typing import Callable, Union |
4 | 5 |
|
|
7 | 8 |
|
8 | 9 |
|
9 | 10 | def _parse_docstring(doc_string: Union[str, None]) -> dict[str, str]: |
10 | | - parsed_docstring = {'description': ''} |
| 11 | + parsed_docstring = defaultdict(str) |
11 | 12 | if not doc_string: |
12 | 13 | return parsed_docstring |
13 | 14 |
|
14 | 15 | lowered_doc_string = doc_string.lower() |
15 | 16 |
|
16 | | - if 'args:' not in lowered_doc_string: |
17 | | - parsed_docstring['description'] = lowered_doc_string.strip() |
18 | | - return parsed_docstring |
19 | | - |
20 | | - else: |
21 | | - parsed_docstring['description'] = lowered_doc_string.split('args:')[0].strip() |
22 | | - args_section = lowered_doc_string.split('args:')[1] |
23 | | - |
24 | | - if 'returns:' in lowered_doc_string: |
25 | | - # Return section can be captured and used |
26 | | - args_section = args_section.split('returns:')[0] |
| 17 | + key = hash(doc_string) |
| 18 | + parsed_docstring[key] = '' |
| 19 | + for line in lowered_doc_string.splitlines(): |
| 20 | + if line.startswith('args:'): |
| 21 | + key = 'args' |
| 22 | + elif line.startswith('returns:') or line.startswith('yields:') or line.startswith('raises:'): |
| 23 | + key = '_' |
27 | 24 |
|
28 | | - if 'yields:' in lowered_doc_string: |
29 | | - args_section = args_section.split('yields:')[0] |
| 25 | + else: |
| 26 | + # maybe change to a list and join later |
| 27 | + parsed_docstring[key] += f'{line.strip()}\n' |
30 | 28 |
|
31 | | - cur_var = None |
32 | | - for line in args_section.split('\n'): |
| 29 | + last_key = None |
| 30 | + for line in parsed_docstring['args'].splitlines(): |
33 | 31 | line = line.strip() |
34 | | - if not line: |
35 | | - continue |
36 | | - if ':' not in line: |
37 | | - # Continuation of the previous parameter's description |
38 | | - if cur_var: |
39 | | - parsed_docstring[cur_var] += f' {line}' |
40 | | - continue |
41 | | - |
42 | | - # For the case with: `param_name (type)`: ... |
43 | | - if '(' in line: |
44 | | - param_name = line.split('(')[0] |
45 | | - param_desc = line.split('):')[1] |
46 | | - |
47 | | - # For the case with: `param_name: ...` |
48 | | - else: |
49 | | - param_name, param_desc = line.split(':', 1) |
| 32 | + if ':' in line and not line.startswith('args'): |
| 33 | + # Split on first occurrence of '(' or ':' to separate arg name from description |
| 34 | + split_char = '(' if '(' in line else ':' |
| 35 | + arg_name, rest = line.split(split_char, 1) |
50 | 36 |
|
51 | | - parsed_docstring[param_name.strip()] = param_desc.strip() |
52 | | - cur_var = param_name.strip() |
| 37 | + last_key = arg_name.strip() |
| 38 | + # Get description after the colon |
| 39 | + arg_description = rest.split(':', 1)[1].strip() if split_char == '(' else rest.strip() |
| 40 | + parsed_docstring[last_key] = arg_description |
| 41 | + |
| 42 | + elif last_key and line: |
| 43 | + parsed_docstring[last_key] += ' ' + line |
53 | 44 |
|
54 | 45 | return parsed_docstring |
55 | 46 |
|
56 | 47 |
|
57 | 48 | def convert_function_to_tool(func: Callable) -> Tool: |
| 49 | + doc_string_hash = hash(inspect.getdoc(func)) |
| 50 | + parsed_docstring = _parse_docstring(inspect.getdoc(func)) |
58 | 51 | schema = type( |
59 | 52 | func.__name__, |
60 | 53 | (pydantic.BaseModel,), |
61 | 54 | { |
62 | | - '__annotations__': {k: v.annotation for k, v in inspect.signature(func).parameters.items()}, |
| 55 | + '__annotations__': {k: v.annotation if v.annotation != inspect._empty else str for k, v in inspect.signature(func).parameters.items()}, |
63 | 56 | '__signature__': inspect.signature(func), |
64 | | - '__doc__': inspect.getdoc(func), |
| 57 | + '__doc__': parsed_docstring[doc_string_hash], |
65 | 58 | }, |
66 | 59 | ).model_json_schema() |
67 | 60 |
|
68 | | - properties = {} |
69 | | - required = [] |
70 | | - parsed_docstring = _parse_docstring(schema.get('description')) |
71 | 61 | for k, v in schema.get('properties', {}).items(): |
72 | | - prop = { |
73 | | - 'description': parsed_docstring.get(k, ''), |
74 | | - 'type': v.get('type'), |
| 62 | + # If type is missing, the default is string |
| 63 | + types = {t.get('type', 'string') for t in v.get('anyOf')} if 'anyOf' in v else {v.get('type', 'string')} |
| 64 | + if 'null' in types: |
| 65 | + schema['required'].remove(k) |
| 66 | + types.discard('null') |
| 67 | + |
| 68 | + schema['properties'][k] = { |
| 69 | + 'description': parsed_docstring[k], |
| 70 | + 'type': ', '.join(types), |
75 | 71 | } |
76 | 72 |
|
77 | | - if 'anyOf' in v: |
78 | | - is_optional = any(t.get('type') == 'null' for t in v['anyOf']) |
79 | | - types = [t.get('type', 'string') for t in v['anyOf'] if t.get('type') != 'null'] |
80 | | - prop['type'] = types[0] if len(types) == 1 else str(types) |
81 | | - if not is_optional: |
82 | | - required.append(k) |
83 | | - else: |
84 | | - if prop['type'] != 'null': |
85 | | - required.append(k) |
86 | | - |
87 | | - properties[k] = prop |
88 | | - |
89 | | - schema['properties'] = properties |
90 | | - |
91 | 73 | tool = Tool( |
92 | 74 | function=Tool.Function( |
93 | 75 | name=func.__name__, |
94 | | - description=parsed_docstring.get('description'), |
95 | | - parameters=Tool.Function.Parameters( |
96 | | - type='object', |
97 | | - properties=schema.get('properties', {}), |
98 | | - required=required, |
99 | | - ), |
| 76 | + description=schema.get('description', ''), |
| 77 | + parameters=Tool.Function.Parameters(**schema), |
100 | 78 | ) |
101 | 79 | ) |
102 | 80 |
|
|
0 commit comments