Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 246 additions & 0 deletions aibridged/aibridged.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
package aibridged

import (
"context"
"errors"
"net/http"
"sync"
"sync/atomic"
"time"

"golang.org/x/xerrors"

"cdr.dev/slog"
"github.com/coder/retry"

"github.com/coder/coder/v2/codersdk"

"github.com/coder/aibridge"

Check failure on line 18 in aibridged/aibridged.go

View workflow job for this annotation

GitHub Actions / test-go-pg-17

github.com/coder/[email protected]: replacement directory /home/coder/aibridge does not exist

Check failure on line 18 in aibridged/aibridged.go

View workflow job for this annotation

GitHub Actions / test-go-pg (ubuntu-latest)

github.com/coder/[email protected]: replacement directory /home/coder/aibridge does not exist

Check failure on line 18 in aibridged/aibridged.go

View workflow job for this annotation

GitHub Actions / test-go-race-pg

github.com/coder/[email protected]: replacement directory /home/coder/aibridge does not exist
)

type Server interface {
http.Handler

Shutdown(context.Context) error
Close() error
}

// server is the implementation which fulfills the DRPCServer interface.
// It is responsible for communication with the
type server struct {
clientDialer Dialer
clientCh chan DRPCClient

requestBridgePool pooler

logger slog.Logger
wg sync.WaitGroup

// initConnectionCh will receive when the daemon connects to coderd for the
// first time.
initConnectionCh chan struct{}
initConnectionOnce sync.Once

// closeContext is canceled when we start closing.
closeContext context.Context
closeCancel context.CancelFunc
closeOnce sync.Once
// closeError stores the error when closing to return to subsequent callers
closeError error
// closingB is set to true when we start closing
closing atomic.Bool
shutdownOnce sync.Once
// shuttingDownCh will receive when we start graceful shutdown
shuttingDownCh chan struct{}
}

func New(rpcDialer Dialer, cfg aibridge.Config, logger slog.Logger) (Server, error) {
if rpcDialer == nil {
return nil, xerrors.Errorf("nil rpcDialer given")
}

pool, err := NewCachedBridgePool(cfg, logger.Named("aibridge-pool")) // TODO: configurable size.
if err != nil {
return nil, xerrors.Errorf("create aibridge pool: %w", err)
}

ctx, cancel := context.WithCancel(context.Background())
daemon := &server{
logger: logger,
clientDialer: rpcDialer,
requestBridgePool: pool,
clientCh: make(chan DRPCClient),
closeContext: ctx,
closeCancel: cancel,
initConnectionCh: make(chan struct{}),
shuttingDownCh: make(chan struct{}),
}

daemon.wg.Add(1)
go daemon.connect()

return daemon, nil
}

// Connect establishes a connection to coderd.
func (s *server) connect() {
defer s.logger.Debug(s.closeContext, "connect loop exited")
defer s.wg.Done()
logConnect := s.logger.With(slog.F("context", "aibridged.server")).Debug
// An exponential back-off occurs when the connection is failing to dial.
// This is to prevent server spam in case of a coderd outage.
connectLoop:
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(s.closeContext); {
// It's possible for the aibridge daemon to be shut down
// before the wait is complete!
if s.isClosed() {
return
}
s.logger.Debug(s.closeContext, "dialing coderd")
client, err := s.clientDialer(s.closeContext)
if err != nil {
if errors.Is(err, context.Canceled) {
return
}
var sdkErr *codersdk.Error
// If something is wrong with our auth, stop trying to connect.
if errors.As(err, &sdkErr) && sdkErr.StatusCode() == http.StatusForbidden {
s.logger.Error(s.closeContext, "not authorized to dial coderd", slog.Error(err))
return
}
if s.isClosed() {
return
}
s.logger.Warn(s.closeContext, "coderd client failed to dial", slog.Error(err))
continue
}

// TODO: log this with INFO level when we implement external aibridge daemons.
logConnect(s.closeContext, "successfully connected to coderd")
retrier.Reset()
s.initConnectionOnce.Do(func() {
close(s.initConnectionCh)
})

// serve the client until we are closed or it disconnects
for {
select {
case <-s.closeContext.Done():
client.DRPCConn().Close()
return
case <-client.DRPCConn().Closed():
logConnect(s.closeContext, "connection to coderd closed")
continue connectLoop
case s.clientCh <- client:
continue
}
}
}
}

func (s *server) Client() (DRPCClient, error) {
select {
case <-s.closeContext.Done():
return nil, xerrors.New("context closed")
case <-s.shuttingDownCh:
// Shutting down should return a nil client and unblock
return nil, xerrors.New("shutting down")
case client := <-s.clientCh:
return client, nil
}
}

// GetRequestHandler retrieves a (possibly reused) *aibridge.RequestBridge from the pool, for the given user.
func (s *server) GetRequestHandler(ctx context.Context, req Request) (http.Handler, error) {
if s.requestBridgePool == nil {
return nil, xerrors.New("nil requestBridgePool")
}

reqBridge, err := s.requestBridgePool.Acquire(ctx, req, s.Client)
if err != nil {
return nil, xerrors.Errorf("acquire request bridge: %w", err)
}

return reqBridge, nil
}

// isClosed returns whether the API is closed or not.
func (s *server) isClosed() bool {
select {
case <-s.closeContext.Done():
return true
default:
return false
}
}

