mcp_stdio.go 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. package main
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "log"
  8. "os"
  9. "os/exec"
  10. "sync"
  11. "time"
  12. )
  13. // MCPStdioClient is an MCP client that communicates via stdin/stdout
  14. // Used for external MCP servers spawned as child processes
  15. type MCPStdioClient struct {
  16. serverName string
  17. config MCPServerConfig
  18. cmd *exec.Cmd
  19. stdin io.WriteCloser
  20. stdout io.Reader
  21. stderr io.Reader
  22. // Request ID counter
  23. idCounter int
  24. idMu sync.Mutex
  25. // Pending requests (ID -> response channel)
  26. pending map[interface{}]chan json.RawMessage
  27. pendingMu sync.Mutex
  28. // Tools cache
  29. tools []Tool
  30. // Done channel for cleanup
  31. done chan struct{}
  32. // Close once for idempotent close
  33. closeOnce sync.Once
  34. }
  35. // NewMCPStdioClient creates a new stdio MCP client for an external server
  36. func NewMCPStdioClient(serverName string, config MCPServerConfig) *MCPStdioClient {
  37. return &MCPStdioClient{
  38. serverName: serverName,
  39. config: config,
  40. pending: make(map[interface{}]chan json.RawMessage),
  41. done: make(chan struct{}),
  42. }
  43. }
  44. // Start spawns the external MCP server process
  45. func (c *MCPStdioClient) Start() error {
  46. // Build command
  47. c.cmd = exec.Command(c.config.Command, c.config.Args...)
  48. // Set environment variables
  49. if len(c.config.Env) > 0 {
  50. env := os.Environ()
  51. for key, value := range c.config.Env {
  52. env = append(env, fmt.Sprintf("%s=%s", key, value))
  53. }
  54. c.cmd.Env = env
  55. }
  56. // Get stdin pipe
  57. stdin, err := c.cmd.StdinPipe()
  58. if err != nil {
  59. return fmt.Errorf("failed to get stdin pipe: %w", err)
  60. }
  61. c.stdin = stdin
  62. // Get stdout pipe
  63. stdout, err := c.cmd.StdoutPipe()
  64. if err != nil {
  65. return fmt.Errorf("failed to get stdout pipe: %w", err)
  66. }
  67. c.stdout = stdout
  68. // Get stderr pipe for logging
  69. stderr, err := c.cmd.StderrPipe()
  70. if err != nil {
  71. return fmt.Errorf("failed to get stderr pipe: %w", err)
  72. }
  73. c.stderr = stderr
  74. // Start the process
  75. if err := c.cmd.Start(); err != nil {
  76. return fmt.Errorf("failed to start MCP server '%s': %w", c.serverName, err)
  77. }
  78. // Start reading stdout in background
  79. go c.readOutput()
  80. // Start reading stderr in background for logging
  81. go c.readStderr()
  82. return nil
  83. }
  84. // readOutput reads JSON-RPC responses from stdout
  85. func (c *MCPStdioClient) readOutput() {
  86. scanner := bufio.NewScanner(c.stdout)
  87. for scanner.Scan() {
  88. line := scanner.Bytes()
  89. // Parse to check if it's a response (has ID) or notification
  90. var msg struct {
  91. ID interface{} `json:"id"`
  92. }
  93. if err := json.Unmarshal(line, &msg); err != nil {
  94. log.Printf("[%s] Failed to parse output: %v", c.serverName, err)
  95. continue
  96. }
  97. if msg.ID != nil {
  98. // JSON numbers are unmarshaled as float64, but we use int for IDs
  99. var idKey interface{} = msg.ID
  100. if f, ok := msg.ID.(float64); ok {
  101. idKey = int(f)
  102. }
  103. // It's a response - dispatch to pending request
  104. c.pendingMu.Lock()
  105. if ch, ok := c.pending[idKey]; ok {
  106. ch <- json.RawMessage(line)
  107. delete(c.pending, idKey)
  108. }
  109. c.pendingMu.Unlock()
  110. }
  111. // Notifications are ignored for now (external servers typically don't send them)
  112. }
  113. }
  114. // readStderr reads stderr output for debugging
  115. func (c *MCPStdioClient) readStderr() {
  116. scanner := bufio.NewScanner(c.stderr)
  117. for scanner.Scan() {
  118. log.Printf("[%s] stderr: %s", c.serverName, scanner.Text())
  119. }
  120. }
  121. // Initialize sends the initialize request
  122. func (c *MCPStdioClient) Initialize() (*InitializeResult, error) {
  123. params := InitializeParams{
  124. ProtocolVersion: ProtocolVersion,
  125. Capabilities: ClientCapabilities{
  126. Roots: &RootsCapability{ListChanged: false},
  127. },
  128. ClientInfo: ImplementationInfo{
  129. Name: "ARP Agent",
  130. Version: "1.0.0",
  131. },
  132. }
  133. result := &InitializeResult{}
  134. if err := c.sendRequest("initialize", params, result); err != nil {
  135. return nil, err
  136. }
  137. return result, nil
  138. }
  139. // ListTools discovers available tools from this server
  140. func (c *MCPStdioClient) ListTools() ([]Tool, error) {
  141. result := &ListToolsResult{}
  142. if err := c.sendRequest("tools/list", nil, result); err != nil {
  143. return nil, err
  144. }
  145. c.tools = result.Tools
  146. return result.Tools, nil
  147. }
  148. // CallTool executes a tool call
  149. func (c *MCPStdioClient) CallTool(name string, arguments map[string]interface{}) (*CallToolResult, error) {
  150. params := CallToolParams{
  151. Name: name,
  152. Arguments: arguments,
  153. }
  154. result := &CallToolResult{}
  155. if err := c.sendRequest("tools/call", params, result); err != nil {
  156. return nil, err
  157. }
  158. return result, nil
  159. }
  160. // GetTools returns the cached tools
  161. func (c *MCPStdioClient) GetTools() []Tool {
  162. return c.tools
  163. }
  164. // Close stops the external MCP server process (idempotent)
  165. func (c *MCPStdioClient) Close() error {
  166. c.closeOnce.Do(func() {
  167. close(c.done)
  168. })
  169. if c.stdin != nil {
  170. c.stdin.Close()
  171. }
  172. if c.cmd != nil && c.cmd.Process != nil {
  173. // Give the process a moment to exit gracefully
  174. done := make(chan error, 1)
  175. go func() {
  176. done <- c.cmd.Wait()
  177. }()
  178. select {
  179. case <-time.After(5 * time.Second):
  180. // Force kill if it doesn't exit gracefully
  181. log.Printf("[%s] Force killing process", c.serverName)
  182. c.cmd.Process.Kill()
  183. case <-done:
  184. // Process exited gracefully
  185. }
  186. }
  187. return nil
  188. }
  189. // nextID generates a unique request ID
  190. func (c *MCPStdioClient) nextID() int {
  191. c.idMu.Lock()
  192. defer c.idMu.Unlock()
  193. c.idCounter++
  194. return c.idCounter
  195. }
  196. // sendRequest sends a JSON-RPC request and waits for the response
  197. func (c *MCPStdioClient) sendRequest(method string, params interface{}, result interface{}) error {
  198. // Build request
  199. id := c.nextID()
  200. var paramsJSON json.RawMessage
  201. if params != nil {
  202. var err error
  203. paramsJSON, err = json.Marshal(params)
  204. if err != nil {
  205. return fmt.Errorf("failed to marshal params: %w", err)
  206. }
  207. }
  208. req := JSONRPCRequest{
  209. JSONRPC: "2.0",
  210. ID: id,
  211. Method: method,
  212. Params: paramsJSON,
  213. }
  214. reqBody, err := json.Marshal(req)
  215. if err != nil {
  216. return fmt.Errorf("failed to marshal request: %w", err)
  217. }
  218. // Add newline as line delimiter
  219. reqBody = append(reqBody, '\n')
  220. // Register pending request before sending
  221. respChan := make(chan json.RawMessage, 1)
  222. c.pendingMu.Lock()
  223. c.pending[id] = respChan
  224. c.pendingMu.Unlock()
  225. // Cleanup on return
  226. defer func() {
  227. c.pendingMu.Lock()
  228. delete(c.pending, id)
  229. c.pendingMu.Unlock()
  230. }()
  231. // Send request
  232. if _, err := c.stdin.Write(reqBody); err != nil {
  233. return fmt.Errorf("failed to send request: %w", err)
  234. }
  235. // Wait for response
  236. select {
  237. case respData := <-respChan:
  238. // Parse response
  239. var rpcResp JSONRPCResponse
  240. if err := json.Unmarshal(respData, &rpcResp); err != nil {
  241. return fmt.Errorf("failed to parse response: %w", err)
  242. }
  243. if rpcResp.Error != nil {
  244. return fmt.Errorf("RPC error %d: %s", rpcResp.Error.Code, rpcResp.Error.Message)
  245. }
  246. // Parse result if provided
  247. if result != nil && rpcResp.Result != nil {
  248. if err := json.Unmarshal(rpcResp.Result, result); err != nil {
  249. return fmt.Errorf("failed to parse result: %w", err)
  250. }
  251. }
  252. return nil
  253. case <-time.After(30 * time.Second):
  254. return fmt.Errorf("timeout waiting for response from %s", c.serverName)
  255. case <-c.done:
  256. return fmt.Errorf("client closed")
  257. }
  258. }