package flightcontrol import ( "context" "io" "runtime" "sort" "sync" "time" "github.com/moby/buildkit/util/progress" "github.com/pkg/errors" ) // flightcontrol is like singleflight but with support for cancellation and // nested progress reporting var ( errRetry = errors.Errorf("retry") errRetryTimeout = errors.Errorf("exceeded retry timeout") ) type contextKeyT string var contextKey = contextKeyT("buildkit/util/flightcontrol.progress") // Group is a flightcontrol synchronization group type Group struct { mu sync.Mutex // protects m m map[string]*call // lazily initialized } // Do executes a context function syncronized by the key func (g *Group) Do(ctx context.Context, key string, fn func(ctx context.Context) (interface{}, error)) (v interface{}, err error) { var backoff time.Duration for { v, err = g.do(ctx, key, fn) if err == nil || !errors.Is(err, errRetry) { return v, err } // backoff logic if backoff >= 3*time.Second { err = errors.Wrapf(errRetryTimeout, "flightcontrol") return v, err } runtime.Gosched() if backoff > 0 { time.Sleep(backoff) backoff *= 2 } else { backoff = time.Millisecond } } } func (g *Group) do(ctx context.Context, key string, fn func(ctx context.Context) (interface{}, error)) (interface{}, error) { g.mu.Lock() if g.m == nil { g.m = make(map[string]*call) } if c, ok := g.m[key]; ok { // register 2nd waiter g.mu.Unlock() return c.wait(ctx) } c := newCall(fn) g.m[key] = c go func() { // cleanup after a caller has returned <-c.ready g.mu.Lock() delete(g.m, key) g.mu.Unlock() close(c.cleaned) }() g.mu.Unlock() return c.wait(ctx) } type call struct { mu sync.Mutex result interface{} err error ready chan struct{} cleaned chan struct{} ctx *sharedContext ctxs []context.Context fn func(ctx context.Context) (interface{}, error) once sync.Once closeProgressWriter func() progressState *progressState progressCtx context.Context } func newCall(fn func(ctx context.Context) (interface{}, error)) *call { c := &call{ fn: fn, ready: make(chan struct{}), cleaned: make(chan struct{}), progressState: newProgressState(), } ctx := newContext(c) // newSharedContext pr, pctx, closeProgressWriter := progress.NewContext(context.Background()) c.progressCtx = pctx c.ctx = ctx c.closeProgressWriter = closeProgressWriter go c.progressState.run(pr) // TODO: remove this, wrap writer instead return c } func (c *call) run() { defer c.closeProgressWriter() ctx, cancel := context.WithCancel(c.ctx) defer cancel() v, err := c.fn(ctx) c.mu.Lock() c.result = v c.err = err c.mu.Unlock() close(c.ready) } func (c *call) wait(ctx context.Context) (v interface{}, err error) { c.mu.Lock() // detect case where caller has just returned, let it clean up before select { case <-c.ready: c.mu.Unlock() <-c.cleaned return nil, errRetry case <-c.ctx.done: // could return if no error c.mu.Unlock() <-c.cleaned return nil, errRetry default: } pw, ok, ctx := progress.FromContext(ctx) if ok { c.progressState.add(pw) } ctx, cancel := context.WithCancel(ctx) defer cancel() c.ctxs = append(c.ctxs, ctx) c.mu.Unlock() go c.once.Do(c.run) select { case <-ctx.Done(): if c.ctx.checkDone() { // if this cancelled the last context, then wait for function to shut down // and don't accept any more callers <-c.ready return c.result, c.err } if ok { c.progressState.close(pw) } return nil, ctx.Err() case <-c.ready: return c.result, c.err // shared not implemented yet } } func (c *call) Deadline() (deadline time.Time, ok bool) { c.mu.Lock() defer c.mu.Unlock() for _, ctx := range c.ctxs { select { case <-ctx.Done(): default: dl, ok := ctx.Deadline() if ok { return dl, ok } } } return time.Time{}, false } func (c *call) Done() <-chan struct{} { return c.ctx.done } func (c *call) Err() error { select { case <-c.ctx.Done(): return c.ctx.err default: return nil } } func (c *call) Value(key interface{}) interface{} { if key == contextKey { return c.progressState } c.mu.Lock() defer c.mu.Unlock() ctx := c.progressCtx select { case <-ctx.Done(): default: if v := ctx.Value(key); v != nil { return v } } if len(c.ctxs) > 0 { ctx = c.ctxs[0] select { case <-ctx.Done(): default: if v := ctx.Value(key); v != nil { return v } } } return nil } type sharedContext struct { *call done chan struct{} err error } func newContext(c *call) *sharedContext { return &sharedContext{call: c, done: make(chan struct{})} } func (sc *sharedContext) checkDone() bool { sc.mu.Lock() select { case <-sc.done: sc.mu.Unlock() return true default: } var err error for _, ctx := range sc.ctxs { select { case <-ctx.Done(): err = ctx.Err() default: sc.mu.Unlock() return false } } sc.err = err close(sc.done) sc.mu.Unlock() return true } type rawProgressWriter interface { WriteRawProgress(*progress.Progress) error Close() error } type progressState struct { mu sync.Mutex items map[string]*progress.Progress writers []rawProgressWriter done bool } func newProgressState() *progressState { return &progressState{ items: make(map[string]*progress.Progress), } } func (ps *progressState) run(pr progress.Reader) { for { p, err := pr.Read(context.TODO()) if err != nil { if err == io.EOF { ps.mu.Lock() ps.done = true ps.mu.Unlock() for _, w := range ps.writers { w.Close() } } return } ps.mu.Lock() for _, p := range p { for _, w := range ps.writers { w.WriteRawProgress(p) } ps.items[p.ID] = p } ps.mu.Unlock() } } func (ps *progressState) add(pw progress.Writer) { rw, ok := pw.(rawProgressWriter) if !ok { return } ps.mu.Lock() plist := make([]*progress.Progress, 0, len(ps.items)) for _, p := range ps.items { plist = append(plist, p) } sort.Slice(plist, func(i, j int) bool { return plist[i].Timestamp.Before(plist[j].Timestamp) }) for _, p := range plist { rw.WriteRawProgress(p) } if ps.done { rw.Close() } else { ps.writers = append(ps.writers, rw) } ps.mu.Unlock() } func (ps *progressState) close(pw progress.Writer) { rw, ok := pw.(rawProgressWriter) if !ok { return } ps.mu.Lock() for i, w := range ps.writers { if w == rw { w.Close() ps.writers = append(ps.writers[:i], ps.writers[i+1:]...) break } } ps.mu.Unlock() }