Thanks to visit codestin.com
Credit goes to github.com

Skip to content

feat: Add peerbroker proxy for agent connections #349

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
260 changes: 260 additions & 0 deletions peerbroker/proxy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
package peerbroker

import (
"context"
"errors"
"fmt"
"io"
"net"
"sync"

"github.com/google/uuid"
"github.com/hashicorp/yamux"
"golang.org/x/xerrors"
protobuf "google.golang.org/protobuf/proto"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"

"cdr.dev/slog"
"github.com/coder/coder/database"
"github.com/coder/coder/peerbroker/proto"
)

var (
// Each NegotiateConnection() function call spawns a new stream.
streamIDLength = len(uuid.NewString())
// We shouldn't PubSub anything larger than this!
maxPayloadSizeBytes = 8192
)

// ProxyOptions provides values to configure a proxy.
type ProxyOptions struct {
ChannelID string
Logger slog.Logger
Pubsub database.Pubsub
}

// ProxyDial writes client negotiation streams over PubSub.
//
// PubSub is used to geodistribute WebRTC handshakes. All negotiation
// messages are small in size (<=8KB), and we don't require delivery
// guarantees because connections can always be renegotiated.
// ┌────────────────────┐ ┌─────────────────────────────┐
// │ coderd │ │ coderd │
// ┌─────────────────────┐ │/<agent-id>/connect │ │ /<agent-id>/listen │
// │ client │ │ │ │ │ ┌─────┐
// │ ├──►│Creates a stream ID │◄─►│Subscribe() to the <agent-id>│◄──┤agent│
// │NegotiateConnection()│ │and Publish() to the│ │channel. Parse the stream ID │ └─────┘
// └─────────────────────┘ │<agent-id> channel: │ │from payloads to create new │
// │ │ │NegotiateConnection() streams│
// │<stream-id><payload>│ │or write to existing ones. │
// └────────────────────┘ └─────────────────────────────┘
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I love me some inline diagrams like this! 🖼️

func ProxyDial(client proto.DRPCPeerBrokerClient, options ProxyOptions) (io.Closer, error) {
proxyDial := &proxyDial{
channelID: options.ChannelID,
logger: options.Logger,
pubsub: options.Pubsub,
connection: client,
streams: make(map[string]proto.DRPCPeerBroker_NegotiateConnectionClient),
}
return proxyDial, proxyDial.listen()
}

// ProxyListen accepts client negotiation streams over PubSub and writes them to the listener
// as new NegotiateConnection() streams.
func ProxyListen(ctx context.Context, connListener net.Listener, options ProxyOptions) error {
mux := drpcmux.New()
err := proto.DRPCRegisterPeerBroker(mux, &proxyListen{
channelID: options.ChannelID,
pubsub: options.Pubsub,
logger: options.Logger,
})
if err != nil {
return xerrors.Errorf("register peer broker: %w", err)
}
server := drpcserver.New(mux)
err = server.Serve(ctx, connListener)
if err != nil {
if errors.Is(err, yamux.ErrSessionShutdown) {
return nil
}
return xerrors.Errorf("serve: %w", err)
}
return nil
}

type proxyListen struct {
channelID string
pubsub database.Pubsub
logger slog.Logger
}

func (p *proxyListen) NegotiateConnection(stream proto.DRPCPeerBroker_NegotiateConnectionStream) error {
streamID := uuid.NewString()
var err error
closeSubscribe, err := p.pubsub.Subscribe(proxyInID(p.channelID), func(ctx context.Context, message []byte) {
err := p.onServerToClientMessage(streamID, stream, message)
if err != nil {
p.logger.Debug(ctx, "failed to accept server message", slog.Error(err))
}
})
if err != nil {
return xerrors.Errorf("subscribe: %w", err)
}
defer closeSubscribe()
for {
clientToServerMessage, err := stream.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
break
}
return xerrors.Errorf("recv: %w", err)
}
data, err := protobuf.Marshal(clientToServerMessage)
if err != nil {
return xerrors.Errorf("marshal: %w", err)
}
if len(data) > maxPayloadSizeBytes {
return xerrors.Errorf("maximum payload size %d exceeded", maxPayloadSizeBytes)
}
data = append([]byte(streamID), data...)
err = p.pubsub.Publish(proxyOutID(p.channelID), data)
if err != nil {
return xerrors.Errorf("publish: %w", err)
}
}
return nil
}

func (*proxyListen) onServerToClientMessage(streamID string, stream proto.DRPCPeerBroker_NegotiateConnectionStream, message []byte) error {
if len(message) < streamIDLength {
return xerrors.Errorf("got message length %d < %d", len(message), streamIDLength)
}
serverStreamID := string(message[0:streamIDLength])
if serverStreamID != streamID {
// It's not trying to communicate with this stream!
return nil
}
var msg proto.NegotiateConnection_ServerToClient
err := protobuf.Unmarshal(message[streamIDLength:], &msg)
if err != nil {
return xerrors.Errorf("unmarshal message: %w", err)
}
err = stream.Send(&msg)
if err != nil {
return xerrors.Errorf("send message: %w", err)
}
return nil
}

type proxyDial struct {
channelID string
pubsub database.Pubsub
logger slog.Logger

connection proto.DRPCPeerBrokerClient
closeSubscribe func()
streamMutex sync.Mutex
streams map[string]proto.DRPCPeerBroker_NegotiateConnectionClient
}

