Add new methods for validating user ID localparts. Fixes #27

This commit is contained in:
Tulir Asokan 2021-02-24 12:08:20 +02:00
commit ecce653670
3 changed files with 130 additions and 4 deletions

View file

@ -1,4 +1,4 @@
// Copyright (c) 2020 Tulir Asokan
// Copyright (c) 2021 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -9,7 +9,9 @@ package id
import (
"bytes"
"encoding/hex"
"errors"
"fmt"
"regexp"
"strings"
)
@ -17,6 +19,8 @@ import (
// https://matrix.org/docs/spec/appendices#user-identifiers
type UserID string
const UserIDMaxLength = 255
func NewUserID(localpart, homeserver string) UserID {
return UserID(fmt.Sprintf("@%s:%s", localpart, homeserver))
}
@ -25,11 +29,22 @@ func NewEncodedUserID(localpart, homeserver string) UserID {
return NewUserID(EncodeUserLocalpart(localpart), homeserver)
}
var (
ErrInvalidUserID = errors.New("is not a valid user ID")
ErrNoncompliantLocalpart = errors.New("contains characters that are not allowed")
ErrUserIDTooLong = errors.New("the given user ID is longer than 255 characters")
ErrEmptyLocalpart = errors.New("empty localparts are not allowed")
)
// Parse parses the user ID into the localpart and server name.
// See http://matrix.org/docs/spec/intro.html#user-identifiers
//
// Note that this only enforces very basic user ID formatting requirements: user IDs start with
// a @, and contain a : after the @. If you want to enforce localpart validity, see the
// ParseAndValidate and ValidateUserLocalpart functions.
func (userID UserID) Parse() (localpart, homeserver string, err error) {
if len(userID) == 0 || userID[0] != '@' || !strings.ContainsRune(string(userID), ':') {
err = fmt.Errorf("%s is not a valid user id", userID)
// This error wrapping lets you use errors.Is() nicely even though the message contains the user ID
err = fmt.Errorf("'%s' %w", userID, ErrInvalidUserID)
return
}
parts := strings.SplitN(string(userID), ":", 2)
@ -37,8 +52,34 @@ func (userID UserID) Parse() (localpart, homeserver string, err error) {
return
}
func (userID UserID) ParseAndDecode() (localpart, homeserver string, err error) {
var ValidLocalpartRegex = regexp.MustCompile("^[0-9a-z-.=_/]+$")
// ValidateUserLocalpart validates a Matrix user ID localpart using the grammar
// in https://matrix.org/docs/spec/appendices#user-identifier
func ValidateUserLocalpart(localpart string) error {
if len(localpart) == 0 {
return ErrEmptyLocalpart
} else if !ValidLocalpartRegex.MatchString(localpart) {
return fmt.Errorf("'%s' %w", localpart, ErrNoncompliantLocalpart)
}
return nil
}
// ParseAndValidate parses the user ID into the localpart and server name like Parse,
// and also validates that the localpart is allowed according to the user identifiers spec.
func (userID UserID) ParseAndValidate() (localpart, homeserver string, err error) {
localpart, homeserver, err = userID.Parse()
if err == nil {
err = ValidateUserLocalpart(localpart)
}
if err == nil && len(userID) > UserIDMaxLength {
err = ErrUserIDTooLong
}
return
}
func (userID UserID) ParseAndDecode() (localpart, homeserver string, err error) {
localpart, homeserver, err = userID.ParseAndValidate()
if err == nil {
localpart, err = DecodeUserLocalpart(localpart)
}

81
id/userid_test.go Normal file
View file

@ -0,0 +1,81 @@
// Copyright (c) 2021 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package id_test
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
"maunium.net/go/mautrix/id"
)
func TestUserID_Parse(t *testing.T) {
const inputUserID = "@s p a c e:maunium.net"
parsedLocalpart, parsedServerName, err := id.UserID(inputUserID).Parse()
assert.NoError(t, err)
assert.Equal(t, "s p a c e", parsedLocalpart)
assert.Equal(t, "maunium.net", parsedServerName)
}
func TestUserID_Parse_Emtpty(t *testing.T) {
const inputUserID = "@:ponies.im"
parsedLocalpart, parsedServerName, err := id.UserID(inputUserID).Parse()
assert.NoError(t, err)
assert.Equal(t, "", parsedLocalpart)
assert.Equal(t, "ponies.im", parsedServerName)
}
func TestUserID_Parse_Invalid(t *testing.T) {
const inputUserID = "hello world"
_, _, err := id.UserID(inputUserID).Parse()
assert.Error(t, err)
assert.True(t, errors.Is(err, id.ErrInvalidUserID))
}
func TestUserID_ParseAndValidate_Invalid(t *testing.T) {
const inputUserID = "@s p a c e:maunium.net"
_, _, err := id.UserID(inputUserID).ParseAndValidate()
assert.Error(t, err)
assert.True(t, errors.Is(err, id.ErrNoncompliantLocalpart))
}
func TestUserID_ParseAndValidate_Empty(t *testing.T) {
const inputUserID = "@:ponies.im"
_, _, err := id.UserID(inputUserID).ParseAndValidate()
assert.Error(t, err)
assert.True(t, errors.Is(err, id.ErrEmptyLocalpart))
}
func TestUserID_ParseAndValidate_Long(t *testing.T) {
const inputUserID = "@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa:example.com"
_, _, err := id.UserID(inputUserID).ParseAndValidate()
assert.Error(t, err)
assert.True(t, errors.Is(err, id.ErrUserIDTooLong))
}
func TestUserID_ParseAndValidate_NotLong(t *testing.T) {
const inputUserID = "@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa:example.com"
_, _, err := id.UserID(inputUserID).ParseAndValidate()
assert.NoError(t, err)
}
func TestUserIDEncoding(t *testing.T) {
const inputLocalpart = "This localpart contains IlLeGaL chäracters 🚨"
const encodedLocalpart = "_this=20localpart=20contains=20_il_le_ga_l=20ch=c3=a4racters=20=f0=9f=9a=a8"
const inputServerName = "example.com"
userID := id.NewEncodedUserID(inputLocalpart, inputServerName)
parsedLocalpart, parsedServerName, err := userID.ParseAndValidate()
assert.NoError(t, err)
assert.Equal(t, encodedLocalpart, parsedLocalpart)
assert.Equal(t, inputServerName, parsedServerName)
decodedLocalpart, decodedServerName, err := userID.ParseAndDecode()
assert.NoError(t, err)
assert.Equal(t, inputLocalpart, decodedLocalpart)
assert.Equal(t, inputServerName, decodedServerName)
}