diff --git a/cmd/dive/cli/internal/mcp/server.go b/cmd/dive/cli/internal/mcp/server.go index ae912e8..4853edc 100644 --- a/cmd/dive/cli/internal/mcp/server.go +++ b/cmd/dive/cli/internal/mcp/server.go @@ -144,19 +144,99 @@ func NewServer(id clio.Identification, opts options.MCP) *server.MCPServer { func Run(s *server.MCPServer, opts options.MCP) error { switch opts.Transport { case "sse": - addr := fmt.Sprintf("%s:%d", opts.Host, opts.Port) - sseServer := server.NewSSEServer(s, server.WithBaseURL(fmt.Sprintf("http://%s", addr))) + host := opts.Host + if host == "" { + host = "0.0.0.0" + } + addr := fmt.Sprintf("%s:%d", host, opts.Port) + + baseURLHost := opts.Host + if baseURLHost == "" || baseURLHost == "0.0.0.0" { + baseURLHost = "localhost" + } + + // If the user specified 0.0.0.0, they might be accessing from another machine. + // We should warn that 'localhost' in the baseURL might cause issues for remote clients. + if opts.Host == "0.0.0.0" { + log.Warnf("Listening on 0.0.0.0 but baseURL is set to localhost. Remote MCP clients might fail to connect to the message endpoint. Consider setting --host to your actual IP or hostname.") + } + + baseURL := fmt.Sprintf("http://%s:%d", baseURLHost, opts.Port) + sseServer := server.NewSSEServer(s, server.WithBaseURL(baseURL)) mux := http.NewServeMux() - mux.Handle("/sse", sseServer.SSEHandler()) - mux.Handle("/messages", sseServer.MessageHandler()) + + // Session extractor middleware to handle both header and query param + sessionMiddleware := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // The 2025-03-26 spec uses Mcp-Session-Id header. + // Older specs/mcp-go uses sessionId query parameter. + sessionID := r.URL.Query().Get("sessionId") + if sessionID == "" { + sessionID = r.Header.Get("Mcp-Session-Id") + } + + if sessionID != "" { + // Ensure mcp-go finds it in the query params if it's only in the header + if r.URL.Query().Get("sessionId") == "" { + q := r.URL.Query() + q.Set("sessionId", sessionID) + r.URL.RawQuery = q.Encode() + } + // Also set it in the header for consistency + r.Header.Set("Mcp-Session-Id", sessionID) + w.Header().Set("Mcp-Session-Id", sessionID) + } else if r.Method == http.MethodPost { + log.Warnf("MCP POST request to %s missing session ID (tried sessionId query and Mcp-Session-Id header) from %s", r.URL.Path, r.RemoteAddr) + } + + if version := r.Header.Get("Mcp-Protocol-Version"); version != "" { + log.Debugf("MCP client protocol version: %s", version) + } + + next.ServeHTTP(w, r) + }) + } + + // Support both GET and POST on /sse to be compatible with all clients. + // Some clients ignore the endpoint event and POST to the same URL they GET from. + mux.HandleFunc("/sse", func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + // We MUST rewrite the path to /message because MessageHandler + // is strict about the path it's mounted on. + r.URL.Path = "/message" + sessionMiddleware(sseServer.MessageHandler()).ServeHTTP(w, r) + return + } + sseServer.SSEHandler().ServeHTTP(w, r) + }) + + // Also support the standard /message endpoint + mux.Handle("/message", sessionMiddleware(sseServer.MessageHandler())) + + // Add CORS middleware to allow cross-origin requests (e.g., from web-based MCP inspectors) + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Infof("MCP Request: %s %s from %s", r.Method, r.URL.Path, r.RemoteAddr) + + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, Mcp-Session-Id, Mcp-Protocol-Version") + w.Header().Set("Access-Control-Expose-Headers", "Mcp-Session-Id, Mcp-Protocol-Version") + w.Header().Set("Access-Control-Max-Age", "86400") // 24 hours + + if r.Method == "OPTIONS" { + w.WriteHeader(http.StatusOK) + return + } + mux.ServeHTTP(w, r) + }) log.Infof("Starting MCP SSE server on %s", addr) fmt.Printf("Starting MCP SSE server on %s\n", addr) - fmt.Printf("- SSE endpoint: http://%s/sse\n", addr) - fmt.Printf("- Message endpoint: http://%s/messages\n", addr) + fmt.Printf("- SSE endpoint: %s/sse\n", baseURL) + fmt.Printf("- Message endpoint: %s/message\n", baseURL) - return http.ListenAndServe(addr, mux) + return http.ListenAndServe(addr, handler) case "stdio": log.Infof("Starting MCP Stdio server") return server.ServeStdio(s)