diff --git a/cli/connhelper/connhelper.go b/cli/connhelper/connhelper.go index da3640db1a..c8d4e5a2ee 100644 --- a/cli/connhelper/connhelper.go +++ b/cli/connhelper/connhelper.go @@ -34,7 +34,7 @@ func GetConnectionHelper(daemonURL string) (*ConnectionHelper, error) { } return &ConnectionHelper{ Dialer: func(ctx context.Context, network, addr string) (net.Conn, error) { - return commandconn.New(ctx, "ssh", append(sp.Args(), []string{"--", "docker", "system", "dial-stdio"}...)...) + return commandconn.New(ctx, "ssh", sp.Args("docker", "system", "dial-stdio")...) }, Host: "http://docker", }, nil diff --git a/cli/connhelper/ssh/ssh.go b/cli/connhelper/ssh/ssh.go index 06cb983641..bde01ae7f7 100644 --- a/cli/connhelper/ssh/ssh.go +++ b/cli/connhelper/ssh/ssh.go @@ -49,8 +49,8 @@ type Spec struct { Port string } -// Args returns args except "ssh" itself and "-- ..." -func (sp *Spec) Args() []string { +// Args returns args except "ssh" itself combined with optional additional command args +func (sp *Spec) Args(add ...string) []string { var args []string if sp.User != "" { args = append(args, "-l", sp.User) @@ -58,6 +58,7 @@ func (sp *Spec) Args() []string { if sp.Port != "" { args = append(args, "-p", sp.Port) } - args = append(args, sp.Host) + args = append(args, "--", sp.Host) + args = append(args, add...) return args } diff --git a/cli/connhelper/ssh/ssh_test.go b/cli/connhelper/ssh/ssh_test.go index 1bbfd49634..87c3c3c49c 100644 --- a/cli/connhelper/ssh/ssh_test.go +++ b/cli/connhelper/ssh/ssh_test.go @@ -16,7 +16,7 @@ func TestParseURL(t *testing.T) { { url: "ssh://foo", expectedArgs: []string{ - "foo", + "--", "foo", }, }, { @@ -24,7 +24,7 @@ func TestParseURL(t *testing.T) { expectedArgs: []string{ "-l", "me", "-p", "10022", - "foo", + "--", "foo", }, }, { @@ -53,12 +53,14 @@ func TestParseURL(t *testing.T) { }, } for _, tc := range testCases { - sp, err := ParseURL(tc.url) - if tc.expectedError == "" { - assert.NilError(t, err) - assert.Check(t, is.DeepEqual(tc.expectedArgs, sp.Args())) - } else { - assert.ErrorContains(t, err, tc.expectedError) - } + t.Run(tc.url, func(t *testing.T) { + sp, err := ParseURL(tc.url) + if tc.expectedError == "" { + assert.NilError(t, err) + assert.Check(t, is.DeepEqual(tc.expectedArgs, sp.Args())) + } else { + assert.ErrorContains(t, err, tc.expectedError) + } + }) } }