inferencerlabs committed on
Commit
8d0007e
·
1 Parent(s): 3de39f7

Upload complete model

Browse files
Files changed (1) hide show
  1. solar_open_tool_parser.py +267 -0
solar_open_tool_parser.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 Upstage AI.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import random
17
+ import re
18
+ import string
19
+ import ast
20
+ import json
21
+ from collections.abc import Sequence
22
+ from typing import Union, Tuple, List, Optional
23
+
24
+ from vllm.entrypoints.openai.protocol import (
25
+ ChatCompletionRequest,
26
+ DeltaMessage,
27
+ DeltaFunctionCall,
28
+ DeltaToolCall,
29
+ ExtractedToolCallInformation,
30
+ ToolCall,
31
+ FunctionCall,
32
+ )
33
+ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
34
+ ToolParser
35
+ )
36
+ from vllm.logger import init_logger
37
+
38
+ import pyjson5
39
+
40
class ToolCallID:
    """Wrapper around a tool-call identifier string.

    A well-formed identifier consists of exactly ``_LENGTH`` characters drawn
    from lowercase ASCII letters and digits. Validation is opt-in: without
    ``validation=True`` any value is stored as-is.
    """

    # Number of characters in a well-formed identifier.
    _LENGTH = 10

    def __init__(self, id_val: str, validation: bool = False):
        """Store ``id_val``; check its format only when ``validation`` is true.

        Raises:
            AssertionError: if ``validation`` is true and ``id_val`` is not a
                well-formed identifier.
        """
        self._id = id_val
        if validation:
            self._validate()

    @classmethod
    def random(cls, validation=False) -> 'ToolCallID':
        """Return a new instance holding a randomly generated identifier."""
        chars = string.ascii_lowercase + string.digits
        generated = ''.join(
            random.choice(chars) for _ in range(ToolCallID._LENGTH)
        )
        return cls(generated, validation=validation)

    def _validate(self):
        """Raise ``AssertionError`` unless the stored id is well-formed.

        The exception is raised explicitly instead of through ``assert``
        statements, because those are removed when Python runs with ``-O`` and
        the requested validation would then silently do nothing.
        """
        # Derive the pattern from _LENGTH so the two can never drift apart;
        # fullmatch also rejects a trailing newline, which '$' would tolerate.
        pattern = rf'[a-z0-9]{{{ToolCallID._LENGTH}}}'
        if re.fullmatch(pattern, self._id) is None:
            raise AssertionError(
                f"invalid tool call id {self._id!r}: expected "
                f"{ToolCallID._LENGTH} lowercase alphanumeric characters"
            )

    def to_string(self) -> str:
        """Return the identifier exactly as it was stored."""
        return self._id

    def __str__(self) -> str:
        return self.to_string()
