Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit d5ea768

Browse files
committed
feat: aibridge package
1 parent ae36bfa commit d5ea768

File tree

8 files changed

+1117
-18
lines changed

8 files changed

+1117
-18
lines changed

aibridged/aibridged.go

Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
package aibridged
2+
3+
import (
4+
"context"
5+
"errors"
6+
"net/http"
7+
"sync"
8+
"sync/atomic"
9+
"time"
10+
11+
"golang.org/x/xerrors"
12+
13+
"cdr.dev/slog"
14+
"github.com/coder/retry"
15+
16+
"github.com/coder/coder/v2/codersdk"
17+
18+
"github.com/coder/aibridge"
19+
)
20+
21+
type Server interface {
22+
http.Handler
23+
24+
Shutdown(context.Context) error
25+
Close() error
26+
}
27+
28+
// server is the implementation which fulfills the DRPCServer interface.
29+
// It is responsible for communication with the
30+
type server struct {
31+
clientDialer Dialer
32+
clientCh chan DRPCClient
33+
34+
requestBridgePool pooler
35+
36+
logger slog.Logger
37+
wg sync.WaitGroup
38+
39+
// initConnectionCh will receive when the daemon connects to coderd for the
40+
// first time.
41+
initConnectionCh chan struct{}
42+
initConnectionOnce sync.Once
43+
44+
// closeContext is canceled when we start closing.
45+
closeContext context.Context
46+
closeCancel context.CancelFunc
47+
closeOnce sync.Once
48+
// closeError stores the error when closing to return to subsequent callers
49+
closeError error
50+
// closingB is set to true when we start closing
51+
closing atomic.Bool
52+
shutdownOnce sync.Once
53+
// shuttingDownCh will receive when we start graceful shutdown
54+
shuttingDownCh chan struct{}
55+
}
56+
57+
func New(rpcDialer Dialer, cfg aibridge.Config, logger slog.Logger) (Server, error) {
58+
if rpcDialer == nil {
59+
return nil, xerrors.Errorf("nil rpcDialer given")
60+
}
61+
62+
pool, err := NewCachedBridgePool(cfg, logger.Named("aibridge-pool")) // TODO: configurable size.
63+
if err != nil {
64+
return nil, xerrors.Errorf("create aibridge pool: %w", err)
65+
}
66+
67+
ctx, cancel := context.WithCancel(context.Background())
68+
daemon := &server{
69+
logger: logger,
70+
clientDialer: rpcDialer,
71+
requestBridgePool: pool,
72+
clientCh: make(chan DRPCClient),
73+
closeContext: ctx,
74+
closeCancel: cancel,
75+
initConnectionCh: make(chan struct{}),
76+
shuttingDownCh: make(chan struct{}),
77+
}
78+
79+
daemon.wg.Add(1)
80+
go daemon.connect()
81+
82+
return daemon, nil
83+
}
84+
85+
// Connect establishes a connection to coderd.
86+
func (s *server) connect() {
87+
defer s.logger.Debug(s.closeContext, "connect loop exited")
88+
defer s.wg.Done()
89+
logConnect := s.logger.With(slog.F("context", "aibridged.server")).Debug
90+
// An exponential back-off occurs when the connection is failing to dial.
91+
// This is to prevent server spam in case of a coderd outage.
92+
connectLoop:
93+
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(s.closeContext); {
94+
// It's possible for the aibridge daemon to be shut down
95+
// before the wait is complete!
96+
if s.isClosed() {
97+
return
98+
}
99+
s.logger.Debug(s.closeContext, "dialing coderd")
100+
client, err := s.clientDialer(s.closeContext)
101+
if err != nil {
102+
if errors.Is(err, context.Canceled) {
103+
return
104+
}
105+
var sdkErr *codersdk.Error
106+
// If something is wrong with our auth, stop trying to connect.
107+
if errors.As(err, &sdkErr) && sdkErr.StatusCode() == http.StatusForbidden {
108+
s.logger.Error(s.closeContext, "not authorized to dial coderd", slog.Error(err))
109+
return
110+
}
111+
if s.isClosed() {
112+
return
113+
}
114+
s.logger.Warn(s.closeContext, "coderd client failed to dial", slog.Error(err))
115+
continue
116+
}
117+
118+
// TODO: log this with INFO level when we implement external aibridge daemons.
119+
logConnect(s.closeContext, "successfully connected to coderd")
120+
retrier.Reset()
121+
s.initConnectionOnce.Do(func() {
122+
close(s.initConnectionCh)
123+
})
124+
125+
// serve the client until we are closed or it disconnects
126+
for {
127+
select {
128+
case <-s.closeContext.Done():
129+
client.DRPCConn().Close()
130+
return
131+
case <-client.DRPCConn().Closed():
132+
logConnect(s.closeContext, "connection to coderd closed")
133+
continue connectLoop
134+
case s.clientCh <- client:
135+
continue
136+
}
137+
}
138+
}
139+
}
140+
141+
func (s *server) Client() (DRPCClient, error) {
142+
select {
143+
case <-s.closeContext.Done():
144+
return nil, xerrors.New("context closed")
145+
case <-s.shuttingDownCh:
146+
// Shutting down should return a nil client and unblock
147+
return nil, xerrors.New("shutting down")
148+
case client := <-s.clientCh:
149+
return client, nil
150+
}
151+
}
152+
153+
// GetRequestHandler retrieves a (possibly reused) *aibridge.RequestBridge from the pool, for the given user.
154+
func (s *server) GetRequestHandler(ctx context.Context, req Request) (http.Handler, error) {
155+
if s.requestBridgePool == nil {
156+
return nil, xerrors.New("nil requestBridgePool")
157+
}
158+
159+
reqBridge, err := s.requestBridgePool.Acquire(ctx, req, s.Client)
160+
if err != nil {
161+
return nil, xerrors.Errorf("acquire request bridge: %w", err)
162+
}
163+
164+
return reqBridge, nil
165+
}
166+
167+
// isClosed returns whether the API is closed or not.
168+
func (s *server) isClosed() bool {
169+
select {
170+
case <-s.closeContext.Done():
171+
return true
172+
default:
173+
return false
174+
}
175+
}
176+
177+
// closeWithError closes aibridged once; subsequent calls will return the error err.
178+
func (s *server) closeWithError(err error) error {
179+
s.closing.Store(true)
180+
s.closeOnce.Do(func() {
181+
s.closeCancel()
182+
s.logger.Debug(context.Background(), "waiting for goroutines to exit")
183+
s.wg.Wait()
184+
s.logger.Debug(context.Background(), "closing server with error", slog.Error(err))
185+
s.closeError = err
186+
})
187+
188+
return s.closeError
189+
}
190+
191+
// Close ends the aibridge daemon.
192+
func (s *server) Close() error {
193+
if s == nil {
194+
return nil
195+
}
196+
197+
s.logger.Info(s.closeContext, "closing aibridged")
198+
return s.closeWithError(nil)
199+
}
200+
201+
// Shutdown waits for all exiting in-flight requests to complete, or the context to expire, whichever comes first.
202+
func (s *server) Shutdown(ctx context.Context) error {
203+
if s == nil {
204+
return nil
205+
}
206+
207+
var err error
208+
s.shutdownOnce.Do(func() {
209+
close(s.shuttingDownCh)
210+
211+
select {
212+
case <-ctx.Done():
213+
s.logger.Warn(ctx, "graceful shutdown failed", slog.Error(ctx.Err()))
214+
err = ctx.Err()
215+
return
216+
default:
217+
}
218+
219+
s.logger.Info(ctx, "shutting down aibridged pool")
220+
if err = s.requestBridgePool.Shutdown(ctx); err != nil && errors.Is(err, http.ErrServerClosed) {
221+
s.logger.Error(ctx, "shutdown failed with error", slog.Error(err))
222+
return
223+
}
224+
225+
s.logger.Info(ctx, "gracefully shutdown")
226+
})
227+
return err
228+
}
229+
230+
var DefaultServer Server = &NoopServer{}
231+
232+
var _ Server = &NoopServer{}
233+
234+
type NoopServer struct{}
235+
236+
func (*NoopServer) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
237+
http.Error(w, "no aibridged server", http.StatusBadGateway)
238+
}
239+
240+
func (*NoopServer) Shutdown(context.Context) error {
241+
return nil
242+
}
243+
244+
func (*NoopServer) Close() error {
245+
return nil
246+
}

