diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 16b158d..5d8eb39 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -11,6 +11,9 @@ on:
 
 permissions:
   contents: write
+  id-token: write
+  attestations: write
+
 
 jobs:
   release:
@@ -22,3 +25,5 @@ jobs:
       go_version_file: go.mod
       release_tag: ${{ github.event.inputs.release_tag || '' }}
       generate_attestations: true
+      release_android: true
+      android_sdk_version: 34
diff --git a/.gitignore b/.gitignore
index 7b903ed..54f9c6b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,3 +3,4 @@
 /gh-models-darwin-*
 /gh-models-linux-*
 /gh-models-windows-*
+/gh-models-android-*
diff --git a/cmd/run/run.go b/cmd/run/run.go
index 4c60f03..7a3a885 100644
--- a/cmd/run/run.go
+++ b/cmd/run/run.go
@@ -21,6 +21,7 @@ import (
 	"github.com/github/gh-models/pkg/util"
 	"github.com/spf13/cobra"
 	"github.com/spf13/pflag"
+	"gopkg.in/yaml.v3"
 )
 
 // ModelParameters represents the parameters that can be set for a model run.
@@ -188,6 +189,22 @@ func isPipe(r io.Reader) bool {
 	return false
 }
 
+// promptFile mirrors the format of .prompt.yml
+type promptFile struct {
+	Name            string `yaml:"name"`
+	Description     string `yaml:"description"`
+	Model           string `yaml:"model"`
+	ModelParameters struct {
+		MaxTokens   *int     `yaml:"maxTokens"`
+		Temperature *float64 `yaml:"temperature"`
+		TopP        *float64 `yaml:"topP"`
+	} `yaml:"modelParameters"`
+	Messages []struct {
+		Role    string `yaml:"role"`
+		Content string `yaml:"content"`
+	} `yaml:"messages"`
+}
+
 // NewRunCommand returns a new gh command for running a model.
 func NewRunCommand(cfg *command.Config) *cobra.Command {
 	cmd := &cobra.Command{
@@ -208,6 +225,24 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
 		Example: "gh models run openai/gpt-4o-mini \"how many types of hyena are there?\"",
 		Args:    cobra.ArbitraryArgs,
 		RunE: func(cmd *cobra.Command, args []string) error {
+			filePath, _ := cmd.Flags().GetString("file")
+			var pf *promptFile
+			if filePath != "" {
+				b, err := os.ReadFile(filePath)
+				if err != nil {
+					return err
+				}
+				p := promptFile{}
+				if err := yaml.Unmarshal(b, &p); err != nil {
+					return err
+				}
+				pf = &p
+				// Inject model name as the first positional arg if user didn't supply one
+				if pf.Model != "" && len(args) == 0 {
+					args = append([]string{pf.Model}, args...)
+				}
+			}
+
 			cmdHandler := newRunCommandHandler(cmd, cfg, args)
 			if cmdHandler == nil {
 				return nil
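The hunk above is the whole of the new file handling: read the bytes, unmarshal them, and inject the file's model as `args[0]` when no positional model was given. A self-contained sketch of just the YAML step — the struct is copied from this diff, while the file path, `main` wrapper, and error handling are illustrative only:

```go
package main

import (
	"fmt"
	"os"

	"gopkg.in/yaml.v3"
)

// promptFile mirrors the struct added in cmd/run/run.go.
type promptFile struct {
	Name            string `yaml:"name"`
	Description     string `yaml:"description"`
	Model           string `yaml:"model"`
	ModelParameters struct {
		MaxTokens   *int     `yaml:"maxTokens"`
		Temperature *float64 `yaml:"temperature"`
		TopP        *float64 `yaml:"topP"`
	} `yaml:"modelParameters"`
	Messages []struct {
		Role    string `yaml:"role"`
		Content string `yaml:"content"`
	} `yaml:"messages"`
}

func main() {
	b, err := os.ReadFile("s.prompt.yml") // example path
	if err != nil {
		panic(err)
	}
	var pf promptFile
	if err := yaml.Unmarshal(b, &pf); err != nil {
		panic(err)
	}
	// Pointer fields stay nil when a parameter is omitted, which is how the
	// flag-population step can later tell "unset" apart from an explicit zero.
	if pf.ModelParameters.Temperature != nil {
		fmt.Println("temperature:", *pf.ModelParameters.Temperature)
	}
	fmt.Println("model:", pf.Model, "messages:", len(pf.Messages))
}
```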
@@ -223,19 +258,25 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
 				return err
 			}
 
+			interactiveMode := true
 			initialPrompt := ""
-			singleShot := false
+			pipedContent := ""
 
 			if len(args) > 1 {
 				initialPrompt = strings.Join(args[1:], " ")
-				singleShot = true
+				interactiveMode = false
 			}
 
 			if isPipe(os.Stdin) {
 				promptFromPipe, _ := io.ReadAll(os.Stdin)
 				if len(promptFromPipe) > 0 {
-					initialPrompt = initialPrompt + "\n" + string(promptFromPipe)
-					singleShot = true
+					interactiveMode = false
+					pipedContent = strings.TrimSpace(string(promptFromPipe))
+					if initialPrompt != "" {
+						initialPrompt = initialPrompt + "\n" + pipedContent
+					} else {
+						initialPrompt = pipedContent
+					}
 				}
 			}
 
@@ -248,70 +289,51 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
 				systemPrompt: systemPrompt,
 			}
 
+			// If there is no prompt file, add the initialPrompt to the conversation.
+			// If a prompt file is passed, load the messages from the file, templating {{input}}
+			// using the initialPrompt.
+			if pf == nil {
+				conversation.AddMessage(azuremodels.ChatMessageRoleUser, initialPrompt)
+			} else {
+				interactiveMode = false
+
+				for _, m := range pf.Messages {
+					content := m.Content
+					switch strings.ToLower(m.Role) {
+					case "system":
+						conversation.systemPrompt = content
+					case "user":
+						content = strings.ReplaceAll(content, "{{input}}", initialPrompt)
+						conversation.AddMessage(azuremodels.ChatMessageRoleUser, content)
+					case "assistant":
+						conversation.AddMessage(azuremodels.ChatMessageRoleAssistant, content)
+					}
+				}
+			}
+
 			mp := ModelParameters{}
+
+			if pf != nil {
+				mp.maxTokens = pf.ModelParameters.MaxTokens
+				mp.temperature = pf.ModelParameters.Temperature
+				mp.topP = pf.ModelParameters.TopP
+			}
+
 			err = mp.PopulateFromFlags(cmd.Flags())
 			if err != nil {
 				return err
 			}
 
 			for {
-				prompt := ""
-				if initialPrompt != "" {
-					prompt = initialPrompt
-					initialPrompt = ""
-				}
-
-				if prompt == "" {
-					fmt.Printf(">>> ")
-					reader := bufio.NewReader(os.Stdin)
-					prompt, err = reader.ReadString('\n')
-					if err != nil {
-						return err
-					}
-				}
-
-				prompt = strings.TrimSpace(prompt)
-
-				if prompt == "" {
-					continue
-				}
-
-				if strings.HasPrefix(prompt, "/") {
-					if prompt == "/bye" || prompt == "/exit" || prompt == "/quit" {
+				if interactiveMode {
+					conversation, err = cmdHandler.ChatWithUser(conversation, mp)
+					if errors.Is(err, ErrExitChat) || errors.Is(err, io.EOF) {
 						break
+					} else if err != nil {
+						return err
 					}
-
-					if prompt == "/parameters" {
-						cmdHandler.handleParametersPrompt(conversation, mp)
-						continue
-					}
-
-					if prompt == "/reset" || prompt == "/clear" {
-						cmdHandler.handleResetPrompt(conversation)
-						continue
-					}
-
-					if strings.HasPrefix(prompt, "/set ") {
-						cmdHandler.handleSetPrompt(prompt, mp)
-						continue
-					}
-
-					if strings.HasPrefix(prompt, "/system-prompt ") {
-						conversation = cmdHandler.handleSystemPrompt(prompt, conversation)
-						continue
-					}
-
-					if prompt == "/help" {
-						cmdHandler.handleHelpPrompt()
-						continue
-					}
-
-					cmdHandler.handleUnrecognizedPrompt(prompt)
-					continue
 				}
 
-				conversation.AddMessage(azuremodels.ChatMessageRoleUser, prompt)
-
 				req := azuremodels.ChatCompletionOptions{
 					Messages: conversation.GetMessages(),
 					Model:    modelName,
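Note the asymmetry in the templating loop above: only `user` messages have `{{input}}` substituted; `system` and `assistant` content passes through verbatim. A minimal sketch of that behavior, using a hypothetical `renderMessage` helper that is not part of this PR:

```go
package main

import (
	"fmt"
	"strings"
)

// renderMessage templates {{input}} into user messages only, matching the
// switch in run.go. Hypothetical helper for illustration.
func renderMessage(role, content, input string) string {
	if strings.ToLower(role) == "user" {
		return strings.ReplaceAll(content, "{{input}}", input)
	}
	// system and assistant messages are passed through untouched
	return content
}

func main() {
	// input is positional prompt + piped stdin, joined with "\n" as above.
	input := "Please summarize the provided text.\nHello there!"
	fmt.Println(renderMessage("user", "Summarize: {{input}}", input))
	fmt.Println(renderMessage("system", "You are a summarizer. {{input}}", input))
}
```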
"/quit" { + return conversation, ErrExitChat + } + + if prompt == "/parameters" { + h.handleParametersPrompt(conversation, mp) + return conversation, nil + } + + if prompt == "/reset" || prompt == "/clear" { + h.handleResetPrompt(conversation) + return conversation, nil + } + + if strings.HasPrefix(prompt, "/set ") { + h.handleSetPrompt(prompt, mp) + return conversation, nil + } + + if strings.HasPrefix(prompt, "/system-prompt ") { + conversation = h.handleSystemPrompt(prompt, conversation) + return conversation, nil + } + + if prompt == "/help" { + h.handleHelpPrompt() + return conversation, nil + } + + h.handleUnrecognizedPrompt(prompt) + return conversation, nil + } + + conversation.AddMessage(azuremodels.ChatMessageRoleUser, prompt) + return conversation, nil +} diff --git a/cmd/run/run_test.go b/cmd/run/run_test.go index 5b88cfa..7395e7c 100644 --- a/cmd/run/run_test.go +++ b/cmd/run/run_test.go @@ -3,6 +3,7 @@ package run import ( "bytes" "context" + "os" "regexp" "testing" @@ -80,4 +81,253 @@ func TestRun(t *testing.T) { require.Regexp(t, regexp.MustCompile(`--top-p string\s+Controls text diversity by selecting the most probable words until a set probability is reached\.`), output) require.Empty(t, errBuf.String()) }) + + t.Run("--file pre-loads YAML from file", func(t *testing.T) { + const yamlBody = ` +name: Text Summarizer +description: Summarizes input text concisely +model: openai/test-model +modelParameters: + temperature: 0.5 +messages: + - role: system + content: You are a text summarizer. + - role: user + content: Hello there! +` + tmp, err := os.CreateTemp(t.TempDir(), "*.prompt.yml") + require.NoError(t, err) + _, err = tmp.WriteString(yamlBody) + require.NoError(t, err) + require.NoError(t, tmp.Close()) + + client := azuremodels.NewMockClient() + modelSummary := &azuremodels.ModelSummary{ + Name: "test-model", + Publisher: "openai", + Task: "chat-completion", + } + client.MockListModels = func(ctx context.Context) ([]*azuremodels.ModelSummary, error) { + return []*azuremodels.ModelSummary{modelSummary}, nil + } + + var capturedReq azuremodels.ChatCompletionOptions + reply := "Summary - foo" + chatCompletion := azuremodels.ChatCompletion{ + Choices: []azuremodels.ChatChoice{{ + Message: &azuremodels.ChatChoiceMessage{ + Content: util.Ptr(reply), + Role: util.Ptr(string(azuremodels.ChatMessageRoleAssistant)), + }, + }}, + } + client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) { + capturedReq = opt + return &azuremodels.ChatCompletionResponse{ + Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{chatCompletion}), + }, nil + } + + out := new(bytes.Buffer) + cfg := command.NewConfig(out, out, client, true, 100) + runCmd := NewRunCommand(cfg) + runCmd.SetArgs([]string{ + "--file", tmp.Name(), + azuremodels.FormatIdentifier("openai", "test-model"), + }) + + _, err = runCmd.ExecuteC() + require.NoError(t, err) + + require.Equal(t, 2, len(capturedReq.Messages)) + require.Equal(t, "You are a text summarizer.", *capturedReq.Messages[0].Content) + require.Equal(t, "Hello there!", *capturedReq.Messages[1].Content) + + require.NotNil(t, capturedReq.Temperature) + require.Equal(t, 0.5, *capturedReq.Temperature) + + require.Contains(t, out.String(), reply) // response streamed to output + }) + + t.Run("--file with {{input}} placeholder is substituted with initial prompt and stdin", func(t *testing.T) { + const yamlBody = ` +name: Summarizer +description: Summarizes input text +model: 
diff --git a/cmd/run/run_test.go b/cmd/run/run_test.go
index 5b88cfa..7395e7c 100644
--- a/cmd/run/run_test.go
+++ b/cmd/run/run_test.go
@@ -3,6 +3,7 @@ package run
 import (
 	"bytes"
 	"context"
+	"os"
 	"regexp"
 	"testing"
 
@@ -80,4 +81,253 @@ func TestRun(t *testing.T) {
 		require.Regexp(t, regexp.MustCompile(`--top-p string\s+Controls text diversity by selecting the most probable words until a set probability is reached\.`), output)
 		require.Empty(t, errBuf.String())
 	})
+
+	t.Run("--file pre-loads YAML from file", func(t *testing.T) {
+		const yamlBody = `
+name: Text Summarizer
+description: Summarizes input text concisely
+model: openai/test-model
+modelParameters:
+  temperature: 0.5
+messages:
+  - role: system
+    content: You are a text summarizer.
+  - role: user
+    content: Hello there!
+`
+
+		tmp, err := os.CreateTemp(t.TempDir(), "*.prompt.yml")
+		require.NoError(t, err)
+		_, err = tmp.WriteString(yamlBody)
+		require.NoError(t, err)
+		require.NoError(t, tmp.Close())
+
+		client := azuremodels.NewMockClient()
+		modelSummary := &azuremodels.ModelSummary{
+			Name:      "test-model",
+			Publisher: "openai",
+			Task:      "chat-completion",
+		}
+		client.MockListModels = func(ctx context.Context) ([]*azuremodels.ModelSummary, error) {
+			return []*azuremodels.ModelSummary{modelSummary}, nil
+		}
+
+		var capturedReq azuremodels.ChatCompletionOptions
+		reply := "Summary - foo"
+		chatCompletion := azuremodels.ChatCompletion{
+			Choices: []azuremodels.ChatChoice{{
+				Message: &azuremodels.ChatChoiceMessage{
+					Content: util.Ptr(reply),
+					Role:    util.Ptr(string(azuremodels.ChatMessageRoleAssistant)),
+				},
+			}},
+		}
+		client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) {
+			capturedReq = opt
+			return &azuremodels.ChatCompletionResponse{
+				Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{chatCompletion}),
+			}, nil
+		}
+
+		out := new(bytes.Buffer)
+		cfg := command.NewConfig(out, out, client, true, 100)
+		runCmd := NewRunCommand(cfg)
+		runCmd.SetArgs([]string{
+			"--file", tmp.Name(),
+			azuremodels.FormatIdentifier("openai", "test-model"),
+		})
+
+		_, err = runCmd.ExecuteC()
+		require.NoError(t, err)
+
+		require.Len(t, capturedReq.Messages, 2)
+		require.Equal(t, "You are a text summarizer.", *capturedReq.Messages[0].Content)
+		require.Equal(t, "Hello there!", *capturedReq.Messages[1].Content)
+
+		require.NotNil(t, capturedReq.Temperature)
+		require.Equal(t, 0.5, *capturedReq.Temperature)
+
+		require.Contains(t, out.String(), reply) // response streamed to output
+	})
+
+	t.Run("--file with {{input}} placeholder is substituted with initial prompt and stdin", func(t *testing.T) {
+		const yamlBody = `
+name: Summarizer
+description: Summarizes input text
+model: openai/test-model
+messages:
+  - role: system
+    content: You are a text summarizer.
+  - role: user
+    content: "{{input}}"
+`
+
+		tmp, err := os.CreateTemp(t.TempDir(), "*.prompt.yml")
+		require.NoError(t, err)
+		_, err = tmp.WriteString(yamlBody)
+		require.NoError(t, err)
+		require.NoError(t, tmp.Close())
+
+		client := azuremodels.NewMockClient()
+		modelSummary := &azuremodels.ModelSummary{
+			Name:      "test-model",
+			Publisher: "openai",
+			Task:      "chat-completion",
+		}
+		client.MockListModels = func(ctx context.Context) ([]*azuremodels.ModelSummary, error) {
+			return []*azuremodels.ModelSummary{modelSummary}, nil
+		}
+
+		var capturedReq azuremodels.ChatCompletionOptions
+		reply := "Summary - bar"
+		chatCompletion := azuremodels.ChatCompletion{
+			Choices: []azuremodels.ChatChoice{{
+				Message: &azuremodels.ChatChoiceMessage{
+					Content: util.Ptr(reply),
+					Role:    util.Ptr(string(azuremodels.ChatMessageRoleAssistant)),
+				},
+			}},
+		}
+		client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) {
+			capturedReq = opt
+			return &azuremodels.ChatCompletionResponse{
+				Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{chatCompletion}),
+			}, nil
+		}
+
+		// create a pipe to fake stdin so that isPipe(os.Stdin) == true
+		r, w, err := os.Pipe()
+		require.NoError(t, err)
+		oldStdin := os.Stdin
+		os.Stdin = r
+		defer func() { os.Stdin = oldStdin }()
+		piped := "Hello there!"
+		go func() {
+			_, _ = w.Write([]byte(piped))
+			_ = w.Close()
+		}()
+
+		out := new(bytes.Buffer)
+		cfg := command.NewConfig(out, out, client, true, 100)
+
+		initialPrompt := "Please summarize the provided text."
+		runCmd := NewRunCommand(cfg)
+		runCmd.SetArgs([]string{
+			"--file", tmp.Name(),
+			azuremodels.FormatIdentifier("openai", "test-model"),
+			initialPrompt,
+		})
+
+		_, err = runCmd.ExecuteC()
+		require.NoError(t, err)
+
+		require.Len(t, capturedReq.Messages, 2)
+		require.Equal(t, "You are a text summarizer.", *capturedReq.Messages[0].Content)
+		require.Equal(t, initialPrompt+"\n"+piped, *capturedReq.Messages[1].Content) // {{input}} -> "Please summarize the provided text.\nHello there!"
+
+		require.Contains(t, out.String(), reply)
+	})
+
+	t.Run("cli flags override params set in the prompt.yaml file", func(t *testing.T) {
+		const yamlBody = `
+name: Example Prompt
+description: Example description
+model: openai/example-model
+modelParameters:
+  maxTokens: 300
+  temperature: 0.8
+  topP: 0.9
+messages:
+  - role: system
+    content: System message
+  - role: user
+    content: User message
+`
+
+		tmp, err := os.CreateTemp(t.TempDir(), "*.prompt.yaml")
+		require.NoError(t, err)
+		_, err = tmp.WriteString(yamlBody)
+		require.NoError(t, err)
+		require.NoError(t, tmp.Close())
+
+		client := azuremodels.NewMockClient()
+		modelSummary := &azuremodels.ModelSummary{
+			Name:      "example-model",
+			Publisher: "openai",
+			Task:      "chat-completion",
+		}
+		modelSummary2 := &azuremodels.ModelSummary{
+			Name:      "example-model-4o-mini-plus",
+			Publisher: "openai",
+			Task:      "chat-completion",
+		}
+		client.MockListModels = func(ctx context.Context) ([]*azuremodels.ModelSummary, error) {
+			return []*azuremodels.ModelSummary{modelSummary, modelSummary2}, nil
+		}
+
+		var capturedReq azuremodels.ChatCompletionOptions
+		reply := "hello"
+		chatCompletion := azuremodels.ChatCompletion{
+			Choices: []azuremodels.ChatChoice{{
+				Message: &azuremodels.ChatChoiceMessage{
+					Content: util.Ptr(reply),
+					Role:    util.Ptr(string(azuremodels.ChatMessageRoleAssistant)),
+				},
+			}},
+		}
+		client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) {
+			capturedReq = opt
+			return &azuremodels.ChatCompletionResponse{
+				Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{chatCompletion}),
+			}, nil
+		}
+
+		out := new(bytes.Buffer)
+		cfg := command.NewConfig(out, out, client, true, 100)
+		runCmd := NewRunCommand(cfg)
+
+		// Test case 1: with no flags, the model params come from the YAML file
+		runCmd.SetArgs([]string{
+			"--file", tmp.Name(),
+		})
+
+		_, err = runCmd.ExecuteC()
+		require.NoError(t, err)
+
+		require.Equal(t, "openai/example-model", capturedReq.Model)
+		require.Equal(t, 300, *capturedReq.MaxTokens)
+		require.Equal(t, 0.8, *capturedReq.Temperature)
+		require.Equal(t, 0.9, *capturedReq.TopP)
+
+		require.Equal(t, "System message", *capturedReq.Messages[0].Content)
+		require.Equal(t, "User message", *capturedReq.Messages[1].Content)
+
+		// Test case 2: values from flags override the params from the YAML file
+		runCmd = NewRunCommand(cfg)
+		runCmd.SetArgs([]string{
+			"openai/example-model-4o-mini-plus",
+			"--file", tmp.Name(),
+			"--max-tokens", "150",
+			"--temperature", "0.1",
+			"--top-p", "0.3",
+		})
+
+		_, err = runCmd.ExecuteC()
+		require.NoError(t, err)
+
+		require.Equal(t, "openai/example-model-4o-mini-plus", capturedReq.Model)
+		require.Equal(t, 150, *capturedReq.MaxTokens)
+		require.Equal(t, 0.1, *capturedReq.Temperature)
+		require.Equal(t, 0.3, *capturedReq.TopP)
+
+		require.Equal(t, "System message", *capturedReq.Messages[0].Content)
+		require.Equal(t, "User message", *capturedReq.Messages[1].Content)
+	})
 }
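The `{{input}}` test fakes piped stdin by swapping `os.Stdin` for the read end of an `os.Pipe`. If this pattern gets reused across tests, it could be factored into a helper along these lines — `withStdin` is hypothetical, not part of this PR, and because it mutates a process-wide global it is not safe for parallel tests:

```go
package main

import (
	"io"
	"os"
)

// withStdin runs fn with os.Stdin replaced by a pipe carrying content,
// restoring the original stdin afterwards. Hypothetical test helper.
func withStdin(content string, fn func()) error {
	r, w, err := os.Pipe()
	if err != nil {
		return err
	}
	oldStdin := os.Stdin
	os.Stdin = r
	defer func() { os.Stdin = oldStdin }()

	// Write from a goroutine and close the write end, so a reader draining
	// the pipe sees EOF instead of blocking forever.
	go func() {
		_, _ = w.Write([]byte(content))
		_ = w.Close()
	}()

	fn()
	return nil
}

func main() {
	_ = withStdin("Hello there!", func() {
		b, _ := io.ReadAll(os.Stdin)
		os.Stdout.Write(b)
	})
}
```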
diff --git a/go.mod b/go.mod
index e9559f6..56dae7e 100644
--- a/go.mod
+++ b/go.mod
@@ -15,6 +15,7 @@ require (
 	github.com/spf13/pflag v1.0.5
 	github.com/stretchr/testify v1.10.0
 	golang.org/x/text v0.23.0
+	gopkg.in/yaml.v3 v3.0.1
 )
 
 require (
@@ -49,5 +50,4 @@ require (
 	golang.org/x/net v0.38.0 // indirect
 	golang.org/x/sys v0.31.0 // indirect
 	golang.org/x/term v0.30.0 // indirect
-	gopkg.in/yaml.v3 v3.0.1 // indirect
 )
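With `gopkg.in/yaml.v3` promoted from an indirect to a direct dependency, one optional hardening — not done in this PR — would be strict decoding, which rejects keys the `promptFile` struct does not declare and so catches typos in hand-written prompt files. A sketch:

```go
package main

import (
	"bytes"
	"fmt"

	"gopkg.in/yaml.v3"
)

func main() {
	doc := []byte("model: openai/gpt-4o-mini\ntemprature: 0.5\n") // note the typo

	var pf struct {
		Model       string   `yaml:"model"`
		Temperature *float64 `yaml:"temperature"`
	}
	dec := yaml.NewDecoder(bytes.NewReader(doc))
	dec.KnownFields(true) // error on keys the struct doesn't declare
	if err := dec.Decode(&pf); err != nil {
		fmt.Println("rejected:", err) // the error names the unknown "temprature" key
	}
}
```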
diff --git a/s.prompt.yml b/s.prompt.yml
new file mode 100644
index 0000000..b8b577f
--- /dev/null
+++ b/s.prompt.yml
@@ -0,0 +1,14 @@
+name: Text Summarizer
+description: Summarizes input text concisely
+model: openai/gpt-4o-mini
+modelParameters:
+  temperature: 0.5
+messages:
+  - role: system
+    content: You are a text summarizer. Your only job is to summarize text given to you.
+  - role: user
+    content: |
+      Summarize the given text, beginning with "Summary -":
+
+      {{input}}
+
\ No newline at end of file
diff --git a/script/build b/script/build
index f481d7c..bfd66d7 100755
--- a/script/build
+++ b/script/build
@@ -28,6 +28,8 @@ fi
 
 if [[ "$OS" == "linux" || "$OS" == "all" ]]; then
   GOOS=linux GOARCH=amd64 build
+  GOOS=android GOARCH=arm64 build
+  GOOS=android GOARCH=amd64 build
 fi
 
 if [[ "$OS" == "darwin" || "$OS" == "all" ]]; then
diff --git a/script/upload-release b/script/upload-release
index db0e3ec..7c215f3 100755
--- a/script/upload-release
+++ b/script/upload-release
@@ -11,6 +11,6 @@ if [ -z $TAG ]; then
 fi
 shift
 
-BINARIES="gh-models-darwin-amd64 gh-models-darwin-arm64 gh-models-linux-amd64 gh-models-windows-amd64.exe"
+BINARIES="gh-models-darwin-amd64 gh-models-darwin-arm64 gh-models-linux-amd64 gh-models-windows-amd64.exe gh-models-android-arm64 gh-models-android-amd64"
 
 gh release upload $* $TAG $BINARIES