-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Expand file tree
/
Copy pathchat.go
More file actions
153 lines (138 loc) · 3.91 KB
/
chat.go
File metadata and controls
153 lines (138 loc) · 3.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
package codersdk
import (
"context"
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/google/uuid"
"github.com/kylecarbs/aisdk-go"
"golang.org/x/xerrors"
)
// CreateChat creates a new chat by POSTing to /api/v2/chats and returns the
// chat decoded from the response body.
//
// The server must answer 201 Created; any other status is converted into an
// error by ReadBodyAsError. On any error the returned Chat is the zero value.
func (c *Client) CreateChat(ctx context.Context) (Chat, error) {
	res, err := c.Request(ctx, http.MethodPost, "/api/v2/chats", nil)
	if err != nil {
		return Chat{}, xerrors.Errorf("execute request: %w", err)
	}
	// Close the body on every path, including the unexpected-status branch,
	// matching the other request helpers in this file.
	defer res.Body.Close()
	if res.StatusCode != http.StatusCreated {
		return Chat{}, ReadBodyAsError(res)
	}
	var chat Chat
	// Decode before returning: `return chat, Decode(&chat)` depends on an
	// evaluation order (variable read vs. call) the Go spec leaves unspecified.
	if err := json.NewDecoder(res.Body).Decode(&chat); err != nil {
		return Chat{}, xerrors.Errorf("decode chat: %w", err)
	}
	return chat, nil
}
// Chat is a chat as sent and received by the /api/v2/chats endpoints.
type Chat struct {
	// ID is the chat's UUID; it is the {id} path segment used by the
	// per-chat endpoints.
	ID        uuid.UUID `json:"id" format:"uuid"`
	// CreatedAt is the chat's creation timestamp. The `format` tags appear to
	// be API-schema hints (NOTE(review): confirm which generator reads them).
	CreatedAt time.Time `json:"created_at" format:"date-time"`
	// UpdatedAt is the chat's last-update timestamp.
	UpdatedAt time.Time `json:"updated_at" format:"date-time"`
	// Title is the chat's human-readable title.
	Title     string    `json:"title"`
}
// ListChats lists all chats returned by GET /api/v2/chats.
//
// A non-200 response is converted into an error by ReadBodyAsError. On any
// error the returned slice is nil.
func (c *Client) ListChats(ctx context.Context) ([]Chat, error) {
	res, err := c.Request(ctx, http.MethodGet, "/api/v2/chats", nil)
	if err != nil {
		return nil, xerrors.Errorf("execute request: %w", err)
	}
	defer res.Body.Close()
	if res.StatusCode != http.StatusOK {
		return nil, ReadBodyAsError(res)
	}
	var chats []Chat
	// Decode before returning: `return chats, Decode(&chats)` depends on an
	// evaluation order (variable read vs. call) the Go spec leaves unspecified.
	if err := json.NewDecoder(res.Body).Decode(&chats); err != nil {
		return nil, xerrors.Errorf("decode chats: %w", err)
	}
	return chats, nil
}
// Chat returns the chat with the given ID via GET /api/v2/chats/{id}.
//
// A non-200 response is converted into an error by ReadBodyAsError. On any
// error the returned Chat is the zero value.
func (c *Client) Chat(ctx context.Context, id uuid.UUID) (Chat, error) {
	res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/chats/%s", id), nil)
	if err != nil {
		return Chat{}, xerrors.Errorf("execute request: %w", err)
	}
	defer res.Body.Close()
	if res.StatusCode != http.StatusOK {
		return Chat{}, ReadBodyAsError(res)
	}
	var chat Chat
	// Decode before returning: `return chat, Decode(&chat)` depends on an
	// evaluation order (variable read vs. call) the Go spec leaves unspecified.
	if err := json.NewDecoder(res.Body).Decode(&chat); err != nil {
		return Chat{}, xerrors.Errorf("decode chat: %w", err)
	}
	return chat, nil
}
// ChatMessages returns the messages of the chat with the given ID via
// GET /api/v2/chats/{id}/messages.
//
// A non-200 response is converted into an error by ReadBodyAsError. On any
// error the returned slice is nil.
func (c *Client) ChatMessages(ctx context.Context, id uuid.UUID) ([]ChatMessage, error) {
	res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/chats/%s/messages", id), nil)
	if err != nil {
		return nil, xerrors.Errorf("execute request: %w", err)
	}
	defer res.Body.Close()
	if res.StatusCode != http.StatusOK {
		return nil, ReadBodyAsError(res)
	}
	var messages []ChatMessage
	// Decode before returning: `return messages, Decode(&messages)` depends on
	// an evaluation order (variable read vs. call) the Go spec leaves unspecified.
	if err := json.NewDecoder(res.Body).Decode(&messages); err != nil {
		return nil, xerrors.Errorf("decode chat messages: %w", err)
	}
	return messages, nil
}
// ChatMessage is a type alias (not a distinct type) for aisdk.Message, so
// values are directly interchangeable with the aisdk package's message type.
type ChatMessage = aisdk.Message
// CreateChatMessageRequest is the JSON body sent by CreateChatMessage to
// POST /api/v2/chats/{id}/messages.
type CreateChatMessageRequest struct {
	// Model names the model to respond with (NOTE(review): accepted values are
	// defined server-side — confirm against the handler).
	Model    string      `json:"model"`
	// Message is the message to add to the chat. A message whose ID conflicts
	// with an existing message overwrites it.
	Message  ChatMessage `json:"message"`
	// Thinking presumably toggles model "thinking"/reasoning output — TODO
	// confirm semantics against the server handler.
	Thinking bool        `json:"thinking"`
}
// CreateChatMessage creates a new chat message and streams the response.
// If the provided message has a conflicting ID with an existing message,
// it will be overwritten.
//
// On success the returned channel yields each aisdk.DataStreamPart decoded
// from the "data" server-sent events of the response. The channel is closed
// when the event stream ends, when ctx is canceled, or when an event cannot
// be read or decoded; those errors are not reported to the caller. The
// response body is owned and closed by the streaming goroutine, so callers
// should drain the channel or cancel ctx to release it. A non-200 response is
// converted into an error by ReadBodyAsError.
func (c *Client) CreateChatMessage(ctx context.Context, id uuid.UUID, req CreateChatMessageRequest) (<-chan aisdk.DataStreamPart, error) {
	res, err := c.Request(ctx, http.MethodPost, fmt.Sprintf("/api/v2/chats/%s/messages", id), req)
	if err != nil {
		// Defensive: release a body if one was returned alongside the error.
		if res != nil && res.Body != nil {
			_ = res.Body.Close()
		}
		return nil, xerrors.Errorf("execute request: %w", err)
	}
	if res.StatusCode != http.StatusOK {
		defer res.Body.Close()
		return nil, ReadBodyAsError(res)
	}

	// From here on the goroutine below owns res.Body. It must NOT be closed
	// when this function returns, otherwise the stream is cut off before the
	// goroutine has read it.
	nextEvent := ServerSentEventReader(ctx, res.Body)
	wc := make(chan aisdk.DataStreamPart, 256)
	go func() {
		defer close(wc)
		defer res.Body.Close()
		for {
			// Stop promptly if the caller has canceled.
			select {
			case <-ctx.Done():
				return
			default:
			}

			sse, err := nextEvent()
			if err != nil {
				return
			}
			// Only data events carry stream parts; skip anything else.
			if sse.Type != ServerSentEventTypeData {
				continue
			}
			b, ok := sse.Data.([]byte)
			if !ok {
				return
			}
			var part aisdk.DataStreamPart
			if err := json.Unmarshal(b, &part); err != nil {
				return
			}
			// Block until the consumer reads the part or ctx is canceled.
			select {
			case <-ctx.Done():
				return
			case wc <- part:
			}
		}
	}()
	return wc, nil
}
// DeleteChat deletes the chat with the given ID via DELETE /api/v2/chats/{id}.
// It succeeds only on 204 No Content; any other status is converted into an
// error by ReadBodyAsError.
func (c *Client) DeleteChat(ctx context.Context, id uuid.UUID) error {
	path := fmt.Sprintf("/api/v2/chats/%s", id)
	resp, reqErr := c.Request(ctx, http.MethodDelete, path, nil)
	if reqErr != nil {
		return xerrors.Errorf("execute request: %w", reqErr)
	}
	defer resp.Body.Close()

	if resp.StatusCode == http.StatusNoContent {
		return nil
	}
	return ReadBodyAsError(resp)
}