diff --git a/cli/command/registry/login.go b/cli/command/registry/login.go index 82204bab84..ee647a6223 100644 --- a/cli/command/registry/login.go +++ b/cli/command/registry/login.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "io" + "os" + "strconv" "strings" "github.com/docker/cli/cli" @@ -141,9 +143,22 @@ func loginWithStoredCredentials(ctx context.Context, dockerCli command.Cli, auth return &response, err } +const OauthLoginEscapeHatchEnvVar = "DOCKER_CLI_DISABLE_OAUTH_LOGIN" + +func isOauthLoginDisabled() bool { + if v := os.Getenv(OauthLoginEscapeHatchEnvVar); v != "" { + enabled, err := strconv.ParseBool(v) + if err != nil { + return false + } + return enabled + } + return false +} + func loginUser(ctx context.Context, dockerCli command.Cli, opts loginOptions, defaultUsername, serverAddress string) (*registrytypes.AuthenticateOKBody, error) { // If we're logging into the index server and the user didn't provide a username or password, use the device flow - if serverAddress == registry.IndexServer && opts.user == "" && opts.password == "" { + if serverAddress == registry.IndexServer && opts.user == "" && opts.password == "" && !isOauthLoginDisabled() { response, err := loginWithDeviceCodeFlow(ctx, dockerCli) // if the error represents a failure to initiate the device-code flow, // then we fallback to regular cli credentials login diff --git a/cli/command/registry/login_test.go b/cli/command/registry/login_test.go index 4c3003cf4b..87d1fd92fc 100644 --- a/cli/command/registry/login_test.go +++ b/cli/command/registry/login_test.go @@ -228,3 +228,47 @@ func TestLoginTermination(t *testing.T) { assert.ErrorIs(t, err, command.ErrPromptTerminated) } } + +func TestIsOauthLoginDisabled(t *testing.T) { + testCases := []struct { + envVar string + disabled bool + }{ + { + envVar: "", + disabled: false, + }, + { + envVar: "bork", + disabled: false, + }, + { + envVar: "0", + disabled: false, + }, + { + envVar: "false", + disabled: false, + }, + { + envVar: "true", + disabled: true, + }, + { + envVar: "TRUE", + disabled: true, + }, + { + envVar: "1", + disabled: true, + }, + } + + for _, tc := range testCases { + t.Setenv(OauthLoginEscapeHatchEnvVar, tc.envVar) + + disabled := isOauthLoginDisabled() + + assert.Equal(t, disabled, tc.disabled) + } +} diff --git a/e2e/registry/login_test.go b/e2e/registry/login_test.go new file mode 100644 index 0000000000..523e00d065 --- /dev/null +++ b/e2e/registry/login_test.go @@ -0,0 +1,56 @@ +package registry + +import ( + "io" + "os/exec" + "strings" + "syscall" + "testing" + "time" + + "github.com/creack/pty" + "gotest.tools/v3/assert" +) + +func TestOauthLogin(t *testing.T) { + t.Parallel() + loginCmd := exec.Command("docker", "login") + + p, err := pty.Start(loginCmd) + assert.NilError(t, err) + defer func() { + _ = loginCmd.Wait() + _ = p.Close() + }() + + time.Sleep(1 * time.Second) + pid := loginCmd.Process.Pid + t.Logf("terminating PID %d", pid) + err = syscall.Kill(pid, syscall.SIGTERM) + assert.NilError(t, err) + + output, _ := io.ReadAll(p) + assert.Check(t, strings.Contains(string(output), "USING WEB BASED LOGIN"), string(output)) +} + +func TestLoginWithEscapeHatch(t *testing.T) { + t.Parallel() + loginCmd := exec.Command("docker", "login") + loginCmd.Env = append(loginCmd.Env, "DOCKER_CLI_DISABLE_OAUTH_LOGIN=1") + + p, err := pty.Start(loginCmd) + assert.NilError(t, err) + defer func() { + _ = loginCmd.Wait() + _ = p.Close() + }() + + time.Sleep(1 * time.Second) + pid := loginCmd.Process.Pid + t.Logf("terminating PID %d", pid) + err = syscall.Kill(pid, syscall.SIGTERM) + assert.NilError(t, err) + + output, _ := io.ReadAll(p) + assert.Check(t, strings.Contains(string(output), "Username:"), string(output)) +} diff --git a/e2e/registry/main_test.go b/e2e/registry/main_test.go new file mode 100644 index 0000000000..9174715188 --- /dev/null +++ b/e2e/registry/main_test.go @@ -0,0 +1,17 @@ +package registry + +import ( + "fmt" + "os" + "testing" + + "github.com/docker/cli/internal/test/environment" +) + +func TestMain(m *testing.M) { + if err := environment.Setup(); err != nil { + fmt.Println(err.Error()) + os.Exit(3) + } + os.Exit(m.Run()) +}