// closeWithError closes aibridged once; subsequent calls will return the error err.
func (s *server) closeWithError(err error) error {
s.closing.Store(true)
s.closeOnce.Do(func() {
s.closeCancel()
s.logger.Debug(context.Background(), "waiting for goroutines to exit")
s.wg.Wait()
s.logger.Debug(context.Background(), "closing server with error", slog.Error(err))
s.closeError = err
})

return s.closeError
}

// Close ends the aibridge daemon.
func (s *server) Close() error {
if s == nil {
return nil
}

s.logger.Info(s.closeContext, "closing aibridged")
return s.closeWithError(nil)
}

// Shutdown waits for all exiting in-flight requests to complete, or the context to expire, whichever comes first.
func (s *server) Shutdown(ctx context.Context) error {
if s == nil {
return nil
}

var err error
s.shutdownOnce.Do(func() {
close(s.shuttingDownCh)

select {
case <-ctx.Done():
s.logger.Warn(ctx, "graceful shutdown failed", slog.Error(ctx.Err()))
err = ctx.Err()
return
default:
}

s.logger.Info(ctx, "shutting down aibridged pool")
if err = s.requestBridgePool.Shutdown(ctx); err != nil && errors.Is(err, http.ErrServerClosed) {
s.logger.Error(ctx, "shutdown failed with error", slog.Error(err))
return
}

s.logger.Info(ctx, "gracefully shutdown")
})
return err
}

var DefaultServer Server = &NoopServer{}

var _ Server = &NoopServer{}

type NoopServer struct{}

func (*NoopServer) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
http.Error(w, "no aibridged server", http.StatusBadGateway)
}

func (*NoopServer) Shutdown(context.Context) error {
return nil
}

func (*NoopServer) Close() error {
return nil
}
32 changes: 32 additions & 0 deletions aibridged/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package aibridged

import (
"context"

"storj.io/drpc"

"github.com/coder/coder/v2/aibridged/proto"
)

type Dialer func(ctx context.Context) (DRPCClient, error)

// DRPCClient is the union of various service interfaces the client must support.
type DRPCClient interface {
proto.DRPCRecorderClient
proto.DRPCMCPConfiguratorClient
proto.DRPCAuthenticatorClient
}

var _ DRPCClient = &Client{}

type Client struct {
proto.DRPCRecorderClient
proto.DRPCMCPConfiguratorClient
proto.DRPCAuthenticatorClient

Conn drpc.Conn
}

func (c *Client) DRPCConn() drpc.Conn {
return c.Conn
}
90 changes: 90 additions & 0 deletions aibridged/http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package aibridged

import (
"net/http"
"strings"

"github.com/google/uuid"

"cdr.dev/slog"
"github.com/coder/aibridge"
"github.com/coder/coder/v2/aibridged/proto"
)

var _ http.Handler = &server{}

// bridgeAIRequest handles requests destined for an upstream AI provider; aibridged intercepts these requests
// and applies a governance layer.
func (s *server) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()

logger := s.logger.With(slog.F("path", r.URL.Path))

key := strings.TrimSpace(extractAuthToken(r))
if key == "" {
logger.Warn(ctx, "no auth key provided")
http.Error(rw, "no authentication key provided", http.StatusBadRequest)
return
}

client, err := s.Client()
if err != nil {
logger.Error(ctx, "failed to connect to coderd", slog.Error(err))
http.Error(rw, "could not connect to coderd", http.StatusInternalServerError)
return
}

resp, err := client.AuthenticateKey(ctx, &proto.AuthenticateKeyRequest{Key: key})
if err != nil {
logger.Error(ctx, "failed to authenticate key", slog.Error(err))
http.Error(rw, "unauthorized", http.StatusForbidden)
return
}

// Rewire request context to include actor.
r = r.WithContext(aibridge.AsActor(ctx, resp.GetOwnerId(), nil))

id, err := uuid.Parse(resp.GetOwnerId())
if err != nil {
logger.Error(ctx, "failed to parse user ID", slog.Error(err), slog.F("id", resp.GetOwnerId()))
http.Error(rw, "unauthorized", http.StatusForbidden)
return
}

handler, err := s.GetRequestHandler(ctx, Request{
SessionKey: key,
InitiatorID: id,
})
if err != nil {
logger.Error(ctx, "failed to handle request", slog.Error(err))
http.Error(rw, "failed to handle request", http.StatusInternalServerError)
return
}

handler.ServeHTTP(rw, r)
}

// extractAuthToken extracts authorization token from HTTP request using multiple sources.
// These sources represent the different ways clients authenticate against AI providers.
// It checks the Authorization header (Bearer token) and X-Api-Key header.
// If neither are present, an empty string is returned.
func extractAuthToken(r *http.Request) string {
// 1. Check Authorization header for Bearer token.
authHeader := r.Header.Get("Authorization")
if authHeader != "" {
segs := strings.Split(authHeader, " ")
if len(segs) > 1 {
if strings.ToLower(segs[0]) == "bearer" {
return strings.Join(segs[1:], "")
}
}
}

// 2. Check X-Api-Key header.
apiKeyHeader := r.Header.Get("X-Api-Key")
if apiKeyHeader != "" {
return apiKeyHeader
}

return ""
}
Loading
Loading