63
+
64
+
65
# Module-level logger, created with vLLM's init_logger and this module's name.
logger = init_logger(__name__)
66
+
67
+
68
class SolarOpenToolParser(ToolParser):
    """Extract assistant content and tool calls from Solar Open model output.

    The parser works on the decoded text and looks for these literal markers:

    * ``<|flush|>`` / ``<|end|>`` terminate the leading assistant content.
    * Every tool call has the shape
      ``<|tool_call:begin|>ID<|tool_call:name|>NAME<|tool_call:args|>ARGS<|tool_call:end|>``.
    * ``<|calls|>`` closes the tool-call section; the opening
      ``<|tool_calls|>`` wrapper is optional.
    """

    _FLUSH_TAG = "<|flush|>"
    _END_TAG = "<|end|>"
    _BEGIN_TAG = "<|begin|>"
    _TOOL_CALLS_TAG = "<|tool_calls|>"
    _CALL_BEGIN_TAG = "<|tool_call:begin|>"
    _CALL_NAME_TAG = "<|tool_call:name|>"
    _CALL_ARGS_TAG = "<|tool_call:args|>"
    _CALL_END_TAG = "<|tool_call:end|>"
    _CALLS_TAG = "<|calls|>"

    # Once any of these has appeared, streamed text is no longer plain content.
    _SPECIAL_MARKERS = (
        _FLUSH_TAG,
        _END_TAG,
        _BEGIN_TAG,
        _TOOL_CALLS_TAG,
        _CALL_BEGIN_TAG,
        _CALL_NAME_TAG,
        _CALL_ARGS_TAG,
        _CALL_END_TAG,
        _CALLS_TAG,
    )

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Parse a complete (non-streaming) model output.

        Args:
            model_output: full decoded text produced by the model.
            request: the originating request (not consulted here).

        Returns:
            The parsed tool calls plus the leading content; ``content`` is
            ``None`` when no content was found or it is empty.
        """
        content, tool_calls = self._parse_text(model_output)
        return ExtractedToolCallInformation(
            tools_called=len(tool_calls) > 0,
            tool_calls=tool_calls,
            content=content if content else None,
        )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Translate one streaming step into a ``DeltaMessage``.

        Plain content is forwarded until the first special marker shows up.
        Afterwards, each tool call is announced (id, type, name) when its
        ``<|tool_call:args|>`` marker arrives, and its argument text is
        forwarded incrementally.

        Argument deltas are computed by comparing the per-call argument text
        of ``previous_text`` and ``current_text`` rather than by forwarding
        ``delta_text`` verbatim. A delta that bundles a marker with ordinary
        text (for example ``<|tool_call:args|>{"a"`` or ``}<|tool_call:end|>``)
        therefore neither loses the text nor leaks the marker.

        NOTE(review): this assumes ``previous_text`` is a prefix of
        ``current_text`` and that each marker arrives whole rather than split
        across deltas -- confirm against the serving loop.

        Args:
            previous_text: text accumulated before this step.
            current_text: text accumulated including this step.
            delta_text: text added by this step.
            previous_token_ids: unused.
            current_token_ids: unused.
            delta_token_ids: unused.
            request: the originating request (not consulted here).

        Returns:
            A ``DeltaMessage`` carrying content and/or tool-call deltas, or
            ``None`` when this step yields nothing to emit.
        """
        leading_content: Optional[str] = None
        if delta_text and not any(
                tag in previous_text for tag in self._SPECIAL_MARKERS):
            marker_positions = [
                pos for tag in self._SPECIAL_MARKERS
                if (pos := delta_text.find(tag)) != -1
            ]
            if not marker_positions:
                # Still in the plain-content phase: forward the delta as-is.
                return DeltaMessage(content=delta_text, tool_calls=[])
            # The first marker arrived together with trailing content in the
            # same delta; keep the content part instead of dropping it.
            leading_content = delta_text[:min(marker_positions)] or None

        tool_call_deltas: list[DeltaToolCall] = []
        calls_before = self._scan_streamed_calls(previous_text)
        calls_now = self._scan_streamed_calls(current_text)

        for index, (tool_id, func_name, args_text) in enumerate(calls_now):
            if index >= len(calls_before):
                # The args marker of this call appeared in this very step.
                if tool_id is not None:
                    # Announce the call; any argument text that came bundled
                    # in the same delta rides along (usually this is "").
                    tool_call_deltas.append(
                        DeltaToolCall(
                            id=tool_id,
                            type="function",
                            index=index,
                            function=DeltaFunctionCall(
                                name=func_name, arguments=args_text),
                        )
                    )
                    continue
                new_args = args_text
            else:
                # Call already known: forward only the unseen argument text.
                new_args = args_text[len(calls_before[index][2]):]

            if new_args:
                tool_call_deltas.append(
                    DeltaToolCall(
                        id=None,
                        type=None,
                        index=index,
                        function=DeltaFunctionCall(
                            name=None, arguments=new_args),
                    )
                )

        if leading_content is None and not tool_call_deltas:
            return None

        return DeltaMessage(content=leading_content,
                            tool_calls=tool_call_deltas)

    # --------------------
    # Internal helpers
    # --------------------
    def _scan_streamed_calls(
        self, text: str
    ) -> List[Tuple[Optional[str], Optional[str], str]]:
        """Describe every tool call whose ``<|tool_call:args|>`` is in ``text``.

        Returns one ``(tool_id, name, args_so_far)`` tuple per args marker, in
        order of appearance, so the list position is the call's stream index.
        ``tool_id`` and ``name`` are ``None`` when no
        ``<|tool_call:begin|>`` / ``<|tool_call:name|>`` pair precedes the
        args marker. ``args_so_far`` runs from the args marker up to the
        matching ``<|tool_call:end|>`` if present, otherwise up to the next
        args marker or the end of ``text``.
        """
        calls: List[Tuple[Optional[str], Optional[str], str]] = []
        # Lower bound for header lookups, so that a call can never borrow the
        # begin/name markers that belong to an earlier call.
        header_floor = 0
        args_pos = text.find(self._CALL_ARGS_TAG)
        while args_pos != -1:
            args_start = args_pos + len(self._CALL_ARGS_TAG)
            next_args_pos = text.find(self._CALL_ARGS_TAG, args_start)
            region_limit = len(text) if next_args_pos == -1 else next_args_pos
            end_pos = text.find(self._CALL_END_TAG, args_start, region_limit)
            args_stop = region_limit if end_pos == -1 else end_pos

            tool_id: Optional[str] = None
            func_name: Optional[str] = None
            name_pos = text.rfind(self._CALL_NAME_TAG, header_floor, args_pos)
            if name_pos != -1:
                begin_pos = text.rfind(
                    self._CALL_BEGIN_TAG, header_floor, name_pos)
                if begin_pos != -1:
                    id_start = begin_pos + len(self._CALL_BEGIN_TAG)
                    name_start = name_pos + len(self._CALL_NAME_TAG)
                    tool_id = text[id_start:name_pos]
                    func_name = text[name_start:args_pos]

            calls.append((tool_id, func_name, text[args_start:args_stop]))
            header_floor = args_start
            args_pos = next_args_pos
        return calls

    def _parse_text(self, text: str) -> Tuple[Optional[str], List[ToolCall]]:
        """Parse the completed segments from the given text.

        Returns (content, tool_calls) where content is the leading text up to
        the first '<|flush|>' or '<|end|>' marker (``None`` if neither marker
        exists), and tool_calls is the list of fully delimited tool calls
        found before '<|calls|>' (or before the end of the text).
        """
        content = self._parse_content(text)
        tool_calls = self._parse_tool_calls(text)
        return content, tool_calls

    def _parse_content(self, text: str) -> Optional[str]:
        """Extract assistant content from the text.

        Rule: take the leading content before the first '<|flush|>' or
        '<|end|>' marker. If neither marker exists, return None.
        """
        end_positions = [
            pos for tag in (self._FLUSH_TAG, self._END_TAG)
            if (pos := text.find(tag)) != -1
        ]
        if not end_positions:
            return None
        # Return the exact substring; no whitespace is stripped.
        return text[:min(end_positions)]

    def _parse_tool_call_args(self, text: str) -> str:
        """Normalise raw argument text into a JSON string when possible.

        The text is tried as JSON, then JSON5, then as a Python literal. A
        successfully parsed non-string value is re-serialised with
        ``json.dumps``; a parsed string value is returned as that string. If
        nothing parses, or the parsed value cannot be represented as JSON
        (for example a Python ``set`` or ``bytes`` literal), the original text
        is returned unchanged.
        """
        try:
            # Try to parse as JSON
            args = json.loads(text)
        except json.JSONDecodeError:
            try:
                # Try to parse as JSON5
                args = pyjson5.decode(text)
            except pyjson5.Json5DecoderException:
                try:
                    # Try to parse as Python literal
                    args = ast.literal_eval(text)
                except Exception:
                    # Fallback: return the original string
                    args = text
        if not isinstance(args, str):
            try:
                # Always convert back to JSON string
                args = json.dumps(args)
            except (TypeError, ValueError):
                # literal_eval can yield values JSON cannot express; keep the
                # raw text instead of letting the error abort extraction.
                args = text
        return args

    def _parse_tool_calls(self, text: str) -> List[ToolCall]:
        """Return every fully delimited tool call found in ``text``.

        Scanning stops at '<|calls|>' when present, otherwise at the end of
        the text; the '<|tool_calls|>' wrapper may or may not be present. The
        scan also stops at the first call that is missing its name, args or
        end marker, so an incomplete trailing call is ignored.
        """
        tool_calls: list[ToolCall] = []
        section_end = text.find(self._CALLS_TAG)
        if section_end == -1:
            section_end = len(text)

        cursor = 0
        while True:
            begin_pos = text.find(self._CALL_BEGIN_TAG, cursor, section_end)
            if begin_pos == -1:
                break
            id_start = begin_pos + len(self._CALL_BEGIN_TAG)

            name_pos = text.find(self._CALL_NAME_TAG, id_start, section_end)
            if name_pos == -1:
                break
            name_start = name_pos + len(self._CALL_NAME_TAG)

            args_pos = text.find(self._CALL_ARGS_TAG, name_start, section_end)
            if args_pos == -1:
                break
            args_start = args_pos + len(self._CALL_ARGS_TAG)

            end_pos = text.find(self._CALL_END_TAG, args_start, section_end)
            if end_pos == -1:
                break

            tool_calls.append(
                ToolCall(
                    id=text[id_start:name_pos],
                    function=FunctionCall(
                        name=text[name_start:args_pos],
                        arguments=self._parse_tool_call_args(
                            text[args_start:end_pos]),
                    ),
                ))
            cursor = end_pos + len(self._CALL_END_TAG)

        return tool_calls