From f3c2c26b1025f3e785f142fbd8ebed5cf8265f44 Mon Sep 17 00:00:00 2001 From: Archimedes Trajano Date: Tue, 6 Aug 2024 19:30:06 -0400 Subject: [PATCH] disable pseudoterminal creation avoided the join, also did manual iteration added test, also added reflect for the DeepEqual comparison Signed-off-by: Archimedes Trajano --- cli/connhelper/connhelper.go | 12 +++++++++++ cli/connhelper/connhelper_test.go | 34 +++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/cli/connhelper/connhelper.go b/cli/connhelper/connhelper.go index 1797abaed4..ab83ee2920 100644 --- a/cli/connhelper/connhelper.go +++ b/cli/connhelper/connhelper.go @@ -52,6 +52,7 @@ func getConnectionHelper(daemonURL string, sshFlags []string) (*ConnectionHelper args = append(args, "--host", "unix://"+sp.Path) } sshFlags = addSSHTimeout(sshFlags) + sshFlags = disablePseudoTerminalAllocation(sshFlags) args = append(args, "system", "dial-stdio") return commandconn.New(ctx, "ssh", append(sshFlags, sp.Args(args...)...)...) }, @@ -79,3 +80,14 @@ func addSSHTimeout(sshFlags []string) []string { } return sshFlags } + +// disablePseudoTerminalAllocation disables pseudo-terminal allocation to +// prevent SSH from executing as a login shell +func disablePseudoTerminalAllocation(sshFlags []string) []string { + for _, flag := range sshFlags { + if flag == "-T" { + return sshFlags + } + } + return append(sshFlags, "-T") +} diff --git a/cli/connhelper/connhelper_test.go b/cli/connhelper/connhelper_test.go index 14384f5c85..0d9aee0fb2 100644 --- a/cli/connhelper/connhelper_test.go +++ b/cli/connhelper/connhelper_test.go @@ -1,6 +1,7 @@ package connhelper import ( + "reflect" "testing" "gotest.tools/v3/assert" @@ -29,3 +30,36 @@ func TestSSHFlags(t *testing.T) { assert.DeepEqual(t, addSSHTimeout(tc.in), tc.out) } } + +func TestDisablePseudoTerminalAllocation(t *testing.T) { + testCases := []struct { + name string + sshFlags []string + expected []string + }{ + { + name: "No -T flag present", + sshFlags: []string{"-v", "-oStrictHostKeyChecking=no"}, + expected: []string{"-v", "-oStrictHostKeyChecking=no", "-T"}, + }, + { + name: "Already contains -T flag", + sshFlags: []string{"-v", "-T", "-oStrictHostKeyChecking=no"}, + expected: []string{"-v", "-T", "-oStrictHostKeyChecking=no"}, + }, + { + name: "Empty sshFlags", + sshFlags: []string{}, + expected: []string{"-T"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := disablePseudoTerminalAllocation(tc.sshFlags) + if !reflect.DeepEqual(result, tc.expected) { + t.Errorf("expected %v, got %v", tc.expected, result) + } + }) + } +}