diff --git a/cli/command/utils_test.go b/cli/command/utils_test.go index 869b33b644..b1ea2dd74c 100644 --- a/cli/command/utils_test.go +++ b/cli/command/utils_test.go @@ -106,120 +106,68 @@ func TestPromptForConfirmation(t *testing.T) { }() for _, tc := range []struct { - desc string - f func(*testing.T, context.Context, chan promptResult) + desc string + f func() error + expected promptResult }{ - {"SIGINT", func(t *testing.T, ctx context.Context, c chan promptResult) { - t.Helper() - + {"SIGINT", func() error { syscall.Kill(syscall.Getpid(), syscall.SIGINT) - - select { - case <-ctx.Done(): - t.Fatal("PromptForConfirmation did not return after SIGINT") - case r := <-c: - assert.Check(t, !r.result) - assert.ErrorContains(t, r.err, "prompt terminated") - } - }}, - {"no", func(t *testing.T, ctx context.Context, c chan promptResult) { - t.Helper() - + return nil + }, promptResult{false, command.ErrPromptTerminated}}, + {"no", func() error { _, err := fmt.Fprint(promptWriter, "n\n") - assert.NilError(t, err) - - select { - case <-ctx.Done(): - t.Fatal("PromptForConfirmation did not return after user input `n`") - case r := <-c: - assert.Check(t, !r.result) - assert.NilError(t, r.err) - } - }}, - {"yes", func(t *testing.T, ctx context.Context, c chan promptResult) { - t.Helper() - + return err + }, promptResult{false, nil}}, + {"yes", func() error { _, err := fmt.Fprint(promptWriter, "y\n") - assert.NilError(t, err) - - select { - case <-ctx.Done(): - t.Fatal("PromptForConfirmation did not return after user input `y`") - case r := <-c: - assert.Check(t, r.result) - assert.NilError(t, r.err) - } - }}, - {"any", func(t *testing.T, ctx context.Context, c chan promptResult) { - t.Helper() - + return err + }, promptResult{true, nil}}, + {"any", func() error { _, err := fmt.Fprint(promptWriter, "a\n") - assert.NilError(t, err) - - select { - case <-ctx.Done(): - t.Fatal("PromptForConfirmation did not return after user input `a`") - case r := <-c: - assert.Check(t, !r.result) - assert.NilError(t, r.err) - } - }}, - {"with space", func(t *testing.T, ctx context.Context, c chan promptResult) { - t.Helper() - + return err + }, promptResult{false, nil}}, + {"with space", func() error { _, err := fmt.Fprint(promptWriter, " y\n") - assert.NilError(t, err) - - select { - case <-ctx.Done(): - t.Fatal("PromptForConfirmation did not return after user input ` y`") - case r := <-c: - assert.Check(t, r.result) - assert.NilError(t, r.err) - } - }}, - {"reader closed", func(t *testing.T, ctx context.Context, c chan promptResult) { - t.Helper() - - assert.NilError(t, promptReader.Close()) - - select { - case <-ctx.Done(): - t.Fatal("PromptForConfirmation did not return after promptReader was closed") - case r := <-c: - assert.Check(t, !r.result) - assert.NilError(t, r.err) - } - }}, + return err + }, promptResult{true, nil}}, + {"reader closed", func() error { + return promptReader.Close() + }, promptResult{false, nil}}, } { t.Run("case="+tc.desc, func(t *testing.T) { buf.Reset() promptReader, promptWriter = io.Pipe() wroteHook := make(chan struct{}, 1) - defer close(wroteHook) promptOut := test.NewWriterWithHook(bufioWriter, func(p []byte) { wroteHook <- struct{}{} }) result := make(chan promptResult, 1) - defer close(result) go func() { r, err := command.PromptForConfirmation(ctx, promptReader, promptOut, "") result <- promptResult{r, err} }() - // wait for the Prompt to write to the buffer - pollForPromptOutput(ctx, t, wroteHook) - drainChannel(ctx, wroteHook) + select { + case <-time.After(100 * time.Millisecond): + case <-wroteHook: + } assert.NilError(t, bufioWriter.Flush()) assert.Equal(t, strings.TrimSpace(buf.String()), "Are you sure you want to proceed? [y/N]") - resultCtx, resultCancel := context.WithTimeout(ctx, 100*time.Millisecond) - defer resultCancel() + // wait for the Prompt to write to the buffer + drainChannel(ctx, wroteHook) - tc.f(t, resultCtx, result) + assert.NilError(t, tc.f()) + + select { + case <-time.After(500 * time.Millisecond): + t.Fatal("timeout waiting for prompt result") + case r := <-result: + assert.Equal(t, r, tc.expected) + } }) } } @@ -235,20 +183,3 @@ func drainChannel(ctx context.Context, ch <-chan struct{}) { } }() } - -func pollForPromptOutput(ctx context.Context, t *testing.T, wroteHook <-chan struct{}) { - t.Helper() - - ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) - defer cancel() - - for { - select { - case <-ctx.Done(): - t.Fatal("Prompt output was not written to before ctx was cancelled") - return - case <-wroteHook: - return - } - } -}