fix(mcp): support Mcp-Session-Id header and improve SSE transport compliance

- Implement session extraction middleware to handle both header and query param
- Fix path rewriting for POST requests on /sse to support 'dumb' clients
- Add CORS support for Mcp-Session-Id and Mcp-Protocol-Version headers
- Improve baseURL logic and add warnings for 0.0.0.0 listening
This commit is contained in:
Daoud AbdelMonem Faleh 2026-03-03 16:53:34 +01:00
commit d635282b41

View file

@ -144,19 +144,99 @@ func NewServer(id clio.Identification, opts options.MCP) *server.MCPServer {
func Run(s *server.MCPServer, opts options.MCP) error {
switch opts.Transport {
case "sse":
addr := fmt.Sprintf("%s:%d", opts.Host, opts.Port)
sseServer := server.NewSSEServer(s, server.WithBaseURL(fmt.Sprintf("http://%s", addr)))
host := opts.Host
if host == "" {
host = "0.0.0.0"
}
addr := fmt.Sprintf("%s:%d", host, opts.Port)
baseURLHost := opts.Host
if baseURLHost == "" || baseURLHost == "0.0.0.0" {
baseURLHost = "localhost"
}
// If the user specified 0.0.0.0, they might be accessing from another machine.
// We should warn that 'localhost' in the baseURL might cause issues for remote clients.
if opts.Host == "0.0.0.0" {
log.Warnf("Listening on 0.0.0.0 but baseURL is set to localhost. Remote MCP clients might fail to connect to the message endpoint. Consider setting --host to your actual IP or hostname.")
}
baseURL := fmt.Sprintf("http://%s:%d", baseURLHost, opts.Port)
sseServer := server.NewSSEServer(s, server.WithBaseURL(baseURL))
mux := http.NewServeMux()
mux.Handle("/sse", sseServer.SSEHandler())
mux.Handle("/messages", sseServer.MessageHandler())
// Session extractor middleware to handle both header and query param
sessionMiddleware := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// The 2025-03-26 spec uses Mcp-Session-Id header.
// Older specs/mcp-go uses sessionId query parameter.
sessionID := r.URL.Query().Get("sessionId")
if sessionID == "" {
sessionID = r.Header.Get("Mcp-Session-Id")
}
if sessionID != "" {
// Ensure mcp-go finds it in the query params if it's only in the header
if r.URL.Query().Get("sessionId") == "" {
q := r.URL.Query()
q.Set("sessionId", sessionID)
r.URL.RawQuery = q.Encode()
}
// Also set it in the header for consistency
r.Header.Set("Mcp-Session-Id", sessionID)
w.Header().Set("Mcp-Session-Id", sessionID)
} else if r.Method == http.MethodPost {
log.Warnf("MCP POST request to %s missing session ID (tried sessionId query and Mcp-Session-Id header) from %s", r.URL.Path, r.RemoteAddr)
}
if version := r.Header.Get("Mcp-Protocol-Version"); version != "" {
log.Debugf("MCP client protocol version: %s", version)
}
next.ServeHTTP(w, r)
})
}
// Support both GET and POST on /sse to be compatible with all clients.
// Some clients ignore the endpoint event and POST to the same URL they GET from.
mux.HandleFunc("/sse", func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodPost {
// We MUST rewrite the path to /message because MessageHandler
// is strict about the path it's mounted on.
r.URL.Path = "/message"
sessionMiddleware(sseServer.MessageHandler()).ServeHTTP(w, r)
return
}
sseServer.SSEHandler().ServeHTTP(w, r)
})
// Also support the standard /message endpoint
mux.Handle("/message", sessionMiddleware(sseServer.MessageHandler()))
// Add CORS middleware to allow cross-origin requests (e.g., from web-based MCP inspectors)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Infof("MCP Request: %s %s from %s", r.Method, r.URL.Path, r.RemoteAddr)
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, Mcp-Session-Id, Mcp-Protocol-Version")
w.Header().Set("Access-Control-Expose-Headers", "Mcp-Session-Id, Mcp-Protocol-Version")
w.Header().Set("Access-Control-Max-Age", "86400") // 24 hours
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
mux.ServeHTTP(w, r)
})
log.Infof("Starting MCP SSE server on %s", addr)
fmt.Printf("Starting MCP SSE server on %s\n", addr)
fmt.Printf("- SSE endpoint: http://%s/sse\n", addr)
fmt.Printf("- Message endpoint: http://%s/messages\n", addr)
fmt.Printf("- SSE endpoint: %s/sse\n", baseURL)
fmt.Printf("- Message endpoint: %s/message\n", baseURL)
return http.ListenAndServe(addr, mux)
return http.ListenAndServe(addr, handler)
case "stdio":
log.Infof("Starting MCP Stdio server")
return server.ServeStdio(s)