| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526 |
- package main
- import (
- "bufio"
- "bytes"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "net/url"
- "strings"
- "sync"
- "time"
- )
// ProtocolVersion is the MCP protocol revision this client advertises in its
// initialize request.
const ProtocolVersion = "2024-11-05"
// JSON-RPC 2.0 envelope types exchanged with the MCP server.
type (
	// JSONRPCRequest is an outgoing JSON-RPC call. Params holds the already
	// encoded parameter object and is omitted from the wire form when empty.
	JSONRPCRequest struct {
		JSONRPC string          `json:"jsonrpc"`
		ID      interface{}     `json:"id,omitempty"`
		Method  string          `json:"method"`
		Params  json.RawMessage `json:"params,omitempty"`
	}

	// JSONRPCResponse is the reply to a JSONRPCRequest. Result is kept raw so
	// the caller can decode it into a method-specific type; Error is non-nil
	// when the server reported a failure.
	JSONRPCResponse struct {
		JSONRPC string          `json:"jsonrpc"`
		ID      interface{}     `json:"id,omitempty"`
		Result  json.RawMessage `json:"result,omitempty"`
		Error   *RPCError       `json:"error,omitempty"`
	}

	// RPCError is the error object of a failed JSON-RPC call.
	RPCError struct {
		Code    int         `json:"code"`
		Message string      `json:"message"`
		Data    interface{} `json:"data,omitempty"`
	}
)
// MCP handshake types: the initialize request/result and the capability
// objects each side advertises.
type (
	// InitializeParams is the payload of the client's "initialize" request.
	InitializeParams struct {
		ProtocolVersion string             `json:"protocolVersion"`
		Capabilities    ClientCapabilities `json:"capabilities"`
		ClientInfo      ImplementationInfo `json:"clientInfo"`
	}

	// InitializeResult is the server's answer to "initialize".
	InitializeResult struct {
		ProtocolVersion string             `json:"protocolVersion"`
		Capabilities    ServerCapabilities `json:"capabilities"`
		ServerInfo      ImplementationInfo `json:"serverInfo"`
		Instructions    string             `json:"instructions,omitempty"`
	}

	// ClientCapabilities lists what this client supports; nil members are
	// omitted from the encoded object.
	ClientCapabilities struct {
		Experimental map[string]interface{} `json:"experimental,omitempty"`
		Roots        *RootsCapability       `json:"roots,omitempty"`
		Sampling     *SamplingCapability    `json:"sampling,omitempty"`
	}

	// RootsCapability is the "roots" entry of ClientCapabilities.
	RootsCapability struct {
		ListChanged bool `json:"listChanged,omitempty"`
	}

	// SamplingCapability has no fields; a non-nil value encodes as {}.
	SamplingCapability struct{}

	// ServerCapabilities lists what the server supports; a nil member means
	// the server did not advertise that capability.
	ServerCapabilities struct {
		Experimental map[string]interface{} `json:"experimental,omitempty"`
		Tools        *ToolsCapability       `json:"tools,omitempty"`
		Resources    *ResourcesCapability   `json:"resources,omitempty"`
	}

	// ToolsCapability is the "tools" entry of ServerCapabilities.
	ToolsCapability struct {
		ListChanged bool `json:"listChanged,omitempty"`
	}

	// ResourcesCapability is the "resources" entry of ServerCapabilities.
	ResourcesCapability struct {
		Subscribe   bool `json:"subscribe,omitempty"`
		ListChanged bool `json:"listChanged,omitempty"`
	}

	// ImplementationInfo names a client or server implementation.
	ImplementationInfo struct {
		Name    string `json:"name"`
		Version string `json:"version"`
	}
)
// Tool-related types: the tools/list catalogue and tools/call request/result.
type (
	// Tool describes one tool offered by the server.
	Tool struct {
		Name        string      `json:"name"`
		Description string      `json:"description"`
		InputSchema InputSchema `json:"inputSchema"`
	}

	// InputSchema is a reduced JSON-Schema description of a tool's arguments.
	InputSchema struct {
		Type       string              `json:"type"`
		Properties map[string]Property `json:"properties,omitempty"`
		Required   []string            `json:"required,omitempty"`
		// Always encoded (no omitempty), so false is sent explicitly.
		// NOTE(review): JSON Schema also allows an object here; encoding/json
		// rejects an object for a bool field, so such a schema would make
		// decoding the whole tools/list result fail.
		AdditionalProperties bool `json:"additionalProperties"`
	}

	// Property describes a single argument in InputSchema.Properties. Schema
	// keywords other than type/description are ignored when decoding.
	Property struct {
		Type        string `json:"type"`
		Description string `json:"description,omitempty"`
	}

	// ListToolsResult is the result of "tools/list".
	ListToolsResult struct {
		Tools []Tool `json:"tools"`
	}

	// CallToolParams is the payload of "tools/call"; nil Arguments are omitted.
	CallToolParams struct {
		Name      string                 `json:"name"`
		Arguments map[string]interface{} `json:"arguments,omitempty"`
	}

	// CallToolResult is the result of "tools/call".
	CallToolResult struct {
		Content []ContentBlock `json:"content"`
		IsError bool           `json:"isError,omitempty"`
	}

	// ContentBlock is one item of CallToolResult.Content. Only Type and Text
	// are modelled; any other fields a server sends are ignored when decoding.
	ContentBlock struct {
		Type string `json:"type"`
		Text string `json:"text"`
	}
)
// Resource-related types: listing, (un)subscribing and update notifications.
type (
	// Resource describes one resource exposed by the server.
	Resource struct {
		URI         string `json:"uri"`
		Name        string `json:"name"`
		Description string `json:"description,omitempty"`
		MimeType    string `json:"mimeType,omitempty"`
	}

	// ListResourcesResult is the result of "resources/list".
	ListResourcesResult struct {
		Resources []Resource `json:"resources"`
	}

	// SubscribeParams is the payload of "resources/subscribe".
	SubscribeParams struct {
		URI string `json:"uri"`
	}

	// UnsubscribeParams is the payload of "resources/unsubscribe".
	UnsubscribeParams struct {
		URI string `json:"uri"`
	}

	// ResourceUpdatedNotification carries a changed resource and its contents.
	ResourceUpdatedNotification struct {
		URI      string           `json:"uri"`
		Contents ResourceContents `json:"contents"`
	}

	// ResourceContents holds a resource body, either as Text or as Blob
	// (presumably base64 for binary data — confirm against the server).
	ResourceContents struct {
		URI      string `json:"uri"`
		MimeType string `json:"mimeType,omitempty"`
		Text     string `json:"text,omitempty"`
		Blob     string `json:"blob,omitempty"`
	}
)
// JSONRPCNotification is a JSON-RPC 2.0 message without an id, so no response
// is expected. A nil Params is omitted from the encoded form.
type JSONRPCNotification struct {
	JSONRPC string      `json:"jsonrpc"`
	Method  string      `json:"method"`
	Params  interface{} `json:"params,omitempty"`
}
- // MCPClient is an MCP client for the ARP server
- type MCPClient struct {
- baseURL string
- token string
- httpClient *http.Client
- sseClient *http.Client // Separate client for SSE (no timeout)
- // SSE connection
- sseResp *http.Response
- sseDone chan struct{}
- sseEvents chan json.RawMessage
- // Message endpoint (received from SSE endpoint event)
- messageEndpoint string
- // Request ID counter
- idCounter int
- idMu sync.Mutex
- // Tools cache
- tools []Tool
- // Pending requests (ID -> response channel)
- pending map[interface{}]chan json.RawMessage
- pendingMu sync.Mutex
- }
- // NewMCPClient creates a new MCP client
- func NewMCPClient(baseURL string, token string) *MCPClient {
- return &MCPClient{
- baseURL: baseURL,
- token: token,
- httpClient: &http.Client{
- Timeout: 30 * time.Second,
- },
- sseClient: &http.Client{
- // No timeout for SSE - connection should stay open indefinitely
- Timeout: 0,
- },
- sseDone: make(chan struct{}),
- sseEvents: make(chan json.RawMessage, 100),
- pending: make(map[interface{}]chan json.RawMessage),
- }
- }
- // Connect establishes SSE connection to the MCP server
- func (c *MCPClient) Connect() error {
- // Build SSE URL
- sseURL := c.baseURL
- if !strings.HasSuffix(sseURL, "/mcp") {
- sseURL = strings.TrimSuffix(sseURL, "/")
- sseURL = sseURL + "/mcp"
- }
- req, err := http.NewRequest("GET", sseURL, nil)
- if err != nil {
- return fmt.Errorf("failed to create SSE request: %w", err)
- }
- req.Header.Set("Accept", "text/event-stream")
- req.Header.Set("Cache-Control", "no-cache")
- req.Header.Set("Connection", "keep-alive")
- if c.token != "" {
- req.Header.Set("Authorization", "Bearer "+c.token)
- }
- resp, err := c.sseClient.Do(req)
- if err != nil {
- return fmt.Errorf("failed to connect to SSE: %w", err)
- }
- if resp.StatusCode != http.StatusOK {
- resp.Body.Close()
- return fmt.Errorf("SSE connection failed with status: %d", resp.StatusCode)
- }
- c.sseResp = resp
- // Start reading SSE events
- go c.readSSE()
- // Wait for endpoint event
- select {
- case event := <-c.sseEvents:
- // The endpoint is sent as plain text, not JSON
- // e.g., "/message?sessionId=123456789"
- c.messageEndpoint = string(event)
- case <-time.After(10 * time.Second):
- return fmt.Errorf("timeout waiting for SSE endpoint event")
- }
- return nil
- }
- // readSSE reads SSE events from the response body
- func (c *MCPClient) readSSE() {
- defer close(c.sseEvents)
- scanner := bufio.NewScanner(c.sseResp.Body)
- var eventType string
- var eventData strings.Builder
- for scanner.Scan() {
- line := scanner.Text()
- if strings.HasPrefix(line, "event:") {
- eventType = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
- eventData.Reset()
- } else if strings.HasPrefix(line, "data:") {
- data := strings.TrimPrefix(line, "data:")
- eventData.WriteString(data)
- } else if line == "" && eventType != "" {
- // End of event
- data := strings.TrimSpace(eventData.String())
- // Handle endpoint event specially
- if eventType == "endpoint" {
- select {
- case c.sseEvents <- json.RawMessage(data):
- default:
- }
- } else if eventType == "message" {
- // Parse to check if it's a response (has ID) or notification
- var msg struct {
- ID interface{} `json:"id"`
- }
- if err := json.Unmarshal([]byte(data), &msg); err == nil && msg.ID != nil {
- // JSON numbers are unmarshaled as float64, but we use int for IDs
- // Convert float64 to int for matching
- var idKey interface{} = msg.ID
- if f, ok := msg.ID.(float64); ok {
- idKey = int(f)
- }
- // It's a response - dispatch to pending request
- c.pendingMu.Lock()
- if ch, ok := c.pending[idKey]; ok {
- ch <- json.RawMessage(data)
- delete(c.pending, idKey)
- }
- c.pendingMu.Unlock()
- } else {
- // It's a notification - send to general events channel
- select {
- case c.sseEvents <- json.RawMessage(data):
- default:
- }
- }
- }
- eventType = ""
- eventData.Reset()
- }
- }
- }
- // Initialize sends the initialize request
- func (c *MCPClient) Initialize() (*InitializeResult, error) {
- params := InitializeParams{
- ProtocolVersion: ProtocolVersion,
- Capabilities: ClientCapabilities{
- Roots: &RootsCapability{ListChanged: false},
- },
- ClientInfo: ImplementationInfo{
- Name: "ARP Agent",
- Version: "1.0.0",
- },
- }
- result := &InitializeResult{}
- if err := c.sendRequest("initialize", params, result); err != nil {
- return nil, err
- }
- return result, nil
- }
- // ListTools discovers available tools
- func (c *MCPClient) ListTools() ([]Tool, error) {
- result := &ListToolsResult{}
- if err := c.sendRequest("tools/list", nil, result); err != nil {
- return nil, err
- }
- c.tools = result.Tools
- return result.Tools, nil
- }
- // CallTool executes a tool call
- func (c *MCPClient) CallTool(name string, arguments map[string]interface{}) (*CallToolResult, error) {
- params := CallToolParams{
- Name: name,
- Arguments: arguments,
- }
- result := &CallToolResult{}
- if err := c.sendRequest("tools/call", params, result); err != nil {
- return nil, err
- }
- return result, nil
- }
- // ListResources lists available resources
- func (c *MCPClient) ListResources() ([]Resource, error) {
- result := &ListResourcesResult{}
- if err := c.sendRequest("resources/list", nil, result); err != nil {
- return nil, err
- }
- return result.Resources, nil
- }
- // SubscribeResource subscribes to a resource for notifications
- func (c *MCPClient) SubscribeResource(uri string) error {
- params := SubscribeParams{URI: uri}
- return c.sendRequest("resources/subscribe", params, nil)
- }
- // UnsubscribeResource unsubscribes from a resource
- func (c *MCPClient) UnsubscribeResource(uri string) error {
- params := UnsubscribeParams{URI: uri}
- return c.sendRequest("resources/unsubscribe", params, nil)
- }
- // Notifications returns a channel for receiving resource notifications
- func (c *MCPClient) Notifications() <-chan json.RawMessage {
- return c.sseEvents
- }
- // Close closes the MCP client connection
- func (c *MCPClient) Close() error {
- close(c.sseDone)
- if c.sseResp != nil {
- return c.sseResp.Body.Close()
- }
- return nil
- }
- // nextID generates a unique request ID
- func (c *MCPClient) nextID() int {
- c.idMu.Lock()
- defer c.idMu.Unlock()
- c.idCounter++
- return c.idCounter
- }
- // sendRequest sends a JSON-RPC request and waits for the response via SSE
- func (c *MCPClient) sendRequest(method string, params interface{}, result interface{}) error {
- // Build request
- id := c.nextID()
- var paramsJSON json.RawMessage
- if params != nil {
- var err error
- paramsJSON, err = json.Marshal(params)
- if err != nil {
- return fmt.Errorf("failed to marshal params: %w", err)
- }
- }
- req := JSONRPCRequest{
- JSONRPC: "2.0",
- ID: id,
- Method: method,
- Params: paramsJSON,
- }
- reqBody, err := json.Marshal(req)
- if err != nil {
- return fmt.Errorf("failed to marshal request: %w", err)
- }
- // Build message URL
- messageURL := c.baseURL
- if c.messageEndpoint != "" {
- // Parse the endpoint URL - it may be relative or absolute
- if strings.HasPrefix(c.messageEndpoint, "/") {
- // Relative URL - parse it and merge with base URL
- endpointURL, err := url.Parse(c.messageEndpoint)
- if err != nil {
- return fmt.Errorf("failed to parse endpoint URL: %w", err)
- }
- baseURL, err := url.Parse(c.baseURL)
- if err != nil {
- return fmt.Errorf("failed to parse base URL: %w", err)
- }
- // Merge the endpoint with the base URL (preserves query string)
- baseURL.Path = endpointURL.Path
- baseURL.RawQuery = endpointURL.RawQuery
- messageURL = baseURL.String()
- } else {
- messageURL = c.messageEndpoint
- }
- }
- // Register pending request before sending
- respChan := make(chan json.RawMessage, 1)
- c.pendingMu.Lock()
- c.pending[id] = respChan
- c.pendingMu.Unlock()
- // Cleanup on return
- defer func() {
- c.pendingMu.Lock()
- delete(c.pending, id)
- c.pendingMu.Unlock()
- }()
- // Send HTTP POST request
- httpReq, err := http.NewRequest("POST", messageURL, bytes.NewReader(reqBody))
- if err != nil {
- return fmt.Errorf("failed to create request: %w", err)
- }
- httpReq.Header.Set("Content-Type", "application/json")
- if c.token != "" {
- httpReq.Header.Set("Authorization", "Bearer "+c.token)
- }
- resp, err := c.httpClient.Do(httpReq)
- if err != nil {
- return fmt.Errorf("request failed: %w", err)
- }
- defer resp.Body.Close()
- // Check HTTP status
- if resp.StatusCode != http.StatusAccepted && resp.StatusCode != http.StatusOK {
- body, _ := io.ReadAll(resp.Body)
- return fmt.Errorf("request failed with status %d: %s", resp.StatusCode, string(body))
- }
- // Wait for response via SSE
- select {
- case respData := <-respChan:
- // Parse response
- var rpcResp JSONRPCResponse
- if err := json.Unmarshal(respData, &rpcResp); err != nil {
- return fmt.Errorf("failed to parse response: %w", err)
- }
- if rpcResp.Error != nil {
- return fmt.Errorf("RPC error %d: %s", rpcResp.Error.Code, rpcResp.Error.Message)
- }
- // Parse result if provided
- if result != nil && rpcResp.Result != nil {
- if err := json.Unmarshal(rpcResp.Result, result); err != nil {
- return fmt.Errorf("failed to parse result: %w", err)
- }
- }
- return nil
- case <-time.After(30 * time.Second):
- return fmt.Errorf("timeout waiting for response")
- case <-c.sseDone:
- return fmt.Errorf("connection closed")
- }
- }
- // GetTools returns the cached tools
- func (c *MCPClient) GetTools() []Tool {
- return c.tools
- }
|