DockerCLI/cli/internal/oauth/api/api.go

262 lines
7.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// FIXME(thaJeztah): remove once we are a module; the go:build directive prevents go from downgrading language version to go1.16:
//go:build go1.21
package api
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"runtime"
"strings"
"time"
"github.com/docker/cli/cli/version"
)
// OAuthAPI describes the operations used to run the OAuth device-code login
// flow against the tenant and to exchange the resulting access token for a
// token from the audience's access-token endpoint. API satisfies this
// interface.
type OAuthAPI interface {
	// GetDeviceCode starts the device-code flow for the given audience and
	// returns the state (device code, verification URL, polling parameters).
	GetDeviceCode(ctx context.Context, audience string) (State, error)
	// WaitForDeviceToken polls the tenant until the user has authenticated
	// for the device code in state, the flow expires, or ctx is done.
	WaitForDeviceToken(ctx context.Context, state State) (TokenResponse, error)
	// RevokeToken revokes refreshToken with the tenant.
	RevokeToken(ctx context.Context, refreshToken string) error
	// GetAutoPAT uses the access token in res to request a token from the
	// audience's access-token endpoint.
	GetAutoPAT(ctx context.Context, audience string, res TokenResponse) (string, error)
}
// API represents API interactions with Auth0.
// The zero value is not usable: TenantURL and ClientID must be set.
type API struct {
	// TenantURL is the base used for each request to Auth0.
	// Endpoint paths such as "/oauth/token" are appended to it by plain string
	// concatenation, so it should not end with a trailing slash.
	TenantURL string
	// ClientID is the client ID for the application to auth with the tenant.
	ClientID string
	// Scopes are the scopes that are requested during the device auth flow.
	// They are sent space-separated in the "scope" form field.
	Scopes []string
}
// TokenResponse represents the response of the /oauth/token route.
//
// The same shape is used both for successful token responses and for OAuth
// error responses: when Error is non-nil the request did not (yet) yield
// tokens and the token fields should be ignored.
type TokenResponse struct {
	AccessToken  string `json:"access_token"`
	IDToken      string `json:"id_token"`
	RefreshToken string `json:"refresh_token"`
	Scope        string `json:"scope"`
	// ExpiresIn is the token lifetime as reported by the tenant
	// (presumably seconds, per OAuth convention — confirm against Auth0).
	ExpiresIn int    `json:"expires_in"`
	TokenType string `json:"token_type"`
	// Error is the OAuth error code (e.g. "authorization_pending"); nil when
	// the response carried no "error" field.
	Error *string `json:"error,omitempty"`
	// ErrorDescription is the human-readable message accompanying Error.
	ErrorDescription string `json:"error_description,omitempty"`
}
// ErrTimeout is returned by WaitForDeviceToken when the expiry period reported
// for the device code elapses before the user finishes authenticating.
var ErrTimeout = errors.New("timed out waiting for device token")
// GetDeviceCode initiates the device-code auth flow with the tenant.
// The state returned contains the device code that the user must use to
// authenticate, as well as the URL to visit, etc.
//
// The request sends the API's ClientID, the given audience and the
// space-joined Scopes as form fields to "<TenantURL>/oauth/device/code".
// A non-200 response is converted into an error (using the tenant's
// error_description when available). On any error the returned State is the
// zero value.
func (a API) GetDeviceCode(ctx context.Context, audience string) (State, error) {
	data := url.Values{
		"client_id": {a.ClientID},
		"audience":  {audience},
		"scope":     {strings.Join(a.Scopes, " ")},
	}

	deviceCodeURL := a.TenantURL + "/oauth/device/code"
	resp, err := postForm(ctx, deviceCodeURL, strings.NewReader(data.Encode()))
	if err != nil {
		return State{}, err
	}
	defer func() {
		_ = resp.Body.Close()
	}()

	if resp.StatusCode != http.StatusOK {
		return State{}, tryDecodeOAuthError(resp)
	}

	var state State
	if err := json.NewDecoder(resp.Body).Decode(&state); err != nil {
		// Don't hand back a partially decoded State alongside the error.
		return State{}, fmt.Errorf("failed to get device code: %w", err)
	}
	return state, nil
}
// tryDecodeOAuthError builds an error for a non-successful tenant response.
//
// If the body is a JSON object with a non-empty string "error_description",
// that description becomes the error message. Otherwise (non-JSON body, a
// missing/empty/non-string description, or a body larger than the read cap)
// the error reports the HTTP status line instead. The caller remains
// responsible for closing resp.Body.
func tryDecodeOAuthError(resp *http.Response) error {
	var body struct {
		ErrorDescription string `json:"error_description"`
	}
	// The body comes from the network and may be arbitrary (e.g. an HTML
	// error page), so cap how much of it we are willing to parse.
	const maxErrorBody = 1 << 20
	err := json.NewDecoder(io.LimitReader(resp.Body, maxErrorBody)).Decode(&body)
	if err == nil && body.ErrorDescription != "" {
		return errors.New(body.ErrorDescription)
	}
	return errors.New("unexpected response from tenant: " + resp.Status)
}
// WaitForDeviceToken polls the tenant to get access/refresh tokens for the user.
// This should be called after GetDeviceCode, and will block until the user has
// authenticated or we have reached the time limit for authenticating (based on
// the response from GetDeviceCode).
//
// Polling follows the OAuth device-flow conventions (RFC 8628 §3.5):
// "authorization_pending" keeps polling at the current interval, and
// "slow_down" keeps polling with the interval increased by 5 seconds for
// this and all subsequent requests. Any other OAuth error ends the wait and
// is returned together with the tenant's response.
//
// It returns ctx.Err() if ctx is done first, and ErrTimeout if the expiry
// period from state elapses first.
func (a API) WaitForDeviceToken(ctx context.Context, state State) (TokenResponse, error) {
	// Polling interval specified by the tenant response; may grow if the
	// tenant asks us to slow down.
	interval := state.IntervalDuration()

	// Timer pacing the polling requests to the tenant.
	pollTimer := time.NewTimer(interval)
	defer pollTimer.Stop()

	// The tenant tells us for how long we can poll it for credentials
	// while the user logs in through their browser. Time out if we don't
	// get credentials within this period.
	timeout := time.NewTimer(state.ExpiryDuration())
	defer timeout.Stop()

	for {
		resetTimer(pollTimer, interval)
		select {
		case <-ctx.Done():
			// user canceled login
			return TokenResponse{}, ctx.Err()
		case <-timeout.C:
			// login timed out
			return TokenResponse{}, ErrTimeout
		case <-pollTimer.C:
			// tick, check for user login
			res, err := a.getDeviceToken(ctx, state)
			if err != nil {
				if errors.Is(err, context.Canceled) {
					// if the caller canceled the context, continue
					// and let the select hit the ctx.Done() branch
					continue
				}
				return TokenResponse{}, err
			}
			if res.Error == nil {
				return res, nil
			}
			switch *res.Error {
			case "authorization_pending":
				continue
			case "slow_down":
				// RFC 8628 §3.5: the interval MUST be increased by 5
				// seconds for this and all subsequent requests.
				interval += 5 * time.Second
				continue
			}
			if res.ErrorDescription != "" {
				return res, errors.New(res.ErrorDescription)
			}
			// Fall back to the error code so the error is never empty.
			return res, errors.New(*res.Error)
		}
	}
}
// resetTimer stops t, discards any tick left in its channel, and re-arms it
// to fire after d.
//
// Before go1.23, Timer.Reset neither stops the timer nor drains its channel,
// so a tick from an earlier expiry could otherwise be observed right after
// the reset. See: https://go-review.googlesource.com/c/go/+/568341
// FIXME: remove/simplify this after we update to go1.23
func resetTimer(t *time.Timer, d time.Duration) {
	if t.Stop() {
		// The timer was still pending, so nothing can be in the channel.
		t.Reset(d)
		return
	}
	// The timer already fired (or was stopped). Take any pending tick
	// without blocking, in case it was already received.
	select {
	case <-t.C:
	default:
	}
	t.Reset(d)
}
// getDeviceToken calls the token endpoint of Auth0 with the device code from
// state and returns the decoded response.
//
// This endpoint returns a 403 with an `authorization_pending` error until the
// user has authenticated, so a non-200 status that carries an OAuth "error"
// field is returned as a TokenResponse with a nil error. Callers must inspect
// TokenResponse.Error. A non-200 status whose body has no OAuth error is
// reported as an error, so it is never mistaken for an empty successful token
// response.
//
// Each request is bounded by a one-minute timeout in addition to ctx.
func (a API) getDeviceToken(ctx context.Context, state State) (TokenResponse, error) {
	ctx, cancel := context.WithTimeout(ctx, 1*time.Minute)
	defer cancel()

	data := url.Values{
		"client_id":   {a.ClientID},
		"grant_type":  {"urn:ietf:params:oauth:grant-type:device_code"},
		"device_code": {state.DeviceCode},
	}
	oauthTokenURL := a.TenantURL + "/oauth/token"
	resp, err := postForm(ctx, oauthTokenURL, strings.NewReader(data.Encode()))
	if err != nil {
		return TokenResponse{}, fmt.Errorf("failed to get tokens: %w", err)
	}
	defer func() {
		_ = resp.Body.Close()
	}()

	var res TokenResponse
	if err := json.NewDecoder(resp.Body).Decode(&res); err != nil {
		return TokenResponse{}, fmt.Errorf("failed to decode response: %w", err)
	}
	if res.Error == nil && resp.StatusCode != http.StatusOK {
		// A failure status without an OAuth error payload (e.g. a 5xx with
		// some other JSON body) must not be treated as a token response.
		return TokenResponse{}, errors.New("unexpected response from tenant: " + resp.Status)
	}
	return res, nil
}
// RevokeToken asks the tenant to revoke refreshToken so that it can no longer
// be used to get new tokens. Any status other than 200 OK is reported as an
// error, using the tenant's error_description when one is provided.
func (a API) RevokeToken(ctx context.Context, refreshToken string) error {
	form := url.Values{}
	form.Set("client_id", a.ClientID)
	form.Set("token", refreshToken)

	resp, err := postForm(ctx, a.TenantURL+"/oauth/revoke", strings.NewReader(form.Encode()))
	if err != nil {
		return err
	}
	defer func() {
		_ = resp.Body.Close()
	}()

	if resp.StatusCode == http.StatusOK {
		return nil
	}
	return tryDecodeOAuthError(resp)
}
// postForm sends data to reqURL as an application/x-www-form-urlencoded POST
// using http.DefaultClient, bound to ctx. The User-Agent identifies the CLI as
// "docker-cli:<version>:<os>-<arch>", with the dots in the version replaced by
// underscores. On success the caller must close the response body.
func postForm(ctx context.Context, reqURL string, data io.Reader) (*http.Response, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, data)
	if err != nil {
		return nil, err
	}

	userAgent := fmt.Sprintf("docker-cli:%s:%s-%s",
		strings.ReplaceAll(version.Version, ".", "_"),
		runtime.GOOS,
		runtime.GOARCH,
	)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("User-Agent", userAgent)

	return http.DefaultClient.Do(req)
}
// GetAutoPAT requests a token from the Hub endpoint
// "<audience>/v2/access-tokens/desktop-generate", authenticating with
// res.AccessToken as a bearer token, and returns the generated token.
//
// The request must be answered with 201 Created and a body containing a
// non-empty token. Anything else is returned as an error.
func (a API) GetAutoPAT(ctx context.Context, audience string, res TokenResponse) (string, error) {
	patURL := audience + "/v2/access-tokens/desktop-generate"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, patURL, nil)
	if err != nil {
		return "", err
	}
	req.Header.Set("Authorization", "Bearer "+res.AccessToken)
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", err
	}
	defer func() {
		_ = resp.Body.Close()
	}()

	if resp.StatusCode != http.StatusCreated {
		return "", fmt.Errorf("unexpected response from Hub: %s", resp.Status)
	}

	var response patGenerateResponse
	if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
		return "", fmt.Errorf("failed to decode response from Hub: %w", err)
	}
	if response.Data.Token == "" {
		// Don't report success with an unusable (empty) token.
		return "", errors.New("unexpected response from Hub: no token in response")
	}
	return response.Data.Token, nil
}
// patGenerateResponse is the body decoded from a successful (201 Created)
// response of the Hub ".../v2/access-tokens/desktop-generate" endpoint used
// by GetAutoPAT. The expected JSON shape is presumably
// {"data": {"token": "..."}} — confirm against the Hub API.
type patGenerateResponse struct {
	// Data has no json tag, so encoding/json matches it against the
	// "data" key case-insensitively.
	Data struct {
		Token string `json:"token"`
	}
}