| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305 |
- package main
- import (
- "bufio"
- "encoding/json"
- "fmt"
- "io"
- "log"
- "os"
- "os/exec"
- "sync"
- "time"
- )
- // MCPStdioClient is an MCP client that communicates via stdin/stdout
- // Used for external MCP servers spawned as child processes
- type MCPStdioClient struct {
- serverName string
- config MCPServerConfig
- cmd *exec.Cmd
- stdin io.WriteCloser
- stdout io.Reader
- stderr io.Reader
- // Request ID counter
- idCounter int
- idMu sync.Mutex
- // Pending requests (ID -> response channel)
- pending map[interface{}]chan json.RawMessage
- pendingMu sync.Mutex
- // Tools cache
- tools []Tool
- // Done channel for cleanup
- done chan struct{}
- // Close once for idempotent close
- closeOnce sync.Once
- }
- // NewMCPStdioClient creates a new stdio MCP client for an external server
- func NewMCPStdioClient(serverName string, config MCPServerConfig) *MCPStdioClient {
- return &MCPStdioClient{
- serverName: serverName,
- config: config,
- pending: make(map[interface{}]chan json.RawMessage),
- done: make(chan struct{}),
- }
- }
- // Start spawns the external MCP server process
- func (c *MCPStdioClient) Start() error {
- // Build command
- c.cmd = exec.Command(c.config.Command, c.config.Args...)
- // Set environment variables
- if len(c.config.Env) > 0 {
- env := os.Environ()
- for key, value := range c.config.Env {
- env = append(env, fmt.Sprintf("%s=%s", key, value))
- }
- c.cmd.Env = env
- }
- // Get stdin pipe
- stdin, err := c.cmd.StdinPipe()
- if err != nil {
- return fmt.Errorf("failed to get stdin pipe: %w", err)
- }
- c.stdin = stdin
- // Get stdout pipe
- stdout, err := c.cmd.StdoutPipe()
- if err != nil {
- return fmt.Errorf("failed to get stdout pipe: %w", err)
- }
- c.stdout = stdout
- // Get stderr pipe for logging
- stderr, err := c.cmd.StderrPipe()
- if err != nil {
- return fmt.Errorf("failed to get stderr pipe: %w", err)
- }
- c.stderr = stderr
- // Start the process
- if err := c.cmd.Start(); err != nil {
- return fmt.Errorf("failed to start MCP server '%s': %w", c.serverName, err)
- }
- // Start reading stdout in background
- go c.readOutput()
- // Start reading stderr in background for logging
- go c.readStderr()
- return nil
- }
- // readOutput reads JSON-RPC responses from stdout
- func (c *MCPStdioClient) readOutput() {
- scanner := bufio.NewScanner(c.stdout)
- for scanner.Scan() {
- line := scanner.Bytes()
- // Parse to check if it's a response (has ID) or notification
- var msg struct {
- ID interface{} `json:"id"`
- }
- if err := json.Unmarshal(line, &msg); err != nil {
- log.Printf("[%s] Failed to parse output: %v", c.serverName, err)
- continue
- }
- if msg.ID != nil {
- // JSON numbers are unmarshaled as float64, but we use int for IDs
- var idKey interface{} = msg.ID
- if f, ok := msg.ID.(float64); ok {
- idKey = int(f)
- }
- // It's a response - dispatch to pending request
- c.pendingMu.Lock()
- if ch, ok := c.pending[idKey]; ok {
- ch <- json.RawMessage(line)
- delete(c.pending, idKey)
- }
- c.pendingMu.Unlock()
- }
- // Notifications are ignored for now (external servers typically don't send them)
- }
- }
- // readStderr reads stderr output for debugging
- func (c *MCPStdioClient) readStderr() {
- scanner := bufio.NewScanner(c.stderr)
- for scanner.Scan() {
- log.Printf("[%s] stderr: %s", c.serverName, scanner.Text())
- }
- }
- // Initialize sends the initialize request
- func (c *MCPStdioClient) Initialize() (*InitializeResult, error) {
- params := InitializeParams{
- ProtocolVersion: ProtocolVersion,
- Capabilities: ClientCapabilities{
- Roots: &RootsCapability{ListChanged: false},
- },
- ClientInfo: ImplementationInfo{
- Name: "ARP Agent",
- Version: "1.0.0",
- },
- }
- result := &InitializeResult{}
- if err := c.sendRequest("initialize", params, result); err != nil {
- return nil, err
- }
- return result, nil
- }
- // ListTools discovers available tools from this server
- func (c *MCPStdioClient) ListTools() ([]Tool, error) {
- result := &ListToolsResult{}
- if err := c.sendRequest("tools/list", nil, result); err != nil {
- return nil, err
- }
- c.tools = result.Tools
- return result.Tools, nil
- }
- // CallTool executes a tool call
- func (c *MCPStdioClient) CallTool(name string, arguments map[string]interface{}) (*CallToolResult, error) {
- params := CallToolParams{
- Name: name,
- Arguments: arguments,
- }
- result := &CallToolResult{}
- if err := c.sendRequest("tools/call", params, result); err != nil {
- return nil, err
- }
- return result, nil
- }
- // GetTools returns the cached tools
- func (c *MCPStdioClient) GetTools() []Tool {
- return c.tools
- }
- // Close stops the external MCP server process (idempotent)
- func (c *MCPStdioClient) Close() error {
- c.closeOnce.Do(func() {
- close(c.done)
- })
- if c.stdin != nil {
- c.stdin.Close()
- }
- if c.cmd != nil && c.cmd.Process != nil {
- // Give the process a moment to exit gracefully
- done := make(chan error, 1)
- go func() {
- done <- c.cmd.Wait()
- }()
- select {
- case <-time.After(5 * time.Second):
- // Force kill if it doesn't exit gracefully
- log.Printf("[%s] Force killing process", c.serverName)
- c.cmd.Process.Kill()
- case <-done:
- // Process exited gracefully
- }
- }
- return nil
- }
- // nextID generates a unique request ID
- func (c *MCPStdioClient) nextID() int {
- c.idMu.Lock()
- defer c.idMu.Unlock()
- c.idCounter++
- return c.idCounter
- }
- // sendRequest sends a JSON-RPC request and waits for the response
- func (c *MCPStdioClient) sendRequest(method string, params interface{}, result interface{}) error {
- // Build request
- id := c.nextID()
- var paramsJSON json.RawMessage
- if params != nil {
- var err error
- paramsJSON, err = json.Marshal(params)
- if err != nil {
- return fmt.Errorf("failed to marshal params: %w", err)
- }
- }
- req := JSONRPCRequest{
- JSONRPC: "2.0",
- ID: id,
- Method: method,
- Params: paramsJSON,
- }
- reqBody, err := json.Marshal(req)
- if err != nil {
- return fmt.Errorf("failed to marshal request: %w", err)
- }
- // Add newline as line delimiter
- reqBody = append(reqBody, '\n')
- // Register pending request before sending
- respChan := make(chan json.RawMessage, 1)
- c.pendingMu.Lock()
- c.pending[id] = respChan
- c.pendingMu.Unlock()
- // Cleanup on return
- defer func() {
- c.pendingMu.Lock()
- delete(c.pending, id)
- c.pendingMu.Unlock()
- }()
- // Send request
- if _, err := c.stdin.Write(reqBody); err != nil {
- return fmt.Errorf("failed to send request: %w", err)
- }
- // Wait for response
- select {
- case respData := <-respChan:
- // Parse response
- var rpcResp JSONRPCResponse
- if err := json.Unmarshal(respData, &rpcResp); err != nil {
- return fmt.Errorf("failed to parse response: %w", err)
- }
- if rpcResp.Error != nil {
- return fmt.Errorf("RPC error %d: %s", rpcResp.Error.Code, rpcResp.Error.Message)
- }
- // Parse result if provided
- if result != nil && rpcResp.Result != nil {
- if err := json.Unmarshal(rpcResp.Result, result); err != nil {
- return fmt.Errorf("failed to parse result: %w", err)
- }
- }
- return nil
- case <-time.After(30 * time.Second):
- return fmt.Errorf("timeout waiting for response from %s", c.serverName)
- case <-c.done:
- return fmt.Errorf("client closed")
- }
- }
|