DockerCLI/cli/internal/oauth/api/api_test.go

429 lines
11 KiB
Go

package api
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
"gotest.tools/v3/assert"
)
func TestGetDeviceCode(t *testing.T) {
t.Parallel()
t.Run("success", func(t *testing.T) {
t.Parallel()
var clientID, audience, scope, path string
expectedState := State{
DeviceCode: "aDeviceCode",
UserCode: "aUserCode",
VerificationURI: "aVerificationURI",
ExpiresIn: 60,
}
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
clientID = r.FormValue("client_id")
audience = r.FormValue("audience")
scope = r.FormValue("scope")
path = r.URL.Path
jsonState, err := json.Marshal(expectedState)
assert.NilError(t, err)
_, _ = w.Write(jsonState)
}))
defer ts.Close()
api := API{
TenantURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
}
state, err := api.GetDeviceCode(context.Background(), "anAudience")
assert.NilError(t, err)
assert.DeepEqual(t, expectedState, state)
assert.Equal(t, clientID, "aClientID")
assert.Equal(t, audience, "anAudience")
assert.Equal(t, scope, "bork meow")
assert.Equal(t, path, "/oauth/device/code")
})
t.Run("error w/ description", func(t *testing.T) {
t.Parallel()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
jsonState, err := json.Marshal(TokenResponse{
ErrorDescription: "invalid audience",
})
assert.NilError(t, err)
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write(jsonState)
}))
defer ts.Close()
api := API{
TenantURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
}
_, err := api.GetDeviceCode(context.Background(), "bad_audience")
assert.ErrorContains(t, err, "invalid audience")
})
t.Run("general error", func(t *testing.T) {
t.Parallel()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
http.Error(w, "an error", http.StatusInternalServerError)
}))
defer ts.Close()
api := API{
TenantURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
}
_, err := api.GetDeviceCode(context.Background(), "anAudience")
assert.ErrorContains(t, err, "unexpected response from tenant: 500 Internal Server Error")
})
t.Run("canceled context", func(t *testing.T) {
t.Parallel()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
time.Sleep(2 * time.Second)
http.Error(w, "an error", http.StatusInternalServerError)
}))
defer ts.Close()
api := API{
TenantURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
}
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(1 * time.Second)
cancel()
}()
_, err := api.GetDeviceCode(ctx, "anAudience")
assert.ErrorContains(t, err, "context canceled")
})
}
func TestWaitForDeviceToken(t *testing.T) {
t.Parallel()
t.Run("success", func(t *testing.T) {
t.Parallel()
expectedToken := TokenResponse{
AccessToken: "a-real-token",
IDToken: "",
RefreshToken: "the-refresh-token",
Scope: "",
ExpiresIn: 3600,
TokenType: "",
}
var respond atomic.Bool
go func() {
time.Sleep(5 * time.Second)
respond.Store(true)
}()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
assert.Equal(t, "/oauth/token", r.URL.Path)
assert.Equal(t, r.FormValue("client_id"), "aClientID")
assert.Equal(t, r.FormValue("grant_type"), "urn:ietf:params:oauth:grant-type:device_code")
assert.Equal(t, r.FormValue("device_code"), "aDeviceCode")
if respond.Load() {
jsonState, err := json.Marshal(expectedToken)
assert.NilError(t, err)
w.Write(jsonState)
} else {
pendingError := "authorization_pending"
jsonResponse, err := json.Marshal(TokenResponse{
Error: &pendingError,
})
assert.NilError(t, err)
w.Write(jsonResponse)
}
}))
defer ts.Close()
api := API{
TenantURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
}
state := State{
DeviceCode: "aDeviceCode",
UserCode: "aUserCode",
Interval: 1,
ExpiresIn: 30,
}
token, err := api.WaitForDeviceToken(context.Background(), state)
assert.NilError(t, err)
assert.DeepEqual(t, token, expectedToken)
})
t.Run("timeout", func(t *testing.T) {
t.Parallel()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
assert.Equal(t, "/oauth/token", r.URL.Path)
assert.Equal(t, r.FormValue("client_id"), "aClientID")
assert.Equal(t, r.FormValue("grant_type"), "urn:ietf:params:oauth:grant-type:device_code")
assert.Equal(t, r.FormValue("device_code"), "aDeviceCode")
pendingError := "authorization_pending"
jsonResponse, err := json.Marshal(TokenResponse{
Error: &pendingError,
})
assert.NilError(t, err)
w.Write(jsonResponse)
}))
defer ts.Close()
api := API{
TenantURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
}
state := State{
DeviceCode: "aDeviceCode",
UserCode: "aUserCode",
Interval: 1,
ExpiresIn: 1,
}
_, err := api.WaitForDeviceToken(context.Background(), state)
assert.ErrorIs(t, err, ErrTimeout)
})
t.Run("canceled context", func(t *testing.T) {
t.Parallel()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
pendingError := "authorization_pending"
jsonResponse, err := json.Marshal(TokenResponse{
Error: &pendingError,
})
assert.NilError(t, err)
w.Write(jsonResponse)
}))
defer ts.Close()
api := API{
TenantURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
}
state := State{
DeviceCode: "aDeviceCode",
UserCode: "aUserCode",
Interval: 1,
ExpiresIn: 5,
}
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(1 * time.Second)
cancel()
}()
_, err := api.WaitForDeviceToken(ctx, state)
assert.ErrorContains(t, err, "context canceled")
})
}
func TestRevoke(t *testing.T) {
t.Parallel()
t.Run("success", func(t *testing.T) {
t.Parallel()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
assert.Equal(t, "/oauth/revoke", r.URL.Path)
assert.Equal(t, r.FormValue("client_id"), "aClientID")
assert.Equal(t, r.FormValue("token"), "v1.a-refresh-token")
w.WriteHeader(http.StatusOK)
}))
defer ts.Close()
api := API{
TenantURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
}
err := api.RevokeToken(context.Background(), "v1.a-refresh-token")
assert.NilError(t, err)
})
t.Run("unexpected response", func(t *testing.T) {
t.Parallel()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
assert.Equal(t, "/oauth/revoke", r.URL.Path)
assert.Equal(t, r.FormValue("client_id"), "aClientID")
assert.Equal(t, r.FormValue("token"), "v1.a-refresh-token")
w.WriteHeader(http.StatusNotFound)
}))
defer ts.Close()
api := API{
TenantURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
}
err := api.RevokeToken(context.Background(), "v1.a-refresh-token")
assert.ErrorContains(t, err, "unexpected response from tenant: 404 Not Found")
})
t.Run("error w/ description", func(t *testing.T) {
t.Parallel()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
jsonState, err := json.Marshal(TokenResponse{
ErrorDescription: "invalid client id",
})
assert.NilError(t, err)
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write(jsonState)
}))
defer ts.Close()
api := API{
TenantURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
}
err := api.RevokeToken(context.Background(), "v1.a-refresh-token")
assert.ErrorContains(t, err, "invalid client id")
})
t.Run("canceled context", func(t *testing.T) {
t.Parallel()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
assert.Equal(t, "/oauth/revoke", r.URL.Path)
assert.Equal(t, r.FormValue("client_id"), "aClientID")
assert.Equal(t, r.FormValue("token"), "v1.a-refresh-token")
w.WriteHeader(http.StatusOK)
}))
defer ts.Close()
api := API{
TenantURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
}
ctx, cancel := context.WithCancel(context.Background())
cancel()
err := api.RevokeToken(ctx, "v1.a-refresh-token")
assert.ErrorContains(t, err, "context canceled")
})
}
func TestGetAutoPAT(t *testing.T) {
t.Parallel()
t.Run("success", func(t *testing.T) {
t.Parallel()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
assert.Equal(t, "/v2/access-tokens/desktop-generate", r.URL.Path)
assert.Equal(t, "Bearer bork", r.Header.Get("Authorization"))
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
marshalledResponse, err := json.Marshal(patGenerateResponse{
Data: struct {
Token string `json:"token"`
}{
Token: "a-docker-pat",
},
})
assert.NilError(t, err)
w.WriteHeader(http.StatusCreated)
w.Write(marshalledResponse)
}))
defer ts.Close()
api := API{
TenantURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
}
pat, err := api.GetAutoPAT(context.Background(), ts.URL, TokenResponse{
AccessToken: "bork",
})
assert.NilError(t, err)
assert.Equal(t, "a-docker-pat", pat)
})
t.Run("general error", func(t *testing.T) {
t.Parallel()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer ts.Close()
api := API{
TenantURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
}
_, err := api.GetAutoPAT(context.Background(), ts.URL, TokenResponse{
AccessToken: "bork",
})
assert.ErrorContains(t, err, "unexpected response from Hub: 500 Internal Server Error")
})
t.Run("context canceled", func(t *testing.T) {
t.Parallel()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
assert.Equal(t, "/v2/access-tokens/desktop-generate", r.URL.Path)
assert.Equal(t, "Bearer bork", r.Header.Get("Authorization"))
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
marshalledResponse, err := json.Marshal(patGenerateResponse{
Data: struct {
Token string `json:"token"`
}{
Token: "a-docker-pat",
},
})
assert.NilError(t, err)
time.Sleep(500 * time.Millisecond)
w.WriteHeader(http.StatusCreated)
w.Write(marshalledResponse)
}))
defer ts.Close()
api := API{
TenantURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
}
ctx, cancel := context.WithCancel(context.Background())
cancel()
pat, err := api.GetAutoPAT(ctx, ts.URL, TokenResponse{
AccessToken: "bork",
})
assert.ErrorContains(t, err, "context canceled")
assert.Equal(t, "", pat)
})
}