|
16 | 16 | from werkzeug.wrappers import Request, Response |
17 | 17 | import jsonschema |
18 | 18 |
|
| 19 | +from starlette.testclient import TestClient |
19 | 20 |
|
20 | 21 | try: |
21 | 22 | import gevent |
@@ -592,6 +593,85 @@ def suppress_deprecation_warnings(): |
592 | 593 | yield |
593 | 594 |
|
594 | 595 |
|
| 596 | +@pytest.fixture() |
| 597 | +def json_rpc(): |
| 598 | + def inner(app, method: str, params, request_id: str | None = None): |
| 599 | + if request_id is None: |
| 600 | + request_id = "1" # arbitrary |
| 601 | + |
| 602 | + with TestClient(app) as client: |
| 603 | + init_response = client.post( |
| 604 | + "/mcp/", |
| 605 | + headers={ |
| 606 | + "Accept": "application/json, text/event-stream", |
| 607 | + "Content-Type": "application/json", |
| 608 | + }, |
| 609 | + json={ |
| 610 | + "jsonrpc": "2.0", |
| 611 | + "method": "initialize", |
| 612 | + "params": { |
| 613 | + "clientInfo": {"name": "test-client", "version": "1.0"}, |
| 614 | + "protocolVersion": "2025-11-25", |
| 615 | + "capabilities": {}, |
| 616 | + }, |
| 617 | + "id": request_id, |
| 618 | + }, |
| 619 | + ) |
| 620 | + |
| 621 | + session_id = init_response.headers["mcp-session-id"] |
| 622 | + |
| 623 | + # Notification response is mandatory. |
| 624 | + # https://modelcontextprotocol.io/specification/2025-11-25/basic/lifecycle |
| 625 | + client.post( |
| 626 | + "/mcp/", |
| 627 | + headers={ |
| 628 | + "Accept": "application/json, text/event-stream", |
| 629 | + "Content-Type": "application/json", |
| 630 | + "mcp-session-id": session_id, |
| 631 | + }, |
| 632 | + json={ |
| 633 | + "jsonrpc": "2.0", |
| 634 | + "method": "notifications/initialized", |
| 635 | + "params": {}, |
| 636 | + }, |
| 637 | + ) |
| 638 | + |
| 639 | + response = client.post( |
| 640 | + "/mcp/", |
| 641 | + headers={ |
| 642 | + "Accept": "application/json, text/event-stream", |
| 643 | + "Content-Type": "application/json", |
| 644 | + "mcp-session-id": session_id, |
| 645 | + }, |
| 646 | + json={ |
| 647 | + "jsonrpc": "2.0", |
| 648 | + "method": method, |
| 649 | + "params": params, |
| 650 | + "id": request_id, |
| 651 | + }, |
| 652 | + ) |
| 653 | + |
| 654 | + return session_id, response |
| 655 | + |
| 656 | + return inner |
| 657 | + |
| 658 | + |
| 659 | +@pytest.fixture() |
| 660 | +def select_transactions_with_mcp_spans(): |
| 661 | + def inner(events, method_name): |
| 662 | + return [ |
| 663 | + transaction |
| 664 | + for transaction in events |
| 665 | + if transaction["type"] == "transaction" |
| 666 | + and any( |
| 667 | + span["data"].get("mcp.method.name") == method_name |
| 668 | + for span in transaction.get("spans", []) |
| 669 | + ) |
| 670 | + ] |
| 671 | + |
| 672 | + return inner |
| 673 | + |
| 674 | + |
595 | 675 | class MockServerRequestHandler(BaseHTTPRequestHandler): |
596 | 676 | def do_GET(self): # noqa: N802 |
597 | 677 | # Process an HTTP GET request and return a response. |
|
0 commit comments