func (p *proxyDial) listen() error {
var err error
p.closeSubscribe, err = p.pubsub.Subscribe(proxyOutID(p.channelID), func(ctx context.Context, message []byte) {
err := p.onClientToServerMessage(ctx, message)
if err != nil {
p.logger.Debug(ctx, "failed to accept client message", slog.Error(err))
}
})
if err != nil {
return err
}
return nil
}

func (p *proxyDial) onClientToServerMessage(ctx context.Context, message []byte) error {
if len(message) < streamIDLength {
return xerrors.Errorf("got message length %d < %d", len(message), streamIDLength)
}
var err error
streamID := string(message[0:streamIDLength])
p.streamMutex.Lock()
stream, ok := p.streams[streamID]
if !ok {
stream, err = p.connection.NegotiateConnection(ctx)
if err != nil {
p.streamMutex.Unlock()
return xerrors.Errorf("negotiate connection: %w", err)
}
p.streams[streamID] = stream
go func() {
defer stream.Close()

err = p.onServerToClientMessage(streamID, stream)
if err != nil {
p.logger.Debug(ctx, "failed to accept server message", slog.Error(err))
}
}()
go func() {
<-stream.Context().Done()
p.streamMutex.Lock()
delete(p.streams, streamID)
p.streamMutex.Unlock()
}()
}
p.streamMutex.Unlock()

var msg proto.NegotiateConnection_ClientToServer
err = protobuf.Unmarshal(message[streamIDLength:], &msg)
if err != nil {
return xerrors.Errorf("unmarshal message: %w", err)
}
err = stream.Send(&msg)
if err != nil {
return xerrors.Errorf("write message: %w", err)
}
return nil
}

func (p *proxyDial) onServerToClientMessage(streamID string, stream proto.DRPCPeerBroker_NegotiateConnectionClient) error {
for {
serverToClientMessage, err := stream.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
break
}
if errors.Is(err, context.Canceled) {
break
}
return xerrors.Errorf("recv: %w", err)
}
data, err := protobuf.Marshal(serverToClientMessage)
if err != nil {
return xerrors.Errorf("marshal: %w", err)
}
if len(data) > maxPayloadSizeBytes {
return xerrors.Errorf("maximum payload size %d exceeded", maxPayloadSizeBytes)
}
data = append([]byte(streamID), data...)
err = p.pubsub.Publish(proxyInID(p.channelID), data)
if err != nil {
return xerrors.Errorf("publish: %w", err)
}
}
return nil
}

func (p *proxyDial) Close() error {
p.streamMutex.Lock()
defer p.streamMutex.Unlock()
p.closeSubscribe()
return nil
}

func proxyOutID(channelID string) string {
return fmt.Sprintf("%s-out", channelID)
}

func proxyInID(channelID string) string {
return fmt.Sprintf("%s-in", channelID)
}
81 changes: 81 additions & 0 deletions peerbroker/proxy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package peerbroker_test

import (
"context"
"sync"
"testing"

"github.com/pion/webrtc/v3"
"github.com/stretchr/testify/require"

"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/database"
"github.com/coder/coder/peer"
"github.com/coder/coder/peerbroker"
"github.com/coder/coder/peerbroker/proto"
"github.com/coder/coder/provisionersdk"
)

func TestProxy(t *testing.T) {
t.Parallel()
ctx := context.Background()
channelID := "hello"
pubsub := database.NewPubsubInMemory()
dialerClient, dialerServer := provisionersdk.TransportPipe()
defer dialerClient.Close()
defer dialerServer.Close()
listenerClient, listenerServer := provisionersdk.TransportPipe()
defer listenerClient.Close()
defer listenerServer.Close()

listener, err := peerbroker.Listen(listenerServer, &peer.ConnOptions{
Logger: slogtest.Make(t, nil).Named("server").Leveled(slog.LevelDebug),
})
require.NoError(t, err)

proxyCloser, err := peerbroker.ProxyDial(proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(listenerClient)), peerbroker.ProxyOptions{
ChannelID: channelID,
Logger: slogtest.Make(t, nil).Named("proxy-listen").Leveled(slog.LevelDebug),
Pubsub: pubsub,
})
require.NoError(t, err)
t.Cleanup(func() {
_ = proxyCloser.Close()
})

var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
err = peerbroker.ProxyListen(ctx, dialerServer, peerbroker.ProxyOptions{
ChannelID: channelID,
Logger: slogtest.Make(t, nil).Named("proxy-dial").Leveled(slog.LevelDebug),
Pubsub: pubsub,
})
require.NoError(t, err)
}()

api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(dialerClient))
stream, err := api.NegotiateConnection(ctx)
require.NoError(t, err)
clientConn, err := peerbroker.Dial(stream, []webrtc.ICEServer{{
URLs: []string{"stun:stun.l.google.com:19302"},
}}, &peer.ConnOptions{
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
})
require.NoError(t, err)
defer clientConn.Close()

serverConn, err := listener.Accept()
require.NoError(t, err)
defer serverConn.Close()
_, err = serverConn.Ping()
require.NoError(t, err)

_, err = clientConn.Ping()
require.NoError(t, err)

_ = dialerServer.Close()
wg.Wait()
}