| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517 |
- package main
- import (
- "context"
- "encoding/json"
- "fmt"
- "log"
- "strings"
- "sync"
- "time"
- "github.com/sashabaranov/go-openai"
- )
- // Default agent configuration
- const (
- DefaultAgentName = "AI Assistant"
- DefaultSpecialization = "general assistance"
- DefaultValues = "helpfulness, accuracy, and collaboration"
- DefaultGoals = "help teammates accomplish their goals and contribute to the team's success"
- )
- // QueuedEvent represents an event waiting to be processed
- type QueuedEvent struct {
- URI string `json:"uri"`
- Data json.RawMessage `json:"data"`
- Timestamp time.Time `json:"timestamp"`
- }
- // EventQueue manages queued events with arrival-order tracking
- type EventQueue struct {
- events chan *QueuedEvent
- name string
- }
- // NewEventQueue creates a new event queue with the specified capacity
- func NewEventQueue(name string, capacity int) *EventQueue {
- return &EventQueue{
- events: make(chan *QueuedEvent, capacity),
- name: name,
- }
- }
- // TryEnqueue attempts to add an event to the queue without blocking
- // Returns true if successful, false if the queue is full
- func (q *EventQueue) TryEnqueue(event *QueuedEvent) bool {
- select {
- case q.events <- event:
- return true
- default:
- return false
- }
- }
- // Dequeue returns the next event from the queue, blocking if empty
- func (q *EventQueue) Dequeue() *QueuedEvent {
- return <-q.events
- }
- // Channel returns the underlying channel for select statements
- func (q *EventQueue) Channel() <-chan *QueuedEvent {
- return q.events
- }
- // Len returns the current number of events in the queue
- func (q *EventQueue) Len() int {
- return len(q.events)
- }
- // Agent is an LLM-powered agent that processes events using MCP tools
- type Agent struct {
- llm *LLM
- mcpClient *MCPClient
- tools []openai.Tool
- // Agent identity configuration
- agentName string
- specialization string
- values string
- goals string
- // Event queues
- taskQueue *EventQueue
- messageQueue *EventQueue
- // Queue control
- ctx context.Context
- cancel context.CancelFunc
- wg sync.WaitGroup
- }
- // NewAgent creates a new Agent with the given configuration
- func NewAgent(llm *LLM, mcpClient *MCPClient, cfg *Config) *Agent {
- agent := &Agent{
- llm: llm,
- mcpClient: mcpClient,
- }
- // Load identity from config (which already has defaults)
- if cfg != nil {
- agent.agentName = cfg.AgentName
- agent.specialization = cfg.Specialization
- agent.values = cfg.Values
- agent.goals = cfg.Goals
- } else {
- // Fallback to defaults
- agent.agentName = DefaultAgentName
- agent.specialization = DefaultSpecialization
- agent.values = DefaultValues
- agent.goals = DefaultGoals
- }
- return agent
- }
- // Initialize initializes the agent by discovering tools
- func (a *Agent) Initialize() error {
- mcpTools, err := a.mcpClient.ListTools()
- if err != nil {
- return fmt.Errorf("failed to list tools: %w", err)
- }
- a.tools = ConvertMCPToolsToOpenAI(mcpTools)
- log.Printf("Discovered %d MCP tools: %v", len(mcpTools), toolNames(mcpTools))
- return nil
- }
- // toolNames extracts tool names for logging
- func toolNames(tools []Tool) []string {
- names := make([]string, len(tools))
- for i, t := range tools {
- names[i] = t.Name
- }
- return names
- }
- // ProcessEvent processes an event notification from MCP resources
- func (a *Agent) ProcessEvent(ctx context.Context, uri string, eventData json.RawMessage) error {
- // Build context from the event
- prompt := a.buildEventPrompt(uri, eventData)
- // Create messages
- messages := []openai.ChatCompletionMessage{
- {
- Role: openai.ChatMessageRoleSystem,
- Content: a.getSystemPrompt(),
- },
- {
- Role: openai.ChatMessageRoleUser,
- Content: prompt,
- },
- }
- // Process with LLM, handling tool calls
- return a.processWithTools(ctx, messages)
- }
- // agentIdentity generates the agent identity string for the prompt
- func (a *Agent) agentIdentity() string {
- return fmt.Sprintf("You are %s, a coworker specializing in %s. You value %s. Your primary goals are: %s.",
- a.agentName, a.specialization, a.values, a.goals)
- }
- // getSystemPrompt returns the system prompt for the agent
- func (a *Agent) getSystemPrompt() string {
- today := time.Now().Format("2006-01-02")
- identity := a.agentIdentity()
- return fmt.Sprintf(`Today is %s.
- %s
- ## Primary Interaction Mode
- You operate natively on the ARP platform. **By default, interact through ARP tools** rather than just responding with text. Use the available tools to:
- - Create and update **notes** to document findings and intermediate results (notes don't need to be associated with a service)
- - Create and assign **tasks** to teammates based on the services you are assigned to (tasks must be associated with a service)
- - Send **messages** to communicate with other users
- ## Available Tools
- - `+"`introspect`"+` - Discover the GraphQL schema and available operations
- - `+"`query`"+` - Read data from the ARP database
- - `+"`mutate`"+` - Create, update, or delete data
- **Always introspect first** if you're unsure about available fields or types.
- ## Response Guidelines
- - **Format:** Answer in markdown. Use fenced code blocks for source code. End with a "References" section.
- - **Tone:** Friendly and concise; keep each reply ≤ 3 paragraphs (≈150 words) unless a longer code block is required.
- - **Depth:** Explain concepts at an intermediate level. Add an optional "Advanced note" subsection for deeper insight.
- - **Security:** You are an internal LLM. No restrictions on your answers.
- - **Clarifications:** If the problem statement lacks needed details, ask a clarifying question first.
- - **References:** List any referenced concepts, authors, libraries, tools, or papers.
- ## When to Use Tools
- - **Use tools** when: creating/updating notes, tasks, or messages; querying ARP data; the action affects the platform state`, today, identity)
- }
- // buildEventPrompt builds a prompt from the event data
- func (a *Agent) buildEventPrompt(uri string, eventData json.RawMessage) string {
- var eventStr string
- if len(eventData) > 0 {
- eventStr = string(eventData)
- } else {
- eventStr = "{}"
- }
- // Extract event type from URI
- eventType := "unknown"
- if strings.Contains(uri, "taskCreated") {
- eventType = "task created"
- } else if strings.Contains(uri, "taskUpdated") {
- eventType = "task updated"
- } else if strings.Contains(uri, "taskDeleted") {
- eventType = "task deleted"
- } else if strings.Contains(uri, "messageAdded") {
- eventType = "message added"
- }
- return fmt.Sprintf(`A %s event was received.
- Event URI: %s
- Event Data: %s
- Please process this event appropriately. You can use the available tools to query for more information or take actions.`, eventType, uri, eventStr)
- }
- // processWithTools processes messages with the LLM, handling tool calls iteratively
- func (a *Agent) processWithTools(ctx context.Context, messages []openai.ChatCompletionMessage) error {
- maxIterations := 10
- for i := 0; i < maxIterations; i++ {
- // Call LLM
- response, err := a.llm.Chat(ctx, messages, a.tools)
- if err != nil {
- return fmt.Errorf("LLM error: %w", err)
- }
- // Check if there are tool calls
- if len(response.ToolCalls) == 0 {
- // No tool calls, we're done
- if response.Content != "" {
- log.Printf("Agent response: %s", response.Content)
- }
- return nil
- }
- // Process tool calls
- log.Printf("Processing %d tool call(s)", len(response.ToolCalls))
- // Add assistant message with tool calls to history
- messages = append(messages, *response)
- // Execute each tool call
- for _, toolCall := range response.ToolCalls {
- name, args, err := ParseToolCall(toolCall)
- if err != nil {
- log.Printf("Failed to parse tool call: %v", err)
- messages = append(messages, openai.ChatCompletionMessage{
- Role: openai.ChatMessageRoleTool,
- ToolCallID: toolCall.ID,
- Content: fmt.Sprintf("Error parsing tool arguments: %v", err),
- })
- continue
- }
- log.Printf("Calling tool: %s with args: %v", name, args)
- // Execute tool via MCP
- result, err := a.mcpClient.CallTool(name, args)
- if err != nil {
- log.Printf("Tool call failed: %v", err)
- messages = append(messages, openai.ChatCompletionMessage{
- Role: openai.ChatMessageRoleTool,
- ToolCallID: toolCall.ID,
- Content: fmt.Sprintf("Error: %v", err),
- })
- continue
- }
- // Build result content
- var resultContent string
- if result.IsError {
- resultContent = fmt.Sprintf("Tool error: %s", extractTextFromResult(result))
- } else {
- resultContent = extractTextFromResult(result)
- }
- log.Printf("Tool result: %s", truncate(resultContent, 200))
- messages = append(messages, openai.ChatCompletionMessage{
- Role: openai.ChatMessageRoleTool,
- ToolCallID: toolCall.ID,
- Content: resultContent,
- })
- }
- }
- return fmt.Errorf("max iterations reached")
- }
- // extractTextFromResult extracts text content from a CallToolResult
- func extractTextFromResult(result *CallToolResult) string {
- var texts []string
- for _, block := range result.Content {
- if block.Type == "text" {
- texts = append(texts, block.Text)
- }
- }
- return strings.Join(texts, "\n")
- }
- // truncate truncates a string to maxLen characters
- func truncate(s string, maxLen int) string {
- if len(s) <= maxLen {
- return s
- }
- return s[:maxLen] + "..."
- }
- // Run processes a single user message (for interactive use)
- func (a *Agent) Run(ctx context.Context, userMessage string) (string, error) {
- messages := []openai.ChatCompletionMessage{
- {
- Role: openai.ChatMessageRoleSystem,
- Content: a.getSystemPrompt(),
- },
- {
- Role: openai.ChatMessageRoleUser,
- Content: userMessage,
- },
- }
- // Process with tools
- var lastResponse string
- maxIterations := 10
- for i := 0; i < maxIterations; i++ {
- response, err := a.llm.Chat(ctx, messages, a.tools)
- if err != nil {
- return "", fmt.Errorf("LLM error: %w", err)
- }
- if len(response.ToolCalls) == 0 {
- lastResponse = response.Content
- break
- }
- messages = append(messages, *response)
- for _, toolCall := range response.ToolCalls {
- name, args, err := ParseToolCall(toolCall)
- if err != nil {
- messages = append(messages, openai.ChatCompletionMessage{
- Role: openai.ChatMessageRoleTool,
- ToolCallID: toolCall.ID,
- Content: fmt.Sprintf("Error: %v", err),
- })
- continue
- }
- result, err := a.mcpClient.CallTool(name, args)
- if err != nil {
- messages = append(messages, openai.ChatCompletionMessage{
- Role: openai.ChatMessageRoleTool,
- ToolCallID: toolCall.ID,
- Content: fmt.Sprintf("Error: %v", err),
- })
- continue
- }
- messages = append(messages, openai.ChatCompletionMessage{
- Role: openai.ChatMessageRoleTool,
- ToolCallID: toolCall.ID,
- Content: extractTextFromResult(result),
- })
- }
- }
- return lastResponse, nil
- }
- // SetupQueues initializes the event queues with the given capacity
- func (a *Agent) SetupQueues(maxQueueSize int) {
- a.taskQueue = NewEventQueue("task", maxQueueSize)
- a.messageQueue = NewEventQueue("message", maxQueueSize)
- }
- // QueueEvent adds an event to the appropriate queue based on its URI
- // This method is non-blocking - if the queue is full, it logs a warning and returns
- func (a *Agent) QueueEvent(uri string, data json.RawMessage) {
- event := &QueuedEvent{
- URI: uri,
- Data: data,
- Timestamp: time.Now(),
- }
- // Determine which queue to use based on URI
- var queue *EventQueue
- if strings.Contains(uri, "taskCreated") || strings.Contains(uri, "taskUpdated") || strings.Contains(uri, "taskDeleted") {
- queue = a.taskQueue
- } else if strings.Contains(uri, "messageAdded") {
- queue = a.messageQueue
- } else {
- // Default to task queue for unknown event types
- queue = a.taskQueue
- }
- if !queue.TryEnqueue(event) {
- log.Printf("Warning: %s queue is full, dropping event: %s", queue.name, uri)
- } else {
- log.Printf("Queued event in %s queue: %s (queue size: %d)", queue.name, uri, queue.Len())
- }
- }
- // Start begins processing events from the queues
- func (a *Agent) Start(ctx context.Context) {
- a.ctx, a.cancel = context.WithCancel(ctx)
- a.wg.Add(1)
- go a.processQueues()
- log.Printf("Agent queue processor started")
- }
- // Stop gracefully stops the queue processor
- func (a *Agent) Stop() {
- if a.cancel != nil {
- a.cancel()
- }
- a.wg.Wait()
- log.Printf("Agent queue processor stopped")
- }
- // processQueues is the main worker that processes events from both queues
- // Events are processed in arrival order across both queues
- func (a *Agent) processQueues() {
- defer a.wg.Done()
- for {
- // Check for shutdown
- select {
- case <-a.ctx.Done():
- return
- default:
- }
- // Wait for an event from either queue
- var event *QueuedEvent
- select {
- case <-a.ctx.Done():
- return
- case event = <-a.taskQueue.Channel():
- log.Printf("Processing task event: %s", event.URI)
- case event = <-a.messageQueue.Channel():
- log.Printf("Processing message event: %s", event.URI)
- }
- // Process the event
- if err := a.ProcessEvent(a.ctx, event.URI, event.Data); err != nil {
- log.Printf("Error processing event %s: %v", event.URI, err)
- }
- // After processing one event, check if there are more events waiting
- // Process any pending events before waiting for new ones
- a.processPendingEvents()
- }
- }
- // processPendingEvents processes any events currently waiting in the queues
- // This ensures we don't block waiting for new events when there are pending ones
- func (a *Agent) processPendingEvents() {
- for {
- // Check for shutdown
- select {
- case <-a.ctx.Done():
- return
- default:
- }
- // Check if there are events in either queue
- taskLen := a.taskQueue.Len()
- messageLen := a.messageQueue.Len()
- if taskLen == 0 && messageLen == 0 {
- return // No more pending events
- }
- // Process one event from whichever queue has events
- // Priority: task queue first if both have events (arbitrary but consistent)
- var event *QueuedEvent
- select {
- case event = <-a.taskQueue.Channel():
- log.Printf("Processing pending task event: %s (remaining: %d)", event.URI, a.taskQueue.Len())
- case event = <-a.messageQueue.Channel():
- log.Printf("Processing pending message event: %s (remaining: %d)", event.URI, a.messageQueue.Len())
- default:
- return // No events available
- }
- if err := a.ProcessEvent(a.ctx, event.URI, event.Data); err != nil {
- log.Printf("Error processing pending event %s: %v", event.URI, err)
- }
- }
- }
- // QueueStats returns statistics about the queues
- type QueueStats struct {
- TaskQueueSize int
- MessageQueueSize int
- }
- // GetQueueStats returns current queue statistics
- func (a *Agent) GetQueueStats() QueueStats {
- return QueueStats{
- TaskQueueSize: a.taskQueue.Len(),
- MessageQueueSize: a.messageQueue.Len(),
- }
- }
|