Correctly handle select on multiple channels in Queues (#22146)

There are a few places in FlushQueueWithContext which make an incorrect
assumption about how `select` on multiple channels works.

The problem is best expressed by looking at the following example:

```go
package main

import "fmt"

func main() {
    closedChan := make(chan struct{})
    close(closedChan)
    toClose := make(chan struct{})
    count := 0

    for {
        select {
        case <-closedChan:
            count++
            fmt.Println(count)
            if count == 2 {
                close(toClose)
            }
        case <-toClose:
            return
        }
    }
}
```

This PR double-checks that the contexts are closed outside of checking
if there is data in the dataChan. It also rationalises the WorkerPool
FlushWithContext because the previous implementation failed to handle
pausing correctly. This will probably fix the underlying problem in
 #22145

Fix #22145

Signed-off-by: Andrew Thornton <art27@cantab.net>

Signed-off-by: Andrew Thornton <art27@cantab.net>
This commit is contained in:
zeripath 2022-12-30 00:06:47 +00:00 committed by GitHub
parent 47efba78ec
commit a609cae9fb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 43 additions and 57 deletions

View file

@ -109,32 +109,6 @@ func (q *ChannelQueue) Flush(timeout time.Duration) error {
return q.FlushWithContext(ctx) return q.FlushWithContext(ctx)
} }
// FlushWithContext is very similar to CleanUp but it will return as soon as the dataChan is empty
func (q *ChannelQueue) FlushWithContext(ctx context.Context) error {
log.Trace("ChannelQueue: %d Flush", q.qid)
paused, _ := q.IsPausedIsResumed()
for {
select {
case <-paused:
return nil
case data, ok := <-q.dataChan:
if !ok {
return nil
}
if unhandled := q.handle(data); unhandled != nil {
log.Error("Unhandled Data whilst flushing queue %d", q.qid)
}
atomic.AddInt64(&q.numInQueue, -1)
case <-q.baseCtx.Done():
return q.baseCtx.Err()
case <-ctx.Done():
return ctx.Err()
default:
return nil
}
}
}
// Shutdown processing from this queue // Shutdown processing from this queue
func (q *ChannelQueue) Shutdown() { func (q *ChannelQueue) Shutdown() {
q.lock.Lock() q.lock.Lock()

View file

@ -8,7 +8,6 @@ import (
"fmt" "fmt"
"runtime/pprof" "runtime/pprof"
"sync" "sync"
"sync/atomic"
"time" "time"
"code.gitea.io/gitea/modules/container" "code.gitea.io/gitea/modules/container"
@ -167,35 +166,6 @@ func (q *ChannelUniqueQueue) Flush(timeout time.Duration) error {
return q.FlushWithContext(ctx) return q.FlushWithContext(ctx)
} }
// FlushWithContext is very similar to CleanUp but it will return as soon as the dataChan is empty
func (q *ChannelUniqueQueue) FlushWithContext(ctx context.Context) error {
log.Trace("ChannelUniqueQueue: %d Flush", q.qid)
paused, _ := q.IsPausedIsResumed()
for {
select {
case <-paused:
return nil
default:
}
select {
case data, ok := <-q.dataChan:
if !ok {
return nil
}
if unhandled := q.handle(data); unhandled != nil {
log.Error("Unhandled Data whilst flushing queue %d", q.qid)
}
atomic.AddInt64(&q.numInQueue, -1)
case <-q.baseCtx.Done():
return q.baseCtx.Err()
case <-ctx.Done():
return ctx.Err()
default:
return nil
}
}
}
// Shutdown processing from this queue // Shutdown processing from this queue
func (q *ChannelUniqueQueue) Shutdown() { func (q *ChannelUniqueQueue) Shutdown() {
log.Trace("ChannelUniqueQueue: %s Shutting down", q.name) log.Trace("ChannelUniqueQueue: %s Shutting down", q.name)

View file

@ -463,13 +463,43 @@ func (p *WorkerPool) IsEmpty() bool {
return atomic.LoadInt64(&p.numInQueue) == 0 return atomic.LoadInt64(&p.numInQueue) == 0
} }
// contextError returns either ctx.Done(), the base context's error or nil
func (p *WorkerPool) contextError(ctx context.Context) error {
select {
case <-p.baseCtx.Done():
return p.baseCtx.Err()
case <-ctx.Done():
return ctx.Err()
default:
return nil
}
}
// FlushWithContext is very similar to CleanUp but it will return as soon as the dataChan is empty // FlushWithContext is very similar to CleanUp but it will return as soon as the dataChan is empty
// NB: The worker will not be registered with the manager. // NB: The worker will not be registered with the manager.
func (p *WorkerPool) FlushWithContext(ctx context.Context) error { func (p *WorkerPool) FlushWithContext(ctx context.Context) error {
log.Trace("WorkerPool: %d Flush", p.qid) log.Trace("WorkerPool: %d Flush", p.qid)
paused, _ := p.IsPausedIsResumed()
for { for {
// Because select will return any case that is satisified at random we precheck here before looking at dataChan.
select { select {
case data := <-p.dataChan: case <-paused:
// Ensure that even if paused that the cancelled error is still sent
return p.contextError(ctx)
case <-p.baseCtx.Done():
return p.baseCtx.Err()
case <-ctx.Done():
return ctx.Err()
default:
}
select {
case <-paused:
return p.contextError(ctx)
case data, ok := <-p.dataChan:
if !ok {
return nil
}
if unhandled := p.handle(data); unhandled != nil { if unhandled := p.handle(data); unhandled != nil {
log.Error("Unhandled Data whilst flushing queue %d", p.qid) log.Error("Unhandled Data whilst flushing queue %d", p.qid)
} }
@ -495,6 +525,7 @@ func (p *WorkerPool) doWork(ctx context.Context) {
paused, _ := p.IsPausedIsResumed() paused, _ := p.IsPausedIsResumed()
data := make([]Data, 0, p.batchLength) data := make([]Data, 0, p.batchLength)
for { for {
// Because select will return any case that is satisified at random we precheck here before looking at dataChan.
select { select {
case <-paused: case <-paused:
log.Trace("Worker for Queue %d Pausing", p.qid) log.Trace("Worker for Queue %d Pausing", p.qid)
@ -515,8 +546,19 @@ func (p *WorkerPool) doWork(ctx context.Context) {
log.Trace("Worker shutting down") log.Trace("Worker shutting down")
return return
} }
case <-ctx.Done():
if len(data) > 0 {
log.Trace("Handling: %d data, %v", len(data), data)
if unhandled := p.handle(data...); unhandled != nil {
log.Error("Unhandled Data in queue %d", p.qid)
}
atomic.AddInt64(&p.numInQueue, -1*int64(len(data)))
}
log.Trace("Worker shutting down")
return
default: default:
} }
select { select {
case <-paused: case <-paused:
// go back around // go back around