mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
id/userid: split validation into 2 functions
Some checks failed
Some checks failed
This commit is contained in:
parent
51edfc27c0
commit
3a300246ac
2 changed files with 23 additions and 15 deletions
20
id/userid.go
20
id/userid.go
|
|
@ -104,16 +104,24 @@ func ValidateUserLocalpart(localpart string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// ParseAndValidate parses the user ID into the localpart and server name like Parse,
|
||||
// and also validates that the localpart is allowed according to the user identifiers spec.
|
||||
func (userID UserID) ParseAndValidate() (localpart, homeserver string, err error) {
|
||||
localpart, homeserver, err = userID.Parse()
|
||||
// ParseAndValidateStrict is a stricter version of ParseAndValidateRelaxed that checks the localpart to only allow non-historical localparts.
|
||||
// This should be used with care: there are real users still using historical localparts.
|
||||
func (userID UserID) ParseAndValidateStrict() (localpart, homeserver string, err error) {
|
||||
localpart, homeserver, err = userID.ParseAndValidateRelaxed()
|
||||
if err == nil {
|
||||
err = ValidateUserLocalpart(localpart)
|
||||
}
|
||||
if err == nil && len(userID) > UserIDMaxLength {
|
||||
return
|
||||
}
|
||||
|
||||
// ParseAndValidateRelaxed parses the user ID into the localpart and server name like Parse,
|
||||
// and also validates that the user ID is not too long and that the server name is valid.
|
||||
func (userID UserID) ParseAndValidateRelaxed() (localpart, homeserver string, err error) {
|
||||
if len(userID) > UserIDMaxLength {
|
||||
err = ErrUserIDTooLong
|
||||
return
|
||||
}
|
||||
localpart, homeserver, err = userID.Parse()
|
||||
if err == nil && !ValidateServerName(homeserver) {
|
||||
err = fmt.Errorf("%q %q", homeserver, ErrNoncompliantServerPart)
|
||||
}
|
||||
|
|
@ -121,7 +129,7 @@ func (userID UserID) ParseAndValidate() (localpart, homeserver string, err error
|
|||
}
|
||||
|
||||
func (userID UserID) ParseAndDecode() (localpart, homeserver string, err error) {
|
||||
localpart, homeserver, err = userID.ParseAndValidate()
|
||||
localpart, homeserver, err = userID.ParseAndValidateStrict()
|
||||
if err == nil {
|
||||
localpart, err = DecodeUserLocalpart(localpart)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -38,30 +38,30 @@ func TestUserID_Parse_Invalid(t *testing.T) {
|
|||
assert.True(t, errors.Is(err, id.ErrInvalidUserID))
|
||||
}
|
||||
|
||||
func TestUserID_ParseAndValidate_Invalid(t *testing.T) {
|
||||
func TestUserID_ParseAndValidateStrict_Invalid(t *testing.T) {
|
||||
const inputUserID = "@s p a c e:maunium.net"
|
||||
_, _, err := id.UserID(inputUserID).ParseAndValidate()
|
||||
_, _, err := id.UserID(inputUserID).ParseAndValidateStrict()
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, id.ErrNoncompliantLocalpart))
|
||||
}
|
||||
|
||||
func TestUserID_ParseAndValidate_Empty(t *testing.T) {
|
||||
func TestUserID_ParseAndValidateStrict_Empty(t *testing.T) {
|
||||
const inputUserID = "@:ponies.im"
|
||||
_, _, err := id.UserID(inputUserID).ParseAndValidate()
|
||||
_, _, err := id.UserID(inputUserID).ParseAndValidateStrict()
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, id.ErrEmptyLocalpart))
|
||||
}
|
||||
|
||||
func TestUserID_ParseAndValidate_Long(t *testing.T) {
|
||||
func TestUserID_ParseAndValidateStrict_Long(t *testing.T) {
|
||||
const inputUserID = "@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa:example.com"
|
||||
_, _, err := id.UserID(inputUserID).ParseAndValidate()
|
||||
_, _, err := id.UserID(inputUserID).ParseAndValidateStrict()
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, id.ErrUserIDTooLong))
|
||||
}
|
||||
|
||||
func TestUserID_ParseAndValidate_NotLong(t *testing.T) {
|
||||
func TestUserID_ParseAndValidateStrict_NotLong(t *testing.T) {
|
||||
const inputUserID = "@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa:example.com"
|
||||
_, _, err := id.UserID(inputUserID).ParseAndValidate()
|
||||
_, _, err := id.UserID(inputUserID).ParseAndValidateStrict()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
|
|
@ -70,7 +70,7 @@ func TestUserIDEncoding(t *testing.T) {
|
|||
const encodedLocalpart = "_this=20local+part=20contains=20_il_le_ga_l=20ch=c3=a4racters=20=f0=9f=9a=a8"
|
||||
const inputServerName = "example.com"
|
||||
userID := id.NewEncodedUserID(inputLocalpart, inputServerName)
|
||||
parsedLocalpart, parsedServerName, err := userID.ParseAndValidate()
|
||||
parsedLocalpart, parsedServerName, err := userID.ParseAndValidateStrict()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, encodedLocalpart, parsedLocalpart)
|
||||
assert.Equal(t, inputServerName, parsedServerName)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue