Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ update-golden-files:
# mock-ssh: change into ./test and run `docker-compose up` there, using the
# default compose file found in that directory.
mock-ssh:
cd ./test && docker-compose up

# mock-performance-ssh: change into ./test and run `docker-compose up` against
# docker-compose-performance.yaml instead of the default compose file.
mock-performance-ssh:
cd ./test && docker-compose -f docker-compose-performance.yaml up

build:
CGO_ENABLED=0 go build \
-ldflags "-s -w -X '${PACKAGE}/cmd.version=${VERSION}' -X '${PACKAGE}/cmd.commit=${GIT}' -X '${PACKAGE}/cmd.date=${DATE}'" \
Expand Down
12 changes: 12 additions & 0 deletions core/dao/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type Server struct {
WorkDir string
IdentityFile *string
Password *string
AuthMethod string

context string // config path
contextLine int // defined at
Expand All @@ -45,6 +46,7 @@ type ServerYAML struct {
WorkDir string `yaml:"work_dir"`
IdentityFile *string `yaml:"identity_file"`
Password *string `yaml:"password"`
AuthMethod string `yaml:"-"`
}

func (s Server) GetValue(key string, _ int) string {
Expand Down Expand Up @@ -188,6 +190,16 @@ func (c *ConfigYAML) ParseServersYAML() ([]Server, []ResourceErrors[Server]) {
server.Password = serverYAML.Password
}

if server.IdentityFile != nil && server.Password != nil {
server.AuthMethod = "password-key"
} else if server.IdentityFile != nil {
server.AuthMethod = "key"
} else if server.Password != nil {
server.AuthMethod = "password"
} else {
server.AuthMethod = "none"
}

servers = append(servers, *server)
}

Expand Down
4 changes: 4 additions & 0 deletions core/run/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,7 @@ type ErrConnect struct {
Port uint16
Reason string
}

// Error implements the error interface for *ErrConnect.
//
// It returns the Reason recorded for the failed connection, so that an
// *ErrConnect which is returned, wrapped or logged as an error carries a
// message rather than an empty string. A nil receiver yields "" so that
// calling Error on a nil *ErrConnect does not panic.
func (e *ErrConnect) Error() string {
	if e == nil {
		return ""
	}
	return e.Reason
}
166 changes: 67 additions & 99 deletions core/run/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"golang.org/x/crypto/ssh"
"os"
"os/signal"
"os/user"
"path/filepath"
"strconv"
"strings"
Expand Down Expand Up @@ -161,45 +160,6 @@ func (run *Run) SetClients(
clientCh chan Client,
errCh chan ErrConnect,
) ([]ErrConnect, error) {
globalIdentityFile, globalPassword := getGlobalIdentity(runFlags)

// Iterate through servers and create a singleton AuthMethod, which is used for
// connecting to all hosts using an identity key. Servers which only use a password
// are not included here and are handled separately.
var identities []Identity
for _, server := range run.Servers {
if server.Local {
continue
}

var pass *string
if server.Password != nil {
pw, err := dao.EvaluatePassword(*server.Password)
pass = &pw
if err != nil {
errConnect := &ErrConnect{
Name: server.Name,
User: server.User,
Host: server.Host,
Port: server.Port,
Reason: err.Error(),
}
return []ErrConnect{*errConnect}, nil
}
}

identities = append(identities, Identity{
IdentityFile: server.IdentityFile,
Password: pass,
})
}

// VerifyHost
authMethod, err := InitAuthMethod(globalIdentityFile, globalPassword, identities)
if err != nil {
return []ErrConnect{}, err
}

createLocalClient := func(server dao.Server, wg *sync.WaitGroup, mu *sync.Mutex) {
defer wg.Done()

Expand All @@ -212,37 +172,15 @@ func (run *Run) SetClients(
clientCh <- local
}

createRemoteClient := func(server dao.Server, wg *sync.WaitGroup, mu *sync.Mutex) {
createRemoteClient := func(authMethod []ssh.AuthMethod, server dao.Server, wg *sync.WaitGroup, mu *sync.Mutex) {
defer wg.Done()

var auth ssh.AuthMethod
if server.IdentityFile == nil && server.Password != nil {
// Password only logic
password, err := dao.EvaluatePassword(*server.Password)
if err != nil {
errConnect := &ErrConnect{
Name: server.Name,
User: server.User,
Host: server.Host,
Port: server.Port,
Reason: err.Error(),
}
errCh <- *errConnect
}

passwordAuth := ssh.Password(password)
auth = passwordAuth
} else {
// Identity key logic
auth = authMethod
}

remote := &SSHClient{
Name: server.Name,
User: server.User,
Host: server.Host,
Port: server.Port,
AuthMethod: auth,
AuthMethod: authMethod,
}

if err := remote.Connect(run.Config.DisableVerifyHost, run.Config.KnownHostsFile, mu); err != nil {
Expand All @@ -255,13 +193,77 @@ func (run *Run) SetClients(

var wg sync.WaitGroup
var mu sync.Mutex

agentSigners, err := GetSSHAgentSigners()
if err != nil {
return []ErrConnect{}, err
}

globalSigner, err := GetGlobalIdentitySigner(runFlags)
if err != nil {
return []ErrConnect{}, err
}

identities := make(map[string]ssh.Signer)
passwordAuthMethods := make(map[string]ssh.AuthMethod)
for _, server := range run.Servers {
if server.AuthMethod == "password-key" {
_, found := identities[*server.IdentityFile]
if !found {
signer, err := GetPassworIdentitySigner(server)
if err != nil {
return []ErrConnect{*err}, nil
}
identities[*server.IdentityFile] = signer
}
} else if server.AuthMethod == "key" {
_, found := identities[*server.IdentityFile]
if !found {
signer, err := GetIdentity(server)
if err != nil {
return []ErrConnect{*err}, nil
}
identities[*server.IdentityFile] = signer
}
} else if server.AuthMethod == "password" {
_, found := passwordAuthMethods[*server.Password]
if !found {
passAuthMethod, err := GetPasswordAuth(server)
if err != nil {
return []ErrConnect{*err}, nil
}
passwordAuthMethods[*server.Password] = passAuthMethod
}
}
}

for _, server := range run.Servers {
wg.Add(1)
go createLocalClient(server, &wg, &mu)

if !server.Local {
wg.Add(1)
go createRemoteClient(server, &wg, &mu)

var authMethods []ssh.AuthMethod
var signers []ssh.Signer

if globalSigner != nil {
signers = append(signers, globalSigner)
} else if server.AuthMethod == "password" {
pwAuth := passwordAuthMethods[*server.Password]
authMethods = append(authMethods, pwAuth)
} else if server.AuthMethod == "key" || server.AuthMethod == "password-key" {
identitySigner := identities[*server.IdentityFile]
signers = append(signers, identitySigner)
} else if agentSigners != nil {
signers = append(signers, agentSigners...)
}

if len(signers) > 0 {
authMethods = append(authMethods, ssh.PublicKeys(signers...))
}

go createRemoteClient(authMethods, server, &wg, &mu)
}
}
wg.Wait()
Expand Down Expand Up @@ -412,40 +414,6 @@ func (run *Run) setKnownHostsFile(knownHostsFileFlag string) error {
return nil
}

// getGlobalIdentity resolves the identity file and password supplied outside
// of any individual server definition.
//
// For each value the run flag wins when it is non-empty; otherwise the
// corresponding environment variable (SAKE_IDENTITY_FILE or SAKE_PASSWORD) is
// used when it is set. A value that is found nowhere is returned as "".
//
// Only an identity file taken from the environment gets "~/" expansion: the
// prefix is replaced with the current user's home directory. The function
// panics if the current user cannot be looked up during that expansion.
func getGlobalIdentity(runFlags *core.RunFlags) (string, string) {
	identityFile := runFlags.IdentityFile
	if identityFile == "" {
		if envPath, ok := os.LookupEnv("SAKE_IDENTITY_FILE"); ok {
			identityFile = envPath
			if strings.HasPrefix(envPath, "~/") {
				current, err := user.Current()
				if err != nil {
					panic(err)
				}
				// Drop the leading "~/" and anchor the rest at the home directory.
				identityFile = filepath.Join(current.HomeDir, envPath[2:])
			}
		}
	}

	password := runFlags.Password
	if password == "" {
		if envPassword, ok := os.LookupEnv("SAKE_PASSWORD"); ok {
			password = envPassword
		}
	}

	return identityFile, password
}

func getWorkDir(cmd dao.TaskCmd, server dao.Server) string {
if cmd.Local || server.Local {
rootDir := os.ExpandEnv(cmd.RootDir)
Expand Down
Loading