mcp_stdio.go 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. package main
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "log"
  8. "os"
  9. "os/exec"
  10. "sync"
  11. "time"
  12. )
  13. // MCPStdioClient is an MCP client that communicates via stdin/stdout
  14. // Used for external MCP servers spawned as child processes
  15. type MCPStdioClient struct {
  16. serverName string
  17. config MCPServerConfig
  18. cmd *exec.Cmd
  19. stdin io.WriteCloser
  20. stdout io.Reader
  21. stderr io.Reader
  22. // Request ID counter
  23. idCounter int
  24. idMu sync.Mutex
  25. // Pending requests (ID -> response channel)
  26. pending map[interface{}]chan json.RawMessage
  27. pendingMu sync.Mutex
  28. // Tools cache
  29. tools []Tool
  30. // Done channel for cleanup
  31. done chan struct{}
  32. doneMu sync.Mutex
  33. }
  34. // NewMCPStdioClient creates a new stdio MCP client for an external server
  35. func NewMCPStdioClient(serverName string, config MCPServerConfig) *MCPStdioClient {
  36. return &MCPStdioClient{
  37. serverName: serverName,
  38. config: config,
  39. pending: make(map[interface{}]chan json.RawMessage),
  40. done: make(chan struct{}),
  41. }
  42. }
  43. // Start spawns the external MCP server process
  44. func (c *MCPStdioClient) Start() error {
  45. // Build command
  46. c.cmd = exec.Command(c.config.Command, c.config.Args...)
  47. // Set environment variables
  48. if len(c.config.Env) > 0 {
  49. env := os.Environ()
  50. for key, value := range c.config.Env {
  51. env = append(env, fmt.Sprintf("%s=%s", key, value))
  52. }
  53. c.cmd.Env = env
  54. }
  55. // Get stdin pipe
  56. stdin, err := c.cmd.StdinPipe()
  57. if err != nil {
  58. return fmt.Errorf("failed to get stdin pipe: %w", err)
  59. }
  60. c.stdin = stdin
  61. // Get stdout pipe
  62. stdout, err := c.cmd.StdoutPipe()
  63. if err != nil {
  64. return fmt.Errorf("failed to get stdout pipe: %w", err)
  65. }
  66. c.stdout = stdout
  67. // Get stderr pipe for logging
  68. stderr, err := c.cmd.StderrPipe()
  69. if err != nil {
  70. return fmt.Errorf("failed to get stderr pipe: %w", err)
  71. }
  72. c.stderr = stderr
  73. // Start the process
  74. if err := c.cmd.Start(); err != nil {
  75. return fmt.Errorf("failed to start MCP server '%s': %w", c.serverName, err)
  76. }
  77. // Start reading stdout in background
  78. go c.readOutput()
  79. // Start reading stderr in background for logging
  80. go c.readStderr()
  81. return nil
  82. }
  83. // readOutput reads JSON-RPC responses from stdout
  84. func (c *MCPStdioClient) readOutput() {
  85. scanner := bufio.NewScanner(c.stdout)
  86. for scanner.Scan() {
  87. line := scanner.Bytes()
  88. // Parse to check if it's a response (has ID) or notification
  89. var msg struct {
  90. ID interface{} `json:"id"`
  91. }
  92. if err := json.Unmarshal(line, &msg); err != nil {
  93. log.Printf("[%s] Failed to parse output: %v", c.serverName, err)
  94. continue
  95. }
  96. if msg.ID != nil {
  97. // JSON numbers are unmarshaled as float64, but we use int for IDs
  98. var idKey interface{} = msg.ID
  99. if f, ok := msg.ID.(float64); ok {
  100. idKey = int(f)
  101. }
  102. // It's a response - dispatch to pending request
  103. c.pendingMu.Lock()
  104. if ch, ok := c.pending[idKey]; ok {
  105. ch <- json.RawMessage(line)
  106. delete(c.pending, idKey)
  107. }
  108. c.pendingMu.Unlock()
  109. }
  110. // Notifications are ignored for now (external servers typically don't send them)
  111. }
  112. }
  113. // readStderr reads stderr output for debugging
  114. func (c *MCPStdioClient) readStderr() {
  115. scanner := bufio.NewScanner(c.stderr)
  116. for scanner.Scan() {
  117. log.Printf("[%s] stderr: %s", c.serverName, scanner.Text())
  118. }
  119. }
  120. // Initialize sends the initialize request
  121. func (c *MCPStdioClient) Initialize() (*InitializeResult, error) {
  122. params := InitializeParams{
  123. ProtocolVersion: ProtocolVersion,
  124. Capabilities: ClientCapabilities{
  125. Roots: &RootsCapability{ListChanged: false},
  126. },
  127. ClientInfo: ImplementationInfo{
  128. Name: "ARP Agent",
  129. Version: "1.0.0",
  130. },
  131. }
  132. result := &InitializeResult{}
  133. if err := c.sendRequest("initialize", params, result); err != nil {
  134. return nil, err
  135. }
  136. return result, nil
  137. }
  138. // ListTools discovers available tools from this server
  139. func (c *MCPStdioClient) ListTools() ([]Tool, error) {
  140. result := &ListToolsResult{}
  141. if err := c.sendRequest("tools/list", nil, result); err != nil {
  142. return nil, err
  143. }
  144. c.tools = result.Tools
  145. return result.Tools, nil
  146. }
  147. // CallTool executes a tool call
  148. func (c *MCPStdioClient) CallTool(name string, arguments map[string]interface{}) (*CallToolResult, error) {
  149. params := CallToolParams{
  150. Name: name,
  151. Arguments: arguments,
  152. }
  153. result := &CallToolResult{}
  154. if err := c.sendRequest("tools/call", params, result); err != nil {
  155. return nil, err
  156. }
  157. return result, nil
  158. }
  159. // GetTools returns the cached tools
  160. func (c *MCPStdioClient) GetTools() []Tool {
  161. return c.tools
  162. }
  163. // Close stops the external MCP server process
  164. func (c *MCPStdioClient) Close() error {
  165. c.doneMu.Lock()
  166. select {
  167. case <-c.done:
  168. // Already closed
  169. default:
  170. close(c.done)
  171. }
  172. c.doneMu.Unlock()
  173. if c.stdin != nil {
  174. c.stdin.Close()
  175. }
  176. if c.cmd != nil && c.cmd.Process != nil {
  177. // Give the process a moment to exit gracefully
  178. done := make(chan error, 1)
  179. go func() {
  180. done <- c.cmd.Wait()
  181. }()
  182. select {
  183. case <-time.After(5 * time.Second):
  184. // Force kill if it doesn't exit gracefully
  185. log.Printf("[%s] Force killing process", c.serverName)
  186. c.cmd.Process.Kill()
  187. case <-done:
  188. // Process exited gracefully
  189. }
  190. }
  191. return nil
  192. }
  193. // nextID generates a unique request ID
  194. func (c *MCPStdioClient) nextID() int {
  195. c.idMu.Lock()
  196. defer c.idMu.Unlock()
  197. c.idCounter++
  198. return c.idCounter
  199. }
  200. // sendRequest sends a JSON-RPC request and waits for the response
  201. func (c *MCPStdioClient) sendRequest(method string, params interface{}, result interface{}) error {
  202. // Build request
  203. id := c.nextID()
  204. var paramsJSON json.RawMessage
  205. if params != nil {
  206. var err error
  207. paramsJSON, err = json.Marshal(params)
  208. if err != nil {
  209. return fmt.Errorf("failed to marshal params: %w", err)
  210. }
  211. }
  212. req := JSONRPCRequest{
  213. JSONRPC: "2.0",
  214. ID: id,
  215. Method: method,
  216. Params: paramsJSON,
  217. }
  218. reqBody, err := json.Marshal(req)
  219. if err != nil {
  220. return fmt.Errorf("failed to marshal request: %w", err)
  221. }
  222. // Add newline as line delimiter
  223. reqBody = append(reqBody, '\n')
  224. // Register pending request before sending
  225. respChan := make(chan json.RawMessage, 1)
  226. c.pendingMu.Lock()
  227. c.pending[id] = respChan
  228. c.pendingMu.Unlock()
  229. // Cleanup on return
  230. defer func() {
  231. c.pendingMu.Lock()
  232. delete(c.pending, id)
  233. c.pendingMu.Unlock()
  234. }()
  235. // Send request
  236. if _, err := c.stdin.Write(reqBody); err != nil {
  237. return fmt.Errorf("failed to send request: %w", err)
  238. }
  239. // Wait for response
  240. select {
  241. case respData := <-respChan:
  242. // Parse response
  243. var rpcResp JSONRPCResponse
  244. if err := json.Unmarshal(respData, &rpcResp); err != nil {
  245. return fmt.Errorf("failed to parse response: %w", err)
  246. }
  247. if rpcResp.Error != nil {
  248. return fmt.Errorf("RPC error %d: %s", rpcResp.Error.Code, rpcResp.Error.Message)
  249. }
  250. // Parse result if provided
  251. if result != nil && rpcResp.Result != nil {
  252. if err := json.Unmarshal(rpcResp.Result, result); err != nil {
  253. return fmt.Errorf("failed to parse result: %w", err)
  254. }
  255. }
  256. return nil
  257. case <-time.After(30 * time.Second):
  258. return fmt.Errorf("timeout waiting for response from %s", c.serverName)
  259. case <-c.done:
  260. return fmt.Errorf("client closed")
  261. }
  262. }