diff --git a/internal/dataprovider/dataprovider.go b/internal/dataprovider/dataprovider.go index 44118e91..d51a59fb 100644 --- a/internal/dataprovider/dataprovider.go +++ b/internal/dataprovider/dataprovider.go @@ -4189,18 +4189,19 @@ func executePreLoginHook(username, loginMethod, ip, protocol string, oidcTokenFi u.Filters.TOTPConfig = totpConfig u.Filters.RecoveryCodes = recoveryCodes err = provider.updateUser(&u) - if err == nil { - webDAVUsersCache.swap(&u, "") - } } if err != nil { return u, err } - providerLog(logger.LevelDebug, "user %q added/updated from pre-login hook response, id: %d", username, userID) - if userID == 0 { - return provider.userExists(username, "") + user, err := provider.userExists(username, "") + if err != nil { + return u, err } - return u, nil + providerLog(logger.LevelDebug, "user %q added/updated from pre-login hook response, id: %d", username, userID) + if userID > 0 { + webDAVUsersCache.swap(&user, "") + } + return user, nil } // ExecutePostLoginHook executes the post login hook if defined diff --git a/internal/sftpd/sftpd_test.go b/internal/sftpd/sftpd_test.go index f8fe779c..0d74974b 100644 --- a/internal/sftpd/sftpd_test.go +++ b/internal/sftpd/sftpd_test.go @@ -2980,13 +2980,46 @@ func TestPreLoginScript(t *testing.T) { err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) + mappedPath := filepath.Join(os.TempDir(), "vdir") + folderName := filepath.Base(mappedPath) + f := vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: mappedPath, + FsConfig: vfs.Filesystem{ + Provider: sdk.CryptedFilesystemProvider, + CryptConfig: vfs.CryptFsConfig{ + Passphrase: kms.NewPlainSecret(defaultPassword), + }, + }, + } + _, _, err = httpdtest.AddFolder(f, http.StatusCreated) + assert.NoError(t, err) + + folderMountPath := "/vpath" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + }, + VirtualPath: folderMountPath, + }) user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(u, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + testFilePath := filepath.Join(homeBasePath, testFileName) + testData := []byte("test data") + err = os.WriteFile(testFilePath, testData, 0666) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, path.Join(folderMountPath, testFileName), int64(len(testData)), client) + assert.NoError(t, err) + info, err := os.Stat(filepath.Join(mappedPath, testFileName)) + if assert.NoError(t, err) { + assert.Greater(t, info.Size(), int64(len(testData))) + } } err = os.WriteFile(preLoginPath, getPreLoginScriptContent(user, true), os.ModePerm) assert.NoError(t, err) @@ -3023,6 +3056,10 @@ func TestPreLoginScript(t *testing.T) { assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath) + assert.NoError(t, err) err = dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "")