Thanks to visit codestin.com
Credit goes to github.com

Skip to content

fix: Use environment variables for agent authentication #1238

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 32 additions & 24 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import (

type Options struct {
ReconnectingPTYTimeout time.Duration
EnvironmentVariables map[string]string
Logger slog.Logger
}

Expand All @@ -66,6 +67,7 @@ func New(dialer Dialer, options *Options) io.Closer {
logger: options.Logger,
closeCancel: cancelFunc,
closed: make(chan struct{}),
envVars: options.EnvironmentVariables,
}
server.init(ctx)
return server
Expand All @@ -83,23 +85,21 @@ type agent struct {
closeMutex sync.Mutex
closed chan struct{}

// Environment variables sent by Coder to inject for shell sessions.
// These are atomic because values can change after reconnect.
envVars atomic.Value
ownerEmail atomic.String
ownerUsername atomic.String
envVars map[string]string
// metadata is atomic because values can change after reconnection.
metadata atomic.Value
startupScript atomic.Bool
sshServer *ssh.Server
}

func (a *agent) run(ctx context.Context) {
var options Metadata
var metadata Metadata
var peerListener *peerbroker.Listener
var err error
// An exponential back-off occurs when the connection is failing to dial.
// This is to prevent server spam in case of a coderd outage.
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
options, peerListener, err = a.dialer(ctx, a.logger)
metadata, peerListener, err = a.dialer(ctx, a.logger)
if err != nil {
if errors.Is(err, context.Canceled) {
return
Expand All @@ -118,14 +118,12 @@ func (a *agent) run(ctx context.Context) {
return
default:
}
a.envVars.Store(options.EnvironmentVariables)
a.ownerEmail.Store(options.OwnerEmail)
a.ownerUsername.Store(options.OwnerUsername)
a.metadata.Store(metadata)

if a.startupScript.CAS(false, true) {
// The startup script has not ran yet!
go func() {
err := a.runStartupScript(ctx, options.StartupScript)
err := a.runStartupScript(ctx, metadata.StartupScript)
if errors.Is(err, context.Canceled) {
return
}
Expand Down Expand Up @@ -172,7 +170,7 @@ func (*agent) runStartupScript(ctx context.Context, script string) error {
writer, err = gsyslog.NewLogger(gsyslog.LOG_INFO, "USER", "coder-startup-script")
if err != nil {
// If the syslog isn't supported or cannot be created, use a text file in temp.
writer, err = os.CreateTemp("", "coder-startup-script.txt")
writer, err = os.CreateTemp("", "coder-startup-script-*.txt")
if err != nil {
return xerrors.Errorf("open startup script log file: %w", err)
}
Expand Down Expand Up @@ -319,6 +317,15 @@ func (a *agent) createCommand(ctx context.Context, rawCommand string, env []stri
return nil, xerrors.Errorf("get user shell: %w", err)
}

rawMetadata := a.metadata.Load()
if rawMetadata == nil {
return nil, xerrors.Errorf("no metadata was provided: %w", err)
}
metadata, valid := rawMetadata.(Metadata)
if !valid {
return nil, xerrors.Errorf("metadata is the wrong type: %T", metadata)
}

// gliderlabs/ssh returns a command slice of zero
// when a shell is requested.
command := rawCommand
Expand All @@ -344,22 +351,23 @@ func (a *agent) createCommand(ctx context.Context, rawCommand string, env []stri
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_SSH_COMMAND=%s gitssh --`, executablePath))
// These prevent the user from having to specify _anything_ to successfully commit.
// Both author and committer must be set!
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_AUTHOR_EMAIL=%s`, a.ownerEmail.Load()))
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_COMMITTER_EMAIL=%s`, a.ownerEmail.Load()))
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_AUTHOR_NAME=%s`, a.ownerUsername.Load()))
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_COMMITTER_NAME=%s`, a.ownerUsername.Load()))
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_AUTHOR_EMAIL=%s`, metadata.OwnerEmail))
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_COMMITTER_EMAIL=%s`, metadata.OwnerEmail))
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_AUTHOR_NAME=%s`, metadata.OwnerUsername))
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_COMMITTER_NAME=%s`, metadata.OwnerUsername))

// Load environment variables passed via the agent.
// These should override all variables we manually specify.
envVars := a.envVars.Load()
if envVars != nil {
envVarMap, ok := envVars.(map[string]string)
if ok {
for key, value := range envVarMap {
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", key, value))
}
}
for key, value := range metadata.EnvironmentVariables {
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", key, value))
}

// Agent-level environment variables should take over all!
// This is used for setting agent-specific variables like "CODER_AGENT_TOKEN".
for key, value := range a.envVars {
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", key, value))
}

return cmd, nil
}

Expand Down
34 changes: 13 additions & 21 deletions cli/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,16 @@ import (

func workspaceAgent() *cobra.Command {
var (
rawURL string
auth string
token string
auth string
)
cmd := &cobra.Command{
Use: "agent",
// This command isn't useful to manually execute.
Hidden: true,
RunE: func(cmd *cobra.Command, args []string) error {
if rawURL == "" {
return xerrors.New("CODER_URL must be set")
rawURL, err := cmd.Flags().GetString(varAgentURL)
if err != nil {
return xerrors.Errorf("CODER_AGENT_URL must be set: %w", err)
}
coderURL, err := url.Parse(rawURL)
if err != nil {
Expand All @@ -46,8 +45,9 @@ func workspaceAgent() *cobra.Command {
var exchangeToken func(context.Context) (codersdk.WorkspaceAgentAuthenticateResponse, error)
switch auth {
case "token":
if token == "" {
return xerrors.Errorf("CODER_TOKEN must be set for token auth")
token, err := cmd.Flags().GetString(varAgentToken)
if err != nil {
return xerrors.Errorf("CODER_AGENT_TOKEN must be set for token auth: %w", err)
}
client.SessionToken = token
case "google-instance-identity":
Expand Down Expand Up @@ -115,27 +115,19 @@ func workspaceAgent() *cobra.Command {
}
}

cfg := createConfig(cmd)
err = cfg.AgentSession().Write(client.SessionToken)
if err != nil {
return xerrors.Errorf("writing agent session token to config: %w", err)
}
err = cfg.URL().Write(client.URL.String())
if err != nil {
return xerrors.Errorf("writing agent url to config: %w", err)
}

closer := agent.New(client.ListenWorkspaceAgent, &agent.Options{
Logger: logger,
EnvironmentVariables: map[string]string{
// Override the "CODER_AGENT_TOKEN" variable in all
// shells so "gitssh" works!
"CODER_AGENT_TOKEN": client.SessionToken,
},
})
<-cmd.Context().Done()
return closer.Close()
},
}

cliflag.StringVarP(cmd.Flags(), &auth, "auth", "", "CODER_AUTH", "token", "Specify the authentication type to use for the agent")
cliflag.StringVarP(cmd.Flags(), &rawURL, "url", "", "CODER_URL", "", "Specify the URL to access Coder")
cliflag.StringVarP(cmd.Flags(), &token, "token", "", "CODER_TOKEN", "", "Specifies the authentication token to access Coder")

cliflag.StringVarP(cmd.Flags(), &auth, "auth", "", "CODER_AGENT_AUTH", "token", "Specify the authentication type to use for the agent")
return cmd
}
6 changes: 3 additions & 3 deletions cli/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func TestWorkspaceAgent(t *testing.T) {
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)

cmd, _ := clitest.New(t, "agent", "--auth", "azure-instance-identity", "--url", client.URL.String())
cmd, _ := clitest.New(t, "agent", "--auth", "azure-instance-identity", "--agent-url", client.URL.String())
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
go func() {
Expand Down Expand Up @@ -100,7 +100,7 @@ func TestWorkspaceAgent(t *testing.T) {
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)

cmd, _ := clitest.New(t, "agent", "--auth", "aws-instance-identity", "--url", client.URL.String())
cmd, _ := clitest.New(t, "agent", "--auth", "aws-instance-identity", "--agent-url", client.URL.String())
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
go func() {
Expand Down Expand Up @@ -154,7 +154,7 @@ func TestWorkspaceAgent(t *testing.T) {
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)

cmd, _ := clitest.New(t, "agent", "--auth", "google-instance-identity", "--url", client.URL.String())
cmd, _ := clitest.New(t, "agent", "--auth", "google-instance-identity", "--agent-url", client.URL.String())
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
go func() {
Expand Down
9 changes: 9 additions & 0 deletions cli/cliflag/cliflag.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@ import (
"github.com/spf13/pflag"
)

// String sets a string flag on the given flag set.
func String(flagset *pflag.FlagSet, name, shorthand, env, def, usage string) {
v, ok := os.LookupEnv(env)
if !ok || v == "" {
v = def
}
flagset.StringP(name, shorthand, v, fmtUsage(usage, env))
}

// StringVarP sets a string flag on the given flag set.
func StringVarP(flagset *pflag.FlagSet, p *string, name string, shorthand string, env string, def string, usage string) {
v, ok := os.LookupEnv(env)
Expand Down
24 changes: 23 additions & 1 deletion cli/cliflag/cliflag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,28 @@ import (
//nolint:paralleltest
func TestCliflag(t *testing.T) {
t.Run("StringDefault", func(t *testing.T) {
flagset, name, shorthand, env, usage := randomFlag()
def, _ := cryptorand.String(10)
cliflag.String(flagset, name, shorthand, env, def, usage)
got, err := flagset.GetString(name)
require.NoError(t, err)
require.Equal(t, def, got)
require.Contains(t, flagset.FlagUsages(), usage)
require.Contains(t, flagset.FlagUsages(), fmt.Sprintf(" - consumes $%s", env))
})

t.Run("StringEnvVar", func(t *testing.T) {
flagset, name, shorthand, env, usage := randomFlag()
envValue, _ := cryptorand.String(10)
t.Setenv(env, envValue)
def, _ := cryptorand.String(10)
cliflag.String(flagset, name, shorthand, env, def, usage)
got, err := flagset.GetString(name)
require.NoError(t, err)
require.Equal(t, envValue, got)
})

t.Run("StringVarPDefault", func(t *testing.T) {
var ptr string
flagset, name, shorthand, env, usage := randomFlag()
def, _ := cryptorand.String(10)
Expand All @@ -28,7 +50,7 @@ func TestCliflag(t *testing.T) {
require.Contains(t, flagset.FlagUsages(), fmt.Sprintf(" - consumes $%s", env))
})

t.Run("StringEnvVar", func(t *testing.T) {
t.Run("StringVarPEnvVar", func(t *testing.T) {
var ptr string
flagset, name, shorthand, env, usage := randomFlag()
envValue, _ := cryptorand.String(10)
Expand Down
4 changes: 0 additions & 4 deletions cli/config/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,6 @@ func (r Root) Organization() File {
return File(filepath.Join(string(r), "organization"))
}

func (r Root) AgentSession() File {
return File(filepath.Join(string(r), "agentsession"))
}

// File provides convenience methods for interacting with *os.File.
type File string

Expand Down
18 changes: 2 additions & 16 deletions cli/gitssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package cli

import (
"fmt"
"net/url"
"os"
"os/exec"
"strings"
Expand All @@ -11,7 +10,6 @@ import (
"golang.org/x/xerrors"

"github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk"
)

func gitssh() *cobra.Command {
Expand All @@ -20,22 +18,10 @@ func gitssh() *cobra.Command {
Hidden: true,
Short: `Wraps the "ssh" command and uses the coder gitssh key for authentication`,
RunE: func(cmd *cobra.Command, args []string) error {
cfg := createConfig(cmd)
rawURL, err := cfg.URL().Read()
client, err := createAgentClient(cmd)
if err != nil {
return xerrors.Errorf("read agent url from config: %w", err)
return xerrors.Errorf("create agent client: %w", err)
}
parsedURL, err := url.Parse(rawURL)
if err != nil {
return xerrors.Errorf("parse agent url from config: %w", err)
}
session, err := cfg.AgentSession().Read()
if err != nil {
return xerrors.Errorf("read agent session from config: %w", err)
}
client := codersdk.New(parsedURL)
client.SessionToken = session

key, err := client.AgentGitSSHKey(cmd.Context())
if err != nil {
return xerrors.Errorf("get agent git ssh token: %w", err)
Expand Down
20 changes: 3 additions & 17 deletions cli/gitssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,10 @@ import (

"github.com/gliderlabs/ssh"
"github.com/google/uuid"
"github.com/spf13/cobra"
"github.com/stretchr/testify/require"
gossh "golang.org/x/crypto/ssh"

"github.com/coder/coder/cli/clitest"
"github.com/coder/coder/cli/config"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/provisioner/echo"
Expand Down Expand Up @@ -61,7 +59,7 @@ func TestGitSSH(t *testing.T) {
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)

// start workspace agent
cmd, root := clitest.New(t, "agent", "--token", agentToken, "--url", client.URL.String())
cmd, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String())
agentClient := &*client
clitest.SetupConfig(t, agentClient, root)
ctx, cancelFunc := context.WithCancel(context.Background())
Expand Down Expand Up @@ -92,7 +90,7 @@ func TestGitSSH(t *testing.T) {
// as long as we get a successful session we don't care if the server errors
_ = ssh.Serve(l, func(s ssh.Session) {
atomic.AddInt64(&inc, 1)
t.Log("got authenticated sesion")
t.Log("got authenticated session")
err := s.Exit(0)
require.NoError(t, err)
}, publicKeyOption)
Expand All @@ -101,22 +99,10 @@ func TestGitSSH(t *testing.T) {
// start ssh session
addr, ok := l.Addr().(*net.TCPAddr)
require.True(t, ok)
cfgDir := createConfig(cmd)
// set to agent config dir
cmd, root = clitest.New(t, "gitssh", "--global-config="+string(cfgDir), "--", fmt.Sprintf("-p%d", addr.Port), "-o", "StrictHostKeyChecking=no", "127.0.0.1")
clitest.SetupConfig(t, agentClient, root)

cmd, _ = clitest.New(t, "gitssh", "--agent-url", agentClient.URL.String(), "--agent-token", agentToken, "--", fmt.Sprintf("-p%d", addr.Port), "-o", "StrictHostKeyChecking=no", "127.0.0.1")
err = cmd.ExecuteContext(context.Background())
require.NoError(t, err)
require.EqualValues(t, 1, inc)
})
}

// createConfig consumes the global configuration flag to produce a config root.
func createConfig(cmd *cobra.Command) config.Root {
globalRoot, err := cmd.Flags().GetString("global-config")
if err != nil {
panic(err)
}
return config.Root(globalRoot)
}
Loading