bridgev2/matrix: add support for generating public media URLs

This commit is contained in:
Tulir Asokan 2024-08-03 18:00:53 +03:00
commit c6bc42f16c
7 changed files with 190 additions and 1 deletions

View file

@ -22,6 +22,7 @@ type Config struct {
AppService AppserviceConfig `yaml:"appservice"`
Matrix MatrixConfig `yaml:"matrix"`
Provisioning ProvisioningConfig `yaml:"provisioning"`
PublicMedia PublicMediaConfig `yaml:"public_media"`
DirectMedia DirectMediaConfig `yaml:"direct_media"`
Backfill BackfillConfig `yaml:"backfill"`
DoublePuppet DoublePuppetConfig `yaml:"double_puppet"`
@ -84,6 +85,13 @@ type DirectMediaConfig struct {
mediaproxy.BasicConfig `yaml:",inline"`
}
type PublicMediaConfig struct {
Enabled bool `yaml:"enabled"`
SigningKey string `yaml:"signing_key"`
HashLength int `yaml:"hash_length"`
Expiry int `yaml:"expiry"`
}
type DoublePuppetConfig struct {
Servers map[string]string `yaml:"servers"`
AllowDiscovery bool `yaml:"allow_discovery"`

View file

@ -93,6 +93,15 @@ func doUpgrade(helper up.Helper) {
helper.Copy(up.Str, "direct_media", "server_key")
}
helper.Copy(up.Bool, "public_media", "enabled")
if signingKey, ok := helper.Get(up.Str, "public_media", "signing_key"); !ok || signingKey == "generate" {
helper.Set(up.Str, random.String(32), "public_media", "signing_key")
} else {
helper.Copy(up.Str, "public_media", "signing_key")
}
helper.Copy(up.Int, "public_media", "expiry")
helper.Copy(up.Int, "public_media", "hash_length")
helper.Copy(up.Bool, "backfill", "enabled")
helper.Copy(up.Int, "backfill", "max_initial_messages")
helper.Copy(up.Int, "backfill", "max_catchup_messages")

View file

@ -69,7 +69,9 @@ type Connector struct {
Provisioning *ProvisioningAPI
DoublePuppet *doublePuppetUtil
MediaProxy *mediaproxy.MediaProxy
dmaSigKey [32]byte
dmaSigKey [32]byte
pubMediaSigKey []byte
doublePuppetIntents *exsync.Map[id.UserID, *appservice.IntentAPI]
@ -152,6 +154,10 @@ func (br *Connector) Start(ctx context.Context) error {
if err != nil {
return err
}
err = br.initPublicMedia()
if err != nil {
return err
}
err = br.StateStore.Upgrade(ctx)
if err != nil {
return bridgev2.DBUpgradeError{Section: "matrix_state", Err: err}

View file

@ -199,6 +199,21 @@ provisioning:
# Enable debug API at /debug with provisioning authentication.
debug_endpoints: false
# Some networks require publicly accessible media download links (e.g. for user avatars when using Discord webhooks).
# These settings control whether the bridge will provide such public media access.
public_media:
# Should public media be enabled at all?
# The public_address field under the appservice section MUST be set when enabling public media.
enabled: false
# A key for signing public media URLs.
# If set to "generate", a random key will be generated.
signing_key: generate
# Number of seconds that public media URLs are valid for.
# If set to 0, URLs will never expire.
expiry: 0
# Length of hash to use for public media URLs.
hash_length: 32
# Settings for converting remote media to custom mxc:// URIs instead of reuploading.
# More details can be found at https://docs.mau.fi/bridges/go/discord/direct-media.html
direct_media:

View file

@ -0,0 +1,128 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package matrix
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/binary"
"fmt"
"io"
"net/http"
"time"
"github.com/gorilla/mux"
"maunium.net/go/mautrix/bridgev2"
"maunium.net/go/mautrix/id"
)
var _ bridgev2.MatrixConnectorWithPublicMedia = (*Connector)(nil)
func (br *Connector) initPublicMedia() error {
if !br.Config.PublicMedia.Enabled {
return nil
} else if br.GetPublicAddress() == "" {
return fmt.Errorf("public media is enabled in config, but no public address is set")
} else if br.Config.PublicMedia.HashLength > 32 {
return fmt.Errorf("public media hash length is too long")
} else if br.Config.PublicMedia.HashLength < 0 {
return fmt.Errorf("public media hash length is negative")
}
br.pubMediaSigKey = []byte(br.Config.PublicMedia.SigningKey)
br.AS.Router.HandleFunc("/_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia).Methods(http.MethodGet)
return nil
}
func (br *Connector) hashContentURI(uri id.ContentURI, expiry []byte) []byte {
hasher := hmac.New(sha256.New, br.pubMediaSigKey)
hasher.Write([]byte(uri.String()))
hasher.Write(expiry)
return hasher.Sum(expiry)[:br.Config.PublicMedia.HashLength+len(expiry)]
}
func (br *Connector) makePublicMediaChecksum(uri id.ContentURI) []byte {
var expiresAt []byte
if br.Config.PublicMedia.Expiry > 0 {
expiresAtInt := time.Now().Add(time.Duration(br.Config.PublicMedia.Expiry) * time.Second).Unix()
expiresAt = binary.BigEndian.AppendUint64(nil, uint64(expiresAtInt))
}
return br.hashContentURI(uri, expiresAt)
}
func (br *Connector) verifyPublicMediaChecksum(uri id.ContentURI, checksum []byte) (valid, expired bool) {
var expiryBytes []byte
if br.Config.PublicMedia.Expiry > 0 {
if len(checksum) < 8 {
return
}
expiryBytes = checksum[:8]
expiresAtInt := binary.BigEndian.Uint64(expiryBytes)
expired = time.Now().Unix() > int64(expiresAtInt)
}
valid = hmac.Equal(checksum, br.hashContentURI(uri, expiryBytes))
return
}
var proxyHeadersToCopy = []string{
"Content-Type", "Content-Disposition", "Content-Length", "Content-Security-Policy",
"Access-Control-Allow-Origin", "Access-Control-Allow-Methods", "Access-Control-Allow-Headers",
"Cache-Control", "Cross-Origin-Resource-Policy",
}
func (br *Connector) servePublicMedia(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
contentURI := id.ContentURI{
Homeserver: vars["server"],
FileID: vars["mediaID"],
}
if !contentURI.IsValid() {
http.Error(w, "invalid content URI", http.StatusBadRequest)
return
}
checksum, err := base64.RawURLEncoding.DecodeString(vars["checksum"])
if err != nil || !hmac.Equal(checksum, br.makePublicMediaChecksum(contentURI)) {
http.Error(w, "invalid base64 in checksum", http.StatusBadRequest)
return
} else if valid, expired := br.verifyPublicMediaChecksum(contentURI, checksum); !valid {
http.Error(w, "invalid checksum", http.StatusNotFound)
return
} else if expired {
http.Error(w, "checksum expired", http.StatusGone)
return
}
resp, err := br.Bot.Download(r.Context(), contentURI)
if err != nil {
br.Log.Warn().Stringer("uri", contentURI).Err(err).Msg("Failed to download media to proxy")
http.Error(w, "failed to download media", http.StatusInternalServerError)
return
}
defer resp.Body.Close()
for _, hdr := range proxyHeadersToCopy {
w.Header()[hdr] = resp.Header[hdr]
}
w.WriteHeader(http.StatusOK)
_, _ = io.Copy(w, resp.Body)
}
func (br *Connector) GetPublicMediaAddress(contentURI id.ContentURIString) string {
if br.pubMediaSigKey == nil {
return ""
}
parsed, err := contentURI.Parse()
if err != nil || !parsed.IsValid() {
return ""
}
return fmt.Sprintf(
"%s/_mautrix/publicmedia/%s/%s/%s",
br.GetPublicAddress(),
parsed.Homeserver,
parsed.FileID,
base64.RawURLEncoding.EncodeToString(br.makePublicMediaChecksum(parsed)),
)
}

View file

@ -58,6 +58,10 @@ type MatrixConnectorWithServer interface {
GetRouter() *mux.Router
}
type MatrixConnectorWithPublicMedia interface {
GetPublicMediaAddress(contentURI id.ContentURIString) string
}
type MatrixConnectorWithNameDisambiguation interface {
IsConfusableName(ctx context.Context, roomID id.RoomID, userID id.UserID, name string) ([]id.UserID, error)
}

View file

@ -12,6 +12,7 @@ import (
"encoding/json"
"errors"
"fmt"
"regexp"
"strings"
)
@ -156,3 +157,21 @@ func (uri ContentURI) CUString() ContentURIString {
func (uri ContentURI) IsEmpty() bool {
return len(uri.Homeserver) == 0 || len(uri.FileID) == 0
}
var simpleHomeserverRegex = regexp.MustCompile(`^[a-zA-Z0-9.:-]+$`)
func (uri ContentURI) IsValid() bool {
return IsValidMediaID(uri.Homeserver) && uri.Homeserver != "" && simpleHomeserverRegex.MatchString(uri.Homeserver)
}
func IsValidMediaID(mediaID string) bool {
if len(mediaID) == 0 {
return false
}
for _, char := range mediaID {
if (char < 'A' || char > 'Z') && (char < 'a' || char > 'z') && (char < '0' || char > '9') && char != '-' && char != '_' {
return false
}
}
return true
}