From 950a5ad9ea642c95a7e90017a4da6f92ee8f64de Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Sat, 31 Oct 2020 11:02:04 +0100 Subject: [PATCH] add a recoverer where appropriate I have never seen this, but a malformed packet can easily crash pkg/sftp --- sftpd/internal_test.go | 25 +++++++++++++++++++++++++ sftpd/scp.go | 10 ++++++++-- sftpd/server.go | 11 +++++++++++ sftpd/sftpd_test.go | 3 +-- sftpd/ssh_cmd.go | 11 +++++++++-- webdavd/internal_test.go | 12 ++++++++++++ webdavd/server.go | 7 +++++++ 7 files changed, 73 insertions(+), 6 deletions(-) diff --git a/sftpd/internal_test.go b/sftpd/internal_test.go index af921d74..33b64895 100644 --- a/sftpd/internal_test.go +++ b/sftpd/internal_test.go @@ -1846,3 +1846,28 @@ func TestSFTPSubSystem(t *testing.T) { err = subsystemChannel.Close() assert.NoError(t, err) } + +func TestRecoverer(t *testing.T) { + c := Configuration{} + c.AcceptInboundConnection(nil, nil) + connID := "connectionID" + connection := &Connection{ + BaseConnection: common.NewBaseConnection(connID, common.ProtocolSFTP, dataprovider.User{}, nil), + } + c.handleSftpConnection(nil, connection) + sshCmd := sshCommand{ + command: "cd", + connection: connection, + } + err := sshCmd.handle() + assert.EqualError(t, err, common.ErrGenericFailure.Error()) + scpCmd := scpCommand{ + sshCommand: sshCommand{ + command: "scp", + connection: connection, + }, + } + err = scpCmd.handle() + assert.EqualError(t, err, common.ErrGenericFailure.Error()) + assert.Len(t, common.Connections.GetStats(), 0) +} diff --git a/sftpd/scp.go b/sftpd/scp.go index 08a12c4e..657af42c 100644 --- a/sftpd/scp.go +++ b/sftpd/scp.go @@ -7,6 +7,7 @@ import ( "os" "path" "path/filepath" + "runtime/debug" "strconv" "strings" @@ -28,11 +29,16 @@ type scpCommand struct { sshCommand } -func (c *scpCommand) handle() error { +func (c *scpCommand) handle() (err error) { + defer func() { + if r := recover(); r != nil { + logger.Error(logSender, "", "panic in handle scp command: %#v stack strace: %v", r, string(debug.Stack())) + err = common.ErrGenericFailure + } + }() common.Connections.Add(c.connection) defer common.Connections.Remove(c.connection.GetID()) - var err error destPath := c.getDestPath() commandType := c.getCommandType() c.connection.Log(logger.LevelDebug, "handle scp command, args: %v user: %v command type: %v, dest path: %#v", diff --git a/sftpd/server.go b/sftpd/server.go index a5571ec3..b30adf6b 100644 --- a/sftpd/server.go +++ b/sftpd/server.go @@ -11,6 +11,7 @@ import ( "net" "os" "path/filepath" + "runtime/debug" "strings" "time" @@ -266,6 +267,11 @@ func (c Configuration) configureKeyboardInteractiveAuth(serverConfig *ssh.Server // AcceptInboundConnection handles an inbound connection to the server instance and determines if the request should be served or not. func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.ServerConfig) { + defer func() { + if r := recover(); r != nil { + logger.Error(logSender, "", "panic in AcceptInboundConnection: %#v stack strace: %v", r, string(debug.Stack())) + } + }() // Before beginning a handshake must be performed on the incoming net.Conn // we'll set a Deadline for handshake to complete, the default is 2 minutes as OpenSSH conn.SetDeadline(time.Now().Add(handshakeTimeout)) //nolint:errcheck @@ -374,6 +380,11 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server } func (c Configuration) handleSftpConnection(channel ssh.Channel, connection *Connection) { + defer func() { + if r := recover(); r != nil { + logger.Error(logSender, "", "panic in handleSftpConnection: %#v stack strace: %v", r, string(debug.Stack())) + } + }() common.Connections.Add(connection) defer common.Connections.Remove(connection.GetID()) diff --git a/sftpd/sftpd_test.go b/sftpd/sftpd_test.go index 712074f2..6a803590 100644 --- a/sftpd/sftpd_test.go +++ b/sftpd/sftpd_test.go @@ -462,13 +462,12 @@ func TestConcurrency(t *testing.T) { client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { - defer client.Close() - err = checkBasicSFTP(client) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName+strconv.Itoa(counter), testFileSize, client) assert.NoError(t, err) assert.Greater(t, common.Connections.GetActiveSessions(defaultUsername), 0) + client.Close() } }(i) } diff --git a/sftpd/ssh_cmd.go b/sftpd/ssh_cmd.go index 1b879387..75e5910a 100644 --- a/sftpd/ssh_cmd.go +++ b/sftpd/ssh_cmd.go @@ -12,6 +12,7 @@ import ( "os" "os/exec" "path" + "runtime/debug" "strings" "sync" @@ -84,7 +85,13 @@ func processSSHCommand(payload []byte, connection *Connection, enabledSSHCommand return false } -func (c *sshCommand) handle() error { +func (c *sshCommand) handle() (err error) { + defer func() { + if r := recover(); r != nil { + logger.Error(logSender, "", "panic in handle ssh command: %#v stack strace: %v", r, string(debug.Stack())) + err = common.ErrGenericFailure + } + }() common.Connections.Add(c.connection) defer common.Connections.Remove(c.connection.GetID()) @@ -108,7 +115,7 @@ func (c *sshCommand) handle() error { } else if c.command == "sftpgo-remove" { return c.handeSFTPGoRemove() } - return nil + return } func (c *sshCommand) handeSFTPGoCopy() error { diff --git a/webdavd/internal_test.go b/webdavd/internal_test.go index b5d5ee60..50473d8e 100644 --- a/webdavd/internal_test.go +++ b/webdavd/internal_test.go @@ -8,6 +8,7 @@ import ( "io" "io/ioutil" "net/http" + "net/http/httptest" "os" "path" "path/filepath" @@ -862,3 +863,14 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) { _, err = httpd.RemoveUser(user4, http.StatusOK) assert.NoError(t, err) } + +func TestRecoverer(t *testing.T) { + c := &Configuration{ + BindPort: 9000, + } + server, err := newServer(c, configDir) + assert.NoError(t, err) + rr := httptest.NewRecorder() + server.ServeHTTP(rr, nil) + assert.Equal(t, http.StatusInternalServerError, rr.Code) +} diff --git a/webdavd/server.go b/webdavd/server.go index fd32882d..2c426cde 100644 --- a/webdavd/server.go +++ b/webdavd/server.go @@ -8,6 +8,7 @@ import ( "net/http" "path" "path/filepath" + "runtime/debug" "strings" "time" @@ -85,6 +86,12 @@ func (s *webDavServer) listenAndServe() error { // ServeHTTP implements the http.Handler interface func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer func() { + if r := recover(); r != nil { + logger.Error(logSender, "", "panic in ServeHTTP: %#v stack strace: %v", r, string(debug.Stack())) + http.Error(w, common.ErrGenericFailure.Error(), http.StatusInternalServerError) + } + }() checkRemoteAddress(r) if err := common.Config.ExecutePostConnectHook(r.RemoteAddr, common.ProtocolWebDAV); err != nil { http.Error(w, common.ErrConnectionDenied.Error(), http.StatusForbidden)