diff --git a/go.sum b/go.sum index a765a0dc..b21bbbc2 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,7 @@ github.com/btcsuite/snappy-go v0.0.0-20151229074030-0bdef8d06723/go.mod h1:8woku github.com/btcsuite/websocket v0.0.0-20150119174127-31079b680792/go.mod h1:ghJtEyQwv5/p4Mg4C0fgbePVuGr935/5ddU9Z3TmDRY= github.com/btcsuite/winsvc v1.0.0/go.mod h1:jsenWakMcC0zFBFurPLEAyrnc/teJEM1O46fmI40EZs= github.com/davecgh/go-spew v0.0.0-20171005155431-ecdeabc65495/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -30,11 +31,13 @@ github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/tidwall/gjson v1.6.0/go.mod 
h1:P256ACg0Mn+j1RXIDXoss50DeIABTYK1PULOJHhxOls= github.com/tidwall/gjson v1.6.8/go.mod h1:zeFuBCIqD4sN/gmqBzZ4j7Jd6UcA2Fc56x7QFsv+8fI= @@ -80,6 +83,7 @@ gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMy gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= maunium.net/go/maulogger/v2 v2.1.1/go.mod h1:TYWy7wKwz/tIXTpsx8G3mZseIRiC5DoMxSZazOHy68A= maunium.net/go/maulogger/v2 v2.2.0/go.mod h1:TYWy7wKwz/tIXTpsx8G3mZseIRiC5DoMxSZazOHy68A= diff --git a/id/userid.go b/id/userid.go index cd832eb2..67199419 100644 --- a/id/userid.go +++ b/id/userid.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020 Tulir Asokan +// Copyright (c) 2021 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. 
If a copy of the MPL was not distributed with this @@ -9,7 +9,9 @@ package id import ( "bytes" "encoding/hex" + "errors" "fmt" + "regexp" "strings" ) @@ -17,6 +19,8 @@ import ( // https://matrix.org/docs/spec/appendices#user-identifiers type UserID string +const UserIDMaxLength = 255 + func NewUserID(localpart, homeserver string) UserID { return UserID(fmt.Sprintf("@%s:%s", localpart, homeserver)) } @@ -25,11 +29,22 @@ func NewEncodedUserID(localpart, homeserver string) UserID { return NewUserID(EncodeUserLocalpart(localpart), homeserver) } +var ( + ErrInvalidUserID = errors.New("is not a valid user ID") + ErrNoncompliantLocalpart = errors.New("contains characters that are not allowed") + ErrUserIDTooLong = errors.New("the given user ID is longer than 255 characters") + ErrEmptyLocalpart = errors.New("empty localparts are not allowed") +) + // Parse parses the user ID into the localpart and server name. -// See http://matrix.org/docs/spec/intro.html#user-identifiers +// +// Note that this only enforces very basic user ID formatting requirements: user IDs start with +// a @, and contain a : after the @. If you want to enforce localpart validity, see the +// ParseAndValidate and ValidateUserLocalpart functions. 
func (userID UserID) Parse() (localpart, homeserver string, err error) { if len(userID) == 0 || userID[0] != '@' || !strings.ContainsRune(string(userID), ':') { - err = fmt.Errorf("%s is not a valid user id", userID) + // This error wrapping lets you use errors.Is() nicely even though the message contains the user ID + err = fmt.Errorf("'%s' %w", userID, ErrInvalidUserID) return } parts := strings.SplitN(string(userID), ":", 2) @@ -37,8 +52,34 @@ func (userID UserID) Parse() (localpart, homeserver string, err error) { return } -func (userID UserID) ParseAndDecode() (localpart, homeserver string, err error) { +var ValidLocalpartRegex = regexp.MustCompile("^[0-9a-z-.=_/]+$") + +// ValidateUserLocalpart validates a Matrix user ID localpart using the grammar +// in https://matrix.org/docs/spec/appendices#user-identifiers +func ValidateUserLocalpart(localpart string) error { + if len(localpart) == 0 { + return ErrEmptyLocalpart + } else if !ValidLocalpartRegex.MatchString(localpart) { + return fmt.Errorf("'%s' %w", localpart, ErrNoncompliantLocalpart) + } + return nil +} + +// ParseAndValidate parses the user ID into the localpart and server name like Parse, +// and also validates that the localpart is allowed according to the user identifiers spec. 
+func (userID UserID) ParseAndValidate() (localpart, homeserver string, err error) { localpart, homeserver, err = userID.Parse() + if err == nil { + err = ValidateUserLocalpart(localpart) + } + if err == nil && len(userID) > UserIDMaxLength { + err = ErrUserIDTooLong + } + return +} + +func (userID UserID) ParseAndDecode() (localpart, homeserver string, err error) { + localpart, homeserver, err = userID.ParseAndValidate() if err == nil { localpart, err = DecodeUserLocalpart(localpart) } diff --git a/id/userid_test.go b/id/userid_test.go new file mode 100644 index 00000000..fdbf240a --- /dev/null +++ b/id/userid_test.go @@ -0,0 +1,81 @@ +// Copyright (c) 2021 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package id_test + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + + "maunium.net/go/mautrix/id" +) + +func TestUserID_Parse(t *testing.T) { + const inputUserID = "@s p a c e:maunium.net" + parsedLocalpart, parsedServerName, err := id.UserID(inputUserID).Parse() + assert.NoError(t, err) + assert.Equal(t, "s p a c e", parsedLocalpart) + assert.Equal(t, "maunium.net", parsedServerName) +} + +func TestUserID_Parse_Emtpty(t *testing.T) { + const inputUserID = "@:ponies.im" + parsedLocalpart, parsedServerName, err := id.UserID(inputUserID).Parse() + assert.NoError(t, err) + assert.Equal(t, "", parsedLocalpart) + assert.Equal(t, "ponies.im", parsedServerName) +} + +func TestUserID_Parse_Invalid(t *testing.T) { + const inputUserID = "hello world" + _, _, err := id.UserID(inputUserID).Parse() + assert.Error(t, err) + assert.True(t, errors.Is(err, id.ErrInvalidUserID)) +} + +func TestUserID_ParseAndValidate_Invalid(t *testing.T) { + const inputUserID = "@s p a c e:maunium.net" + _, _, err := id.UserID(inputUserID).ParseAndValidate() + assert.Error(t, err) + assert.True(t, 
errors.Is(err, id.ErrNoncompliantLocalpart)) +} + +func TestUserID_ParseAndValidate_Empty(t *testing.T) { + const inputUserID = "@:ponies.im" + _, _, err := id.UserID(inputUserID).ParseAndValidate() + assert.Error(t, err) + assert.True(t, errors.Is(err, id.ErrEmptyLocalpart)) +} + +func TestUserID_ParseAndValidate_Long(t *testing.T) { + const inputUserID = "@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa:example.com" + _, _, err := id.UserID(inputUserID).ParseAndValidate() + assert.Error(t, err) + assert.True(t, errors.Is(err, id.ErrUserIDTooLong)) +} + +func TestUserID_ParseAndValidate_NotLong(t *testing.T) { + const inputUserID = "@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa:example.com" + _, _, err := id.UserID(inputUserID).ParseAndValidate() + assert.NoError(t, err) +} + +func TestUserIDEncoding(t *testing.T) { + const inputLocalpart = "This localpart contains IlLeGaL chäracters 🚨" + const encodedLocalpart = "_this=20localpart=20contains=20_il_le_ga_l=20ch=c3=a4racters=20=f0=9f=9a=a8" + const inputServerName = "example.com" + userID := id.NewEncodedUserID(inputLocalpart, inputServerName) + parsedLocalpart, parsedServerName, err := userID.ParseAndValidate() + assert.NoError(t, err) + assert.Equal(t, encodedLocalpart, parsedLocalpart) + assert.Equal(t, inputServerName, parsedServerName) + decodedLocalpart, decodedServerName, err := userID.ParseAndDecode() + assert.NoError(t, err) + assert.Equal(t, inputLocalpart, decodedLocalpart) + assert.Equal(t, inputServerName, decodedServerName) +}