Defer closing the gitrepo until the end of the wrapped context functions (#15653)

There was a mistake in #15372 where deferral of gitrepo close occurs before it should.

This PR fixes this.

Signed-off-by: Andrew Thornton <art27@cantab.net>
This commit is contained in:
zeripath 2021-05-06 00:30:25 +01:00 committed by GitHub
parent e071b53686
commit eedc0c8324
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 8 deletions

View file

@ -6,6 +6,7 @@
package context package context
import ( import (
"context"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/url" "net/url"
@ -393,7 +394,7 @@ func RepoIDAssignment() func(ctx *Context) {
} }
// RepoAssignment returns a middleware to handle repository assignment // RepoAssignment returns a middleware to handle repository assignment
func RepoAssignment(ctx *Context) { func RepoAssignment(ctx *Context) (cancel context.CancelFunc) {
var ( var (
owner *models.User owner *models.User
err error err error
@ -529,12 +530,12 @@ func RepoAssignment(ctx *Context) {
ctx.Repo.GitRepo = gitRepo ctx.Repo.GitRepo = gitRepo
// We opened it, we should close it // We opened it, we should close it
defer func() { cancel = func() {
// If it's been set to nil then assume someone else has closed it. // If it's been set to nil then assume someone else has closed it.
if ctx.Repo.GitRepo != nil { if ctx.Repo.GitRepo != nil {
ctx.Repo.GitRepo.Close() ctx.Repo.GitRepo.Close()
} }
}() }
// Stop at this point when the repo is empty. // Stop at this point when the repo is empty.
if ctx.Repo.Repository.IsEmpty { if ctx.Repo.Repository.IsEmpty {
@ -619,6 +620,7 @@ func RepoAssignment(ctx *Context) {
ctx.Data["GoDocDirectory"] = prefix + "{/dir}" ctx.Data["GoDocDirectory"] = prefix + "{/dir}"
ctx.Data["GoDocFile"] = prefix + "{/dir}/{file}#L{line}" ctx.Data["GoDocFile"] = prefix + "{/dir}/{file}#L{line}"
} }
return
} }
// RepoRefType type of repo reference // RepoRefType type of repo reference
@ -643,7 +645,7 @@ const (
// RepoRef handles repository reference names when the ref name is not // RepoRef handles repository reference names when the ref name is not
// explicitly given // explicitly given
func RepoRef() func(*Context) { func RepoRef() func(*Context) context.CancelFunc {
// since no ref name is explicitly specified, ok to just use branch // since no ref name is explicitly specified, ok to just use branch
return RepoRefByType(RepoRefBranch) return RepoRefByType(RepoRefBranch)
} }
@ -722,8 +724,8 @@ func getRefName(ctx *Context, pathType RepoRefType) string {
// RepoRefByType handles repository reference name for a specific type // RepoRefByType handles repository reference name for a specific type
// of repository reference // of repository reference
func RepoRefByType(refType RepoRefType) func(*Context) { func RepoRefByType(refType RepoRefType) func(*Context) context.CancelFunc {
return func(ctx *Context) { return func(ctx *Context) (cancel context.CancelFunc) {
// Empty repository does not have reference information. // Empty repository does not have reference information.
if ctx.Repo.Repository.IsEmpty { if ctx.Repo.Repository.IsEmpty {
return return
@ -742,12 +744,12 @@ func RepoRefByType(refType RepoRefType) func(*Context) {
return return
} }
// We opened it, we should close it // We opened it, we should close it
defer func() { cancel = func() {
// If it's been set to nil then assume someone else has closed it. // If it's been set to nil then assume someone else has closed it.
if ctx.Repo.GitRepo != nil { if ctx.Repo.GitRepo != nil {
ctx.Repo.GitRepo.Close() ctx.Repo.GitRepo.Close()
} }
}() }
} }
// Get default branch. // Get default branch.
@ -841,6 +843,7 @@ func RepoRefByType(refType RepoRefType) func(*Context) {
return return
} }
ctx.Data["CommitsCount"] = ctx.Repo.CommitsCount ctx.Data["CommitsCount"] = ctx.Repo.CommitsCount
return
} }
} }

View file

@ -5,6 +5,7 @@
package web package web
import ( import (
goctx "context"
"fmt" "fmt"
"net/http" "net/http"
"reflect" "reflect"
@ -27,6 +28,7 @@ func Wrap(handlers ...interface{}) http.HandlerFunc {
switch t := handler.(type) { switch t := handler.(type) {
case http.HandlerFunc, func(http.ResponseWriter, *http.Request), case http.HandlerFunc, func(http.ResponseWriter, *http.Request),
func(ctx *context.Context), func(ctx *context.Context),
func(ctx *context.Context) goctx.CancelFunc,
func(*context.APIContext), func(*context.APIContext),
func(*context.PrivateContext), func(*context.PrivateContext),
func(http.Handler) http.Handler: func(http.Handler) http.Handler:
@ -48,6 +50,15 @@ func Wrap(handlers ...interface{}) http.HandlerFunc {
if r, ok := resp.(context.ResponseWriter); ok && r.Status() > 0 { if r, ok := resp.(context.ResponseWriter); ok && r.Status() > 0 {
return return
} }
case func(ctx *context.Context) goctx.CancelFunc:
ctx := context.GetContext(req)
cancel := t(ctx)
if cancel != nil {
defer cancel()
}
if ctx.Written() {
return
}
case func(ctx *context.Context): case func(ctx *context.Context):
ctx := context.GetContext(req) ctx := context.GetContext(req)
t(ctx) t(ctx)
@ -94,6 +105,23 @@ func Middle(f func(ctx *context.Context)) func(netx http.Handler) http.Handler {
} }
} }
// MiddleCancel wrap a context function as a chi middleware
func MiddleCancel(f func(ctx *context.Context) goctx.CancelFunc) func(netx http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
ctx := context.GetContext(req)
cancel := f(ctx)
if cancel != nil {
defer cancel()
}
if ctx.Written() {
return
}
next.ServeHTTP(ctx.Resp, ctx.Req)
})
}
}
// MiddleAPI wrap a context function as a chi middleware // MiddleAPI wrap a context function as a chi middleware
func MiddleAPI(f func(ctx *context.APIContext)) func(netx http.Handler) http.Handler { func MiddleAPI(f func(ctx *context.APIContext)) func(netx http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
@ -163,6 +191,8 @@ func (r *Route) Use(middlewares ...interface{}) {
r.R.Use(t) r.R.Use(t)
case func(*context.Context): case func(*context.Context):
r.R.Use(Middle(t)) r.R.Use(Middle(t))
case func(*context.Context) goctx.CancelFunc:
r.R.Use(MiddleCancel(t))
case func(*context.APIContext): case func(*context.APIContext):
r.R.Use(MiddleAPI(t)) r.R.Use(MiddleAPI(t))
default: default: