diff --git a/cli/command/cli.go b/cli/command/cli.go index 2e051440fd..d325d5156c 100644 --- a/cli/command/cli.go +++ b/cli/command/cli.go @@ -3,26 +3,25 @@ package command import ( "context" "io" - "net" - "net/http" "os" "path/filepath" "runtime" "strconv" - "time" "github.com/docker/cli/cli" "github.com/docker/cli/cli/config" cliconfig "github.com/docker/cli/cli/config" "github.com/docker/cli/cli/config/configfile" - "github.com/docker/cli/cli/connhelper" + dcontext "github.com/docker/cli/cli/context" + "github.com/docker/cli/cli/context/docker" + kubcontext "github.com/docker/cli/cli/context/kubernetes" + "github.com/docker/cli/cli/context/store" cliflags "github.com/docker/cli/cli/flags" manifeststore "github.com/docker/cli/cli/manifest/store" registryclient "github.com/docker/cli/cli/registry/client" "github.com/docker/cli/cli/trust" dopts "github.com/docker/cli/opts" clitypes "github.com/docker/cli/types" - "github.com/docker/docker/api" "github.com/docker/docker/api/types" registrytypes "github.com/docker/docker/api/types/registry" "github.com/docker/docker/client" @@ -34,6 +33,9 @@ import ( "github.com/theupdateframework/notary/passphrase" ) +// ContextDockerHost is the reported context when DOCKER_HOST env var or -H flag is set +const ContextDockerHost = "" + // Streams is an interface which exposes the standard input and output streams type Streams interface { In() *InStream @@ -57,6 +59,9 @@ type Cli interface { RegistryClient(bool) registryclient.RegistryClient ContentTrustEnabled() bool NewContainerizedEngineClient(sockPath string) (clitypes.ContainerizedClient, error) + ContextStore() store.Store + CurrentContext() string + StackOrchestrator(flagValue string) (Orchestrator, error) } // DockerCli is an instance the docker command line client. @@ -71,8 +76,16 @@ type DockerCli struct { clientInfo ClientInfo contentTrust bool newContainerizeClient func(string) (clitypes.ContainerizedClient, error) + contextStore store.Store + currentContext string } +var storeConfig = store.NewConfig( + func() interface{} { return &DockerContext{} }, + store.EndpointTypeGetter(docker.DockerEndpoint, func() interface{} { return &docker.EndpointMeta{} }), + store.EndpointTypeGetter(kubcontext.KubernetesEndpoint, func() interface{} { return &kubcontext.EndpointMeta{} }), +) + // DefaultVersion returns api.defaultVersion or DOCKER_API_VERSION if specified. func (cli *DockerCli) DefaultVersion() string { return cli.clientInfo.DefaultVersion @@ -167,14 +180,23 @@ func (cli *DockerCli) RegistryClient(allowInsecure bool) registryclient.Registry // line flags are parsed. func (cli *DockerCli) Initialize(opts *cliflags.ClientOptions) error { cli.configFile = cliconfig.LoadDefaultConfigFile(cli.err) - var err error - cli.client, err = NewAPIClientFromFlags(opts.Common, cli.configFile) + cli.contextStore = store.New(cliconfig.ContextStoreDir(), storeConfig) + cli.currentContext, err = resolveContextName(opts.Common, cli.configFile) + if err != nil { + return err + } + endpoint, err := resolveDockerEndpoint(cli.contextStore, cli.currentContext, opts.Common) + if err != nil { + return errors.Wrap(err, "unable to resolve docker endpoint") + } + + cli.client, err = newAPIClientFromEndpoint(endpoint, cli.configFile) if tlsconfig.IsErrEncryptedKey(err) { passRetriever := passphrase.PromptRetrieverWithInOut(cli.In(), cli.Out(), nil) newClient := func(password string) (client.APIClient, error) { - opts.Common.TLSOptions.Passphrase = password - return NewAPIClientFromFlags(opts.Common, cli.configFile) + endpoint.TLSPassword = password + return newAPIClientFromEndpoint(endpoint, cli.configFile) } cli.client, err = getClientWithPassword(passRetriever, newClient) } @@ -198,6 +220,75 @@ func (cli *DockerCli) Initialize(opts *cliflags.ClientOptions) error { return nil } +// NewAPIClientFromFlags creates a new APIClient from command line flags +func NewAPIClientFromFlags(opts *cliflags.CommonOptions, configFile *configfile.ConfigFile) (client.APIClient, error) { + store := store.New(cliconfig.ContextStoreDir(), storeConfig) + contextName, err := resolveContextName(opts, configFile) + if err != nil { + return nil, err + } + endpoint, err := resolveDockerEndpoint(store, contextName, opts) + if err != nil { + return nil, errors.Wrap(err, "unable to resolve docker endpoint") + } + return newAPIClientFromEndpoint(endpoint, configFile) +} + +func newAPIClientFromEndpoint(ep docker.Endpoint, configFile *configfile.ConfigFile) (client.APIClient, error) { + clientOpts, err := ep.ClientOpts() + if err != nil { + return nil, err + } + customHeaders := configFile.HTTPHeaders + if customHeaders == nil { + customHeaders = map[string]string{} + } + customHeaders["User-Agent"] = UserAgent() + clientOpts = append(clientOpts, client.WithHTTPHeaders(customHeaders)) + return client.NewClientWithOpts(clientOpts...) +} + +func resolveDockerEndpoint(s store.Store, contextName string, opts *cliflags.CommonOptions) (docker.Endpoint, error) { + if contextName != ContextDockerHost { + ctxMeta, err := s.GetContextMetadata(contextName) + if err != nil { + return docker.Endpoint{}, err + } + epMeta, err := docker.EndpointFromContext(ctxMeta) + if err != nil { + return docker.Endpoint{}, err + } + return epMeta.WithTLSData(s, contextName) + } + host, err := getServerHost(opts.Hosts, opts.TLSOptions) + if err != nil { + return docker.Endpoint{}, err + } + + var ( + skipTLSVerify bool + tlsData *dcontext.TLSData + ) + + if opts.TLSOptions != nil { + skipTLSVerify = opts.TLSOptions.InsecureSkipVerify + tlsData, err = dcontext.TLSDataFromFiles(opts.TLSOptions.CAFile, opts.TLSOptions.CertFile, opts.TLSOptions.KeyFile) + if err != nil { + return docker.Endpoint{}, err + } + } + + return docker.Endpoint{ + EndpointMeta: docker.EndpointMeta{ + EndpointMetaBase: dcontext.EndpointMetaBase{ + Host: host, + SkipTLSVerify: skipTLSVerify, + }, + }, + TLSData: tlsData, + }, nil +} + func isEnabled(value string) (bool, error) { switch value { case "enabled": @@ -253,6 +344,51 @@ func (cli *DockerCli) NewContainerizedEngineClient(sockPath string) (clitypes.Co return cli.newContainerizeClient(sockPath) } +// ContextStore returns the ContextStore +func (cli *DockerCli) ContextStore() store.Store { + return cli.contextStore +} + +// CurrentContext returns the current context name +func (cli *DockerCli) CurrentContext() string { + return cli.currentContext +} + +// StackOrchestrator resolves which stack orchestrator is in use +func (cli *DockerCli) StackOrchestrator(flagValue string) (Orchestrator, error) { + var ctxOrchestrator string + + configFile := cli.configFile + if configFile == nil { + configFile = cliconfig.LoadDefaultConfigFile(cli.Err()) + } + + currentContext := cli.CurrentContext() + if currentContext == "" { + currentContext = configFile.CurrentContext + } + if currentContext == "" { + currentContext = ContextDockerHost + } + if currentContext != ContextDockerHost { + contextstore := cli.contextStore + if contextstore == nil { + contextstore = store.New(cliconfig.ContextStoreDir(), storeConfig) + } + ctxRaw, err := contextstore.GetContextMetadata(currentContext) + if err != nil { + return "", err + } + ctxMeta, err := GetDockerContext(ctxRaw) + if err != nil { + return "", err + } + ctxOrchestrator = string(ctxMeta.StackOrchestrator) + } + + return GetStackOrchestrator(flagValue, ctxOrchestrator, configFile.StackOrchestrator, cli.Err()) +} + // ServerInfo stores details about the supported features and platform of the // server type ServerInfo struct { @@ -272,51 +408,6 @@ func NewDockerCli(in io.ReadCloser, out, err io.Writer, isTrusted bool, containe return &DockerCli{in: NewInStream(in), out: NewOutStream(out), err: err, contentTrust: isTrusted, newContainerizeClient: containerizedFn} } -// NewAPIClientFromFlags creates a new APIClient from command line flags -func NewAPIClientFromFlags(opts *cliflags.CommonOptions, configFile *configfile.ConfigFile) (client.APIClient, error) { - host, err := getServerHost(opts.Hosts, opts.TLSOptions) - if err != nil { - return &client.Client{}, err - } - var clientOpts []func(*client.Client) error - helper, err := connhelper.GetConnectionHelper(host) - if err != nil { - return &client.Client{}, err - } - if helper == nil { - clientOpts = append(clientOpts, withHTTPClient(opts.TLSOptions)) - clientOpts = append(clientOpts, client.WithHost(host)) - } else { - clientOpts = append(clientOpts, func(c *client.Client) error { - httpClient := &http.Client{ - // No tls - // No proxy - Transport: &http.Transport{ - DialContext: helper.Dialer, - }, - } - return client.WithHTTPClient(httpClient)(c) - }) - clientOpts = append(clientOpts, client.WithHost(helper.Host)) - clientOpts = append(clientOpts, client.WithDialContext(helper.Dialer)) - } - - customHeaders := configFile.HTTPHeaders - if customHeaders == nil { - customHeaders = map[string]string{} - } - customHeaders["User-Agent"] = UserAgent() - clientOpts = append(clientOpts, client.WithHTTPHeaders(customHeaders)) - - verStr := api.DefaultVersion - if tmpStr := os.Getenv("DOCKER_API_VERSION"); tmpStr != "" { - verStr = tmpStr - } - clientOpts = append(clientOpts, client.WithVersion(verStr)) - - return client.NewClientWithOpts(clientOpts...) -} - func getServerHost(hosts []string, tlsOptions *tlsconfig.Options) (string, error) { var host string switch len(hosts) { @@ -331,35 +422,37 @@ func getServerHost(hosts []string, tlsOptions *tlsconfig.Options) (string, error return dopts.ParseHost(tlsOptions != nil, host) } -func withHTTPClient(tlsOpts *tlsconfig.Options) func(*client.Client) error { - return func(c *client.Client) error { - if tlsOpts == nil { - // Use the default HTTPClient - return nil - } - - opts := *tlsOpts - opts.ExclusiveRootPools = true - tlsConfig, err := tlsconfig.Client(opts) - if err != nil { - return err - } - - httpClient := &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: tlsConfig, - DialContext: (&net.Dialer{ - KeepAlive: 30 * time.Second, - Timeout: 30 * time.Second, - }).DialContext, - }, - CheckRedirect: client.CheckRedirect, - } - return client.WithHTTPClient(httpClient)(c) - } -} - // UserAgent returns the user agent string used for making API requests func UserAgent() string { return "Docker-Client/" + cli.Version + " (" + runtime.GOOS + ")" } + +// resolveContextName resolves the current context name with the following rules: +// - setting both --context and --host flags is ambiguous +// - if --context is set, use this value +// - if --host flag or DOCKER_HOST is set, fallbacks to use the same logic as before context-store was added +// for backward compatibility with existing scripts +// - if DOCKER_CONTEXT is set, use this value +// - if Config file has a globally set "CurrentContext", use this value +// - fallbacks to default HOST, uses TLS config from flags/env vars +func resolveContextName(opts *cliflags.CommonOptions, config *configfile.ConfigFile) (string, error) { + if opts.Context != "" && len(opts.Hosts) > 0 { + return "", errors.New("Conflicting options: either specify --host or --context, not bot") + } + if opts.Context != "" { + return opts.Context, nil + } + if len(opts.Hosts) > 0 { + return ContextDockerHost, nil + } + if _, present := os.LookupEnv("DOCKER_HOST"); present { + return ContextDockerHost, nil + } + if ctxName, ok := os.LookupEnv("DOCKER_CONTEXT"); ok { + return ctxName, nil + } + if config != nil && config.CurrentContext != "" { + return config.CurrentContext, nil + } + return ContextDockerHost, nil +} diff --git a/cli/command/cli_test.go b/cli/command/cli_test.go index 6ac120fa22..71029e9107 100644 --- a/cli/command/cli_test.go +++ b/cli/command/cli_test.go @@ -66,6 +66,7 @@ func TestNewAPIClientFromFlagsForDefaultSchema(t *testing.T) { func TestNewAPIClientFromFlagsWithAPIVersionFromEnv(t *testing.T) { customVersion := "v3.3.3" defer env.Patch(t, "DOCKER_API_VERSION", customVersion)() + defer env.Patch(t, "DOCKER_HOST", ":2375")() opts := &flags.CommonOptions{} configFile := &configfile.ConfigFile{} diff --git a/cli/command/context.go b/cli/command/context.go new file mode 100644 index 0000000000..2b4c76ed15 --- /dev/null +++ b/cli/command/context.go @@ -0,0 +1,27 @@ +package command + +import ( + "errors" + + "github.com/docker/cli/cli/context/store" +) + +// DockerContext is a typed representation of what we put in Context metadata +type DockerContext struct { + Description string `json:"description,omitempty"` + StackOrchestrator Orchestrator `json:"stack_orchestrator,omitempty"` +} + +// GetDockerContext extracts metadata from stored context metadata +func GetDockerContext(storeMetadata store.ContextMetadata) (DockerContext, error) { + if storeMetadata.Metadata == nil { + // can happen if we save endpoints before assigning a context metadata + // it is totally valid, and we should return a default initialized value + return DockerContext{}, nil + } + res, ok := storeMetadata.Metadata.(DockerContext) + if !ok { + return DockerContext{}, errors.New("context metadata is not a valid DockerContext") + } + return res, nil +} diff --git a/cli/command/orchestrator.go b/cli/command/orchestrator.go index 5f3e446205..c71b3f8963 100644 --- a/cli/command/orchestrator.go +++ b/cli/command/orchestrator.go @@ -44,7 +44,7 @@ func normalize(value string) (Orchestrator, error) { return OrchestratorKubernetes, nil case "swarm": return OrchestratorSwarm, nil - case "": + case "", "unset": return orchestratorUnset, nil case "all": return OrchestratorAll, nil @@ -53,9 +53,14 @@ func normalize(value string) (Orchestrator, error) { } } +// NormalizeOrchestrator parses an orchestrator value and checks if it is valid +func NormalizeOrchestrator(value string) (Orchestrator, error) { + return normalize(value) +} + // GetStackOrchestrator checks DOCKER_STACK_ORCHESTRATOR environment variable and configuration file // orchestrator value and returns user defined Orchestrator. -func GetStackOrchestrator(flagValue, value string, stderr io.Writer) (Orchestrator, error) { +func GetStackOrchestrator(flagValue, contextValue, globalDefault string, stderr io.Writer) (Orchestrator, error) { // Check flag if o, err := normalize(flagValue); o != orchestratorUnset { return o, err @@ -68,8 +73,10 @@ func GetStackOrchestrator(flagValue, value string, stderr io.Writer) (Orchestrat if o, err := normalize(env); o != orchestratorUnset { return o, err } - // Check specified orchestrator - if o, err := normalize(value); o != orchestratorUnset { + if o, err := normalize(contextValue); o != orchestratorUnset { + return o, err + } + if o, err := normalize(globalDefault); o != orchestratorUnset { return o, err } // Nothing set, use default orchestrator diff --git a/cli/command/orchestrator_test.go b/cli/command/orchestrator_test.go index 322e8a9169..141c27e434 100644 --- a/cli/command/orchestrator_test.go +++ b/cli/command/orchestrator_test.go @@ -2,87 +2,82 @@ package command import ( "io/ioutil" - "os" "testing" - cliconfig "github.com/docker/cli/cli/config" - "github.com/docker/cli/cli/flags" "gotest.tools/assert" is "gotest.tools/assert/cmp" "gotest.tools/env" - "gotest.tools/fs" ) func TestOrchestratorSwitch(t *testing.T) { - defaultVersion := "v0.00" - var testcases = []struct { doc string - configfile string + globalOrchestrator string envOrchestrator string flagOrchestrator string + contextOrchestrator string expectedOrchestrator string expectedKubernetes bool expectedSwarm bool }{ { - doc: "default", - configfile: `{ - }`, + doc: "default", expectedOrchestrator: "swarm", expectedKubernetes: false, expectedSwarm: true, }, { - doc: "kubernetesConfigFile", - configfile: `{ - "stackOrchestrator": "kubernetes" - }`, + doc: "kubernetesConfigFile", + globalOrchestrator: "kubernetes", expectedOrchestrator: "kubernetes", expectedKubernetes: true, expectedSwarm: false, }, { - doc: "kubernetesEnv", - configfile: `{ - }`, + doc: "kubernetesEnv", envOrchestrator: "kubernetes", expectedOrchestrator: "kubernetes", expectedKubernetes: true, expectedSwarm: false, }, { - doc: "kubernetesFlag", - configfile: `{ - }`, + doc: "kubernetesFlag", flagOrchestrator: "kubernetes", expectedOrchestrator: "kubernetes", expectedKubernetes: true, expectedSwarm: false, }, { - doc: "allOrchestratorFlag", - configfile: `{ - }`, + doc: "allOrchestratorFlag", flagOrchestrator: "all", expectedOrchestrator: "all", expectedKubernetes: true, expectedSwarm: true, }, { - doc: "envOverridesConfigFile", - configfile: `{ - "stackOrchestrator": "kubernetes" - }`, + doc: "kubernetesContext", + contextOrchestrator: "kubernetes", + expectedOrchestrator: "kubernetes", + expectedKubernetes: true, + }, + { + doc: "contextOverridesConfigFile", + globalOrchestrator: "kubernetes", + contextOrchestrator: "swarm", + expectedOrchestrator: "swarm", + expectedKubernetes: false, + expectedSwarm: true, + }, + { + doc: "envOverridesConfigFile", + globalOrchestrator: "kubernetes", envOrchestrator: "swarm", expectedOrchestrator: "swarm", expectedKubernetes: false, expectedSwarm: true, }, { - doc: "flagOverridesEnv", - configfile: `{ - }`, + doc: "flagOverridesEnv", envOrchestrator: "kubernetes", flagOrchestrator: "swarm", expectedOrchestrator: "swarm", @@ -93,22 +88,10 @@ func TestOrchestratorSwitch(t *testing.T) { for _, testcase := range testcases { t.Run(testcase.doc, func(t *testing.T) { - dir := fs.NewDir(t, testcase.doc, fs.WithFile("config.json", testcase.configfile)) - defer dir.Remove() - apiclient := &fakeClient{ - version: defaultVersion, - } if testcase.envOrchestrator != "" { defer env.Patch(t, "DOCKER_STACK_ORCHESTRATOR", testcase.envOrchestrator)() } - - cli := &DockerCli{client: apiclient, err: os.Stderr} - cliconfig.SetDir(dir.Path()) - options := flags.NewClientOptions() - err := cli.Initialize(options) - assert.NilError(t, err) - - orchestrator, err := GetStackOrchestrator(testcase.flagOrchestrator, cli.ConfigFile().StackOrchestrator, ioutil.Discard) + orchestrator, err := GetStackOrchestrator(testcase.flagOrchestrator, testcase.contextOrchestrator, testcase.globalOrchestrator, ioutil.Discard) assert.NilError(t, err) assert.Check(t, is.Equal(testcase.expectedKubernetes, orchestrator.HasKubernetes())) assert.Check(t, is.Equal(testcase.expectedSwarm, orchestrator.HasSwarm())) diff --git a/cli/command/stack/cmd.go b/cli/command/stack/cmd.go index 851ac13c4a..1570080d1e 100644 --- a/cli/command/stack/cmd.go +++ b/cli/command/stack/cmd.go @@ -3,13 +3,10 @@ package stack import ( "errors" "fmt" - "io" "strings" "github.com/docker/cli/cli" "github.com/docker/cli/cli/command" - cliconfig "github.com/docker/cli/cli/config" - "github.com/docker/cli/cli/config/configfile" "github.com/spf13/cobra" "github.com/spf13/pflag" ) @@ -28,11 +25,7 @@ func NewStackCommand(dockerCli command.Cli) *cobra.Command { Short: "Manage Docker stacks", Args: cli.NoArgs, PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - configFile := dockerCli.ConfigFile() - if configFile == nil { - configFile = cliconfig.LoadDefaultConfigFile(dockerCli.Err()) - } - orchestrator, err := getOrchestrator(configFile, cmd, dockerCli.Err()) + orchestrator, err := getOrchestrator(dockerCli, cmd) if err != nil { return err } @@ -81,12 +74,12 @@ func NewTopLevelDeployCommand(dockerCli command.Cli) *cobra.Command { return cmd } -func getOrchestrator(config *configfile.ConfigFile, cmd *cobra.Command, stderr io.Writer) (command.Orchestrator, error) { +func getOrchestrator(dockerCli command.Cli, cmd *cobra.Command) (command.Orchestrator, error) { var orchestratorFlag string if o, err := cmd.Flags().GetString("orchestrator"); err == nil { orchestratorFlag = o } - return command.GetStackOrchestrator(orchestratorFlag, config.StackOrchestrator, stderr) + return dockerCli.StackOrchestrator(orchestratorFlag) } func hideOrchestrationFlags(cmd *cobra.Command, orchestrator command.Orchestrator) { diff --git a/cli/command/stack/kubernetes/cli.go b/cli/command/stack/kubernetes/cli.go index f98b4c4f3a..a531846809 100644 --- a/cli/command/stack/kubernetes/cli.go +++ b/cli/command/stack/kubernetes/cli.go @@ -7,12 +7,14 @@ import ( "os" "github.com/docker/cli/cli/command" + kubecontext "github.com/docker/cli/cli/context/kubernetes" kubernetes "github.com/docker/compose-on-kubernetes/api" cliv1beta1 "github.com/docker/compose-on-kubernetes/api/client/clientset/typed/compose/v1beta1" "github.com/pkg/errors" flag "github.com/spf13/pflag" kubeclient "k8s.io/client-go/kubernetes" restclient "k8s.io/client-go/rest" + "k8s.io/client-go/tools/clientcmd" ) // KubeCli holds kubernetes specifics (client, namespace) with the command.Cli @@ -55,7 +57,18 @@ func WrapCli(dockerCli command.Cli, opts Options) (*KubeCli, error) { cli := &KubeCli{ Cli: dockerCli, } - clientConfig := kubernetes.NewKubernetesConfig(opts.Config) + var ( + clientConfig clientcmd.ClientConfig + err error + ) + if dockerCli.CurrentContext() == "" { + clientConfig = kubernetes.NewKubernetesConfig(opts.Config) + } else { + clientConfig, err = kubecontext.ConfigFromContext(dockerCli.CurrentContext(), dockerCli.ContextStore()) + } + if err != nil { + return nil, err + } cli.kubeNamespace = opts.Namespace if opts.Namespace == "" { diff --git a/cli/command/system/version.go b/cli/command/system/version.go index 97c50f1bdc..19407a1426 100644 --- a/cli/command/system/version.go +++ b/cli/command/system/version.go @@ -11,6 +11,7 @@ import ( "github.com/docker/cli/cli" "github.com/docker/cli/cli/command" + kubecontext "github.com/docker/cli/cli/context/kubernetes" "github.com/docker/cli/templates" kubernetes "github.com/docker/compose-on-kubernetes/api" "github.com/docker/docker/api/types" @@ -18,6 +19,7 @@ import ( "github.com/sirupsen/logrus" "github.com/spf13/cobra" kubernetesClient "k8s.io/client-go/kubernetes" + "k8s.io/client-go/tools/clientcmd" ) var versionTemplate = `{{with .Client -}} @@ -126,7 +128,7 @@ func runVersion(dockerCli command.Cli, opts *versionOptions) error { return cli.StatusError{StatusCode: 64, Status: err.Error()} } - orchestrator, err := command.GetStackOrchestrator("", dockerCli.ConfigFile().StackOrchestrator, dockerCli.Err()) + orchestrator, err := dockerCli.StackOrchestrator("") if err != nil { return cli.StatusError{StatusCode: 64, Status: err.Error()} } @@ -151,7 +153,7 @@ func runVersion(dockerCli command.Cli, opts *versionOptions) error { vd.Server = &sv var kubeVersion *kubernetesVersion if orchestrator.HasKubernetes() { - kubeVersion = getKubernetesVersion(opts.kubeConfig) + kubeVersion = getKubernetesVersion(dockerCli, opts.kubeConfig) } foundEngine := false foundKubernetes := false @@ -230,17 +232,29 @@ func getDetailsOrder(v types.ComponentVersion) []string { return out } -func getKubernetesVersion(kubeConfig string) *kubernetesVersion { +func getKubernetesVersion(dockerCli command.Cli, kubeConfig string) *kubernetesVersion { version := kubernetesVersion{ Kubernetes: "Unknown", StackAPI: "Unknown", } - clientConfig := kubernetes.NewKubernetesConfig(kubeConfig) - config, err := clientConfig.ClientConfig() + var ( + clientConfig clientcmd.ClientConfig + err error + ) + if dockerCli.CurrentContext() == command.ContextDockerHost { + clientConfig = kubernetes.NewKubernetesConfig(kubeConfig) + } else { + clientConfig, err = kubecontext.ConfigFromContext(dockerCli.CurrentContext(), dockerCli.ContextStore()) + } if err != nil { logrus.Debugf("failed to get Kubernetes configuration: %s", err) return &version } + config, err := clientConfig.ClientConfig() + if err != nil { + logrus.Debugf("failed to get Kubernetes client config: %s", err) + return &version + } kubeClient, err := kubernetesClient.NewForConfig(config) if err != nil { logrus.Debugf("failed to get Kubernetes client: %s", err) diff --git a/cli/config/config.go b/cli/config/config.go index 9161921a2d..64f8d3b49c 100644 --- a/cli/config/config.go +++ b/cli/config/config.go @@ -18,6 +18,7 @@ const ( ConfigFileName = "config.json" configFileDir = ".docker" oldConfigfile = ".dockercfg" + contextsDir = "contexts" ) var ( @@ -35,6 +36,11 @@ func Dir() string { return configDir } +// ContextStoreDir returns the directory the docker contexts are stored in +func ContextStoreDir() string { + return filepath.Join(Dir(), contextsDir) +} + // SetDir sets the directory the configuration file is stored in func SetDir(dir string) { configDir = dir diff --git a/cli/config/configfile/file.go b/cli/config/configfile/file.go index 7fa9b734b9..d815570362 100644 --- a/cli/config/configfile/file.go +++ b/cli/config/configfile/file.go @@ -48,6 +48,7 @@ type ConfigFile struct { Experimental string `json:"experimental,omitempty"` StackOrchestrator string `json:"stackOrchestrator,omitempty"` Kubernetes *KubernetesConfig `json:"kubernetes,omitempty"` + CurrentContext string `json:"currentContext,omitempty"` } // ProxyConfig contains proxy configuration settings diff --git a/cli/context/docker/constants.go b/cli/context/docker/constants.go new file mode 100644 index 0000000000..1db5556d5f --- /dev/null +++ b/cli/context/docker/constants.go @@ -0,0 +1,6 @@ +package docker + +const ( + // DockerEndpoint is the name of the docker endpoint in a stored context + DockerEndpoint = "docker" +) diff --git a/cli/context/docker/load.go b/cli/context/docker/load.go new file mode 100644 index 0000000000..4e0316a238 --- /dev/null +++ b/cli/context/docker/load.go @@ -0,0 +1,172 @@ +package docker + +import ( + "crypto/tls" + "crypto/x509" + "encoding/pem" + "fmt" + "net" + "net/http" + "os" + "time" + + "github.com/docker/cli/cli/connhelper" + "github.com/docker/cli/cli/context" + "github.com/docker/cli/cli/context/store" + "github.com/docker/docker/client" + "github.com/docker/go-connections/tlsconfig" + "github.com/pkg/errors" +) + +// EndpointMeta is a typed wrapper around a context-store generic endpoint describing +// a Docker Engine endpoint, without its tls config +type EndpointMeta struct { + context.EndpointMetaBase + APIVersion string `json:"api_version,omitempty"` +} + +// Endpoint is a typed wrapper around a context-store generic endpoint describing +// a Docker Engine endpoint, with its tls data +type Endpoint struct { + EndpointMeta + TLSData *context.TLSData + TLSPassword string +} + +// WithTLSData loads TLS materials for the endpoint +func (c *EndpointMeta) WithTLSData(s store.Store, contextName string) (Endpoint, error) { + tlsData, err := context.LoadTLSData(s, contextName, DockerEndpoint) + if err != nil { + return Endpoint{}, err + } + return Endpoint{ + EndpointMeta: *c, + TLSData: tlsData, + }, nil +} + +// tlsConfig extracts a context docker endpoint TLS config +func (c *Endpoint) tlsConfig() (*tls.Config, error) { + if c.TLSData == nil && !c.SkipTLSVerify { + // there is no specific tls config + return nil, nil + } + var tlsOpts []func(*tls.Config) + if c.TLSData != nil && c.TLSData.CA != nil { + certPool := x509.NewCertPool() + if !certPool.AppendCertsFromPEM(c.TLSData.CA) { + return nil, errors.New("failed to retrieve context tls info: ca.pem seems invalid") + } + tlsOpts = append(tlsOpts, func(cfg *tls.Config) { + cfg.RootCAs = certPool + }) + } + if c.TLSData != nil && c.TLSData.Key != nil && c.TLSData.Cert != nil { + keyBytes := c.TLSData.Key + pemBlock, _ := pem.Decode(keyBytes) + if pemBlock == nil { + return nil, fmt.Errorf("no valid private key found") + } + + var err error + if x509.IsEncryptedPEMBlock(pemBlock) { + keyBytes, err = x509.DecryptPEMBlock(pemBlock, []byte(c.TLSPassword)) + if err != nil { + return nil, errors.Wrap(err, "private key is encrypted, but could not decrypt it") + } + keyBytes = pem.EncodeToMemory(&pem.Block{Type: pemBlock.Type, Bytes: keyBytes}) + } + + x509cert, err := tls.X509KeyPair(c.TLSData.Cert, keyBytes) + if err != nil { + return nil, errors.Wrap(err, "failed to retrieve context tls info") + } + tlsOpts = append(tlsOpts, func(cfg *tls.Config) { + cfg.Certificates = []tls.Certificate{x509cert} + }) + } + if c.SkipTLSVerify { + tlsOpts = append(tlsOpts, func(cfg *tls.Config) { + cfg.InsecureSkipVerify = true + }) + } + return tlsconfig.ClientDefault(tlsOpts...), nil +} + +// ClientOpts returns a slice of Client options to configure an API client with this endpoint +func (c *Endpoint) ClientOpts() ([]func(*client.Client) error, error) { + var result []func(*client.Client) error + if c.Host != "" { + helper, err := connhelper.GetConnectionHelper(c.Host) + if err != nil { + return nil, err + } + if helper == nil { + tlsConfig, err := c.tlsConfig() + if err != nil { + return nil, err + } + result = append(result, + client.WithHost(c.Host), + withHTTPClient(tlsConfig), + ) + + } else { + httpClient := &http.Client{ + // No tls + // No proxy + Transport: &http.Transport{ + DialContext: helper.Dialer, + }, + } + result = append(result, + client.WithHTTPClient(httpClient), + client.WithHost(helper.Host), + client.WithDialContext(helper.Dialer), + ) + } + } + + version := os.Getenv("DOCKER_API_VERSION") + if version == "" { + version = c.APIVersion + } + if version != "" { + result = append(result, client.WithVersion(version)) + } + return result, nil +} + +func withHTTPClient(tlsConfig *tls.Config) func(*client.Client) error { + return func(c *client.Client) error { + if tlsConfig == nil { + // Use the default HTTPClient + return nil + } + + httpClient := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: tlsConfig, + DialContext: (&net.Dialer{ + KeepAlive: 30 * time.Second, + Timeout: 30 * time.Second, + }).DialContext, + }, + CheckRedirect: client.CheckRedirect, + } + return client.WithHTTPClient(httpClient)(c) + } +} + +// EndpointFromContext parses a context docker endpoint metadata into a typed EndpointMeta structure +func EndpointFromContext(metadata store.ContextMetadata) (EndpointMeta, error) { + ep, ok := metadata.Endpoints[DockerEndpoint] + if !ok { + return EndpointMeta{}, errors.New("cannot find docker endpoint in context") + } + typed, ok := ep.(EndpointMeta) + if !ok { + return EndpointMeta{}, errors.Errorf("endpoint %q is not of type EndpointMeta", DockerEndpoint) + } + return typed, nil +} diff --git a/cli/context/endpoint.go b/cli/context/endpoint.go new file mode 100644 index 0000000000..806a8524ef --- /dev/null +++ b/cli/context/endpoint.go @@ -0,0 +1,7 @@ +package context + +// EndpointMetaBase contains fields we expect to be common for most context endpoints +type EndpointMetaBase struct { + Host string `json:"host,omitempty"` + SkipTLSVerify bool `json:"skip_tls_verify"` +} diff --git a/cli/context/kubernetes/constants.go b/cli/context/kubernetes/constants.go new file mode 100644 index 0000000000..8998de989a --- /dev/null +++ b/cli/context/kubernetes/constants.go @@ -0,0 +1,6 @@ +package kubernetes + +const ( + // KubernetesEndpoint is the kubernetes endpoint name in a stored context + KubernetesEndpoint = "kubernetes" +) diff --git a/cli/context/kubernetes/endpoint_test.go b/cli/context/kubernetes/endpoint_test.go new file mode 100644 index 0000000000..57b308b8f7 --- /dev/null +++ b/cli/context/kubernetes/endpoint_test.go @@ -0,0 +1,183 @@ +package kubernetes + +import ( + "io/ioutil" + "os" + "testing" + + "github.com/docker/cli/cli/context" + "github.com/docker/cli/cli/context/store" + "gotest.tools/assert" + "k8s.io/client-go/tools/clientcmd" + clientcmdapi "k8s.io/client-go/tools/clientcmd/api" +) + +func testEndpoint(server, defaultNamespace string, ca, cert, key []byte, skipTLSVerify bool) *Endpoint { + var tlsData *context.TLSData + if ca != nil || cert != nil || key != nil { + tlsData = &context.TLSData{ + CA: ca, + Cert: cert, + Key: key, + } + } + return &Endpoint{ + EndpointMeta: EndpointMeta{ + EndpointMetaBase: context.EndpointMetaBase{ + Host: server, + SkipTLSVerify: skipTLSVerify, + }, + DefaultNamespace: defaultNamespace, + }, + TLSData: tlsData, + } +} + +var testStoreCfg = store.NewConfig( + func() interface{} { + return &map[string]interface{}{} + }, + store.EndpointTypeGetter(KubernetesEndpoint, func() interface{} { return &EndpointMeta{} }), +) + +func TestSaveLoadContexts(t *testing.T) { + storeDir, err := ioutil.TempDir("", "test-load-save-k8-context") + assert.NilError(t, err) + defer os.RemoveAll(storeDir) + store := store.New(storeDir, testStoreCfg) + assert.NilError(t, testEndpoint("https://test", "test", nil, nil, nil, false).Save(store, "raw-notls")) + assert.NilError(t, testEndpoint("https://test", "test", nil, nil, nil, true).Save(store, "raw-notls-skip")) + assert.NilError(t, testEndpoint("https://test", "test", []byte("ca"), []byte("cert"), []byte("key"), true).Save(store, "raw-tls")) + + kcFile, err := ioutil.TempFile(os.TempDir(), "test-load-save-k8-context") + assert.NilError(t, err) + defer os.Remove(kcFile.Name()) + defer kcFile.Close() + cfg := clientcmdapi.NewConfig() + cfg.AuthInfos["user"] = clientcmdapi.NewAuthInfo() + cfg.Contexts["context1"] = clientcmdapi.NewContext() + cfg.Clusters["cluster1"] = clientcmdapi.NewCluster() + cfg.Contexts["context2"] = clientcmdapi.NewContext() + cfg.Clusters["cluster2"] = clientcmdapi.NewCluster() + cfg.AuthInfos["user"].ClientCertificateData = []byte("cert") + cfg.AuthInfos["user"].ClientKeyData = []byte("key") + cfg.Clusters["cluster1"].Server = "https://server1" + cfg.Clusters["cluster1"].InsecureSkipTLSVerify = true + cfg.Clusters["cluster2"].Server = "https://server2" + cfg.Clusters["cluster2"].CertificateAuthorityData = []byte("ca") + cfg.Contexts["context1"].AuthInfo = "user" + cfg.Contexts["context1"].Cluster = "cluster1" + cfg.Contexts["context1"].Namespace = "namespace1" + cfg.Contexts["context2"].AuthInfo = "user" + cfg.Contexts["context2"].Cluster = "cluster2" + cfg.Contexts["context2"].Namespace = "namespace2" + cfg.CurrentContext = "context1" + cfgData, err := clientcmd.Write(*cfg) + assert.NilError(t, err) + _, err = kcFile.Write(cfgData) + assert.NilError(t, err) + kcFile.Close() + + epDefault, err := FromKubeConfig(kcFile.Name(), "", "") + assert.NilError(t, err) + epContext2, err := FromKubeConfig(kcFile.Name(), "context2", "namespace-override") + assert.NilError(t, err) + assert.NilError(t, epDefault.Save(store, "embed-default-context")) + assert.NilError(t, epContext2.Save(store, "embed-context2")) + + rawNoTLSMeta, err := store.GetContextMetadata("raw-notls") + assert.NilError(t, err) + rawNoTLSSkipMeta, err := store.GetContextMetadata("raw-notls-skip") + assert.NilError(t, err) + rawTLSMeta, err := store.GetContextMetadata("raw-tls") + assert.NilError(t, err) + embededDefaultMeta, err := store.GetContextMetadata("embed-default-context") + assert.NilError(t, err) + embededContext2Meta, err := store.GetContextMetadata("embed-context2") + assert.NilError(t, err) + + rawNoTLS := EndpointFromContext(rawNoTLSMeta) + rawNoTLSSkip := EndpointFromContext(rawNoTLSSkipMeta) + rawTLS := EndpointFromContext(rawTLSMeta) + embededDefault := EndpointFromContext(embededDefaultMeta) + embededContext2 := EndpointFromContext(embededContext2Meta) + + rawNoTLSEP, err := rawNoTLS.WithTLSData(store, "raw-notls") + assert.NilError(t, err) + checkClientConfig(t, store, rawNoTLSEP, "https://test", "test", nil, nil, nil, false) + rawNoTLSSkipEP, err := rawNoTLSSkip.WithTLSData(store, "raw-notls-skip") + assert.NilError(t, err) + checkClientConfig(t, store, rawNoTLSSkipEP, "https://test", "test", nil, nil, nil, true) + rawTLSEP, err := rawTLS.WithTLSData(store, "raw-tls") + assert.NilError(t, err) + checkClientConfig(t, store, rawTLSEP, "https://test", "test", []byte("ca"), []byte("cert"), []byte("key"), true) + embededDefaultEP, err := embededDefault.WithTLSData(store, "embed-default-context") + assert.NilError(t, err) + checkClientConfig(t, store, embededDefaultEP, "https://server1", "namespace1", nil, []byte("cert"), []byte("key"), true) + embededContext2EP, err := embededContext2.WithTLSData(store, "embed-context2") + assert.NilError(t, err) + checkClientConfig(t, store, embededContext2EP, "https://server2", "namespace-override", []byte("ca"), []byte("cert"), []byte("key"), false) +} + +func checkClientConfig(t *testing.T, s store.Store, ep Endpoint, server, namespace string, ca, cert, key []byte, skipTLSVerify bool) { + config := ep.KubernetesConfig() + cfg, err := config.ClientConfig() + assert.NilError(t, err) + ns, _, _ := config.Namespace() + assert.Equal(t, server, cfg.Host) + assert.Equal(t, namespace, ns) + assert.DeepEqual(t, ca, cfg.CAData) + assert.DeepEqual(t, cert, cfg.CertData) + assert.DeepEqual(t, key, cfg.KeyData) + assert.Equal(t, skipTLSVerify, cfg.Insecure) +} + +func TestSaveLoadGKEConfig(t *testing.T) { + storeDir, err := ioutil.TempDir("", t.Name()) + assert.NilError(t, err) + defer os.RemoveAll(storeDir) + store := store.New(storeDir, testStoreCfg) + cfg, err := clientcmd.LoadFromFile("testdata/gke-kubeconfig") + assert.NilError(t, err) + clientCfg := clientcmd.NewDefaultClientConfig(*cfg, &clientcmd.ConfigOverrides{}) + expectedCfg, err := clientCfg.ClientConfig() + assert.NilError(t, err) + ep, err := FromKubeConfig("testdata/gke-kubeconfig", "", "") + assert.NilError(t, err) + assert.NilError(t, ep.Save(store, "gke-context")) + persistedMetadata, err := store.GetContextMetadata("gke-context") + assert.NilError(t, err) + persistedEPMeta := EndpointFromContext(persistedMetadata) + assert.Check(t, persistedEPMeta != nil) + persistedEP, err := persistedEPMeta.WithTLSData(store, "gke-context") + assert.NilError(t, err) + persistedCfg := persistedEP.KubernetesConfig() + actualCfg, err := persistedCfg.ClientConfig() + assert.NilError(t, err) + assert.DeepEqual(t, expectedCfg.AuthProvider, actualCfg.AuthProvider) +} + +func TestSaveLoadEKSConfig(t *testing.T) { + storeDir, err := ioutil.TempDir("", t.Name()) + assert.NilError(t, err) + defer os.RemoveAll(storeDir) + store := store.New(storeDir, testStoreCfg) + cfg, err := clientcmd.LoadFromFile("testdata/eks-kubeconfig") + assert.NilError(t, err) + clientCfg := clientcmd.NewDefaultClientConfig(*cfg, &clientcmd.ConfigOverrides{}) + expectedCfg, err := clientCfg.ClientConfig() + assert.NilError(t, err) + ep, err := FromKubeConfig("testdata/eks-kubeconfig", "", "") + assert.NilError(t, err) + assert.NilError(t, ep.Save(store, "eks-context")) + persistedMetadata, err := store.GetContextMetadata("eks-context") + assert.NilError(t, err) + persistedEPMeta := EndpointFromContext(persistedMetadata) + assert.Check(t, persistedEPMeta != nil) + persistedEP, err := persistedEPMeta.WithTLSData(store, "eks-context") + assert.NilError(t, err) + persistedCfg := persistedEP.KubernetesConfig() + actualCfg, err := persistedCfg.ClientConfig() + assert.NilError(t, err) + assert.DeepEqual(t, expectedCfg.ExecProvider, actualCfg.ExecProvider) +} diff --git a/cli/context/kubernetes/load.go b/cli/context/kubernetes/load.go new file mode 100644 index 0000000000..1898f57422 --- /dev/null +++ b/cli/context/kubernetes/load.go @@ -0,0 +1,95 @@ +package kubernetes + +import ( + "github.com/docker/cli/cli/context" + "github.com/docker/cli/cli/context/store" + "github.com/docker/cli/kubernetes" + "k8s.io/client-go/tools/clientcmd" + clientcmdapi "k8s.io/client-go/tools/clientcmd/api" +) + +// EndpointMeta is a typed wrapper around a context-store generic endpoint describing +// a Kubernetes endpoint, without TLS data +type EndpointMeta struct { + context.EndpointMetaBase + DefaultNamespace string `json:"default_namespace,omitempty"` + AuthProvider *clientcmdapi.AuthProviderConfig `json:"auth_provider,omitempty"` + Exec *clientcmdapi.ExecConfig `json:"exec,omitempty"` +} + +// Endpoint is a typed wrapper around a context-store generic endpoint describing +// a Kubernetes endpoint, with TLS data +type Endpoint struct { + EndpointMeta + TLSData *context.TLSData +} + +// WithTLSData loads TLS materials for the endpoint +func (c *EndpointMeta) WithTLSData(s store.Store, contextName string) (Endpoint, error) { + tlsData, err := context.LoadTLSData(s, contextName, KubernetesEndpoint) + if err != nil { + return Endpoint{}, err + } + return Endpoint{ + EndpointMeta: *c, + TLSData: tlsData, + }, nil +} + +// KubernetesConfig creates the kubernetes client config from the endpoint +func (c *Endpoint) KubernetesConfig() clientcmd.ClientConfig { + cfg := clientcmdapi.NewConfig() + cluster := clientcmdapi.NewCluster() + cluster.Server = c.Host + cluster.InsecureSkipTLSVerify = c.SkipTLSVerify + authInfo := clientcmdapi.NewAuthInfo() + if c.TLSData != nil { + cluster.CertificateAuthorityData = c.TLSData.CA + authInfo.ClientCertificateData = c.TLSData.Cert + authInfo.ClientKeyData = c.TLSData.Key + } + authInfo.AuthProvider = c.AuthProvider + authInfo.Exec = c.Exec + cfg.Clusters["cluster"] = cluster + cfg.AuthInfos["authInfo"] = authInfo + ctx := clientcmdapi.NewContext() + ctx.AuthInfo = "authInfo" + ctx.Cluster = "cluster" + ctx.Namespace = c.DefaultNamespace + cfg.Contexts["context"] = ctx + cfg.CurrentContext = "context" + return clientcmd.NewDefaultClientConfig(*cfg, &clientcmd.ConfigOverrides{}) +} + +// EndpointFromContext extracts kubernetes endpoint info from current context +func EndpointFromContext(metadata store.ContextMetadata) *EndpointMeta { + ep, ok := metadata.Endpoints[KubernetesEndpoint] + if !ok { + return nil + } + typed, ok := ep.(EndpointMeta) + if !ok { + return nil + } + return &typed +} + +// ConfigFromContext resolves a kubernetes client config for the specified context. +// If kubeconfigOverride is specified, use this config file instead of the context defaults.ConfigFromContext +// if command.ContextDockerHost is specified as the context name, fallsback to the default user's kubeconfig file +func ConfigFromContext(name string, s store.Store) (clientcmd.ClientConfig, error) { + ctxMeta, err := s.GetContextMetadata(name) + if err != nil { + return nil, err + } + epMeta := EndpointFromContext(ctxMeta) + if epMeta != nil { + ep, err := epMeta.WithTLSData(s, name) + if err != nil { + return nil, err + } + return ep.KubernetesConfig(), nil + } + // context has no kubernetes endpoint + return kubernetes.NewKubernetesConfig(""), nil +} diff --git a/cli/context/kubernetes/save.go b/cli/context/kubernetes/save.go new file mode 100644 index 0000000000..35646bc57c --- /dev/null +++ b/cli/context/kubernetes/save.go @@ -0,0 +1,79 @@ +package kubernetes + +import ( + "io/ioutil" + + "github.com/docker/cli/cli/context" + "github.com/docker/cli/cli/context/store" + "k8s.io/client-go/tools/clientcmd" + clientcmdapi "k8s.io/client-go/tools/clientcmd/api" +) + +// FromKubeConfig creates a Kubernetes endpoint from a Kubeconfig file +func FromKubeConfig(kubeconfig, kubeContext, namespaceOverride string) (Endpoint, error) { + cfg := clientcmd.NewNonInteractiveDeferredLoadingClientConfig( + &clientcmd.ClientConfigLoadingRules{ExplicitPath: kubeconfig}, + &clientcmd.ConfigOverrides{CurrentContext: kubeContext, Context: clientcmdapi.Context{Namespace: namespaceOverride}}) + ns, _, err := cfg.Namespace() + if err != nil { + return Endpoint{}, err + } + clientcfg, err := cfg.ClientConfig() + if err != nil { + return Endpoint{}, err + } + var ca, key, cert []byte + if ca, err = readFileOrDefault(clientcfg.CAFile, clientcfg.CAData); err != nil { + return Endpoint{}, err + } + if key, err = readFileOrDefault(clientcfg.KeyFile, clientcfg.KeyData); err != nil { + return Endpoint{}, err + } + if cert, err = readFileOrDefault(clientcfg.CertFile, clientcfg.CertData); err != nil { + return Endpoint{}, err + } + var tlsData *context.TLSData + if ca != nil || cert != nil || key != nil { + tlsData = &context.TLSData{ + CA: ca, + Cert: cert, + Key: key, + } + } + return Endpoint{ + EndpointMeta: EndpointMeta{ + EndpointMetaBase: context.EndpointMetaBase{ + Host: clientcfg.Host, + SkipTLSVerify: clientcfg.Insecure, + }, + DefaultNamespace: ns, + AuthProvider: clientcfg.AuthProvider, + Exec: clientcfg.ExecProvider, + }, + TLSData: tlsData, + }, nil +} + +func readFileOrDefault(path string, defaultValue []byte) ([]byte, error) { + if path != "" { + return ioutil.ReadFile(path) + } + return defaultValue, nil +} + +// Save the endpoint metadata and TLS bundle in the context store +func (ep *Endpoint) Save(s store.Store, contextName string) error { + tlsData := ep.TLSData.ToStoreTLSData() + existingContext, err := s.GetContextMetadata(contextName) + if err != nil && !store.IsErrContextDoesNotExist(err) { + return err + } + if existingContext.Endpoints == nil { + existingContext.Endpoints = make(map[string]interface{}) + } + existingContext.Endpoints[KubernetesEndpoint] = ep.EndpointMeta + if err := s.CreateOrUpdateContext(contextName, existingContext); err != nil { + return err + } + return s.ResetContextEndpointTLSMaterial(contextName, KubernetesEndpoint, tlsData) +} diff --git a/cli/context/kubernetes/testdata/eks-kubeconfig b/cli/context/kubernetes/testdata/eks-kubeconfig new file mode 100644 index 0000000000..deed186a8a --- /dev/null +++ b/cli/context/kubernetes/testdata/eks-kubeconfig @@ -0,0 +1,23 @@ + apiVersion: v1 + clusters: + - cluster: + server: https://some-server + name: kubernetes + contexts: + - context: + cluster: kubernetes + user: aws + name: aws + current-context: aws + kind: Config + preferences: {} + users: + - name: aws + user: + exec: + apiVersion: client.authentication.k8s.io/v1alpha1 + command: heptio-authenticator-aws + args: + - "token" + - "-i" + - "eks-cf" \ No newline at end of file diff --git a/cli/context/kubernetes/testdata/gke-kubeconfig b/cli/context/kubernetes/testdata/gke-kubeconfig new file mode 100644 index 0000000000..5a6384cbae --- /dev/null +++ b/cli/context/kubernetes/testdata/gke-kubeconfig @@ -0,0 +1,23 @@ +apiVersion: v1 +clusters: +- cluster: + server: https://some-server + name: gke_sample +contexts: +- context: + cluster: gke_sample + user: gke_sample + name: gke_sample +current-context: gke_sample +kind: Config +preferences: {} +users: +- name: gke_sample + user: + auth-provider: + config: + cmd-args: config config-helper --format=json + cmd-path: /google/google-cloud-sdk/bin/gcloud + expiry-key: '{.credential.token_expiry}' + token-key: '{.credential.access_token}' + name: gcp diff --git a/cli/context/store/doc.go b/cli/context/store/doc.go new file mode 100644 index 0000000000..e432dae3b9 --- /dev/null +++ b/cli/context/store/doc.go @@ -0,0 +1,21 @@ +// Package store provides a generic way to store credentials to connect to virtually any kind of remote system. +// The term `context` comes from the similar feature in Kubernetes kubectl config files. +// +// Conceptually, a context is a set of metadata and TLS data, that can be used to connect to various endpoints +// of a remote system. TLS data and metadata are stored separately, so that in the future, we will be able to store sensitive +// information in a more secure way, depending on the os we are running on (e.g.: on Windows we could use the user Certificate Store, on Mac OS the user Keychain...). +// +// Current implementation is purely file based with the following structure: +// ${CONTEXT_ROOT} +// - meta/ +// - context1/meta.json: contains context medata (key/value pairs) as well as a list of endpoints (themselves containing key/value pair metadata) +// - contexts/can/also/be/folded/like/this/meta.json: same as context1, but for a context named `contexts/can/also/be/folded/like/this` +// - tls/ +// - context1/endpoint1/: directory containing TLS data for the endpoint1 in context1 +// +// The context store itself has absolutely no knowledge about what a docker or a kubernetes endpoint should contain in term of metadata or TLS config. +// Client code is responsible for generating and parsing endpoint metadata and TLS files. +// The multi-endpoints approach of this package allows to combine many different endpoints in the same "context" (e.g., the Docker CLI +// is able for a single context to define both a docker endpoint and a Kubernetes endpoint for the same cluster, and also specify which +// orchestrator to use by default when deploying a compose stack on this cluster). +package store diff --git a/cli/context/store/metadata_test.go b/cli/context/store/metadata_test.go new file mode 100644 index 0000000000..f8121ec20c --- /dev/null +++ b/cli/context/store/metadata_test.go @@ -0,0 +1,142 @@ +package store + +import ( + "io/ioutil" + "os" + "path/filepath" + "testing" + + "gotest.tools/assert" + "gotest.tools/assert/cmp" +) + +var testMetadata = ContextMetadata{ + Endpoints: map[string]interface{}{ + "ep1": endpoint{Foo: "bar"}, + }, + Metadata: context{Bar: "baz"}, +} + +func TestMetadataGetNotExisting(t *testing.T) { + testDir, err := ioutil.TempDir("", t.Name()) + assert.NilError(t, err) + defer os.RemoveAll(testDir) + testee := metadataStore{root: testDir, config: testCfg} + _, err = testee.get("noexist") + assert.Assert(t, IsErrContextDoesNotExist(err)) +} + +func TestMetadataCreateGetRemove(t *testing.T) { + testDir, err := ioutil.TempDir("", t.Name()) + assert.NilError(t, err) + defer os.RemoveAll(testDir) + testee := metadataStore{root: testDir, config: testCfg} + expected2 := ContextMetadata{ + Endpoints: map[string]interface{}{ + "ep1": endpoint{Foo: "baz"}, + "ep2": endpoint{Foo: "bee"}, + }, + Metadata: context{Bar: "foo"}, + } + err = testee.createOrUpdate("test-context", testMetadata) + assert.NilError(t, err) + // create a new instance to check it does not depend on some sort of state + testee = metadataStore{root: testDir, config: testCfg} + meta, err := testee.get("test-context") + assert.NilError(t, err) + assert.DeepEqual(t, meta, testMetadata) + + // update + + err = testee.createOrUpdate("test-context", expected2) + assert.NilError(t, err) + meta, err = testee.get("test-context") + assert.NilError(t, err) + assert.DeepEqual(t, meta, expected2) + + assert.NilError(t, testee.remove("test-context")) + assert.NilError(t, testee.remove("test-context")) // support duplicate remove + _, err = testee.get("test-context") + assert.Assert(t, IsErrContextDoesNotExist(err)) +} + +func TestMetadataRespectJsonAnnotation(t *testing.T) { + testDir, err := ioutil.TempDir("", t.Name()) + assert.NilError(t, err) + defer os.RemoveAll(testDir) + testee := metadataStore{root: testDir, config: testCfg} + assert.NilError(t, testee.createOrUpdate("test", testMetadata)) + bytes, err := ioutil.ReadFile(filepath.Join(testDir, "test", "meta.json")) + assert.NilError(t, err) + assert.Assert(t, cmp.Contains(string(bytes), "a_very_recognizable_field_name")) + assert.Assert(t, cmp.Contains(string(bytes), "another_very_recognizable_field_name")) +} + +func TestMetadataList(t *testing.T) { + testDir, err := ioutil.TempDir("", t.Name()) + assert.NilError(t, err) + defer os.RemoveAll(testDir) + testee := metadataStore{root: testDir, config: testCfg} + wholeData := map[string]ContextMetadata{ + "simple": testMetadata, + "simple2": testMetadata, + "nested/context": testMetadata, + "nestedwith-parent/context": testMetadata, + "nestedwith-parent": testMetadata, + } + + for k, s := range wholeData { + err = testee.createOrUpdate(k, s) + assert.NilError(t, err) + } + + data, err := testee.list() + assert.NilError(t, err) + assert.DeepEqual(t, data, wholeData) +} + +func TestEmptyConfig(t *testing.T) { + testDir, err := ioutil.TempDir("", t.Name()) + assert.NilError(t, err) + defer os.RemoveAll(testDir) + testee := metadataStore{root: testDir} + wholeData := map[string]ContextMetadata{ + "simple": testMetadata, + "simple2": testMetadata, + "nested/context": testMetadata, + "nestedwith-parent/context": testMetadata, + "nestedwith-parent": testMetadata, + } + + for k, s := range wholeData { + err = testee.createOrUpdate(k, s) + assert.NilError(t, err) + } + + data, err := testee.list() + assert.NilError(t, err) + assert.Equal(t, len(data), len(wholeData)) +} + +type contextWithEmbedding struct { + embeddedStruct +} +type embeddedStruct struct { + Val string +} + +func TestWithEmbedding(t *testing.T) { + testDir, err := ioutil.TempDir("", t.Name()) + assert.NilError(t, err) + defer os.RemoveAll(testDir) + testee := metadataStore{root: testDir, config: NewConfig(func() interface{} { return &contextWithEmbedding{} })} + testCtxMeta := contextWithEmbedding{ + embeddedStruct: embeddedStruct{ + Val: "Hello", + }, + } + assert.NilError(t, testee.createOrUpdate("test", ContextMetadata{Metadata: testCtxMeta})) + res, err := testee.get("test") + assert.NilError(t, err) + assert.Equal(t, testCtxMeta, res.Metadata) +} diff --git a/cli/context/store/metadatastore.go b/cli/context/store/metadatastore.go new file mode 100644 index 0000000000..178cf9e288 --- /dev/null +++ b/cli/context/store/metadatastore.go @@ -0,0 +1,146 @@ +package store + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "reflect" +) + +const ( + metadataDir = "meta" + metaFile = "meta.json" +) + +type metadataStore struct { + root string + config Config +} + +func (s *metadataStore) contextDir(name string) string { + return filepath.Join(s.root, name) +} + +func (s *metadataStore) createOrUpdate(name string, meta ContextMetadata) error { + contextDir := s.contextDir(name) + if err := os.MkdirAll(contextDir, 0755); err != nil { + return err + } + bytes, err := json.Marshal(&meta) + if err != nil { + return err + } + return ioutil.WriteFile(filepath.Join(contextDir, metaFile), bytes, 0644) +} + +func parseTypedOrMap(payload []byte, getter TypeGetter) (interface{}, error) { + if len(payload) == 0 || string(payload) == "null" { + return nil, nil + } + if getter == nil { + var res map[string]interface{} + if err := json.Unmarshal(payload, &res); err != nil { + return nil, err + } + return res, nil + } + typed := getter() + if err := json.Unmarshal(payload, typed); err != nil { + return nil, err + } + return reflect.ValueOf(typed).Elem().Interface(), nil +} + +func (s *metadataStore) get(name string) (ContextMetadata, error) { + contextDir := s.contextDir(name) + bytes, err := ioutil.ReadFile(filepath.Join(contextDir, metaFile)) + if err != nil { + return ContextMetadata{}, convertContextDoesNotExist(name, err) + } + var untyped untypedContextMetadata + r := ContextMetadata{ + Endpoints: make(map[string]interface{}), + } + if err := json.Unmarshal(bytes, &untyped); err != nil { + return ContextMetadata{}, err + } + if r.Metadata, err = parseTypedOrMap(untyped.Metadata, s.config.contextType); err != nil { + return ContextMetadata{}, err + } + for k, v := range untyped.Endpoints { + if r.Endpoints[k], err = parseTypedOrMap(v, s.config.endpointTypes[k]); err != nil { + return ContextMetadata{}, err + } + } + return r, err +} + +func (s *metadataStore) remove(name string) error { + contextDir := s.contextDir(name) + return os.RemoveAll(contextDir) +} + +func (s *metadataStore) list() (map[string]ContextMetadata, error) { + ctxNames, err := listRecursivelyMetadataDirs(s.root) + if err != nil { + if os.IsNotExist(err) { + // store is empty, meta dir does not exist yet + // this should not be considered an error + return map[string]ContextMetadata{}, nil + } + return nil, err + } + res := make(map[string]ContextMetadata) + for _, name := range ctxNames { + res[name], err = s.get(name) + if err != nil { + return nil, err + } + } + return res, nil +} + +func isContextDir(path string) bool { + s, err := os.Stat(filepath.Join(path, metaFile)) + if err != nil { + return false + } + return !s.IsDir() +} + +func listRecursivelyMetadataDirs(root string) ([]string, error) { + fis, err := ioutil.ReadDir(root) + if err != nil { + return nil, err + } + var result []string + for _, fi := range fis { + if fi.IsDir() { + if isContextDir(filepath.Join(root, fi.Name())) { + result = append(result, fi.Name()) + } + subs, err := listRecursivelyMetadataDirs(filepath.Join(root, fi.Name())) + if err != nil { + return nil, err + } + for _, s := range subs { + result = append(result, fmt.Sprintf("%s/%s", fi.Name(), s)) + } + } + } + return result, nil +} + +func convertContextDoesNotExist(name string, err error) error { + if os.IsNotExist(err) { + return &contextDoesNotExistError{name: name} + } + return err +} + +type untypedContextMetadata struct { + Metadata json.RawMessage `json:"metadata,omitempty"` + Endpoints map[string]json.RawMessage `json:"endpoints,omitempty"` +} diff --git a/cli/context/store/store.go b/cli/context/store/store.go new file mode 100644 index 0000000000..9238a92a23 --- /dev/null +++ b/cli/context/store/store.go @@ -0,0 +1,282 @@ +package store + +import ( + "archive/tar" + "encoding/json" + "errors" + "fmt" + "io" + "io/ioutil" + "path" + "path/filepath" + "strings" +) + +// Store provides a context store for easily remembering endpoints configuration +type Store interface { + ListContexts() (map[string]ContextMetadata, error) + CreateOrUpdateContext(name string, meta ContextMetadata) error + RemoveContext(name string) error + GetContextMetadata(name string) (ContextMetadata, error) + ResetContextTLSMaterial(name string, data *ContextTLSData) error + ResetContextEndpointTLSMaterial(contextName string, endpointName string, data *EndpointTLSData) error + ListContextTLSFiles(name string) (map[string]EndpointFiles, error) + GetContextTLSData(contextName, endpointName, fileName string) ([]byte, error) +} + +// ContextMetadata contains metadata about a context and its endpoints +type ContextMetadata struct { + Metadata interface{} `json:"metadata,omitempty"` + Endpoints map[string]interface{} `json:"endpoints,omitempty"` +} + +// EndpointTLSData represents tls data for a given endpoint +type EndpointTLSData struct { + Files map[string][]byte +} + +// ContextTLSData represents tls data for a whole context +type ContextTLSData struct { + Endpoints map[string]EndpointTLSData +} + +// New creates a store from a given directory. +// If the directory does not exist or is empty, initialize it +func New(dir string, cfg Config) Store { + metaRoot := filepath.Join(dir, metadataDir) + tlsRoot := filepath.Join(dir, tlsDir) + + return &store{ + meta: &metadataStore{ + root: metaRoot, + config: cfg, + }, + tls: &tlsStore{ + root: tlsRoot, + }, + } +} + +type store struct { + meta *metadataStore + tls *tlsStore +} + +func (s *store) ListContexts() (map[string]ContextMetadata, error) { + return s.meta.list() +} + +func (s *store) CreateOrUpdateContext(name string, meta ContextMetadata) error { + return s.meta.createOrUpdate(name, meta) +} + +func (s *store) RemoveContext(name string) error { + if err := s.meta.remove(name); err != nil { + return err + } + return s.tls.removeAllContextData(name) +} + +func (s *store) GetContextMetadata(name string) (ContextMetadata, error) { + return s.meta.get(name) +} + +func (s *store) ResetContextTLSMaterial(name string, data *ContextTLSData) error { + if err := s.tls.removeAllContextData(name); err != nil { + return err + } + if data == nil { + return nil + } + for ep, files := range data.Endpoints { + for fileName, data := range files.Files { + if err := s.tls.createOrUpdate(name, ep, fileName, data); err != nil { + return err + } + } + } + return nil +} + +func (s *store) ResetContextEndpointTLSMaterial(contextName string, endpointName string, data *EndpointTLSData) error { + if err := s.tls.removeAllEndpointData(contextName, endpointName); err != nil { + return err + } + if data == nil { + return nil + } + for fileName, data := range data.Files { + if err := s.tls.createOrUpdate(contextName, endpointName, fileName, data); err != nil { + return err + } + } + return nil +} + +func (s *store) ListContextTLSFiles(name string) (map[string]EndpointFiles, error) { + return s.tls.listContextData(name) +} + +func (s *store) GetContextTLSData(contextName, endpointName, fileName string) ([]byte, error) { + return s.tls.getData(contextName, endpointName, fileName) +} + +// Export exports an existing namespace into an opaque data stream +// This stream is actually a tarball containing context metadata and TLS materials, but it does +// not map 1:1 the layout of the context store (don't try to restore it manually without calling store.Import) +func Export(name string, s Store) io.ReadCloser { + reader, writer := io.Pipe() + go func() { + tw := tar.NewWriter(writer) + defer tw.Close() + defer writer.Close() + meta, err := s.GetContextMetadata(name) + if err != nil { + writer.CloseWithError(err) + return + } + metaBytes, err := json.Marshal(&meta) + if err != nil { + writer.CloseWithError(err) + return + } + if err = tw.WriteHeader(&tar.Header{ + Name: metaFile, + Mode: 0644, + Size: int64(len(metaBytes)), + }); err != nil { + writer.CloseWithError(err) + return + } + if _, err = tw.Write(metaBytes); err != nil { + writer.CloseWithError(err) + return + } + tlsFiles, err := s.ListContextTLSFiles(name) + if err != nil { + writer.CloseWithError(err) + return + } + if err = tw.WriteHeader(&tar.Header{ + Name: "tls", + Mode: 0700, + Size: 0, + Typeflag: tar.TypeDir, + }); err != nil { + writer.CloseWithError(err) + return + } + for endpointName, endpointFiles := range tlsFiles { + if err = tw.WriteHeader(&tar.Header{ + Name: path.Join("tls", endpointName), + Mode: 0700, + Size: 0, + Typeflag: tar.TypeDir, + }); err != nil { + writer.CloseWithError(err) + return + } + for _, fileName := range endpointFiles { + data, err := s.GetContextTLSData(name, endpointName, fileName) + if err != nil { + writer.CloseWithError(err) + return + } + if err = tw.WriteHeader(&tar.Header{ + Name: path.Join("tls", endpointName, fileName), + Mode: 0600, + Size: int64(len(data)), + }); err != nil { + writer.CloseWithError(err) + return + } + if _, err = tw.Write(data); err != nil { + writer.CloseWithError(err) + return + } + } + } + }() + return reader +} + +// Import imports an exported context into a store +func Import(name string, s Store, reader io.Reader) error { + tr := tar.NewReader(reader) + tlsData := ContextTLSData{ + Endpoints: map[string]EndpointTLSData{}, + } + for { + hdr, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + return err + } + if hdr.Typeflag == tar.TypeDir { + // skip this entry, only taking files into account + continue + } + if hdr.Name == metaFile { + data, err := ioutil.ReadAll(tr) + if err != nil { + return err + } + var meta ContextMetadata + if err := json.Unmarshal(data, &meta); err != nil { + return err + } + if err := s.CreateOrUpdateContext(name, meta); err != nil { + return err + } + } else if strings.HasPrefix(hdr.Name, "tls/") { + relative := strings.TrimPrefix(hdr.Name, "tls/") + parts := strings.SplitN(relative, "/", 2) + if len(parts) != 2 { + return errors.New("archive format is invalid") + } + endpointName := parts[0] + fileName := parts[1] + data, err := ioutil.ReadAll(tr) + if err != nil { + return err + } + if _, ok := tlsData.Endpoints[endpointName]; !ok { + tlsData.Endpoints[endpointName] = EndpointTLSData{ + Files: map[string][]byte{}, + } + } + tlsData.Endpoints[endpointName].Files[fileName] = data + } + } + return s.ResetContextTLSMaterial(name, &tlsData) +} + +type contextDoesNotExistError struct { + name string +} + +func (e *contextDoesNotExistError) Error() string { + return fmt.Sprintf("context %q does not exist", e.name) +} + +type tlsDataDoesNotExistError struct { + context, endpoint, file string +} + +func (e *tlsDataDoesNotExistError) Error() string { + return fmt.Sprintf("tls data for %s/%s/%s does not exist", e.context, e.endpoint, e.file) +} + +// IsErrContextDoesNotExist checks if the given error is a "context does not exist" condition +func IsErrContextDoesNotExist(err error) bool { + _, ok := err.(*contextDoesNotExistError) + return ok +} + +// IsErrTLSDataDoesNotExist checks if the given error is a "context does not exist" condition +func IsErrTLSDataDoesNotExist(err error) bool { + _, ok := err.(*tlsDataDoesNotExistError) + return ok +} diff --git a/cli/context/store/store_test.go b/cli/context/store/store_test.go new file mode 100644 index 0000000000..c1994d0441 --- /dev/null +++ b/cli/context/store/store_test.go @@ -0,0 +1,100 @@ +package store + +import ( + "io/ioutil" + "os" + "testing" + + "gotest.tools/assert" +) + +type endpoint struct { + Foo string `json:"a_very_recognizable_field_name"` +} + +type context struct { + Bar string `json:"another_very_recognizable_field_name"` +} + +var testCfg = NewConfig(func() interface{} { return &context{} }, + EndpointTypeGetter("ep1", func() interface{} { return &endpoint{} }), + EndpointTypeGetter("ep2", func() interface{} { return &endpoint{} }), +) + +func TestExportImport(t *testing.T) { + testDir, err := ioutil.TempDir("", t.Name()) + assert.NilError(t, err) + defer os.RemoveAll(testDir) + s := New(testDir, testCfg) + err = s.CreateOrUpdateContext("source", + ContextMetadata{ + Endpoints: map[string]interface{}{ + "ep1": endpoint{Foo: "bar"}, + }, + Metadata: context{Bar: "baz"}, + }) + assert.NilError(t, err) + err = s.ResetContextEndpointTLSMaterial("source", "ep1", &EndpointTLSData{ + Files: map[string][]byte{ + "file1": []byte("test-data"), + }, + }) + assert.NilError(t, err) + r := Export("source", s) + defer r.Close() + err = Import("dest", s, r) + assert.NilError(t, err) + srcMeta, err := s.GetContextMetadata("source") + assert.NilError(t, err) + destMeta, err := s.GetContextMetadata("dest") + assert.NilError(t, err) + assert.DeepEqual(t, destMeta, srcMeta) + srcFileList, err := s.ListContextTLSFiles("source") + assert.NilError(t, err) + destFileList, err := s.ListContextTLSFiles("dest") + assert.NilError(t, err) + assert.DeepEqual(t, srcFileList, destFileList) + srcData, err := s.GetContextTLSData("source", "ep1", "file1") + assert.NilError(t, err) + assert.Equal(t, "test-data", string(srcData)) + destData, err := s.GetContextTLSData("dest", "ep1", "file1") + assert.NilError(t, err) + assert.Equal(t, "test-data", string(destData)) +} + +func TestRemove(t *testing.T) { + testDir, err := ioutil.TempDir("", t.Name()) + assert.NilError(t, err) + defer os.RemoveAll(testDir) + s := New(testDir, testCfg) + err = s.CreateOrUpdateContext("source", + ContextMetadata{ + Endpoints: map[string]interface{}{ + "ep1": endpoint{Foo: "bar"}, + }, + Metadata: context{Bar: "baz"}, + }) + assert.NilError(t, err) + assert.NilError(t, s.ResetContextEndpointTLSMaterial("source", "ep1", &EndpointTLSData{ + Files: map[string][]byte{ + "file1": []byte("test-data"), + }, + })) + assert.NilError(t, s.RemoveContext("source")) + _, err = s.GetContextMetadata("source") + assert.Check(t, IsErrContextDoesNotExist(err)) + f, err := s.ListContextTLSFiles("source") + assert.NilError(t, err) + assert.Equal(t, 0, len(f)) +} + +func TestListEmptyStore(t *testing.T) { + testDir, err := ioutil.TempDir("", t.Name()) + assert.NilError(t, err) + defer os.RemoveAll(testDir) + store := New(testDir, testCfg) + result, err := store.ListContexts() + assert.NilError(t, err) + assert.Check(t, result != nil) + assert.Check(t, len(result) == 0) +} diff --git a/cli/context/store/storeconfig.go b/cli/context/store/storeconfig.go new file mode 100644 index 0000000000..9746d93d77 --- /dev/null +++ b/cli/context/store/storeconfig.go @@ -0,0 +1,38 @@ +package store + +// TypeGetter is a func used to determine the concrete type of a context or +// endpoint metadata by returning a pointer to an instance of the object +// eg: for a context of type DockerContext, the corresponding TypeGetter should return new(DockerContext) +type TypeGetter func() interface{} + +// NamedTypeGetter is a TypeGetter associated with a name +type NamedTypeGetter struct { + name string + typeGetter TypeGetter +} + +// EndpointTypeGetter returns a NamedTypeGetter with the spcecified name and getter +func EndpointTypeGetter(name string, getter TypeGetter) NamedTypeGetter { + return NamedTypeGetter{ + name: name, + typeGetter: getter, + } +} + +// Config is used to configure the metadata marshaler of the context store +type Config struct { + contextType TypeGetter + endpointTypes map[string]TypeGetter +} + +// NewConfig creates a config object +func NewConfig(contextType TypeGetter, endpoints ...NamedTypeGetter) Config { + res := Config{ + contextType: contextType, + endpointTypes: make(map[string]TypeGetter), + } + for _, e := range endpoints { + res.endpointTypes[e.name] = e.typeGetter + } + return res +} diff --git a/cli/context/store/tlsstore.go b/cli/context/store/tlsstore.go new file mode 100644 index 0000000000..0d978df129 --- /dev/null +++ b/cli/context/store/tlsstore.go @@ -0,0 +1,99 @@ +package store + +import ( + "io/ioutil" + "os" + "path/filepath" +) + +const tlsDir = "tls" + +type tlsStore struct { + root string +} + +func (s *tlsStore) contextDir(name string) string { + return filepath.Join(s.root, name) +} + +func (s *tlsStore) endpointDir(contextName, name string) string { + return filepath.Join(s.root, contextName, name) +} + +func (s *tlsStore) filePath(contextName, endpointName, filename string) string { + return filepath.Join(s.root, contextName, endpointName, filename) +} + +func (s *tlsStore) createOrUpdate(contextName, endpointName, filename string, data []byte) error { + epdir := s.endpointDir(contextName, endpointName) + parentOfRoot := filepath.Dir(s.root) + if err := os.MkdirAll(parentOfRoot, 0755); err != nil { + return err + } + if err := os.MkdirAll(epdir, 0700); err != nil { + return err + } + return ioutil.WriteFile(s.filePath(contextName, endpointName, filename), data, 0600) +} + +func (s *tlsStore) getData(contextName, endpointName, filename string) ([]byte, error) { + data, err := ioutil.ReadFile(s.filePath(contextName, endpointName, filename)) + if err != nil { + return nil, convertTLSDataDoesNotExist(contextName, endpointName, filename, err) + } + return data, nil +} + +func (s *tlsStore) remove(contextName, endpointName, filename string) error { + err := os.Remove(s.filePath(contextName, endpointName, filename)) + if os.IsNotExist(err) { + return nil + } + return err +} + +func (s *tlsStore) removeAllEndpointData(contextName, endpointName string) error { + return os.RemoveAll(s.endpointDir(contextName, endpointName)) +} + +func (s *tlsStore) removeAllContextData(contextName string) error { + return os.RemoveAll(s.contextDir(contextName)) +} + +func (s *tlsStore) listContextData(contextName string) (map[string]EndpointFiles, error) { + epFSs, err := ioutil.ReadDir(s.contextDir(contextName)) + if err != nil { + if os.IsNotExist(err) { + return map[string]EndpointFiles{}, nil + } + return nil, err + } + r := make(map[string]EndpointFiles) + for _, epFS := range epFSs { + if epFS.IsDir() { + epDir := s.endpointDir(contextName, epFS.Name()) + fss, err := ioutil.ReadDir(epDir) + if err != nil { + return nil, err + } + var files EndpointFiles + for _, fs := range fss { + if !fs.IsDir() { + files = append(files, fs.Name()) + } + } + r[epFS.Name()] = files + } + } + return r, nil +} + +// EndpointFiles is a slice of strings representing file names +type EndpointFiles []string + +func convertTLSDataDoesNotExist(context, endpoint, file string, err error) error { + if os.IsNotExist(err) { + return &tlsDataDoesNotExistError{context: context, endpoint: endpoint, file: file} + } + return err +} diff --git a/cli/context/store/tlsstore_test.go b/cli/context/store/tlsstore_test.go new file mode 100644 index 0000000000..6079de0f8e --- /dev/null +++ b/cli/context/store/tlsstore_test.go @@ -0,0 +1,79 @@ +package store + +import ( + "io/ioutil" + "os" + "testing" + + "gotest.tools/assert" +) + +func TestTlsCreateUpdateGetRemove(t *testing.T) { + testDir, err := ioutil.TempDir("", "TestTlsCreateUpdateGetRemove") + assert.NilError(t, err) + defer os.RemoveAll(testDir) + testee := tlsStore{root: testDir} + _, err = testee.getData("test-ctx", "test-ep", "test-data") + assert.Equal(t, true, IsErrTLSDataDoesNotExist(err)) + + err = testee.createOrUpdate("test-ctx", "test-ep", "test-data", []byte("data")) + assert.NilError(t, err) + data, err := testee.getData("test-ctx", "test-ep", "test-data") + assert.NilError(t, err) + assert.Equal(t, string(data), "data") + err = testee.createOrUpdate("test-ctx", "test-ep", "test-data", []byte("data2")) + assert.NilError(t, err) + data, err = testee.getData("test-ctx", "test-ep", "test-data") + assert.NilError(t, err) + assert.Equal(t, string(data), "data2") + + err = testee.remove("test-ctx", "test-ep", "test-data") + assert.NilError(t, err) + err = testee.remove("test-ctx", "test-ep", "test-data") + assert.NilError(t, err) + + _, err = testee.getData("test-ctx", "test-ep", "test-data") + assert.Equal(t, true, IsErrTLSDataDoesNotExist(err)) +} + +func TestTlsListAndBatchRemove(t *testing.T) { + testDir, err := ioutil.TempDir("", "TestTlsListAndBatchRemove") + assert.NilError(t, err) + defer os.RemoveAll(testDir) + testee := tlsStore{root: testDir} + + all := map[string]EndpointFiles{ + "ep1": {"f1", "f2", "f3"}, + "ep2": {"f1", "f2", "f3"}, + "ep3": {"f1", "f2", "f3"}, + } + + ep1ep2 := map[string]EndpointFiles{ + "ep1": {"f1", "f2", "f3"}, + "ep2": {"f1", "f2", "f3"}, + } + + for name, files := range all { + for _, file := range files { + err = testee.createOrUpdate("test-ctx", name, file, []byte("data")) + assert.NilError(t, err) + } + } + + resAll, err := testee.listContextData("test-ctx") + assert.NilError(t, err) + assert.DeepEqual(t, resAll, all) + + err = testee.removeAllEndpointData("test-ctx", "ep3") + assert.NilError(t, err) + resEp1ep2, err := testee.listContextData("test-ctx") + assert.NilError(t, err) + assert.DeepEqual(t, resEp1ep2, ep1ep2) + + err = testee.removeAllContextData("test-ctx") + assert.NilError(t, err) + resEmpty, err := testee.listContextData("test-ctx") + assert.NilError(t, err) + assert.DeepEqual(t, resEmpty, map[string]EndpointFiles{}) + +} diff --git a/cli/context/tlsdata.go b/cli/context/tlsdata.go new file mode 100644 index 0000000000..6bd05fbb78 --- /dev/null +++ b/cli/context/tlsdata.go @@ -0,0 +1,98 @@ +package context + +import ( + "io/ioutil" + + "github.com/docker/cli/cli/context/store" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" +) + +const ( + caKey = "ca.pem" + certKey = "cert.pem" + keyKey = "key.pem" +) + +// TLSData holds ca/cert/key raw data +type TLSData struct { + CA []byte + Key []byte + Cert []byte +} + +// ToStoreTLSData converts TLSData to the store representation +func (data *TLSData) ToStoreTLSData() *store.EndpointTLSData { + if data == nil { + return nil + } + result := store.EndpointTLSData{ + Files: make(map[string][]byte), + } + if data.CA != nil { + result.Files[caKey] = data.CA + } + if data.Cert != nil { + result.Files[certKey] = data.Cert + } + if data.Key != nil { + result.Files[keyKey] = data.Key + } + return &result +} + +// LoadTLSData loads TLS data from the store +func LoadTLSData(s store.Store, contextName, endpointName string) (*TLSData, error) { + tlsFiles, err := s.ListContextTLSFiles(contextName) + if err != nil { + return nil, errors.Wrapf(err, "failed to retrieve context tls files for context %q", contextName) + } + if epTLSFiles, ok := tlsFiles[endpointName]; ok { + var tlsData TLSData + for _, f := range epTLSFiles { + data, err := s.GetContextTLSData(contextName, endpointName, f) + if err != nil { + return nil, errors.Wrapf(err, "failed to retrieve context tls data for file %q of context %q", f, contextName) + } + switch f { + case caKey: + tlsData.CA = data + case certKey: + tlsData.Cert = data + case keyKey: + tlsData.Key = data + default: + logrus.Warnf("unknown file %s in context %s tls bundle", f, contextName) + } + } + return &tlsData, nil + } + return nil, nil +} + +// TLSDataFromFiles reads files into a TLSData struct (or returns nil if all paths are empty) +func TLSDataFromFiles(caPath, certPath, keyPath string) (*TLSData, error) { + var ( + ca, cert, key []byte + err error + ) + if caPath != "" { + if ca, err = ioutil.ReadFile(caPath); err != nil { + return nil, err + } + } + if certPath != "" { + if cert, err = ioutil.ReadFile(certPath); err != nil { + return nil, err + } + } + if keyPath != "" { + if key, err = ioutil.ReadFile(keyPath); err != nil { + return nil, err + } + } + if ca == nil && cert == nil && key == nil { + return nil, nil + } + return &TLSData{CA: ca, Cert: cert, Key: key}, nil +} diff --git a/cli/flags/common.go b/cli/flags/common.go index 22faf12ca6..a3bbf29571 100644 --- a/cli/flags/common.go +++ b/cli/flags/common.go @@ -37,6 +37,7 @@ type CommonOptions struct { TLS bool TLSVerify bool TLSOptions *tlsconfig.Options + Context string } // NewCommonOptions returns a new CommonOptions @@ -70,6 +71,8 @@ func (commonOpts *CommonOptions) InstallFlags(flags *pflag.FlagSet) { // opts.ValidateHost is not used here, so as to allow connection helpers hostOpt := opts.NewNamedListOptsRef("hosts", &commonOpts.Hosts, nil) flags.VarP(hostOpt, "host", "H", "Daemon socket(s) to connect to") + flags.StringVarP(&commonOpts.Context, "context", "c", "", + `Name of the context to use to connect to the daemon (overrides DOCKER_HOST env var and default context set with "docker context use")`) } // SetDefaultOptions sets default values for options after flag parsing is