mcp_client.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526
  1. package main
  2. import (
  3. "bufio"
  4. "bytes"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "net/url"
  10. "strings"
  11. "sync"
  12. "time"
  13. )
  14. // MCP Protocol constants
  15. const (
  16. ProtocolVersion = "2024-11-05"
  17. )
  18. // JSON-RPC types
  19. type JSONRPCRequest struct {
  20. JSONRPC string `json:"jsonrpc"`
  21. ID interface{} `json:"id,omitempty"`
  22. Method string `json:"method"`
  23. Params json.RawMessage `json:"params,omitempty"`
  24. }
  25. type JSONRPCResponse struct {
  26. JSONRPC string `json:"jsonrpc"`
  27. ID interface{} `json:"id,omitempty"`
  28. Result json.RawMessage `json:"result,omitempty"`
  29. Error *RPCError `json:"error,omitempty"`
  30. }
  31. type RPCError struct {
  32. Code int `json:"code"`
  33. Message string `json:"message"`
  34. Data interface{} `json:"data,omitempty"`
  35. }
  36. // MCP types
  37. type InitializeParams struct {
  38. ProtocolVersion string `json:"protocolVersion"`
  39. Capabilities ClientCapabilities `json:"capabilities"`
  40. ClientInfo ImplementationInfo `json:"clientInfo"`
  41. }
  42. type InitializeResult struct {
  43. ProtocolVersion string `json:"protocolVersion"`
  44. Capabilities ServerCapabilities `json:"capabilities"`
  45. ServerInfo ImplementationInfo `json:"serverInfo"`
  46. Instructions string `json:"instructions,omitempty"`
  47. }
  48. type ClientCapabilities struct {
  49. Experimental map[string]interface{} `json:"experimental,omitempty"`
  50. Roots *RootsCapability `json:"roots,omitempty"`
  51. Sampling *SamplingCapability `json:"sampling,omitempty"`
  52. }
  53. type RootsCapability struct {
  54. ListChanged bool `json:"listChanged,omitempty"`
  55. }
  56. type SamplingCapability struct{}
  57. type ServerCapabilities struct {
  58. Experimental map[string]interface{} `json:"experimental,omitempty"`
  59. Tools *ToolsCapability `json:"tools,omitempty"`
  60. Resources *ResourcesCapability `json:"resources,omitempty"`
  61. }
  62. type ToolsCapability struct {
  63. ListChanged bool `json:"listChanged,omitempty"`
  64. }
  65. type ResourcesCapability struct {
  66. Subscribe bool `json:"subscribe,omitempty"`
  67. ListChanged bool `json:"listChanged,omitempty"`
  68. }
  69. type ImplementationInfo struct {
  70. Name string `json:"name"`
  71. Version string `json:"version"`
  72. }
  73. // Tool types
  74. type Tool struct {
  75. Name string `json:"name"`
  76. Description string `json:"description"`
  77. InputSchema InputSchema `json:"inputSchema"`
  78. }
  79. type InputSchema struct {
  80. Type string `json:"type"`
  81. Properties map[string]Property `json:"properties,omitempty"`
  82. Required []string `json:"required,omitempty"`
  83. AdditionalProperties bool `json:"additionalProperties"`
  84. }
  85. type Property struct {
  86. Type string `json:"type"`
  87. Description string `json:"description,omitempty"`
  88. }
  89. type ListToolsResult struct {
  90. Tools []Tool `json:"tools"`
  91. }
  92. type CallToolParams struct {
  93. Name string `json:"name"`
  94. Arguments map[string]interface{} `json:"arguments,omitempty"`
  95. }
  96. type CallToolResult struct {
  97. Content []ContentBlock `json:"content"`
  98. IsError bool `json:"isError,omitempty"`
  99. }
  100. type ContentBlock struct {
  101. Type string `json:"type"`
  102. Text string `json:"text"`
  103. }
  104. // Resource types
  105. type Resource struct {
  106. URI string `json:"uri"`
  107. Name string `json:"name"`
  108. Description string `json:"description,omitempty"`
  109. MimeType string `json:"mimeType,omitempty"`
  110. }
  111. type ListResourcesResult struct {
  112. Resources []Resource `json:"resources"`
  113. }
  114. type SubscribeParams struct {
  115. URI string `json:"uri"`
  116. }
  117. type UnsubscribeParams struct {
  118. URI string `json:"uri"`
  119. }
  120. // Resource notification
  121. type ResourceUpdatedNotification struct {
  122. URI string `json:"uri"`
  123. Contents ResourceContents `json:"contents"`
  124. }
  125. type ResourceContents struct {
  126. URI string `json:"uri"`
  127. MimeType string `json:"mimeType,omitempty"`
  128. Text string `json:"text,omitempty"`
  129. Blob string `json:"blob,omitempty"`
  130. }
  131. // JSON-RPC Notification
  132. type JSONRPCNotification struct {
  133. JSONRPC string `json:"jsonrpc"`
  134. Method string `json:"method"`
  135. Params interface{} `json:"params,omitempty"`
  136. }
  137. // MCPClient is an MCP client for the ARP server
  138. type MCPClient struct {
  139. baseURL string
  140. token string
  141. httpClient *http.Client
  142. sseClient *http.Client // Separate client for SSE (no timeout)
  143. // SSE connection
  144. sseResp *http.Response
  145. sseDone chan struct{}
  146. sseEvents chan json.RawMessage
  147. // Message endpoint (received from SSE endpoint event)
  148. messageEndpoint string
  149. // Request ID counter
  150. idCounter int
  151. idMu sync.Mutex
  152. // Tools cache
  153. tools []Tool
  154. // Pending requests (ID -> response channel)
  155. pending map[interface{}]chan json.RawMessage
  156. pendingMu sync.Mutex
  157. }
  158. // NewMCPClient creates a new MCP client
  159. func NewMCPClient(baseURL string, token string) *MCPClient {
  160. return &MCPClient{
  161. baseURL: baseURL,
  162. token: token,
  163. httpClient: &http.Client{
  164. Timeout: 30 * time.Second,
  165. },
  166. sseClient: &http.Client{
  167. // No timeout for SSE - connection should stay open indefinitely
  168. Timeout: 0,
  169. },
  170. sseDone: make(chan struct{}),
  171. sseEvents: make(chan json.RawMessage, 100),
  172. pending: make(map[interface{}]chan json.RawMessage),
  173. }
  174. }
  175. // Connect establishes SSE connection to the MCP server
  176. func (c *MCPClient) Connect() error {
  177. // Build SSE URL
  178. sseURL := c.baseURL
  179. if !strings.HasSuffix(sseURL, "/mcp") {
  180. sseURL = strings.TrimSuffix(sseURL, "/")
  181. sseURL = sseURL + "/mcp"
  182. }
  183. req, err := http.NewRequest("GET", sseURL, nil)
  184. if err != nil {
  185. return fmt.Errorf("failed to create SSE request: %w", err)
  186. }
  187. req.Header.Set("Accept", "text/event-stream")
  188. req.Header.Set("Cache-Control", "no-cache")
  189. req.Header.Set("Connection", "keep-alive")
  190. if c.token != "" {
  191. req.Header.Set("Authorization", "Bearer "+c.token)
  192. }
  193. resp, err := c.sseClient.Do(req)
  194. if err != nil {
  195. return fmt.Errorf("failed to connect to SSE: %w", err)
  196. }
  197. if resp.StatusCode != http.StatusOK {
  198. resp.Body.Close()
  199. return fmt.Errorf("SSE connection failed with status: %d", resp.StatusCode)
  200. }
  201. c.sseResp = resp
  202. // Start reading SSE events
  203. go c.readSSE()
  204. // Wait for endpoint event
  205. select {
  206. case event := <-c.sseEvents:
  207. // The endpoint is sent as plain text, not JSON
  208. // e.g., "/message?sessionId=123456789"
  209. c.messageEndpoint = string(event)
  210. case <-time.After(10 * time.Second):
  211. return fmt.Errorf("timeout waiting for SSE endpoint event")
  212. }
  213. return nil
  214. }
  215. // readSSE reads SSE events from the response body
  216. func (c *MCPClient) readSSE() {
  217. defer close(c.sseEvents)
  218. scanner := bufio.NewScanner(c.sseResp.Body)
  219. var eventType string
  220. var eventData strings.Builder
  221. for scanner.Scan() {
  222. line := scanner.Text()
  223. if strings.HasPrefix(line, "event:") {
  224. eventType = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
  225. eventData.Reset()
  226. } else if strings.HasPrefix(line, "data:") {
  227. data := strings.TrimPrefix(line, "data:")
  228. eventData.WriteString(data)
  229. } else if line == "" && eventType != "" {
  230. // End of event
  231. data := strings.TrimSpace(eventData.String())
  232. // Handle endpoint event specially
  233. if eventType == "endpoint" {
  234. select {
  235. case c.sseEvents <- json.RawMessage(data):
  236. default:
  237. }
  238. } else if eventType == "message" {
  239. // Parse to check if it's a response (has ID) or notification
  240. var msg struct {
  241. ID interface{} `json:"id"`
  242. }
  243. if err := json.Unmarshal([]byte(data), &msg); err == nil && msg.ID != nil {
  244. // JSON numbers are unmarshaled as float64, but we use int for IDs
  245. // Convert float64 to int for matching
  246. var idKey interface{} = msg.ID
  247. if f, ok := msg.ID.(float64); ok {
  248. idKey = int(f)
  249. }
  250. // It's a response - dispatch to pending request
  251. c.pendingMu.Lock()
  252. if ch, ok := c.pending[idKey]; ok {
  253. ch <- json.RawMessage(data)
  254. delete(c.pending, idKey)
  255. }
  256. c.pendingMu.Unlock()
  257. } else {
  258. // It's a notification - send to general events channel
  259. select {
  260. case c.sseEvents <- json.RawMessage(data):
  261. default:
  262. }
  263. }
  264. }
  265. eventType = ""
  266. eventData.Reset()
  267. }
  268. }
  269. }
  270. // Initialize sends the initialize request
  271. func (c *MCPClient) Initialize() (*InitializeResult, error) {
  272. params := InitializeParams{
  273. ProtocolVersion: ProtocolVersion,
  274. Capabilities: ClientCapabilities{
  275. Roots: &RootsCapability{ListChanged: false},
  276. },
  277. ClientInfo: ImplementationInfo{
  278. Name: "ARP Agent",
  279. Version: "1.0.0",
  280. },
  281. }
  282. result := &InitializeResult{}
  283. if err := c.sendRequest("initialize", params, result); err != nil {
  284. return nil, err
  285. }
  286. return result, nil
  287. }
  288. // ListTools discovers available tools
  289. func (c *MCPClient) ListTools() ([]Tool, error) {
  290. result := &ListToolsResult{}
  291. if err := c.sendRequest("tools/list", nil, result); err != nil {
  292. return nil, err
  293. }
  294. c.tools = result.Tools
  295. return result.Tools, nil
  296. }
  297. // CallTool executes a tool call
  298. func (c *MCPClient) CallTool(name string, arguments map[string]interface{}) (*CallToolResult, error) {
  299. params := CallToolParams{
  300. Name: name,
  301. Arguments: arguments,
  302. }
  303. result := &CallToolResult{}
  304. if err := c.sendRequest("tools/call", params, result); err != nil {
  305. return nil, err
  306. }
  307. return result, nil
  308. }
  309. // ListResources lists available resources
  310. func (c *MCPClient) ListResources() ([]Resource, error) {
  311. result := &ListResourcesResult{}
  312. if err := c.sendRequest("resources/list", nil, result); err != nil {
  313. return nil, err
  314. }
  315. return result.Resources, nil
  316. }
  317. // SubscribeResource subscribes to a resource for notifications
  318. func (c *MCPClient) SubscribeResource(uri string) error {
  319. params := SubscribeParams{URI: uri}
  320. return c.sendRequest("resources/subscribe", params, nil)
  321. }
  322. // UnsubscribeResource unsubscribes from a resource
  323. func (c *MCPClient) UnsubscribeResource(uri string) error {
  324. params := UnsubscribeParams{URI: uri}
  325. return c.sendRequest("resources/unsubscribe", params, nil)
  326. }
  327. // Notifications returns a channel for receiving resource notifications
  328. func (c *MCPClient) Notifications() <-chan json.RawMessage {
  329. return c.sseEvents
  330. }
  331. // Close closes the MCP client connection
  332. func (c *MCPClient) Close() error {
  333. close(c.sseDone)
  334. if c.sseResp != nil {
  335. return c.sseResp.Body.Close()
  336. }
  337. return nil
  338. }
  339. // nextID generates a unique request ID
  340. func (c *MCPClient) nextID() int {
  341. c.idMu.Lock()
  342. defer c.idMu.Unlock()
  343. c.idCounter++
  344. return c.idCounter
  345. }
  346. // sendRequest sends a JSON-RPC request and waits for the response via SSE
  347. func (c *MCPClient) sendRequest(method string, params interface{}, result interface{}) error {
  348. // Build request
  349. id := c.nextID()
  350. var paramsJSON json.RawMessage
  351. if params != nil {
  352. var err error
  353. paramsJSON, err = json.Marshal(params)
  354. if err != nil {
  355. return fmt.Errorf("failed to marshal params: %w", err)
  356. }
  357. }
  358. req := JSONRPCRequest{
  359. JSONRPC: "2.0",
  360. ID: id,
  361. Method: method,
  362. Params: paramsJSON,
  363. }
  364. reqBody, err := json.Marshal(req)
  365. if err != nil {
  366. return fmt.Errorf("failed to marshal request: %w", err)
  367. }
  368. // Build message URL
  369. messageURL := c.baseURL
  370. if c.messageEndpoint != "" {
  371. // Parse the endpoint URL - it may be relative or absolute
  372. if strings.HasPrefix(c.messageEndpoint, "/") {
  373. // Relative URL - parse it and merge with base URL
  374. endpointURL, err := url.Parse(c.messageEndpoint)
  375. if err != nil {
  376. return fmt.Errorf("failed to parse endpoint URL: %w", err)
  377. }
  378. baseURL, err := url.Parse(c.baseURL)
  379. if err != nil {
  380. return fmt.Errorf("failed to parse base URL: %w", err)
  381. }
  382. // Merge the endpoint with the base URL (preserves query string)
  383. baseURL.Path = endpointURL.Path
  384. baseURL.RawQuery = endpointURL.RawQuery
  385. messageURL = baseURL.String()
  386. } else {
  387. messageURL = c.messageEndpoint
  388. }
  389. }
  390. // Register pending request before sending
  391. respChan := make(chan json.RawMessage, 1)
  392. c.pendingMu.Lock()
  393. c.pending[id] = respChan
  394. c.pendingMu.Unlock()
  395. // Cleanup on return
  396. defer func() {
  397. c.pendingMu.Lock()
  398. delete(c.pending, id)
  399. c.pendingMu.Unlock()
  400. }()
  401. // Send HTTP POST request
  402. httpReq, err := http.NewRequest("POST", messageURL, bytes.NewReader(reqBody))
  403. if err != nil {
  404. return fmt.Errorf("failed to create request: %w", err)
  405. }
  406. httpReq.Header.Set("Content-Type", "application/json")
  407. if c.token != "" {
  408. httpReq.Header.Set("Authorization", "Bearer "+c.token)
  409. }
  410. resp, err := c.httpClient.Do(httpReq)
  411. if err != nil {
  412. return fmt.Errorf("request failed: %w", err)
  413. }
  414. defer resp.Body.Close()
  415. // Check HTTP status
  416. if resp.StatusCode != http.StatusAccepted && resp.StatusCode != http.StatusOK {
  417. body, _ := io.ReadAll(resp.Body)
  418. return fmt.Errorf("request failed with status %d: %s", resp.StatusCode, string(body))
  419. }
  420. // Wait for response via SSE
  421. select {
  422. case respData := <-respChan:
  423. // Parse response
  424. var rpcResp JSONRPCResponse
  425. if err := json.Unmarshal(respData, &rpcResp); err != nil {
  426. return fmt.Errorf("failed to parse response: %w", err)
  427. }
  428. if rpcResp.Error != nil {
  429. return fmt.Errorf("RPC error %d: %s", rpcResp.Error.Code, rpcResp.Error.Message)
  430. }
  431. // Parse result if provided
  432. if result != nil && rpcResp.Result != nil {
  433. if err := json.Unmarshal(rpcResp.Result, result); err != nil {
  434. return fmt.Errorf("failed to parse result: %w", err)
  435. }
  436. }
  437. return nil
  438. case <-time.After(30 * time.Second):
  439. return fmt.Errorf("timeout waiting for response")
  440. case <-c.sseDone:
  441. return fmt.Errorf("connection closed")
  442. }
  443. }
  444. // GetTools returns the cached tools
  445. func (c *MCPClient) GetTools() []Tool {
  446. return c.tools
  447. }