id/userid: split validation into 2 functions
Some checks failed
Go / Lint (latest) (push) Has been cancelled
Go / Build (old, libolm) (push) Has been cancelled
Go / Build (latest, libolm) (push) Has been cancelled
Go / Build (old, goolm) (push) Has been cancelled
Go / Build (latest, goolm) (push) Has been cancelled

This commit is contained in:
Tulir Asokan 2025-10-06 23:10:04 +03:00
commit 3a300246ac
2 changed files with 23 additions and 15 deletions

View file

@ -104,16 +104,24 @@ func ValidateUserLocalpart(localpart string) error {
return nil
}
// ParseAndValidate parses the user ID into the localpart and server name like Parse,
// and also validates that the localpart is allowed according to the user identifiers spec.
func (userID UserID) ParseAndValidate() (localpart, homeserver string, err error) {
localpart, homeserver, err = userID.Parse()
// ParseAndValidateStrict is a stricter version of ParseAndValidateRelaxed that checks the localpart to only allow non-historical localparts.
// This should be used with care: there are real users still using historical localparts.
func (userID UserID) ParseAndValidateStrict() (localpart, homeserver string, err error) {
localpart, homeserver, err = userID.ParseAndValidateRelaxed()
if err == nil {
err = ValidateUserLocalpart(localpart)
}
if err == nil && len(userID) > UserIDMaxLength {
return
}
// ParseAndValidateRelaxed parses the user ID into the localpart and server name like Parse,
// and also validates that the user ID is not too long and that the server name is valid.
func (userID UserID) ParseAndValidateRelaxed() (localpart, homeserver string, err error) {
if len(userID) > UserIDMaxLength {
err = ErrUserIDTooLong
return
}
localpart, homeserver, err = userID.Parse()
if err == nil && !ValidateServerName(homeserver) {
err = fmt.Errorf("%q %q", homeserver, ErrNoncompliantServerPart)
}
@ -121,7 +129,7 @@ func (userID UserID) ParseAndValidate() (localpart, homeserver string, err error
}
func (userID UserID) ParseAndDecode() (localpart, homeserver string, err error) {
localpart, homeserver, err = userID.ParseAndValidate()
localpart, homeserver, err = userID.ParseAndValidateStrict()
if err == nil {
localpart, err = DecodeUserLocalpart(localpart)
}

View file

@ -38,30 +38,30 @@ func TestUserID_Parse_Invalid(t *testing.T) {
assert.True(t, errors.Is(err, id.ErrInvalidUserID))
}
func TestUserID_ParseAndValidate_Invalid(t *testing.T) {
func TestUserID_ParseAndValidateStrict_Invalid(t *testing.T) {
const inputUserID = "@s p a c e:maunium.net"
_, _, err := id.UserID(inputUserID).ParseAndValidate()
_, _, err := id.UserID(inputUserID).ParseAndValidateStrict()
assert.Error(t, err)
assert.True(t, errors.Is(err, id.ErrNoncompliantLocalpart))
}
func TestUserID_ParseAndValidate_Empty(t *testing.T) {
func TestUserID_ParseAndValidateStrict_Empty(t *testing.T) {
const inputUserID = "@:ponies.im"
_, _, err := id.UserID(inputUserID).ParseAndValidate()
_, _, err := id.UserID(inputUserID).ParseAndValidateStrict()
assert.Error(t, err)
assert.True(t, errors.Is(err, id.ErrEmptyLocalpart))
}
func TestUserID_ParseAndValidate_Long(t *testing.T) {
func TestUserID_ParseAndValidateStrict_Long(t *testing.T) {
const inputUserID = "@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa:example.com"
_, _, err := id.UserID(inputUserID).ParseAndValidate()
_, _, err := id.UserID(inputUserID).ParseAndValidateStrict()
assert.Error(t, err)
assert.True(t, errors.Is(err, id.ErrUserIDTooLong))
}
func TestUserID_ParseAndValidate_NotLong(t *testing.T) {
func TestUserID_ParseAndValidateStrict_NotLong(t *testing.T) {
const inputUserID = "@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa:example.com"
_, _, err := id.UserID(inputUserID).ParseAndValidate()
_, _, err := id.UserID(inputUserID).ParseAndValidateStrict()
assert.NoError(t, err)
}
@ -70,7 +70,7 @@ func TestUserIDEncoding(t *testing.T) {
const encodedLocalpart = "_this=20local+part=20contains=20_il_le_ga_l=20ch=c3=a4racters=20=f0=9f=9a=a8"
const inputServerName = "example.com"
userID := id.NewEncodedUserID(inputLocalpart, inputServerName)
parsedLocalpart, parsedServerName, err := userID.ParseAndValidate()
parsedLocalpart, parsedServerName, err := userID.ParseAndValidateStrict()
assert.NoError(t, err)
assert.Equal(t, encodedLocalpart, parsedLocalpart)
assert.Equal(t, inputServerName, parsedServerName)