Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion plumbing/transport/ssh/auth_method.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,10 @@ func (a *PublicKeysCallback) ClientConfig() (*ssh.ClientConfig, error) {
// /etc/ssh/ssh_known_hosts
func NewKnownHostsCallback(files ...string) (ssh.HostKeyCallback, error) {
kh, err := NewKnownHostsDb(files...)
return kh.HostKeyCallback(), err
if err != nil {
return nil, err
}
return kh.HostKeyCallback(), nil
}

// NewKnownHostsDb returns knownhosts.HostKeyDB based on a file based on a
Expand Down Expand Up @@ -311,13 +314,40 @@ type HostKeyCallbackHelper struct {
// HostKeyAlgorithms is a list of supported host key algorithms that will
// be used for host key verification.
HostKeyAlgorithms []string

// fallback allows for injecting the fallback call, which is called
// when a HostKeyCallback is not set.
fallback func(files ...string) (ssh.HostKeyCallback, error)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be worth indicating that this is only meant for testing, but in practice it always calls NewKnownHostsCallback

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I was initially considered this, but given that it is already an internal field and the code may not necessarily move into v6, I ended up deciding against it.

}

// SetHostKeyCallbackAndAlgorithms sets the field HostKeyCallback and HostKeyAlgorithms in the given cfg.
// If the host key callback or algorithms is empty it is left empty. It will be handled by the dial method,
// falling back to knownhosts.
func (m *HostKeyCallbackHelper) SetHostKeyCallbackAndAlgorithms(cfg *ssh.ClientConfig) (*ssh.ClientConfig, error) {
if cfg == nil {
cfg = &ssh.ClientConfig{}
}

if m.HostKeyCallback == nil {
if m.fallback == nil {
m.fallback = NewKnownHostsCallback
}

hkcb, err := m.fallback()
if err != nil {
return nil, fmt.Errorf("cannot create known hosts callback: %w", err)
}

cfg.HostKeyCallback = hkcb
cfg.HostKeyAlgorithms = m.HostKeyAlgorithms
return cfg, err
}

cfg.HostKeyCallback = m.HostKeyCallback
cfg.HostKeyAlgorithms = m.HostKeyAlgorithms
return cfg, nil
}

func (m *HostKeyCallbackHelper) SetHostKeyCallback(cfg *ssh.ClientConfig) (*ssh.ClientConfig, error) {
return m.SetHostKeyCallbackAndAlgorithms(cfg)
}
101 changes: 101 additions & 0 deletions plumbing/transport/ssh/auth_method_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,16 @@ import (
"bufio"
"fmt"
"os"
"reflect"
"runtime"
"slices"
"strings"
"testing"

"github.com/go-git/go-billy/v5/osfs"
"github.com/go-git/go-billy/v5/util"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/testdata"

Expand Down Expand Up @@ -317,3 +321,100 @@ func (*SuiteCommon) TestNewKnownHostsDbWithCert(c *C) {
}
}
}

func TestHostKeyCallbackHelper(t *testing.T) {
cb1 := ssh.FixedHostKey(nil)
tests := []struct {
name string
cb ssh.HostKeyCallback
algos []string
fallback func(files ...string) (ssh.HostKeyCallback, error)
cc *ssh.ClientConfig
want *ssh.ClientConfig
wantErr string
}{
{
name: "keep existing callback if set",
cb: cb1,
cc: &ssh.ClientConfig{},
want: &ssh.ClientConfig{
HostKeyCallback: cb1,
},
},
{
name: "create new client config is one isn't provided",
cb: cb1,
cc: nil,
want: &ssh.ClientConfig{
HostKeyCallback: cb1,
},
},
{
name: "respect pre-set algos",
cb: cb1,
algos: []string{"foo"},
cc: &ssh.ClientConfig{},
want: &ssh.ClientConfig{
HostKeyCallback: cb1,
HostKeyAlgorithms: []string{"foo"},
},
},
{
name: "no callback is set, call fallback",
cc: &ssh.ClientConfig{},
fallback: func(files ...string) (ssh.HostKeyCallback, error) {
return cb1, nil
},
want: &ssh.ClientConfig{
HostKeyCallback: cb1,
},
},
{
name: "no callback is set with nil client config",
fallback: func(files ...string) (ssh.HostKeyCallback, error) {
return cb1, nil
},
want: &ssh.ClientConfig{
HostKeyCallback: cb1,
},
},
{
name: "algos with no callback, call fallback",
algos: []string{"bar"},
cc: &ssh.ClientConfig{},
fallback: func(files ...string) (ssh.HostKeyCallback, error) {
return cb1, nil
},
want: &ssh.ClientConfig{
HostKeyCallback: cb1,
HostKeyAlgorithms: []string{"bar"},
},
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
helper := HostKeyCallbackHelper{
HostKeyCallback: tc.cb,
HostKeyAlgorithms: tc.algos,
fallback: tc.fallback,
}

got, gotErr := helper.SetHostKeyCallback(tc.cc)

if tc.wantErr == "" {
require.NoError(t, gotErr)
require.NotNil(t, got)

wantFunc := runtime.FuncForPC(reflect.ValueOf(tc.want.HostKeyCallback).Pointer()).Name()
gotFunc := runtime.FuncForPC(reflect.ValueOf(got.HostKeyCallback).Pointer()).Name()
assert.Equal(t, wantFunc, gotFunc)

assert.Equal(t, tc.want.HostKeyAlgorithms, got.HostKeyAlgorithms)
} else {
assert.ErrorContains(t, gotErr, tc.wantErr)
assert.Nil(t, got)
}
})
}
}
Loading