From 4d67ef09c8d960efdfcc3b1e6a6eb20ab63a3d94 Mon Sep 17 00:00:00 2001 From: Alano Terblanche <18033717+Benehiko@users.noreply.github.com> Date: Tue, 18 Jun 2024 14:18:49 +0200 Subject: [PATCH] fix: ctx cancellation on login prompt Signed-off-by: Alano Terblanche <18033717+Benehiko@users.noreply.github.com> (cherry picked from commit c15ade0c647606b769deb009a3c2e508efa71e67) Signed-off-by: Sebastiaan van Stijn --- cli/command/registry.go | 47 +++++++---------------- cli/command/registry/login.go | 2 +- cli/command/registry/login_test.go | 41 ++++++++++++++++++++ cli/command/utils.go | 43 +++++++++++++++++++++ cli/command/utils_test.go | 61 ++++++++++++++++++++++++++++++ 5 files changed, 160 insertions(+), 34 deletions(-) diff --git a/cli/command/registry.go b/cli/command/registry.go index ba97861a6b..89fa0cca59 100644 --- a/cli/command/registry.go +++ b/cli/command/registry.go @@ -1,10 +1,8 @@ package command import ( - "bufio" "context" "fmt" - "io" "os" "runtime" "strings" @@ -18,7 +16,6 @@ import ( "github.com/docker/docker/api/types" registrytypes "github.com/docker/docker/api/types/registry" "github.com/docker/docker/registry" - "github.com/moby/term" "github.com/pkg/errors" ) @@ -44,7 +41,7 @@ func RegistryAuthenticationPrivilegedFunc(cli Cli, index *registrytypes.IndexInf default: } - err = ConfigureAuth(cli, "", "", &authConfig, isDefaultRegistry) + err = ConfigureAuth(ctx, cli, "", "", &authConfig, isDefaultRegistry) if err != nil { return "", err } @@ -90,7 +87,7 @@ func GetDefaultAuthConfig(cfg *configfile.ConfigFile, checkCredStore bool, serve } // ConfigureAuth handles prompting of user's username and password if needed -func ConfigureAuth(cli Cli, flUser, flPassword string, authconfig *registrytypes.AuthConfig, isDefaultRegistry bool) error { +func ConfigureAuth(ctx context.Context, cli Cli, flUser, flPassword string, authconfig *registrytypes.AuthConfig, isDefaultRegistry bool) error { // On Windows, force the use of the regular OS stdin stream. // // See: @@ -125,9 +122,15 @@ func ConfigureAuth(cli Cli, flUser, flPassword string, authconfig *registrytypes fmt.Fprintln(cli.Out()) } } - promptWithDefault(cli.Out(), "Username", authconfig.Username) + + var prompt string + if authconfig.Username == "" { + prompt = "Username: " + } else { + prompt = fmt.Sprintf("Username (%s): ", authconfig.Username) + } var err error - flUser, err = readInput(cli.In()) + flUser, err = PromptForInput(ctx, cli.In(), cli.Out(), prompt) if err != nil { return err } @@ -139,16 +142,13 @@ func ConfigureAuth(cli Cli, flUser, flPassword string, authconfig *registrytypes return errors.Errorf("Error: Non-null Username Required") } if flPassword == "" { - oldState, err := term.SaveState(cli.In().FD()) + restoreInput, err := DisableInputEcho(cli.In()) if err != nil { return err } - fmt.Fprintf(cli.Out(), "Password: ") - _ = term.DisableEcho(cli.In().FD(), oldState) - defer func() { - _ = term.RestoreTerminal(cli.In().FD(), oldState) - }() - flPassword, err = readInput(cli.In()) + defer restoreInput() + + flPassword, err = PromptForInput(ctx, cli.In(), cli.Out(), "Password: ") if err != nil { return err } @@ -164,25 +164,6 @@ func ConfigureAuth(cli Cli, flUser, flPassword string, authconfig *registrytypes return nil } -// readInput reads, and returns user input from in. It tries to return a -// single line, not including the end-of-line bytes, and trims leading -// and trailing whitespace. -func readInput(in io.Reader) (string, error) { - line, _, err := bufio.NewReader(in).ReadLine() - if err != nil { - return "", errors.Wrap(err, "error while reading input") - } - return strings.TrimSpace(string(line)), nil -} - -func promptWithDefault(out io.Writer, prompt string, configDefault string) { - if configDefault == "" { - fmt.Fprintf(out, "%s: ", prompt) - } else { - fmt.Fprintf(out, "%s (%s): ", prompt, configDefault) - } -} - // RetrieveAuthTokenFromImage retrieves an encoded auth token given a complete // image. The auth configuration is serialized as a base64url encoded RFC4648, // section 5) JSON string for sending through the X-Registry-Auth header. diff --git a/cli/command/registry/login.go b/cli/command/registry/login.go index 69476b6ea0..75cf7c5265 100644 --- a/cli/command/registry/login.go +++ b/cli/command/registry/login.go @@ -121,7 +121,7 @@ func runLogin(ctx context.Context, dockerCli command.Cli, opts loginOptions) err response, err = loginWithCredStoreCreds(ctx, dockerCli, &authConfig) } if err != nil || authConfig.Username == "" || authConfig.Password == "" { - err = command.ConfigureAuth(dockerCli, opts.user, opts.password, &authConfig, isDefaultRegistry) + err = command.ConfigureAuth(ctx, dockerCli, opts.user, opts.password, &authConfig, isDefaultRegistry) if err != nil { return err } diff --git a/cli/command/registry/login_test.go b/cli/command/registry/login_test.go index 9d87ad7667..a6ee4a7f60 100644 --- a/cli/command/registry/login_test.go +++ b/cli/command/registry/login_test.go @@ -6,7 +6,10 @@ import ( "errors" "fmt" "testing" + "time" + "github.com/creack/pty" + "github.com/docker/cli/cli/command" configtypes "github.com/docker/cli/cli/config/types" "github.com/docker/cli/cli/streams" "github.com/docker/cli/internal/test" @@ -185,3 +188,41 @@ func TestRunLogin(t *testing.T) { }) } } + +func TestLoginTermination(t *testing.T) { + p, tty, err := pty.Open() + assert.NilError(t, err) + + t.Cleanup(func() { + _ = tty.Close() + _ = p.Close() + }) + + cli := test.NewFakeCli(&fakeClient{}, func(fc *test.FakeCli) { + fc.SetOut(streams.NewOut(tty)) + fc.SetIn(streams.NewIn(tty)) + }) + tmpFile := fs.NewFile(t, "test-login-termination") + defer tmpFile.Remove() + + configFile := cli.ConfigFile() + configFile.Filename = tmpFile.Path() + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + runErr := make(chan error) + go func() { + runErr <- runLogin(ctx, cli, loginOptions{}) + }() + + // Let the prompt get canceled by the context + cancel() + + select { + case <-time.After(1 * time.Second): + t.Fatal("timed out after 1 second. `runLogin` did not return") + case err := <-runErr: + assert.ErrorIs(t, err, command.ErrPromptTerminated) + } +} diff --git a/cli/command/utils.go b/cli/command/utils.go index 48d2c4250a..b206db8edd 100644 --- a/cli/command/utils.go +++ b/cli/command/utils.go @@ -19,6 +19,7 @@ import ( "github.com/docker/docker/api/types/versions" "github.com/docker/docker/errdefs" "github.com/moby/sys/sequential" + "github.com/moby/term" "github.com/pkg/errors" "github.com/spf13/pflag" ) @@ -76,6 +77,48 @@ func PrettyPrint(i any) string { var ErrPromptTerminated = errdefs.Cancelled(errors.New("prompt terminated")) +// DisableInputEcho disables input echo on the provided streams.In. +// This is useful when the user provides sensitive information like passwords. +// The function returns a restore function that should be called to restore the +// terminal state. +func DisableInputEcho(ins *streams.In) (restore func() error, err error) { + oldState, err := term.SaveState(ins.FD()) + if err != nil { + return nil, err + } + restore = func() error { + return term.RestoreTerminal(ins.FD(), oldState) + } + return restore, term.DisableEcho(ins.FD(), oldState) +} + +// PromptForInput requests input from the user. +// +// If the user terminates the CLI with SIGINT or SIGTERM while the prompt is +// active, the prompt will return an empty string ("") with an ErrPromptTerminated error. +// When the prompt returns an error, the caller should propagate the error up +// the stack and close the io.Reader used for the prompt which will prevent the +// background goroutine from blocking indefinitely. +func PromptForInput(ctx context.Context, in io.Reader, out io.Writer, message string) (string, error) { + _, _ = fmt.Fprint(out, message) + + result := make(chan string) + go func() { + scanner := bufio.NewScanner(in) + if scanner.Scan() { + result <- strings.TrimSpace(scanner.Text()) + } + }() + + select { + case <-ctx.Done(): + _, _ = fmt.Fprintln(out, "") + return "", ErrPromptTerminated + case r := <-result: + return r, nil + } +} + // PromptForConfirmation requests and checks confirmation from the user. // This will display the provided message followed by ' [y/N] '. If the user // input 'y' or 'Y' it returns true otherwise false. If no message is provided, diff --git a/cli/command/utils_test.go b/cli/command/utils_test.go index 1566067f3b..2f2140757e 100644 --- a/cli/command/utils_test.go +++ b/cli/command/utils_test.go @@ -15,6 +15,7 @@ import ( "time" "github.com/docker/cli/cli/command" + "github.com/docker/cli/cli/streams" "github.com/docker/cli/internal/test" "github.com/pkg/errors" "gotest.tools/v3/assert" @@ -80,6 +81,66 @@ func TestValidateOutputPath(t *testing.T) { } } +func TestPromptForInput(t *testing.T) { + t.Run("case=cancelling the context", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + reader, _ := io.Pipe() + + buf := new(bytes.Buffer) + bufioWriter := bufio.NewWriter(buf) + + wroteHook := make(chan struct{}, 1) + promptOut := test.NewWriterWithHook(bufioWriter, func(p []byte) { + wroteHook <- struct{}{} + }) + + promptErr := make(chan error, 1) + go func() { + _, err := command.PromptForInput(ctx, streams.NewIn(reader), streams.NewOut(promptOut), "Enter something") + promptErr <- err + }() + + select { + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for prompt to write to buffer") + case <-wroteHook: + cancel() + } + + select { + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for prompt to be canceled") + case err := <-promptErr: + assert.ErrorIs(t, err, command.ErrPromptTerminated) + } + }) + + t.Run("case=user input should be properly trimmed", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(cancel) + + reader, writer := io.Pipe() + + buf := new(bytes.Buffer) + bufioWriter := bufio.NewWriter(buf) + + wroteHook := make(chan struct{}, 1) + promptOut := test.NewWriterWithHook(bufioWriter, func(p []byte) { + wroteHook <- struct{}{} + }) + + go func() { + <-wroteHook + writer.Write([]byte(" foo \n")) + }() + + answer, err := command.PromptForInput(ctx, streams.NewIn(reader), streams.NewOut(promptOut), "Enter something") + assert.NilError(t, err) + assert.Equal(t, answer, "foo") + }) +} + func TestPromptForConfirmation(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel)