1
0

llm_test.go 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398
  1. package main
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "testing"
  6. "github.com/sashabaranov/go-openai"
  7. )
  8. // TestLLM_ConvertMCPToolsToOpenAI tests the MCP to OpenAI tool conversion
  9. func TestLLM_ConvertMCPToolsToOpenAI(t *testing.T) {
  10. tests := []struct {
  11. name string
  12. mcpTools []Tool
  13. wantLen int
  14. }{
  15. {
  16. name: "EmptyTools",
  17. mcpTools: []Tool{},
  18. wantLen: 0,
  19. },
  20. {
  21. name: "SingleTool",
  22. mcpTools: []Tool{
  23. {
  24. Name: "introspect",
  25. Description: "Discover the GraphQL schema",
  26. InputSchema: InputSchema{
  27. Type: "object",
  28. Properties: map[string]Property{
  29. "typeName": {Type: "string", Description: "The type to introspect"},
  30. },
  31. Required: []string{},
  32. AdditionalProperties: false,
  33. },
  34. },
  35. },
  36. wantLen: 1,
  37. },
  38. {
  39. name: "MultipleTools",
  40. mcpTools: []Tool{
  41. {
  42. Name: "query",
  43. Description: "Execute a GraphQL query",
  44. InputSchema: InputSchema{
  45. Type: "object",
  46. Properties: map[string]Property{
  47. "query": {Type: "string", Description: "The GraphQL query"},
  48. },
  49. Required: []string{"query"},
  50. AdditionalProperties: false,
  51. },
  52. },
  53. {
  54. Name: "mutate",
  55. Description: "Execute a GraphQL mutation",
  56. InputSchema: InputSchema{
  57. Type: "object",
  58. Properties: map[string]Property{
  59. "mutation": {Type: "string", Description: "The GraphQL mutation"},
  60. },
  61. Required: []string{"mutation"},
  62. AdditionalProperties: false,
  63. },
  64. },
  65. },
  66. wantLen: 2,
  67. },
  68. }
  69. for _, tt := range tests {
  70. t.Run(tt.name, func(t *testing.T) {
  71. tools := ConvertMCPToolsToOpenAI(tt.mcpTools)
  72. if len(tools) != tt.wantLen {
  73. t.Errorf("Expected %d tools, got %d", tt.wantLen, len(tools))
  74. }
  75. // Verify tool conversion details
  76. for i, tool := range tools {
  77. if tool.Type != openai.ToolTypeFunction {
  78. t.Errorf("Tool %d: Expected type %s, got %s", i, openai.ToolTypeFunction, tool.Type)
  79. }
  80. if tool.Function.Name != tt.mcpTools[i].Name {
  81. t.Errorf("Tool %d: Expected name %s, got %s", i, tt.mcpTools[i].Name, tool.Function.Name)
  82. }
  83. if tool.Function.Description != tt.mcpTools[i].Description {
  84. t.Errorf("Tool %d: Expected description %s, got %s", i, tt.mcpTools[i].Description, tool.Function.Description)
  85. }
  86. }
  87. })
  88. }
  89. }
  90. // TestLLM_ConvertMCPToolsToOpenAI_ObjectProperties tests that object-type properties
  91. // get additionalProperties: true to allow arbitrary key-value pairs
  92. func TestLLM_ConvertMCPToolsToOpenAI_ObjectProperties(t *testing.T) {
  93. mcpTools := []Tool{
  94. {
  95. Name: "query",
  96. Description: "Execute a GraphQL query",
  97. InputSchema: InputSchema{
  98. Type: "object",
  99. Properties: map[string]Property{
  100. "query": {
  101. Type: "string",
  102. Description: "The GraphQL query string",
  103. },
  104. "variables": {
  105. Type: "object",
  106. Description: "Optional query variables as key-value pairs",
  107. },
  108. },
  109. Required: []string{"query"},
  110. AdditionalProperties: false,
  111. },
  112. },
  113. }
  114. tools := ConvertMCPToolsToOpenAI(mcpTools)
  115. if len(tools) != 1 {
  116. t.Fatalf("Expected 1 tool, got %d", len(tools))
  117. }
  118. // Check that parameters don't have additionalProperties at top level
  119. params := tools[0].Function.Parameters.(map[string]interface{})
  120. if _, hasAdditionalProps := params["additionalProperties"]; hasAdditionalProps {
  121. t.Error("Top-level parameters should NOT have additionalProperties field")
  122. }
  123. // Check that the variables property has additionalProperties: true
  124. props := params["properties"].(map[string]interface{})
  125. variablesProp, ok := props["variables"].(map[string]interface{})
  126. if !ok {
  127. t.Fatal("variables property not found")
  128. }
  129. if additionalProps, ok := variablesProp["additionalProperties"]; !ok {
  130. t.Error("Object property 'variables' should have additionalProperties field")
  131. } else if additionalProps != true {
  132. t.Errorf("Object property 'variables' additionalProperties should be true, got %v", additionalProps)
  133. }
  134. // Check that string property does NOT have additionalProperties
  135. queryProp, ok := props["query"].(map[string]interface{})
  136. if !ok {
  137. t.Fatal("query property not found")
  138. }
  139. if _, hasAdditionalProps := queryProp["additionalProperties"]; hasAdditionalProps {
  140. t.Error("String property 'query' should NOT have additionalProperties field")
  141. }
  142. }
  143. // TestLLM_ParseToolCall tests parsing tool calls from LLM responses
  144. func TestLLM_ParseToolCall(t *testing.T) {
  145. tests := []struct {
  146. name string
  147. toolCall openai.ToolCall
  148. wantName string
  149. wantArgs map[string]interface{}
  150. wantErr bool
  151. }{
  152. {
  153. name: "ValidToolCall",
  154. toolCall: openai.ToolCall{
  155. ID: "call-123",
  156. Function: openai.FunctionCall{
  157. Name: "query",
  158. Arguments: `{"query": "{ users { email } }"}`,
  159. },
  160. },
  161. wantName: "query",
  162. wantArgs: map[string]interface{}{
  163. "query": "{ users { email } }",
  164. },
  165. wantErr: false,
  166. },
  167. {
  168. name: "EmptyArguments",
  169. toolCall: openai.ToolCall{
  170. ID: "call-456",
  171. Function: openai.FunctionCall{
  172. Name: "introspect",
  173. Arguments: `{}`,
  174. },
  175. },
  176. wantName: "introspect",
  177. wantArgs: map[string]interface{}{},
  178. wantErr: false,
  179. },
  180. {
  181. name: "InvalidJSON",
  182. toolCall: openai.ToolCall{
  183. ID: "call-789",
  184. Function: openai.FunctionCall{
  185. Name: "mutate",
  186. Arguments: `invalid json`,
  187. },
  188. },
  189. wantName: "mutate",
  190. wantArgs: nil,
  191. wantErr: true,
  192. },
  193. {
  194. name: "NestedArguments",
  195. toolCall: openai.ToolCall{
  196. ID: "call-abc",
  197. Function: openai.FunctionCall{
  198. Name: "createTask",
  199. Arguments: `{"title": "Test Task", "priority": "high", "assigneeId": "user-123"}`,
  200. },
  201. },
  202. wantName: "createTask",
  203. wantArgs: map[string]interface{}{
  204. "title": "Test Task",
  205. "priority": "high",
  206. "assigneeId": "user-123",
  207. },
  208. wantErr: false,
  209. },
  210. }
  211. for _, tt := range tests {
  212. t.Run(tt.name, func(t *testing.T) {
  213. name, args, err := ParseToolCall(tt.toolCall)
  214. if (err != nil) != tt.wantErr {
  215. t.Errorf("ParseToolCall() error = %v, wantErr %v", err, tt.wantErr)
  216. return
  217. }
  218. if name != tt.wantName {
  219. t.Errorf("ParseToolCall() name = %v, want %v", name, tt.wantName)
  220. }
  221. if !tt.wantErr && args != nil {
  222. // Compare args
  223. argsJSON, _ := json.Marshal(args)
  224. wantJSON, _ := json.Marshal(tt.wantArgs)
  225. if string(argsJSON) != string(wantJSON) {
  226. t.Errorf("ParseToolCall() args = %v, want %v", args, tt.wantArgs)
  227. }
  228. }
  229. })
  230. }
  231. }
  232. // TestLLM_ToolConversionSnapshot tests tool conversion with snapshot
  233. func TestLLM_ToolConversionSnapshot(t *testing.T) {
  234. mcpTools := []Tool{
  235. {
  236. Name: "introspect",
  237. Description: "Discover the GraphQL schema structure",
  238. InputSchema: InputSchema{
  239. Type: "object",
  240. Properties: map[string]Property{
  241. "typeName": {
  242. Type: "string",
  243. Description: "Optional type name to introspect",
  244. },
  245. },
  246. Required: []string{},
  247. AdditionalProperties: false,
  248. },
  249. },
  250. {
  251. Name: "query",
  252. Description: "Execute a GraphQL query",
  253. InputSchema: InputSchema{
  254. Type: "object",
  255. Properties: map[string]Property{
  256. "query": {
  257. Type: "string",
  258. Description: "The GraphQL query string",
  259. },
  260. },
  261. Required: []string{"query"},
  262. AdditionalProperties: false,
  263. },
  264. },
  265. {
  266. Name: "mutate",
  267. Description: "Execute a GraphQL mutation",
  268. InputSchema: InputSchema{
  269. Type: "object",
  270. Properties: map[string]Property{
  271. "mutation": {
  272. Type: "string",
  273. Description: "The GraphQL mutation string",
  274. },
  275. },
  276. Required: []string{"mutation"},
  277. AdditionalProperties: false,
  278. },
  279. },
  280. }
  281. openaiTools := ConvertMCPToolsToOpenAI(mcpTools)
  282. testSnapshotResult(t, "converted_tools", openaiTools)
  283. }
  284. // TestIsRetryableError tests the isRetryableError function
  285. func TestIsRetryableError(t *testing.T) {
  286. tests := []struct {
  287. name string
  288. err error
  289. want bool
  290. }{
  291. {
  292. name: "nil error",
  293. err: nil,
  294. want: false,
  295. },
  296. {
  297. name: "context canceled",
  298. err: fmt.Errorf("context canceled"),
  299. want: true,
  300. },
  301. {
  302. name: "context deadline exceeded",
  303. err: fmt.Errorf("context deadline exceeded"),
  304. want: true,
  305. },
  306. {
  307. name: "connection refused",
  308. err: fmt.Errorf("dial tcp 127.0.0.1:8080: connect: connection refused"),
  309. want: true,
  310. },
  311. {
  312. name: "connection reset",
  313. err: fmt.Errorf("read tcp 127.0.0.1:8080: read: connection reset by peer"),
  314. want: true,
  315. },
  316. {
  317. name: "connection closed",
  318. err: fmt.Errorf("connection closed"),
  319. want: true,
  320. },
  321. {
  322. name: "timeout",
  323. err: fmt.Errorf("dial tcp 127.0.0.1:8080: i/o timeout"),
  324. want: true,
  325. },
  326. {
  327. name: "service unavailable",
  328. err: fmt.Errorf("service unavailable (HTTP 503)"),
  329. want: true,
  330. },
  331. {
  332. name: "rate limit",
  333. err: fmt.Errorf("rate limit exceeded (HTTP 429)"),
  334. want: true,
  335. },
  336. {
  337. name: "server error 500",
  338. err: fmt.Errorf("internal server error (HTTP 500)"),
  339. want: true,
  340. },
  341. {
  342. name: "server error 502",
  343. err: fmt.Errorf("bad gateway (HTTP 502)"),
  344. want: true,
  345. },
  346. {
  347. name: "server error 503",
  348. err: fmt.Errorf("service unavailable (HTTP 503)"),
  349. want: true,
  350. },
  351. {
  352. name: "server error 504",
  353. err: fmt.Errorf("gateway timeout (HTTP 504)"),
  354. want: true,
  355. },
  356. {
  357. name: "invalid API key",
  358. err: fmt.Errorf("invalid API key (HTTP 401)"),
  359. want: false,
  360. },
  361. {
  362. name: "bad request",
  363. err: fmt.Errorf("bad request (HTTP 400)"),
  364. want: false,
  365. },
  366. {
  367. name: "generic error",
  368. err: fmt.Errorf("something went wrong"),
  369. want: false,
  370. },
  371. }
  372. for _, tt := range tests {
  373. t.Run(tt.name, func(t *testing.T) {
  374. got := isRetryableError(tt.err)
  375. if got != tt.want {
  376. t.Errorf("isRetryableError(%v) = %v, want %v", tt.err, got, tt.want)
  377. }
  378. })
  379. }
  380. }