-
Notifications
You must be signed in to change notification settings - Fork 95
Expand file tree
/
Copy pathhelpers.py
More file actions
560 lines (518 loc) · 18.1 KB
/
helpers.py
File metadata and controls
560 lines (518 loc) · 18.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
"""
插件通用能力:页面复用、Cookie 登录、在浏览器内发起 fetch 并流式回传。
接入方只需实现站点特有的 URL/请求体/SSE 解析,其余复用此处逻辑。
"""
import asyncio
import base64
import json
import logging
from collections.abc import Callable
from typing import Any, AsyncIterator
from playwright.async_api import BrowserContext, Page
from core.plugin.errors import AccountFrozenError
ParseSseEvent = Callable[[str], tuple[list[str], str | None, str | None]]
logger = logging.getLogger(__name__)
# 在页面内 POST 请求并流式回传:成功时逐块发送响应体,失败时发送 __error__: 前缀 + 信息,最后发送 __done__
# bindingName 按请求唯一,同一 page 多并发时互不串数据
PAGE_FETCH_STREAM_JS = """
async ({ url, body, bindingName }) => {
const send = globalThis[bindingName];
const done = "__done__";
const errPrefix = "__error__:";
try {
const ctrl = new AbortController();
const t = setTimeout(() => ctrl.abort(), 90000);
const resp = await fetch(url, {
method: "POST",
body: body,
headers: { "Content-Type": "application/json", "Accept": "text/event-stream" },
credentials: "include",
signal: ctrl.signal
});
clearTimeout(t);
if (!resp.ok) {
const errText = await resp.text();
const errSnippet = (errText && errText.length > 800) ? errText.slice(0, 800) + "..." : (errText || "");
await send(errPrefix + "HTTP " + resp.status + " " + errSnippet);
await send(done);
return;
}
if (!resp.body) {
await send(errPrefix + "No response body");
await send(done);
return;
}
const headersObj = {};
resp.headers.forEach((v, k) => { headersObj[k] = v; });
await send("__headers__:" + JSON.stringify(headersObj));
const reader = resp.body.getReader();
const dec = new TextDecoder();
while (true) {
const { done: streamDone, value } = await reader.read();
if (streamDone) break;
await send(dec.decode(value));
}
} catch (e) {
const msg = e.name === "AbortError" ? "请求超时(90s)" : (e.message || String(e));
await send(errPrefix + msg);
}
await send(done);
}
"""
PAGE_FETCH_JSON_JS = """
async ({ url, method, body, headers, timeoutMs }) => {
const ctrl = new AbortController();
const t = setTimeout(() => ctrl.abort(), timeoutMs || 15000);
try {
const resp = await fetch(url, {
method: method || "GET",
body: body ?? undefined,
headers: headers || {},
credentials: "include",
signal: ctrl.signal
});
clearTimeout(t);
const text = await resp.text();
const headersObj = {};
resp.headers.forEach((v, k) => { headersObj[k] = v; });
return {
ok: resp.ok,
status: resp.status,
statusText: resp.statusText,
url: resp.url,
redirected: resp.redirected,
headers: headersObj,
text,
};
} catch (e) {
clearTimeout(t);
const msg = e.name === "AbortError" ? `请求超时(${Math.floor((timeoutMs || 15000) / 1000)}s)` : (e.message || String(e));
return { error: msg };
}
}
"""
PAGE_FETCH_MULTIPART_JS = """
async ({ url, filename, mimeType, dataBase64, fieldName, extraFields, timeoutMs }) => {
const ctrl = new AbortController();
const t = setTimeout(() => ctrl.abort(), timeoutMs || 30000);
try {
const binary = atob(dataBase64);
const bytes = new Uint8Array(binary.length);
for (let i = 0; i < binary.length; i += 1) {
bytes[i] = binary.charCodeAt(i);
}
const form = new FormData();
if (extraFields) {
Object.entries(extraFields).forEach(([k, v]) => {
if (v !== undefined && v !== null) form.append(k, String(v));
});
}
const file = new File([bytes], filename, { type: mimeType || "application/octet-stream" });
form.append(fieldName || "file", file);
const resp = await fetch(url, {
method: "POST",
body: form,
credentials: "include",
signal: ctrl.signal
});
clearTimeout(t);
const text = await resp.text();
const headersObj = {};
resp.headers.forEach((v, k) => { headersObj[k] = v; });
return {
ok: resp.ok,
status: resp.status,
statusText: resp.statusText,
url: resp.url,
redirected: resp.redirected,
headers: headersObj,
text,
};
} catch (e) {
clearTimeout(t);
const msg = e.name === "AbortError" ? `请求超时(${Math.floor((timeoutMs || 30000) / 1000)}s)` : (e.message || String(e));
return { error: msg };
}
}
"""
async def ensure_page_for_site(
context: BrowserContext,
url_contains: str,
start_url: str,
*,
timeout: int = 20000,
) -> Page:
"""
若已有页面 URL 包含 url_contains 则复用,否则 new_page 并 goto start_url。
接入方只需提供「站点特征」和「入口 URL」。
"""
if context.pages:
for p in context.pages:
if url_contains in (p.url or ""):
return p
page = await context.new_page()
await page.goto(start_url, wait_until="domcontentloaded", timeout=timeout)
return page
async def create_page_for_site(
context: BrowserContext,
start_url: str,
*,
reuse_page: Page | None = None,
timeout: int = 20000,
) -> Page:
"""
若传入 reuse_page 则在其上 goto start_url,否则 new_page 再 goto。
用于复用浏览器默认空白页或 page 池的初始化与补回。
"""
if reuse_page is not None:
await reuse_page.goto(start_url, wait_until="domcontentloaded", timeout=timeout)
return reuse_page
page = await context.new_page()
await page.goto(start_url, wait_until="domcontentloaded", timeout=timeout)
return page
def _cookie_domain_matches(cookie_domain: str, site_domain: str) -> bool:
"""判断 cookie 的 domain 是否属于站点 domain(如 .claude.ai 与 claude.ai 视为同一域)。"""
a = cookie_domain if cookie_domain.startswith(".") else f".{cookie_domain}"
b = site_domain if site_domain.startswith(".") else f".{site_domain}"
return a == b
def _cookie_to_set_param(c: Any) -> dict[str, str]:
"""将 context.cookies() 返回的项转为 add_cookies 接受的 SetCookieParam 格式。"""
return {
"name": c["name"],
"value": c["value"],
"domain": c.get("domain") or "",
"path": c.get("path") or "/",
}
async def clear_cookies_for_domain(
context: BrowserContext,
site_domain: str,
) -> None:
"""清除 context 内属于指定站点域的所有 cookie,保留其他域。"""
cookies = await context.cookies()
keep = [
c
for c in cookies
if not _cookie_domain_matches(c.get("domain", ""), site_domain)
]
await context.clear_cookies()
if keep:
await context.add_cookies([_cookie_to_set_param(c) for c in keep]) # type: ignore[arg-type]
logger.info(
"[auth] cleared cookies for domain=%s (kept %s cookies)", site_domain, len(keep)
)
async def clear_page_storage_for_switch(page: Page) -> None:
"""切号前清空当前页面的 localStorage(当前 origin)。"""
try:
await page.evaluate("() => { window.localStorage.clear(); }")
logger.info("[auth] cleared localStorage for switch")
except Exception as e:
logger.warning("[auth] clear localStorage failed (page may be detached): %s", e)
async def safe_page_reload(page: Page, url: str | None = None) -> None:
"""安全地 reload 或 goto(url),忽略因 ERR_ABORTED / frame detached 导致的异常。"""
try:
if url:
await page.goto(url, wait_until="domcontentloaded")
else:
await page.reload(wait_until="domcontentloaded")
except Exception as e:
err_msg = str(e)
if "ERR_ABORTED" in err_msg or "detached" in err_msg.lower():
logger.warning(
"[auth] page.reload/goto 被中止或 frame 已分离: %s", err_msg[:200]
)
else:
raise
async def apply_cookie_auth(
context: BrowserContext,
page: Page,
auth: dict[str, Any],
cookie_name: str,
auth_keys: list[str],
domain: str,
*,
path: str = "/",
reload: bool = True,
) -> None:
"""
从 auth 中按 auth_keys 顺序取第一个非空值作为 cookie 值,写入 context 并可选 reload。
接入方只需提供 cookie 名、auth 里的 key 列表、域名。
仅写 cookie 不 reload 时,同 context 内的 fetch() 仍会带上 cookie;reload 仅在需要页面文档同步登录态时用。
"""
value = None
for k in auth_keys:
v = auth.get(k)
if v is not None and v != "":
value = str(v).strip()
if value:
break
if not value:
raise ValueError(f"auth 需包含以下其一且非空: {auth_keys}")
logger.info(
"[auth] context.add_cookies domain=%s name=%s reload=%s page.url=%s",
domain,
cookie_name,
reload,
page.url,
)
await context.add_cookies(
[
{
"name": cookie_name,
"value": value,
"domain": domain,
"path": path,
"secure": True,
"httpOnly": True,
}
]
)
if reload:
await safe_page_reload(page)
async def request_json_via_page_fetch(
page: Page,
url: str,
*,
method: str = "GET",
body: str | None = None,
headers: dict[str, str] | None = None,
timeout_ms: int = 15000,
) -> dict[str, Any]:
"""
在页面内发起非流式 fetch,请求结果按 JSON 优先解析返回。
这样能复用浏览器真实网络栈、cookie 与代理扩展能力。
"""
logger.info(
"[fetch] page request method=%s url=%s page.url=%s",
method,
url[:120] + "..." if len(url) > 120 else url,
page.url or "",
)
result = await page.evaluate(
PAGE_FETCH_JSON_JS,
{
"url": url,
"method": method,
"body": body,
"headers": headers or {},
"timeoutMs": timeout_ms,
},
)
if not isinstance(result, dict):
raise RuntimeError("页面 fetch 返回结果异常")
error = result.get("error")
if error:
raise RuntimeError(str(error))
text = result.get("text")
if isinstance(text, str) and text:
try:
result["json"] = json.loads(text)
except json.JSONDecodeError:
result["json"] = None
else:
result["json"] = None
return result
async def upload_file_via_page_fetch(
page: Page,
url: str,
*,
filename: str,
mime_type: str,
data: bytes,
field_name: str = "file",
extra_fields: dict[str, str] | None = None,
timeout_ms: int = 30000,
) -> dict[str, Any]:
logger.info(
"[fetch] page upload filename=%s mime=%s url=%s page.url=%s",
filename,
mime_type,
url[:120] + "..." if len(url) > 120 else url,
page.url or "",
)
result = await page.evaluate(
PAGE_FETCH_MULTIPART_JS,
{
"url": url,
"filename": filename,
"mimeType": mime_type,
"dataBase64": base64.b64encode(data).decode("ascii"),
"fieldName": field_name,
"extraFields": extra_fields or {},
"timeoutMs": timeout_ms,
},
)
if not isinstance(result, dict):
raise RuntimeError("页面上传返回结果异常")
error = result.get("error")
if error:
raise RuntimeError(str(error))
text = result.get("text")
if isinstance(text, str) and text:
try:
result["json"] = json.loads(text)
except json.JSONDecodeError:
result["json"] = None
else:
result["json"] = None
return result
async def stream_raw_via_page_fetch(
context: BrowserContext,
page: Page,
url: str,
body: str,
request_id: str,
*,
on_http_error: Callable[[str, dict[str, str] | None], int | None] | None = None,
on_headers: Callable[[dict[str, str]], None] | None = None,
error_state: dict[str, bool] | None = None,
fetch_timeout: int = 90,
read_timeout: float = 130.0,
) -> AsyncIterator[str]:
"""
在浏览器内对 url 发起 POST body,流式回传原始字符串块(含 SSE 等)。
同一 page 多请求用 request_id 区分 binding,互不串数据。
通过 CDP Runtime.addBinding 注入 sendChunk_<request_id>,用 Runtime.bindingCalled 接收。
收到 __headers__: 时解析 JSON 并调用 on_headers(headers);收到 __error__: 时调用 on_http_error(msg);收到 __done__ 结束。
"""
chunk_queue: asyncio.Queue[str] = asyncio.Queue()
BINDING_NAME = "sendChunk_" + request_id
def on_binding_called(event: dict[str, Any]) -> None:
name = event.get("name")
payload = event.get("payload", "")
if name == BINDING_NAME:
chunk_queue.put_nowait(
payload if isinstance(payload, str) else str(payload)
)
cdp = None
try:
cdp = await context.new_cdp_session(page)
cdp.on("Runtime.bindingCalled", on_binding_called)
await cdp.send("Runtime.addBinding", {"name": BINDING_NAME})
logger.info(
"[fetch] evaluate fetch on pool page (no goto) url=%s page.url=%s",
url[:80] + "..." if len(url) > 80 else url,
page.url or "",
)
async def run_fetch() -> None:
await page.evaluate(
PAGE_FETCH_STREAM_JS,
{"url": url, "body": body, "bindingName": BINDING_NAME},
)
fetch_task = asyncio.create_task(run_fetch())
try:
headers = None
while True:
try:
chunk = await asyncio.wait_for(
chunk_queue.get(), timeout=read_timeout
)
except asyncio.TimeoutError:
logger.warning("流式读取超时")
break
if chunk == "__done__":
break
if chunk.startswith("__headers__:"):
try:
headers = json.loads(chunk[12:])
if on_headers and isinstance(headers, dict):
on_headers({k: str(v) for k, v in headers.items()})
except (json.JSONDecodeError, TypeError) as e:
logger.debug("[fetch] 解析 __headers__ 失败: %s", e)
continue
if chunk.startswith("__error__:"):
msg = chunk[10:].strip()
saw_terminal = bool(error_state and error_state.get("terminal"))
if on_http_error:
unfreeze_at = on_http_error(msg, headers)
if isinstance(unfreeze_at, int):
logger.warning("[fetch] __error__ from page: %s", msg)
raise AccountFrozenError(msg, unfreeze_at)
if saw_terminal:
logger.info(
"[fetch] page fetch disconnected after terminal event: %s",
msg,
)
continue
logger.warning(
"[fetch] __error__ from page before terminal event: %s", msg
)
raise RuntimeError(msg)
continue
yield chunk
finally:
try:
await asyncio.wait_for(fetch_task, timeout=5.0)
except asyncio.TimeoutError:
fetch_task.cancel()
try:
await fetch_task
except asyncio.CancelledError:
pass
finally:
if cdp is not None:
try:
await cdp.detach()
except Exception as e:
logger.debug("detach CDP session 时异常: %s", e)
def parse_sse_to_events(buffer: str, chunk: str) -> tuple[str, list[str]]:
"""
把 chunk 追加到 buffer,按行拆出 data: 后的 payload 列表,返回 (剩余 buffer, payload 列表)。
接入方对每个 payload 自行 JSON 解析并抽取 text / message_id / error。
"""
buffer += chunk
lines = buffer.split("\n")
buffer = lines[-1]
payloads: list[str] = []
for line in lines[:-1]:
line = line.strip()
if not line.startswith("data: "):
continue
payload = line[6:].strip()
if payload == "[DONE]" or not payload:
continue
payloads.append(payload)
return (buffer, payloads)
async def stream_completion_via_sse(
context: BrowserContext,
page: Page,
url: str,
body: str,
parse_event: ParseSseEvent,
request_id: str,
*,
on_http_error: Callable,
is_terminal_event: Callable[[str], bool] | None = None,
collect_message_id: list[str] | None = None,
) -> AsyncIterator[str]:
"""
在浏览器内 POST 拿到流,按 SSE 行拆成 data 事件,用 parse_event(payload) 解析每条;
逐块 yield 文本,可选把 message_id 收集到 collect_message_id。
parse_event(payload) 返回 (texts, message_id, error),error 非空时仅打 debug 日志不抛错。
"""
buffer = ""
stream_state: dict[str, bool] = {"terminal": False}
async for chunk in stream_raw_via_page_fetch(
context,
page,
url,
body,
request_id,
on_http_error=on_http_error,
error_state=stream_state,
):
buffer, payloads = parse_sse_to_events(buffer, chunk)
for payload in payloads:
if is_terminal_event and is_terminal_event(payload):
stream_state["terminal"] = True
try:
texts, message_id, error = parse_event(payload)
except Exception as e:
logger.debug("parse_stream_event 单条解析异常: %s", e)
continue
if error:
logger.debug("SSE error: %s", error)
continue
if message_id and collect_message_id is not None:
collect_message_id.append(message_id)
for t in texts:
yield t