actions.ts 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. "use server";
  2. import { Client } from "@modelcontextprotocol/sdk/client/index.js";
  3. import {
  4. createClient,
  5. executeRequest,
  6. listPrimitives,
  7. Primitive,
  8. } from "./client";
  9. import { MCPClientLogger } from "./logger";
  10. import { McpRequestMessage, McpConfig, ServerConfig } from "./types";
  11. import fs from "fs/promises";
  12. import path from "path";
  const logger = new MCPClientLogger("MCP Actions");

  /** In-memory record kept for each configured MCP server. */
  type McpClientEntry = {
    /** Connected client, or `null` when initialization failed. */
    client: Client | null;
    /** Primitives listed at initialization time (empty when initialization failed). */
    primitives: Primitive[];
    /** Initialization failure message, or `null` when the client started cleanly. */
    errorMsg: string | null;
  };

  /** Every known client, keyed by its server id from the MCP config. */
  const clientsMap = new Map<string, McpClientEntry>();

  /** True once initialization has run; reset to false by the restart/reinit paths. */
  let initialized = false;

  /** Ids of the clients whose most recent initialization failed. */
  let errorClients: string[] = [];

  /** Path of the JSON config file, relative to the process working directory. */
  const CONFIG_PATH = path.join(process.cwd(), "app/mcp/mcp_config.json");
  /**
   * Reads and parses the MCP config file at `CONFIG_PATH`.
   *
   * Never throws. If the file cannot be read or parsed, or its top-level value
   * is not a JSON object, the problem is logged and an empty config
   * (`{ mcpServers: {} }`) is returned. A missing or non-object `mcpServers`
   * field is normalized to an empty map, so callers can always iterate
   * `config.mcpServers` safely.
   *
   * @returns The parsed config, with `mcpServers` guaranteed to be an object.
   */
  export async function getMcpConfig(): Promise<McpConfig> {
    try {
      const configStr = await fs.readFile(CONFIG_PATH, "utf-8");
      const parsed: unknown = JSON.parse(configStr);
      if (typeof parsed !== "object" || parsed === null || Array.isArray(parsed)) {
        console.error(
          "Failed to read MCP config: top-level JSON value is not an object",
        );
        return { mcpServers: {} };
      }
      const { mcpServers } = parsed as { mcpServers?: unknown };
      const hasServerMap =
        typeof mcpServers === "object" &&
        mcpServers !== null &&
        !Array.isArray(mcpServers);
      if (!hasServerMap) {
        // `Object.entries(config.mcpServers)` throws on undefined/null, so a
        // config without a usable server map would abort client initialization.
        console.warn(
          'MCP config has no valid "mcpServers" object; treating it as empty',
        );
        return { ...parsed, mcpServers: {} } as McpConfig;
      }
      return parsed as McpConfig;
    } catch (error) {
      console.error("Failed to read MCP config:", error);
      return { mcpServers: {} };
    }
  }
  /**
   * Overwrites the config file at `CONFIG_PATH` with `config`, serialized as
   * JSON with 2-space indentation.
   *
   * @param config - The complete MCP config to persist.
   * @throws Re-throws any serialization or write error after logging it.
   */
  export async function updateMcpConfig(config: McpConfig): Promise<void> {
    try {
      const serialized = JSON.stringify(config, null, 2);
      await fs.writeFile(CONFIG_PATH, serialized);
    } catch (writeError) {
      console.error("Failed to write MCP config:", writeError);
      throw writeError;
    }
  }
  /**
   * Closes every known client, resets all module state, then initializes the
   * clients again from the current config file.
   *
   * A failure while closing one client is logged and does not prevent the
   * other clients from being closed or the re-initialization from running.
   *
   * @returns The result of `initializeMcpClients()`, i.e. `{ errorClients }`.
   */
  export async function reinitializeMcpClients() {
    logger.info("Reinitializing MCP clients...");
    // Close all clients and wait for them, so old clients are shut down before
    // new ones are created. Awaiting inside an async callback catches both
    // synchronous throws and rejected promises from `close()`, which the
    // previous un-awaited call let escape as unhandled rejections.
    await Promise.all(
      Array.from(clientsMap.entries()).map(async ([clientId, { client }]) => {
        if (!client) {
          return;
        }
        try {
          await client.close();
        } catch (error) {
          logger.error(`Failed to close client ${clientId}: ${error}`);
        }
      }),
    );
    // Reset state
    clientsMap.clear();
    errorClients = [];
    initialized = false;
    // Re-initialize from the config file
    return initializeMcpClients();
  }
  /**
   * Creates a client for every server in the MCP config and records each
   * client's primitives, or its failure, in `clientsMap`.
   *
   * Runs only once until state is reset. Later calls return the cached failure
   * list without touching any client.
   *
   * @returns `{ errorClients }` — ids of the servers that failed to initialize.
   */
  export async function initializeMcpClients() {
    // Already initialized: report the cached failures.
    if (initialized) {
      return { errorClients };
    }
    // NOTE(review): `initialized` is only set after the awaits below. Two
    // overlapping calls would therefore both run this loop, and the later one
    // would overwrite (without closing) the earlier one's map entries. Confirm
    // that callers never overlap.
    logger.info("Starting to initialize MCP clients...");
    errorClients = [];
    const config = await getMcpConfig();
    // A config without a `mcpServers` map must not crash initialization:
    // `Object.entries(undefined)` throws a TypeError.
    const servers = config.mcpServers ?? {};
    // Initialize every client; key is the client id, value is its server config.
    for (const [clientId, serverConfig] of Object.entries(servers)) {
      try {
        logger.info(`Initializing MCP client: ${clientId}`);
        const client = await createClient(serverConfig as ServerConfig, clientId);
        const primitives = await listPrimitives(client);
        clientsMap.set(clientId, { client, primitives, errorMsg: null });
        logger.success(
          `Client [${clientId}] initialized, ${primitives.length} primitives supported`,
        );
      } catch (error) {
        errorClients.push(clientId);
        const message = error instanceof Error ? error.message : String(error);
        clientsMap.set(clientId, {
          client: null,
          primitives: [],
          // Never store an empty message (e.g. from `new Error()`), so that a
          // failed client cannot pass for a healthy one in truthiness checks.
          errorMsg: message || "Unknown error",
        });
        logger.error(`Failed to initialize client ${clientId}: ${error}`);
      }
    }
    initialized = true;
    if (errorClients.length > 0) {
      logger.warn(`Failed to initialize clients: ${errorClients.join(", ")}`);
    } else {
      logger.success("All MCP clients initialized");
    }
    const availableClients = await getAvailableClients();
    logger.info(`Available clients: ${availableClients.join(",")}`);
    return { errorClients };
  }
  /**
   * Sends `request` to the client registered under `clientId`.
   *
   * @param clientId - Id of the target client in `clientsMap`.
   * @param request - The MCP request message to execute.
   * @returns The result of `executeRequest`, or `undefined` when there is no
   *   connected client for `clientId` (that case is logged, not thrown).
   * @throws Re-throws any error raised while executing, after logging it.
   */
  export async function executeMcpAction(
    clientId: string,
    request: McpRequestMessage,
  ) {
    try {
      const entry = clientsMap.get(clientId);
      const target = entry ? entry.client : undefined;
      if (target) {
        logger.info(`Executing MCP request for ${clientId}`);
        return await executeRequest(target, request);
      }
      logger.error(`Client ${clientId} not found`);
      return;
    } catch (executionError) {
      logger.error(`MCP execution error: ${executionError}`);
      throw executionError;
    }
  }
  /**
   * Lists the ids of clients that initialized without error
   * (`errorMsg === null`), in `clientsMap` insertion order.
   */
  export async function getAvailableClients() {
    const healthyIds: string[] = [];
    for (const [clientId, { errorMsg }] of clientsMap.entries()) {
      if (errorMsg === null) {
        healthyIds.push(clientId);
      }
    }
    return healthyIds;
  }
  /**
   * Returns every known client's id together with its primitives. Clients that
   * failed to initialize are included with an empty `primitives` array.
   */
  export async function getAllPrimitives(): Promise<
    {
      clientId: string;
      primitives: Primitive[];
    }[]
  > {
    const collected: { clientId: string; primitives: Primitive[] }[] = [];
    clientsMap.forEach((entry, clientId) => {
      collected.push({ clientId, primitives: entry.primitives });
    });
    return collected;
  }
  /**
   * Returns the primitives of a client that initialized successfully.
   *
   * @param clientId - Id of the client in `clientsMap`.
   * @returns The client's primitives, or `null` when the client is unknown or
   *   has a recorded initialization error. Both cases are logged as warnings.
   */
  export async function getClientPrimitives(clientId: string) {
    const clientData = clientsMap.get(clientId);
    if (!clientData) {
      console.warn(`Client ${clientId} not found in map`);
      return null;
    }
    // Compare against `null` rather than truthiness, matching the other status
    // helpers. An error recorded with an empty message is still a failure.
    if (clientData.errorMsg !== null) {
      console.warn(`Client ${clientId} has error: ${clientData.errorMsg}`);
      return null;
    }
    return clientData.primitives;
  }
  /**
   * Closes all existing clients, resets module state, and initializes the
   * clients again from the config file.
   *
   * @returns `success` is true when every client initialized; `errorClients`
   *   lists the ids that failed.
   */
  export async function restartAllClients() {
    logger.info("Restarting all MCP clients...");
    // Close the old clients before dropping them. Clearing the map alone would
    // discard the references without ever calling `close()`, leaking whatever
    // resources those clients hold. Close failures are logged, not fatal.
    await Promise.all(
      Array.from(clientsMap.entries()).map(async ([clientId, { client }]) => {
        if (!client) {
          return;
        }
        try {
          await client.close();
        } catch (error) {
          logger.error(`Failed to close client ${clientId}: ${error}`);
        }
      }),
    );
    // Reset state
    clientsMap.clear();
    errorClients = [];
    initialized = false;
    // Re-initialize from the config file
    await initializeMcpClients();
    return {
      success: errorClients.length === 0,
      errorClients,
    };
  }
  /**
   * Maps every known client id to its initialization error message, or to
   * `null` for clients that initialized successfully.
   */
  export async function getAllClientStatus(): Promise<
    Record<string, string | null>
  > {
    const statusById: Record<string, string | null> = {};
    clientsMap.forEach(({ errorMsg }, clientId) => {
      statusById[clientId] = errorMsg;
    });
    return statusById;
  }
  /**
   * Returns each known client's recorded error message keyed by client id,
   * with `null` for clients that have no error.
   */
  export async function getClientErrors(): Promise<
    Record<string, string | null>
  > {
    return Array.from(clientsMap.entries()).reduce<Record<string, string | null>>(
      (acc, [clientId, { errorMsg }]) => {
        acc[clientId] = errorMsg;
        return acc;
      },
      {},
    );
  }
  /**
   * Reports which clients are in an error state without re-initializing them.
   * If the clients have never been initialized, this initializes them instead.
   *
   * @returns `{ errorClients }` — ids whose recorded `errorMsg` is not `null`.
   */
  export async function refreshClientStatus() {
    logger.info("Refreshing client status...");
    if (initialized) {
      // Rebuild the failure list from the per-client errors already recorded.
      errorClients = Array.from(clientsMap.entries())
        .filter(([, entry]) => entry.errorMsg !== null)
        .map(([clientId]) => clientId);
      return { errorClients };
    }
    // Nothing has been initialized yet, so run the full initialization.
    return initializeMcpClients();
  }