mirror of
https://github.com/wagoodman/dive
synced 2026-03-14 22:35:50 +01:00
fix(mcp): support Mcp-Session-Id header and improve SSE transport compliance
- Implement session extraction middleware to handle both header and query param - Fix path rewriting for POST requests on /sse to support 'dumb' clients - Add CORS support for Mcp-Session-Id and Mcp-Protocol-Version headers - Improve baseURL logic and add warnings for 0.0.0.0 listening
This commit is contained in:
parent
2f0d59e489
commit
d635282b41
1 changed files with 87 additions and 7 deletions
|
|
@ -144,19 +144,99 @@ func NewServer(id clio.Identification, opts options.MCP) *server.MCPServer {
|
|||
func Run(s *server.MCPServer, opts options.MCP) error {
|
||||
switch opts.Transport {
|
||||
case "sse":
|
||||
addr := fmt.Sprintf("%s:%d", opts.Host, opts.Port)
|
||||
sseServer := server.NewSSEServer(s, server.WithBaseURL(fmt.Sprintf("http://%s", addr)))
|
||||
host := opts.Host
|
||||
if host == "" {
|
||||
host = "0.0.0.0"
|
||||
}
|
||||
addr := fmt.Sprintf("%s:%d", host, opts.Port)
|
||||
|
||||
baseURLHost := opts.Host
|
||||
if baseURLHost == "" || baseURLHost == "0.0.0.0" {
|
||||
baseURLHost = "localhost"
|
||||
}
|
||||
|
||||
// If the user specified 0.0.0.0, they might be accessing from another machine.
|
||||
// We should warn that 'localhost' in the baseURL might cause issues for remote clients.
|
||||
if opts.Host == "0.0.0.0" {
|
||||
log.Warnf("Listening on 0.0.0.0 but baseURL is set to localhost. Remote MCP clients might fail to connect to the message endpoint. Consider setting --host to your actual IP or hostname.")
|
||||
}
|
||||
|
||||
baseURL := fmt.Sprintf("http://%s:%d", baseURLHost, opts.Port)
|
||||
sseServer := server.NewSSEServer(s, server.WithBaseURL(baseURL))
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/sse", sseServer.SSEHandler())
|
||||
mux.Handle("/messages", sseServer.MessageHandler())
|
||||
|
||||
// Session extractor middleware to handle both header and query param
|
||||
sessionMiddleware := func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// The 2025-03-26 spec uses Mcp-Session-Id header.
|
||||
// Older specs/mcp-go uses sessionId query parameter.
|
||||
sessionID := r.URL.Query().Get("sessionId")
|
||||
if sessionID == "" {
|
||||
sessionID = r.Header.Get("Mcp-Session-Id")
|
||||
}
|
||||
|
||||
if sessionID != "" {
|
||||
// Ensure mcp-go finds it in the query params if it's only in the header
|
||||
if r.URL.Query().Get("sessionId") == "" {
|
||||
q := r.URL.Query()
|
||||
q.Set("sessionId", sessionID)
|
||||
r.URL.RawQuery = q.Encode()
|
||||
}
|
||||
// Also set it in the header for consistency
|
||||
r.Header.Set("Mcp-Session-Id", sessionID)
|
||||
w.Header().Set("Mcp-Session-Id", sessionID)
|
||||
} else if r.Method == http.MethodPost {
|
||||
log.Warnf("MCP POST request to %s missing session ID (tried sessionId query and Mcp-Session-Id header) from %s", r.URL.Path, r.RemoteAddr)
|
||||
}
|
||||
|
||||
if version := r.Header.Get("Mcp-Protocol-Version"); version != "" {
|
||||
log.Debugf("MCP client protocol version: %s", version)
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// Support both GET and POST on /sse to be compatible with all clients.
|
||||
// Some clients ignore the endpoint event and POST to the same URL they GET from.
|
||||
mux.HandleFunc("/sse", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodPost {
|
||||
// We MUST rewrite the path to /message because MessageHandler
|
||||
// is strict about the path it's mounted on.
|
||||
r.URL.Path = "/message"
|
||||
sessionMiddleware(sseServer.MessageHandler()).ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
sseServer.SSEHandler().ServeHTTP(w, r)
|
||||
})
|
||||
|
||||
// Also support the standard /message endpoint
|
||||
mux.Handle("/message", sessionMiddleware(sseServer.MessageHandler()))
|
||||
|
||||
// Add CORS middleware to allow cross-origin requests (e.g., from web-based MCP inspectors)
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
log.Infof("MCP Request: %s %s from %s", r.Method, r.URL.Path, r.RemoteAddr)
|
||||
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, Mcp-Session-Id, Mcp-Protocol-Version")
|
||||
w.Header().Set("Access-Control-Expose-Headers", "Mcp-Session-Id, Mcp-Protocol-Version")
|
||||
w.Header().Set("Access-Control-Max-Age", "86400") // 24 hours
|
||||
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
mux.ServeHTTP(w, r)
|
||||
})
|
||||
|
||||
log.Infof("Starting MCP SSE server on %s", addr)
|
||||
fmt.Printf("Starting MCP SSE server on %s\n", addr)
|
||||
fmt.Printf("- SSE endpoint: http://%s/sse\n", addr)
|
||||
fmt.Printf("- Message endpoint: http://%s/messages\n", addr)
|
||||
fmt.Printf("- SSE endpoint: %s/sse\n", baseURL)
|
||||
fmt.Printf("- Message endpoint: %s/message\n", baseURL)
|
||||
|
||||
return http.ListenAndServe(addr, mux)
|
||||
return http.ListenAndServe(addr, handler)
|
||||
case "stdio":
|
||||
log.Infof("Starting MCP Stdio server")
|
||||
return server.ServeStdio(s)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue