From 7a64b7bcb0dfb2b0846885ac4b95a3971ded5fe0 Mon Sep 17 00:00:00 2001 From: 35niavlys <35niavlys@gmail.com> Date: Fri, 13 Oct 2023 20:15:36 +0200 Subject: [PATCH] Work in progress to add restrict-destination-v00@openssh.com --- ssh/agent/client.go | 14 ++ ssh/agent/client_test.go | 2 +- ssh/agent/extension/messages_copy.go | 29 +++ ssh/agent/extension/restrict_destination.go | 216 ++++++++++++++++++++ ssh/agent/extension/session_bind.go | 26 +++ ssh/agent/extension/sign_data.go | 82 ++++++++ ssh/agent/keyring.go | 200 ++++++++++++++++-- ssh/agent/server.go | 129 ++++++++---- 8 files changed, 634 insertions(+), 64 deletions(-) create mode 100644 ssh/agent/extension/messages_copy.go create mode 100644 ssh/agent/extension/restrict_destination.go create mode 100644 ssh/agent/extension/session_bind.go create mode 100644 ssh/agent/extension/sign_data.go diff --git a/ssh/agent/client.go b/ssh/agent/client.go index fecba8eb38..d01ee12ff0 100644 --- a/ssh/agent/client.go +++ b/ssh/agent/client.go @@ -14,6 +14,7 @@ package agent // import "golang.org/x/crypto/ssh/agent" import ( "bytes" + "context" "crypto/dsa" "crypto/ecdsa" "crypto/ed25519" @@ -89,6 +90,19 @@ type ExtendedAgent interface { Extension(extensionType string, contents []byte) ([]byte, error) } +type ContextAgent interface { + InitContext(ctx context.Context) context.Context + List(ctx context.Context) ([]*Key, error) + Add(ctx context.Context, key AddedKey) error + Remove(ctx context.Context, key ssh.PublicKey) error + RemoveAll(ctx context.Context) error + Lock(ctx context.Context, passphrase []byte) error + Unlock(ctx context.Context, passphrase []byte) error + Signers(ctx context.Context) ([]ssh.Signer, error) + Sign(ctx context.Context, key ssh.PublicKey, data []byte, flags SignatureFlags) (*ssh.Signature, error) + Extension(ctx context.Context, extensionType string, contents []byte) ([]byte, error) +} + // ConstraintExtension describes an optional constraint defined by users. type ConstraintExtension struct { // ExtensionName consist of a UTF-8 string suffixed by the diff --git a/ssh/agent/client_test.go b/ssh/agent/client_test.go index fdc8000654..7b5ce7593c 100644 --- a/ssh/agent/client_test.go +++ b/ssh/agent/client_test.go @@ -502,7 +502,7 @@ func testAgentLifetime(t *testing.T, agent Agent) { } type keyringExtended struct { - *keyring + *legacyKeyring } func (r *keyringExtended) Extension(extensionType string, contents []byte) ([]byte, error) { diff --git a/ssh/agent/extension/messages_copy.go b/ssh/agent/extension/messages_copy.go new file mode 100644 index 0000000000..912099c34a --- /dev/null +++ b/ssh/agent/extension/messages_copy.go @@ -0,0 +1,29 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package extension + +import ( + "encoding/binary" + "errors" +) + +// copy of messages.go +var errShortRead = errors.New("ssh: short read") + +// copy of messages.go +func parseString(in []byte) (out, rest []byte, ok bool) { + if len(in) < 4 { + return + } + length := binary.BigEndian.Uint32(in) + in = in[4:] + if uint32(len(in)) < length { + return + } + out = in[:length] + rest = in[length:] + ok = true + return +} diff --git a/ssh/agent/extension/restrict_destination.go b/ssh/agent/extension/restrict_destination.go new file mode 100644 index 0000000000..cfb01f78f0 --- /dev/null +++ b/ssh/agent/extension/restrict_destination.go @@ -0,0 +1,216 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package extension + +import ( + "bytes" +) + +const EXT_NAME_RESTRICT_DESTINATION_00 = "restrict-destination-v00@openssh.com" + +type KeySpec struct { + keyblob []byte + is_ca bool +} + +type Hop struct { + Username string + Hostname string + reserved []byte + Hostkeys []KeySpec +} + +type DestinationConstraint struct { + From Hop + To Hop + reserved []byte +} + +/* +func parseKeyspec(data []byte) (KeySpec, error) { + fmt.Println(data) + var ks KeySpec + var ok bool + if ks.keyblob, _, ok = parseString(data); !ok { + return ks, errShortRead + } + return ks, nil +}*/ + +func parseHop(data []byte) (Hop, error) { + var hop Hop + var ok bool + var to_use []byte + { + if to_use, data, ok = parseString(data); !ok { + return hop, errShortRead + } + hop.Username = string(to_use) + } + { + if to_use, data, ok = parseString(data); !ok { + return hop, errShortRead + } + hop.Hostname = string(to_use) + } + { + if hop.reserved, data, ok = parseString(data); !ok { + return hop, errShortRead + } + } + { + for len(data) > 0 { + var keyspec KeySpec + if keyspec.keyblob, data, ok = parseString(data); !ok { + return hop, errShortRead + } + if len(data) == 0 { + return hop, errShortRead + } + keyspec.is_ca = data[0] != 0 + data = data[1:] + hop.Hostkeys = append(hop.Hostkeys, keyspec) + } + } + + return hop, nil +} + +func parseConstraint(data []byte) (DestinationConstraint, error) { + var constraint DestinationConstraint + var datahop []byte + var ok bool + var err error + if datahop, data, ok = parseString(data); !ok { + return constraint, errShortRead + } + if constraint.From, err = parseHop(datahop); err != nil { + return constraint, err + } + if datahop, constraint.reserved, ok = parseString(data); !ok { + return constraint, errShortRead + } + if constraint.To, err = parseHop(datahop); err != nil { + return constraint, err + } + return constraint, nil +} + +func ParseRestrictDestinations(data []byte) ([]DestinationConstraint, error) { + var constraints []DestinationConstraint + var to_use []byte + var ok bool + + for len(data) > 0 { + var constr DestinationConstraint + var err error + if to_use, data, ok = parseString(data); !ok { + return constraints, errShortRead + } + if constr, err = parseConstraint(to_use); err != nil { + return constraints, err + } + constraints = append(constraints, constr) + } + + return constraints, nil +} + +func IdentityPermitted(destinationConstraints []DestinationConstraint, sessions []SessionBind, user string) bool { + + if len(destinationConstraints) == 0 { + return true // unconstrained + } + if len(sessions) == 0 { + return true // local use + } + + fromkey := []KeySpec(nil) + /* + * Walk through the hops recorded by session_id and try to find a constraint that satisfies each. + */ + for i := 0; i < len(sessions); i++ { + sessionBind := sessions[i] + testuser := "" + if i == len(sessions)-1 { + testuser = user + if sessionBind.IsForwarding { + return false // tried to sign on forwarding hop + } + } else if !sessionBind.IsForwarding { + return false // tried to forward though signing bind + } + + ok := false + for _, destinationConstraint := range destinationConstraints { + if destinationConstraint.IdentityPermitted(fromkey, sessionBind.Hostkey, testuser) { + ok = true + break + } + } + if !ok { + return false + } + fromkey = []KeySpec{{keyblob: sessionBind.Hostkey, is_ca: false}} + } + + /* + * Another special case: if the last bound session ID was for a + * forwarding, and this function is not being called to check a sign + * request (i.e. no 'user' supplied), then only permit the key if + * there is a permission that would allow it to be used at another + * destination. This hides keys that are allowed to be used to + * authenicate *to* a host but not permitted for *use* beyond it. + */ + lastBindSession := sessions[len(sessions)-1] + if lastBindSession.IsForwarding && len(user) == 0 { + for _, destinationConstraint := range destinationConstraints { + if destinationConstraint.IdentityPermitted([]KeySpec{{keyblob: lastBindSession.Hostkey, is_ca: false}}, nil, "") { + return true + } + } + return false + } + /* success */ + return true +} + +func (d *DestinationConstraint) IdentityPermitted(fromkey []KeySpec, tokey []byte, user string) bool { + if len(fromkey) == 0 { + /* We are matching the first hop */ + if len(d.From.Hostname) > 0 || len(d.From.Hostkeys) > 0 { + return false + } + } else { + for _, hk := range d.From.Hostkeys { + for _, keySpec := range fromkey { + if bytes.Equal(hk.keyblob, keySpec.keyblob) { + goto to + } + } + } + return false + } + +to: + /* Match 'to' key */ + if len(tokey) > 0 { + for _, hk := range d.To.Hostkeys { + if bytes.Equal(hk.keyblob, tokey) { + goto user + } + } + return false + } + +user: + /* Match user if specified */ + // FIXME: sould be a pattern + if len(d.To.Username) > 0 && len(user) > 0 && d.To.Username != user { + return false + } + + return true +} diff --git a/ssh/agent/extension/session_bind.go b/ssh/agent/extension/session_bind.go new file mode 100644 index 0000000000..d160eeef00 --- /dev/null +++ b/ssh/agent/extension/session_bind.go @@ -0,0 +1,26 @@ +package extension + +import ( + "bytes" + + "golang.org/x/crypto/ssh" +) + +const EXT_NAME_SESSION_BIND = "session-bind@openssh.com" + +type SessionBind struct { + Hostkey []byte + SessionIdentifier []byte + Signature []byte + IsForwarding bool +} + +func (s *SessionBind) Matching(sessionId []byte) bool { + return bytes.Equal(s.SessionIdentifier, sessionId) +} + +func ParseSessionBind(data []byte) (SessionBind, error) { + var sessionBind SessionBind + err := ssh.Unmarshal(data, &sessionBind) + return sessionBind, err +} diff --git a/ssh/agent/extension/sign_data.go b/ssh/agent/extension/sign_data.go new file mode 100644 index 0000000000..e841ae59f8 --- /dev/null +++ b/ssh/agent/extension/sign_data.go @@ -0,0 +1,82 @@ +package extension + +import ( + "bytes" + "fmt" +) + +type SignData struct { + Session []byte + Type byte + User string + Service string + Method string + Sign bool + Algo string + PubKey []byte + HostKey []byte + Signature []byte +} + +func ParseSignData(data []byte, expectedKey []byte) (signData SignData, err error) { + signData = SignData{} + var tempArray []byte + + var ok bool + // Session + if signData.Session, data, ok = parseString(data); !ok { + return signData, errShortRead + } + // Type + signData.Type = data[0] + data = data[1:] + // User + if tempArray, data, ok = parseString(data); !ok { + return signData, errShortRead + } + signData.User = string(tempArray) + // Service + if tempArray, data, ok = parseString(data); !ok { + return signData, errShortRead + } + signData.Service = string(tempArray) + // Method + if tempArray, data, ok = parseString(data); !ok { + return signData, errShortRead + } + signData.Method = string(tempArray) + // Sign + signData.Sign = data[0] != 0 + data = data[1:] + // Algo + if tempArray, data, ok = parseString(data); !ok { + return signData, errShortRead + } + signData.Algo = string(tempArray) + + // Key or Host + var key []byte + if key, data, ok = parseString(data); !ok { + return signData, errShortRead + } + + if signData.Type != 50 || + !signData.Sign || + signData.Service != "ssh-connection" || + !bytes.Equal(expectedKey, key) { + return signData, fmt.Errorf("ssh: invalid sign data") + } + + if signData.Method == "publickey-hostbound-v00@openssh.com" { + signData.PubKey = key + if signData.HostKey, _, ok = parseString(data); !ok { + return signData, errShortRead + } + } else if signData.Method == "publickey" { + signData.HostKey = key + } else { + return signData, fmt.Errorf("ssh: invalid method sign data") + } + + return signData, err +} diff --git a/ssh/agent/keyring.go b/ssh/agent/keyring.go index 21bfa870fa..1659a21414 100644 --- a/ssh/agent/keyring.go +++ b/ssh/agent/keyring.go @@ -6,6 +6,7 @@ package agent import ( "bytes" + "context" "crypto/rand" "crypto/subtle" "errors" @@ -14,15 +15,24 @@ import ( "time" "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent/extension" ) +type SessionsKey struct{} + type privKey struct { - signer ssh.Signer - comment string - expire *time.Time + signer ssh.Signer + comment string + expire *time.Time + destinationConstraints []extension.DestinationConstraint +} + +type legacyKeyring struct { + k ctxKeyring + dummyCtx context.Context } -type keyring struct { +type ctxKeyring struct { mu sync.Mutex keys []privKey @@ -31,15 +41,40 @@ type keyring struct { } var errLocked = errors.New("agent: locked") +var errLegacyAgent = errors.New("agent: please use NewContextKeyring instead of NewKeyring") // NewKeyring returns an Agent that holds keys in memory. It is safe // for concurrent use by multiple goroutines. -func NewKeyring() Agent { - return &keyring{} +// Deprecated: use NewContextKeyring +func NewKeyring() ExtendedAgent { + return &legacyKeyring{dummyCtx: context.TODO()} +} + +func NewContextKeyring() ContextAgent { + return &ctxKeyring{} +} + +func (r *ctxKeyring) InitContext(ctx context.Context) context.Context { + sessions := make([]extension.SessionBind, 0) + return context.WithValue(ctx, SessionsKey{}, &sessions) +} + +func (r *ctxKeyring) addBindSession(ctx context.Context, sessions ...extension.SessionBind) { + ctxSessions := ctx.Value(SessionsKey{}).(*[]extension.SessionBind) + if ctxSessions != nil { + *ctxSessions = append(*ctxSessions, sessions...) + } +} + +func (r *ctxKeyring) getBindSessions(ctx context.Context) []extension.SessionBind { + return *ctx.Value(SessionsKey{}).(*[]extension.SessionBind) } // RemoveAll removes all identities. -func (r *keyring) RemoveAll() error { +func (r *legacyKeyring) RemoveAll() error { + return r.k.RemoveAll(r.dummyCtx) +} +func (r *ctxKeyring) RemoveAll(ctx context.Context) error { r.mu.Lock() defer r.mu.Unlock() if r.locked { @@ -52,7 +87,7 @@ func (r *keyring) RemoveAll() error { // removeLocked does the actual key removal. The caller must already be holding the // keyring mutex. -func (r *keyring) removeLocked(want []byte) error { +func (r *ctxKeyring) removeLocked(want []byte) error { found := false for i := 0; i < len(r.keys); { if bytes.Equal(r.keys[i].signer.PublicKey().Marshal(), want) { @@ -72,7 +107,10 @@ func (r *keyring) removeLocked(want []byte) error { } // Remove removes all identities with the given public key. -func (r *keyring) Remove(key ssh.PublicKey) error { +func (r *legacyKeyring) Remove(key ssh.PublicKey) error { + return r.k.Remove(r.dummyCtx, key) +} +func (r *ctxKeyring) Remove(ctx context.Context, key ssh.PublicKey) error { r.mu.Lock() defer r.mu.Unlock() if r.locked { @@ -83,7 +121,10 @@ func (r *keyring) Remove(key ssh.PublicKey) error { } // Lock locks the agent. Sign and Remove will fail, and List will return an empty list. -func (r *keyring) Lock(passphrase []byte) error { +func (r *legacyKeyring) Lock(passphrase []byte) error { + return r.k.Lock(r.dummyCtx, passphrase) +} +func (r *ctxKeyring) Lock(ctx context.Context, passphrase []byte) error { r.mu.Lock() defer r.mu.Unlock() if r.locked { @@ -96,7 +137,10 @@ func (r *keyring) Lock(passphrase []byte) error { } // Unlock undoes the effect of Lock -func (r *keyring) Unlock(passphrase []byte) error { +func (r *legacyKeyring) Unlock(passphrase []byte) error { + return r.k.Unlock(r.dummyCtx, passphrase) +} +func (r *ctxKeyring) Unlock(ctx context.Context, passphrase []byte) error { r.mu.Lock() defer r.mu.Unlock() if !r.locked { @@ -114,7 +158,7 @@ func (r *keyring) Unlock(passphrase []byte) error { // expireKeysLocked removes expired keys from the keyring. If a key was added // with a lifetimesecs contraint and seconds >= lifetimesecs seconds have // elapsed, it is removed. The caller *must* be holding the keyring mutex. -func (r *keyring) expireKeysLocked() { +func (r *ctxKeyring) expireKeysLocked() { for _, k := range r.keys { if k.expire != nil && time.Now().After(*k.expire) { r.removeLocked(k.signer.PublicKey().Marshal()) @@ -123,7 +167,10 @@ func (r *keyring) expireKeysLocked() { } // List returns the identities known to the agent. -func (r *keyring) List() ([]*Key, error) { +func (r *legacyKeyring) List() ([]*Key, error) { + return r.k.List(r.dummyCtx) +} +func (r *ctxKeyring) List(ctx context.Context) ([]*Key, error) { r.mu.Lock() defer r.mu.Unlock() if r.locked { @@ -135,10 +182,18 @@ func (r *keyring) List() ([]*Key, error) { var ids []*Key for _, k := range r.keys { pub := k.signer.PublicKey() + + //if len(k.destinationConstraints) > 0 && len(r.sessions) > 0 { + // if !extension.IdentityPermitted(k.destinationConstraints, r.sessions, "") { + // continue + // } + //} + ids = append(ids, &Key{ Format: pub.Type(), Blob: pub.Marshal(), - Comment: k.comment}) + Comment: k.comment, + }) } return ids, nil } @@ -146,7 +201,10 @@ func (r *keyring) List() ([]*Key, error) { // Insert adds a private key to the keyring. If a certificate // is given, that certificate is added as public key. Note that // any constraints given are ignored. -func (r *keyring) Add(key AddedKey) error { +func (r *legacyKeyring) Add(key AddedKey) error { + return r.k.Add(r.dummyCtx, key) +} +func (r *ctxKeyring) Add(ctx context.Context, key AddedKey) error { r.mu.Lock() defer r.mu.Unlock() if r.locked { @@ -175,17 +233,39 @@ func (r *keyring) Add(key AddedKey) error { p.expire = &t } + // take care of "official" extensions + for _, constraint := range key.ConstraintExtensions { + switch constraint.ExtensionName { + case extension.EXT_NAME_RESTRICT_DESTINATION_00: + if ctx.Value(SessionsKey{}) == nil { + return errLegacyAgent + } else { + fmt.Println("Taking care of", extension.EXT_NAME_RESTRICT_DESTINATION_00) + + if p.destinationConstraints != nil { + return fmt.Errorf("agent: multiple %s extensions", extension.EXT_NAME_RESTRICT_DESTINATION_00) + } + p.destinationConstraints, err = extension.ParseRestrictDestinations(constraint.ExtensionDetails) + if err != nil { + return err + } + } + } + } + r.keys = append(r.keys, p) return nil } // Sign returns a signature for the data. -func (r *keyring) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) { - return r.SignWithFlags(key, data, 0) +func (r *legacyKeyring) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) { + return r.k.Sign(r.dummyCtx, key, data, 0) } - -func (r *keyring) SignWithFlags(key ssh.PublicKey, data []byte, flags SignatureFlags) (*ssh.Signature, error) { +func (r *legacyKeyring) SignWithFlags(key ssh.PublicKey, data []byte, flags SignatureFlags) (*ssh.Signature, error) { + return r.k.Sign(r.dummyCtx, key, data, flags) +} +func (r *ctxKeyring) Sign(ctx context.Context, key ssh.PublicKey, data []byte, flags SignatureFlags) (*ssh.Signature, error) { r.mu.Lock() defer r.mu.Unlock() if r.locked { @@ -196,6 +276,58 @@ func (r *keyring) SignWithFlags(key ssh.PublicKey, data []byte, flags SignatureF wanted := key.Marshal() for _, k := range r.keys { if bytes.Equal(k.signer.PublicKey().Marshal(), wanted) { + + if k.destinationConstraints != nil { + sessions := r.getBindSessions(ctx) + if len(sessions) == 0 { + return nil, fmt.Errorf("agent: refusing use of destination-constrained key to sign on unbound connection") + } + lastSessionBind := sessions[len(sessions)-1] + + fmt.Println("#######") + signData, err := extension.ParseSignData(data, wanted) + if err != nil { + return nil, fmt.Errorf("agent: unable to parse sign data") + } + fmt.Printf("signData: %+v\n", signData) + fmt.Printf("lastSessionBind: %+v\n", lastSessionBind) + fmt.Printf("destinationConstraints: %+v\n", k.destinationConstraints) + fmt.Println("#######") + + /*if err := ssh.Unmarshal(data, &signData); err != nil { + fmt.Printf("%+v\n", signData) + fmt.Println(err.Error()) + //return nil, err + //refusing use of destination-constrained key to sign an unidentified signature + }*/ + + if !extension.IdentityPermitted(k.destinationConstraints, sessions, signData.User) { + return nil, fmt.Errorf("agent: destination contrained") + } + + /* + * Ensure that the session ID is the most recent one + * registered on the socket - it should have been bound by + * ssh immediately before userauth. + */ + if !lastSessionBind.Matching(signData.Session) { + return nil, fmt.Errorf("agent: invalid session id") + } + + /* + * Ensure that the hostkey embedded in the signature matches + * the one most recently bound to the socket. An exception is + * made for the initial forwarding hop. + */ + if len(sessions) > 1 && len(signData.HostKey) == 0 { + return nil, fmt.Errorf("agent: refusing use of destination-constrained key: no hostkey recorded in signature for forwarded connection") + } + if len(signData.HostKey) > 0 && !bytes.Equal(signData.HostKey, lastSessionBind.Hostkey) { + fmt.Println("refusing use of destination-constrained key: mismatch between hostkey in request and most recently bound session") + //return nil, fmt.Errorf("agent: refusing use of destination-constrained key: mismatch between hostkey in request and most recently bound session") + } + } + if flags == 0 { return k.signer.Sign(rand.Reader, data) } else { @@ -220,7 +352,10 @@ func (r *keyring) SignWithFlags(key ssh.PublicKey, data []byte, flags SignatureF } // Signers returns signers for all the known keys. -func (r *keyring) Signers() ([]ssh.Signer, error) { +func (r *legacyKeyring) Signers() ([]ssh.Signer, error) { + return r.k.Signers(r.dummyCtx) +} +func (r *ctxKeyring) Signers(ctx context.Context) ([]ssh.Signer, error) { r.mu.Lock() defer r.mu.Unlock() if r.locked { @@ -235,7 +370,28 @@ func (r *keyring) Signers() ([]ssh.Signer, error) { return s, nil } -// The keyring does not support any extensions -func (r *keyring) Extension(extensionType string, contents []byte) ([]byte, error) { +// The keyring implements only some extensions +func (r *legacyKeyring) Extension(extensionType string, contents []byte) ([]byte, error) { + return r.k.Extension(r.dummyCtx, extensionType, contents) +} + +func (r *ctxKeyring) Extension(ctx context.Context, extensionType string, contents []byte) ([]byte, error) { + + r.mu.Lock() + defer r.mu.Unlock() + + switch extensionType { + case extension.EXT_NAME_SESSION_BIND: + if ctx.Value(SessionsKey{}) == nil { + return nil, errLegacyAgent + } + sessionBind, err := extension.ParseSessionBind(contents) + if err != nil { + return nil, err + } + r.addBindSession(ctx, sessionBind) + return []byte{agentSuccess}, nil + } + fmt.Println("Extension not implemented:", extensionType) return nil, ErrExtensionUnsupported } diff --git a/ssh/agent/server.go b/ssh/agent/server.go index e35ca7ce31..8f35a57eae 100644 --- a/ssh/agent/server.go +++ b/ssh/agent/server.go @@ -5,6 +5,7 @@ package agent import ( + "context" "crypto/dsa" "crypto/ecdsa" "crypto/ed25519" @@ -23,7 +24,8 @@ import ( // server wraps an Agent and uses it to implement the agent side of // the SSH-agent, wire protocol. type server struct { - agent Agent + agent ContextAgent + ctx context.Context } func (s *server) processRequestBytes(reqData []byte) []byte { @@ -92,10 +94,10 @@ func (s *server) processRequest(data []byte) (interface{}, error) { return nil, err } - return nil, s.agent.Remove(&Key{Format: wk.Format, Blob: req.KeyBlob}) + return nil, s.agent.Remove(s.ctx, &Key{Format: wk.Format, Blob: req.KeyBlob}) case agentRemoveAllIdentities: - return nil, s.agent.RemoveAll() + return nil, s.agent.RemoveAll(s.ctx) case agentLock: var req agentLockMsg @@ -103,14 +105,14 @@ func (s *server) processRequest(data []byte) (interface{}, error) { return nil, err } - return nil, s.agent.Lock(req.Passphrase) + return nil, s.agent.Lock(s.ctx, req.Passphrase) case agentUnlock: var req agentUnlockMsg if err := ssh.Unmarshal(data, &req); err != nil { return nil, err } - return nil, s.agent.Unlock(req.Passphrase) + return nil, s.agent.Unlock(s.ctx, req.Passphrase) case agentSignRequest: var req signRequestAgentMsg @@ -128,13 +130,7 @@ func (s *server) processRequest(data []byte) (interface{}, error) { Blob: req.KeyBlob, } - var sig *ssh.Signature - var err error - if extendedAgent, ok := s.agent.(ExtendedAgent); ok { - sig, err = extendedAgent.SignWithFlags(k, req.Data, SignatureFlags(req.Flags)) - } else { - sig, err = s.agent.Sign(k, req.Data) - } + sig, err := s.agent.Sign(s.ctx, k, req.Data, SignatureFlags(req.Flags)) if err != nil { return nil, err @@ -142,7 +138,7 @@ func (s *server) processRequest(data []byte) (interface{}, error) { return &signResponseAgentMsg{SigBlob: ssh.Marshal(sig)}, nil case agentRequestIdentities: - keys, err := s.agent.List() + keys, err := s.agent.List(s.ctx) if err != nil { return nil, err } @@ -164,33 +160,27 @@ func (s *server) processRequest(data []byte) (interface{}, error) { Rest []byte `ssh:"rest"` } - if extendedAgent, ok := s.agent.(ExtendedAgent); !ok { - // If this agent doesn't implement extensions, [PROTOCOL.agent] section 4.7 - // requires that we return a standard SSH_AGENT_FAILURE message. - responseStub.Rest = []byte{agentFailure} - } else { - var req extensionAgentMsg - if err := ssh.Unmarshal(data, &req); err != nil { - return nil, err - } - res, err := extendedAgent.Extension(req.ExtensionType, req.Contents) - if err != nil { - // If agent extensions are unsupported, return a standard SSH_AGENT_FAILURE - // message as required by [PROTOCOL.agent] section 4.7. - if err == ErrExtensionUnsupported { - responseStub.Rest = []byte{agentFailure} - } else { - // As the result of any other error processing an extension request, - // [PROTOCOL.agent] section 4.7 requires that we return a - // SSH_AGENT_EXTENSION_FAILURE code. - responseStub.Rest = []byte{agentExtensionFailure} - } + var req extensionAgentMsg + if err := ssh.Unmarshal(data, &req); err != nil { + return nil, err + } + res, err := s.agent.Extension(s.ctx, req.ExtensionType, req.Contents) + if err != nil { + // If agent extensions are unsupported, return a standard SSH_AGENT_FAILURE + // message as required by [PROTOCOL.agent] section 4.7. + if err == ErrExtensionUnsupported { + responseStub.Rest = []byte{agentFailure} } else { - if len(res) == 0 { - return nil, nil - } - responseStub.Rest = res + // As the result of any other error processing an extension request, + // [PROTOCOL.agent] section 4.7 requires that we return a + // SSH_AGENT_EXTENSION_FAILURE code. + responseStub.Rest = []byte{agentExtensionFailure} + } + } else { + if len(res) == 0 { + return nil, nil } + responseStub.Rest = res } return responseStub, nil @@ -527,13 +517,25 @@ func (s *server) insertIdentity(req []byte) error { if err != nil { return err } - return s.agent.Add(*addedKey) + return s.agent.Add(s.ctx, *addedKey) } // ServeAgent serves the agent protocol on the given connection. It // returns when an I/O error occurs. -func ServeAgent(agent Agent, c io.ReadWriter) error { - s := &server{agent} +func ServeAgent(agent interface{}, c io.ReadWriter) error { + + var s *server + if ctxagent, ok := agent.(ContextAgent); ok { + s = &server{ + ctxagent, + ctxagent.InitContext(context.TODO()), + } + } else { + s = &server{ + &ctxAgentWrapper{agent.(Agent)}, + context.TODO(), + } + } var length [4]byte for { @@ -568,3 +570,48 @@ func ServeAgent(agent Agent, c io.ReadWriter) error { } } } + +type ctxAgentWrapper struct { + agent Agent +} + +func (wrap *ctxAgentWrapper) InitContext(ctx context.Context) context.Context { + return nil +} +func (wrap *ctxAgentWrapper) List(ctx context.Context) ([]*Key, error) { + return wrap.agent.List() +} +func (wrap *ctxAgentWrapper) Add(ctx context.Context, key AddedKey) error { + return wrap.agent.Add(key) +} +func (wrap *ctxAgentWrapper) Remove(ctx context.Context, key ssh.PublicKey) error { + return wrap.agent.Remove(key) +} +func (wrap *ctxAgentWrapper) RemoveAll(ctx context.Context) error { + return wrap.agent.RemoveAll() +} +func (wrap *ctxAgentWrapper) Lock(ctx context.Context, passphrase []byte) error { + return wrap.agent.Lock(passphrase) +} +func (wrap *ctxAgentWrapper) Unlock(ctx context.Context, passphrase []byte) error { + return wrap.agent.Unlock(passphrase) +} +func (wrap *ctxAgentWrapper) Signers(ctx context.Context) ([]ssh.Signer, error) { + return wrap.agent.Signers() +} +func (wrap *ctxAgentWrapper) Sign(ctx context.Context, key ssh.PublicKey, data []byte, flags SignatureFlags) (*ssh.Signature, error) { + if extendedAgent, ok := wrap.agent.(ExtendedAgent); ok { + return extendedAgent.SignWithFlags(key, data, flags) + } else { + return wrap.agent.Sign(key, data) + } +} +func (wrap *ctxAgentWrapper) Extension(ctx context.Context, extensionType string, contents []byte) ([]byte, error) { + if extendedAgent, ok := wrap.agent.(ExtendedAgent); ok { + return extendedAgent.Extension(extensionType, contents) + } else { + // If this agent doesn't implement extensions, [PROTOCOL.agent] section 4.7 + // requires that we return a standard SSH_AGENT_FAILURE message. + return []byte{agentFailure}, nil + } +}