From 63f35754c6e4e806ffb2e239b5ac9c02854331c5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 4 May 2025 01:04:12 +0300 Subject: [PATCH] federation/serverauth: store verified origin in request context --- federation/context.go | 12 ++++++++++++ federation/serverauth.go | 6 +++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/federation/context.go b/federation/context.go index 8280431f..eedb2dc1 100644 --- a/federation/context.go +++ b/federation/context.go @@ -16,6 +16,7 @@ type contextKey int const ( contextKeyIPPort contextKey = iota contextKeyDestinationServer + contextKeyOriginServer ) func DestinationServerNameFromRequest(r *http.Request) string { @@ -28,3 +29,14 @@ func DestinationServerName(ctx context.Context) string { } return "" } + +func OriginServerNameFromRequest(r *http.Request) string { + return OriginServerName(r.Context()) +} + +func OriginServerName(ctx context.Context) string { + if origin, ok := ctx.Value(contextKeyOriginServer).(string); ok { + return origin + } + return "" +} diff --git a/federation/serverauth.go b/federation/serverauth.go index e2036d30..02780ff8 100644 --- a/federation/serverauth.go +++ b/federation/serverauth.go @@ -234,7 +234,11 @@ func (sa *ServerAuth) Authenticate(r *http.Request) (*http.Request, *mautrix.Res return nil, &errInvalidRequestSignature } ctx := context.WithValue(r.Context(), contextKeyDestinationServer, destination) - ctx = log.With().Str("destination_server_name", destination).Logger().WithContext(ctx) + ctx = context.WithValue(ctx, contextKeyOriginServer, parsed.Origin) + ctx = log.With(). + Str("origin_server_name", parsed.Origin). + Str("destination_server_name", destination). + Logger().WithContext(ctx) modifiedReq := r.WithContext(ctx) modifiedReq.Body = io.NopCloser(bytes.NewReader(reqBody)) return modifiedReq, nil