oauth/api: drain timer channel on each iteration

Previously, if while polling for oauth device-code login results a user
suspended the process (such as with CTRL-Z) and then restored it with
`fg`, an error might occur in the form of:

```
failed waiting for authentication: You are polling faster than the specified interval of 5 seconds.
```

This is due to our use of a `time.Ticker` here - if no receiver drains
the ticker channel (and timers/tickers use a buffered channel behind the
scenes), more than one tick will pile up, causing the program to "tick"
twice, in fast succession, after it is resumed.

The new implementation replaces the `time.Ticker` with a `time.Timer`
(`time.Ticker` is just a nice wrapper) and introduces a helper function
`resetTimer` to ensure that before every `select`, the timer is stopped
and it's channel is drained.

Signed-off-by: Laura Brehm <laurabrehm@hey.com>
(cherry picked from commit 60d0450287)
Signed-off-by: Paweł Gronowski <pawel.gronowski@docker.com>
This commit is contained in:
Laura Brehm 2024-08-28 13:09:42 +01:00 committed by Paweł Gronowski
parent c5d846735c
commit 1dfd11acc0
No known key found for this signature in database
GPG Key ID: B85EFCFE26DEF92A
2 changed files with 38 additions and 5 deletions

View File

@ -96,18 +96,32 @@ func tryDecodeOAuthError(resp *http.Response) error {
// authenticated or we have reached the time limit for authenticating (based on // authenticated or we have reached the time limit for authenticating (based on
// the response from GetDeviceCode). // the response from GetDeviceCode).
func (a API) WaitForDeviceToken(ctx context.Context, state State) (TokenResponse, error) { func (a API) WaitForDeviceToken(ctx context.Context, state State) (TokenResponse, error) {
ticker := time.NewTicker(state.IntervalDuration()) // Ticker for polling tenant for login based on the interval
// specified by the tenant response.
ticker := time.NewTimer(state.IntervalDuration())
defer ticker.Stop() defer ticker.Stop()
timeout := time.After(state.ExpiryDuration()) // The tenant tells us for as long as we can poll it for credentials
// while the user logs in through their browser. Timeout if we don't get
// credentials within this period.
timeout := time.NewTimer(state.ExpiryDuration())
defer timeout.Stop()
for { for {
resetTimer(ticker, state.IntervalDuration())
select { select {
case <-ctx.Done(): case <-ctx.Done():
// user canceled login
return TokenResponse{}, ctx.Err() return TokenResponse{}, ctx.Err()
case <-ticker.C: case <-ticker.C:
// tick, check for user login
res, err := a.getDeviceToken(ctx, state) res, err := a.getDeviceToken(ctx, state)
if err != nil { if err != nil {
return res, err if errors.Is(err, context.Canceled) {
// if the caller canceled the context, continue
// and let the select hit the ctx.Done() branch
continue
}
return TokenResponse{}, err
} }
if res.Error != nil { if res.Error != nil {
@ -119,14 +133,33 @@ func (a API) WaitForDeviceToken(ctx context.Context, state State) (TokenResponse
} }
return res, nil return res, nil
case <-timeout: case <-timeout.C:
// login timed out
return TokenResponse{}, ErrTimeout return TokenResponse{}, ErrTimeout
} }
} }
} }
// resetTimer is a helper function thatstops, drains and resets the timer.
// This is necessary in go versions <1.23, since the timer isn't stopped +
// the timer's channel isn't drained on timer.Reset.
// See: https://go-review.googlesource.com/c/go/+/568341
// FIXME: remove/simplify this after we update to go1.23
func resetTimer(t *time.Timer, d time.Duration) {
if !t.Stop() {
select {
case <-t.C:
default:
}
}
t.Reset(d)
}
// getToken calls the token endpoint of Auth0 and returns the response. // getToken calls the token endpoint of Auth0 and returns the response.
func (a API) getDeviceToken(ctx context.Context, state State) (TokenResponse, error) { func (a API) getDeviceToken(ctx context.Context, state State) (TokenResponse, error) {
ctx, cancel := context.WithTimeout(ctx, 1*time.Minute)
defer cancel()
data := url.Values{ data := url.Values{
"client_id": {a.ClientID}, "client_id": {a.ClientID},
"grant_type": {"urn:ietf:params:oauth:grant-type:device_code"}, "grant_type": {"urn:ietf:params:oauth:grant-type:device_code"},

View File

@ -198,7 +198,7 @@ func TestWaitForDeviceToken(t *testing.T) {
state := State{ state := State{
DeviceCode: "aDeviceCode", DeviceCode: "aDeviceCode",
UserCode: "aUserCode", UserCode: "aUserCode",
Interval: 1, Interval: 5,
ExpiresIn: 1, ExpiresIn: 1,
} }