ソースを参照

feat: MCP message type

Kadxy 11 ヶ月 前
コミット
fe67f79050
4 ファイル変更103 行追加21 行削除
  1. 7 2
      app/mcp/actions.ts
  2. 5 1
      app/mcp/client.ts
  3. 61 0
      app/mcp/types.ts
  4. 30 18
      app/store/chat.ts

+ 7 - 2
app/mcp/actions.ts

@@ -3,8 +3,9 @@
 import { createClient, executeRequest } from "./client";
 import { MCPClientLogger } from "./logger";
 import conf from "./mcp_config.json";
+import { McpRequestMessage } from "./types";
 
-const logger = new MCPClientLogger("MCP Server");
+const logger = new MCPClientLogger("MCP Actions");
 
 // Use Map to store all clients
 const clientsMap = new Map<string, any>();
@@ -51,7 +52,10 @@ export async function initializeMcpClients() {
 }
 
 // Execute MCP request
-export async function executeMcpAction(clientId: string, request: any) {
+export async function executeMcpAction(
+  clientId: string,
+  request: McpRequestMessage,
+) {
   try {
     // Find the corresponding client
     const client = clientsMap.get(clientId);
@@ -61,6 +65,7 @@ export async function executeMcpAction(clientId: string, request: any) {
     }
 
     logger.info(`Executing MCP request for ${clientId}`);
+
     // Execute request and return result
     return await executeRequest(client, request);
   } catch (error) {

+ 5 - 1
app/mcp/client.ts

@@ -1,6 +1,7 @@
 import { Client } from "@modelcontextprotocol/sdk/client/index.js";
 import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js";
 import { MCPClientLogger } from "./logger";
+import { McpRequestMessage } from "./types";
 import { z } from "zod";
 
 export interface ServerConfig {
@@ -79,6 +80,9 @@ export async function listPrimitives(client: Client) {
 }
 
 /** Execute a request */
-export async function executeRequest(client: Client, request: any) {
+export async function executeRequest(
+  client: Client,
+  request: McpRequestMessage,
+) {
   return client.request(request, z.any());
 }

+ 61 - 0
app/mcp/types.ts

@@ -0,0 +1,61 @@
+// ref: https://spec.modelcontextprotocol.io/specification/basic/messages/
+
+import { z } from "zod";
+
+export interface McpRequestMessage {
+  jsonrpc?: "2.0";
+  id?: string | number;
+  method: "tools/call" | string;
+  params?: {
+    [key: string]: unknown;
+  };
+}
+
+export const McpRequestMessageSchema: z.ZodType<McpRequestMessage> = z.object({
+  jsonrpc: z.literal("2.0").optional(),
+  id: z.union([z.string(), z.number()]).optional(),
+  method: z.string(),
+  params: z.record(z.unknown()).optional(),
+});
+
+export interface McpResponseMessage {
+  jsonrpc?: "2.0";
+  id?: string | number;
+  result?: {
+    [key: string]: unknown;
+  };
+  error?: {
+    code: number;
+    message: string;
+    data?: unknown;
+  };
+}
+
+export const McpResponseMessageSchema: z.ZodType<McpResponseMessage> = z.object(
+  {
+    jsonrpc: z.literal("2.0").optional(),
+    id: z.union([z.string(), z.number()]).optional(),
+    result: z.record(z.unknown()).optional(),
+    error: z
+      .object({
+        code: z.number(),
+        message: z.string(),
+        data: z.unknown().optional(),
+      })
+      .optional(),
+  },
+);
+
+export interface McpNotifications {
+  jsonrpc?: "2.0";
+  method: string;
+  params?: {
+    [key: string]: unknown;
+  };
+}
+
+export const McpNotificationsSchema: z.ZodType<McpNotifications> = z.object({
+  jsonrpc: z.literal("2.0").optional(),
+  method: z.string(),
+  params: z.record(z.unknown()).optional(),
+});

+ 30 - 18
app/store/chat.ts

@@ -1,4 +1,9 @@
-import { getMessageTextContent, trimTopic } from "../utils";
+import {
+  getMessageTextContent,
+  isDalle3,
+  safeLocalStorage,
+  trimTopic,
+} from "../utils";
 
 import { indexedDBStorage } from "@/app/utils/indexedDB-storage";
 import { nanoid } from "nanoid";
@@ -14,14 +19,13 @@ import {
   DEFAULT_INPUT_TEMPLATE,
   DEFAULT_MODELS,
   DEFAULT_SYSTEM_TEMPLATE,
+  GEMINI_SUMMARIZE_MODEL,
   KnowledgeCutOffDate,
+  ServiceProvider,
   StoreKey,
   SUMMARIZE_MODEL,
-  GEMINI_SUMMARIZE_MODEL,
-  ServiceProvider,
 } from "../constant";
 import Locale, { getLang } from "../locales";
-import { isDalle3, safeLocalStorage } from "../utils";
 import { prettyObject } from "../utils/format";
 import { createPersistStore } from "../utils/store";
 import { estimateTokenLength } from "../utils/token";
@@ -55,6 +59,7 @@ export type ChatMessage = RequestMessage & {
   model?: ModelType;
   tools?: ChatMessageTool[];
   audio_url?: string;
+  isMcpResponse?: boolean;
 };
 
 export function createMessage(override: Partial<ChatMessage>): ChatMessage {
@@ -368,20 +373,22 @@ export const useChatStore = createPersistStore(
         get().summarizeSession(false, targetSession);
       },
 
-      async onUserInput(content: string, attachImages?: string[]) {
+      async onUserInput(
+        content: string,
+        attachImages?: string[],
+        isMcpResponse?: boolean,
+      ) {
         const session = get().currentSession();
         const modelConfig = session.mask.modelConfig;
 
-        const userContent = fillTemplateWith(content, modelConfig);
-        console.log("[User Input] after template: ", userContent);
-
-        let mContent: string | MultimodalContent[] = userContent;
+        // MCP Response no need to fill template
+        let mContent: string | MultimodalContent[] = isMcpResponse
+          ? content
+          : fillTemplateWith(content, modelConfig);
 
-        if (attachImages && attachImages.length > 0) {
+        if (!isMcpResponse && attachImages && attachImages.length > 0) {
           mContent = [
-            ...(userContent
-              ? [{ type: "text" as const, text: userContent }]
-              : []),
+            ...(content ? [{ type: "text" as const, text: content }] : []),
             ...attachImages.map((url) => ({
               type: "image_url" as const,
               image_url: { url },
@@ -392,6 +399,7 @@ export const useChatStore = createPersistStore(
         let userMessage: ChatMessage = createMessage({
           role: "user",
           content: mContent,
+          isMcpResponse,
         });
 
         const botMessage: ChatMessage = createMessage({
@@ -770,9 +778,10 @@ export const useChatStore = createPersistStore(
           lastInput,
         });
       },
+
+      /** check if the message contains MCP JSON and execute the MCP action */
       checkMcpJson(message: ChatMessage) {
-        const content =
-          typeof message.content === "string" ? message.content : "";
+        const content = getMessageTextContent(message);
         if (isMcpJson(content)) {
           try {
             const mcpRequest = extractMcpJson(content);
@@ -782,11 +791,14 @@ export const useChatStore = createPersistStore(
               executeMcpAction(mcpRequest.clientId, mcpRequest.mcp)
                 .then((result) => {
                   console.log("[MCP Response]", result);
-                  // 直接使用onUserInput发送结果
-                  get().onUserInput(
+                  const mcpResponse =
                     typeof result === "object"
                       ? JSON.stringify(result)
-                      : String(result),
+                      : String(result);
+                  get().onUserInput(
+                    `\`\`\`json:mcp:${mcpRequest.clientId}\n${mcpResponse}\n\`\`\``,
+                    [],
+                    true,
                   );
                 })
                 .catch((error) => showToast(String(error)));