aibridged/client.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
package aibridged
2+
3+
import (
4+
"context"
5+
6+
"storj.io/drpc"
7+
8+
"github.com/coder/coder/v2/aibridged/proto"
9+
)
10+
11+
type Dialer func(ctx context.Context) (DRPCClient, error)
12+
13+
// DRPCClient is the union of various service interfaces the client must support.
14+
type DRPCClient interface {
15+
proto.DRPCRecorderClient
16+
proto.DRPCMCPConfiguratorClient
17+
proto.DRPCAuthenticatorClient
18+
}
19+
20+
var _ DRPCClient = &Client{}
21+
22+
type Client struct {
23+
proto.DRPCRecorderClient
24+
proto.DRPCMCPConfiguratorClient
25+
proto.DRPCAuthenticatorClient
26+
27+
Conn drpc.Conn
28+
}
29+
30+
func (c *Client) DRPCConn() drpc.Conn {
31+
return c.Conn
32+
}

aibridged/http.go

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
package aibridged
2+
3+
import (
4+
"net/http"
5+
"strings"
6+
7+
"github.com/google/uuid"
8+
9+
"cdr.dev/slog"
10+
"github.com/coder/aibridge"
11+
"github.com/coder/coder/v2/aibridged/proto"
12+
)
13+
14+
var _ http.Handler = &server{}
15+
16+
// bridgeAIRequest handles requests destined for an upstream AI provider; aibridged intercepts these requests
17+
// and applies a governance layer.
18+
func (s *server) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
19+
ctx := r.Context()
20+
21+
logger := s.logger.With(slog.F("path", r.URL.Path))
22+
23+
key := strings.TrimSpace(extractAuthToken(r))
24+
if key == "" {
25+
logger.Warn(ctx, "no auth key provided")
26+
http.Error(rw, "no authentication key provided", http.StatusBadRequest)
27+
return
28+
}
29+
30+
client, err := s.Client()
31+
if err != nil {
32+
logger.Error(ctx, "failed to connect to coderd", slog.Error(err))
33+
http.Error(rw, "could not connect to coderd", http.StatusInternalServerError)
34+
return
35+
}
36+
37+
resp, err := client.AuthenticateKey(ctx, &proto.AuthenticateKeyRequest{Key: key})
38+
if err != nil {
39+
logger.Error(ctx, "failed to authenticate key", slog.Error(err))
40+
http.Error(rw, "unauthorized", http.StatusForbidden)
41+
return
42+
}
43+
44+
// Rewire request context to include actor.
45+
r = r.WithContext(aibridge.AsActor(ctx, resp.GetOwnerId(), nil))
46+
47+
id, err := uuid.Parse(resp.GetOwnerId())
48+
if err != nil {
49+
logger.Error(ctx, "failed to parse user ID", slog.Error(err), slog.F("id", resp.GetOwnerId()))
50+
http.Error(rw, "unauthorized", http.StatusForbidden)
51+
return
52+
}
53+
54+
handler, err := s.GetRequestHandler(ctx, Request{
55+
SessionKey: key,
56+
InitiatorID: id,
57+
})
58+
if err != nil {
59+
logger.Error(ctx, "failed to handle request", slog.Error(err))
60+
http.Error(rw, "failed to handle request", http.StatusInternalServerError)
61+
return
62+
}
63+
64+
handler.ServeHTTP(rw, r)
65+
}
66+
67+
// extractAuthToken extracts authorization token from HTTP request using multiple sources.
68+
// These sources represent the different ways clients authenticate against AI providers.
69+
// It checks the Authorization header (Bearer token) and X-Api-Key header.
70+
// If neither are present, an empty string is returned.
71+
func extractAuthToken(r *http.Request) string {
72+
// 1. Check Authorization header for Bearer token.
73+
authHeader := r.Header.Get("Authorization")
74+
if authHeader != "" {
75+
segs := strings.Split(authHeader, " ")
76+
if len(segs) > 1 {
77+
if strings.ToLower(segs[0]) == "bearer" {
78+
return strings.Join(segs[1:], "")
79+
}
80+
}
81+
}
82+
83+
// 2. Check X-Api-Key header.
84+
apiKeyHeader := r.Header.Get("X-Api-Key")
85+
if apiKeyHeader != "" {
86+
return apiKeyHeader
87+
}
88+
89+
return ""
90+
}

0 commit comments

Comments
 (0)