diff --git a/cmd/run/run.go b/cmd/run/run.go index 5668e06..7a3a885 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -258,19 +258,19 @@ func NewRunCommand(cfg *command.Config) *cobra.Command { return err } + interactiveMode := true initialPrompt := "" - singleShot := false pipedContent := "" if len(args) > 1 { initialPrompt = strings.Join(args[1:], " ") - singleShot = true + interactiveMode = false } if isPipe(os.Stdin) { promptFromPipe, _ := io.ReadAll(os.Stdin) if len(promptFromPipe) > 0 { - singleShot = true + interactiveMode = false pipedContent = strings.TrimSpace(string(promptFromPipe)) if initialPrompt != "" { initialPrompt = initialPrompt + "\n" + pipedContent @@ -289,35 +289,29 @@ func NewRunCommand(cfg *command.Config) *cobra.Command { systemPrompt: systemPrompt, } - // If a prompt file is passed, load the messages from the file, templating {{input}} from stdin - if pf != nil { + // If there is no prompt file, add the initialPrompt to the conversation. + // If a prompt file is passed, load the messages from the file, templating {{input}} + // using the initialPrompt. + if pf == nil { + conversation.AddMessage(azuremodels.ChatMessageRoleUser, initialPrompt) + } else { + interactiveMode = false + for _, m := range pf.Messages { content := m.Content - if strings.ToLower(m.Role) == "user" { - content = strings.ReplaceAll(content, "{{input}}", initialPrompt) - } switch strings.ToLower(m.Role) { case "system": - if conversation.systemPrompt == "" { - conversation.systemPrompt = content - } else { - conversation.AddMessage(azuremodels.ChatMessageRoleSystem, content) - } + conversation.systemPrompt = content case "user": + content = strings.ReplaceAll(content, "{{input}}", initialPrompt) conversation.AddMessage(azuremodels.ChatMessageRoleUser, content) case "assistant": conversation.AddMessage(azuremodels.ChatMessageRoleAssistant, content) } } - - initialPrompt = "" } mp := ModelParameters{} - err = mp.PopulateFromFlags(cmd.Flags()) - if err != nil { - return err - } if pf != nil { mp.maxTokens = pf.ModelParameters.MaxTokens @@ -325,64 +319,21 @@ func NewRunCommand(cfg *command.Config) *cobra.Command { mp.topP = pf.ModelParameters.TopP } - for { - prompt := "" - if initialPrompt != "" { - prompt = initialPrompt - initialPrompt = "" - } - - if prompt == "" && pf == nil { - fmt.Printf(">>> ") - reader := bufio.NewReader(os.Stdin) - prompt, err = reader.ReadString('\n') - if err != nil { - return err - } - } - - prompt = strings.TrimSpace(prompt) - - if prompt == "" && pf == nil { - continue - } + err = mp.PopulateFromFlags(cmd.Flags()) + if err != nil { + return err + } - if strings.HasPrefix(prompt, "/") { - if prompt == "/bye" || prompt == "/exit" || prompt == "/quit" { + for { + if interactiveMode { + conversation, err = cmdHandler.ChatWithUser(conversation, mp) + if errors.Is(err, ErrExitChat) || errors.Is(err, io.EOF) { break + } else if err != nil { + return err } - - if prompt == "/parameters" { - cmdHandler.handleParametersPrompt(conversation, mp) - continue - } - - if prompt == "/reset" || prompt == "/clear" { - cmdHandler.handleResetPrompt(conversation) - continue - } - - if strings.HasPrefix(prompt, "/set ") { - cmdHandler.handleSetPrompt(prompt, mp) - continue - } - - if strings.HasPrefix(prompt, "/system-prompt ") { - conversation = cmdHandler.handleSystemPrompt(prompt, conversation) - continue - } - - if prompt == "/help" { - cmdHandler.handleHelpPrompt() - continue - } - - cmdHandler.handleUnrecognizedPrompt(prompt) - continue } - conversation.AddMessage(azuremodels.ChatMessageRoleUser, prompt) - req := azuremodels.ChatCompletionOptions{ Messages: conversation.GetMessages(), Model: modelName, @@ -431,7 +382,7 @@ func NewRunCommand(cfg *command.Config) *cobra.Command { conversation.AddMessage(azuremodels.ChatMessageRoleAssistant, messageBuilder.String()) - if singleShot || pf != nil { + if !interactiveMode { break } } @@ -618,3 +569,57 @@ func (h *runCommandHandler) handleCompletionChoice(choice azuremodels.ChatChoice func (h *runCommandHandler) writeToOut(message string) { h.cfg.WriteToOut(message) } + +var ErrExitChat = errors.New("exiting chat") + +func (h *runCommandHandler) ChatWithUser(conversation Conversation, mp ModelParameters) (Conversation, error) { + fmt.Printf(">>> ") + reader := bufio.NewReader(os.Stdin) + + prompt, err := reader.ReadString('\n') + if err != nil { + return conversation, err + } + + prompt = strings.TrimSpace(prompt) + if prompt == "" { + return conversation, nil + } + + if strings.HasPrefix(prompt, "/") { + if prompt == "/bye" || prompt == "/exit" || prompt == "/quit" { + return conversation, ErrExitChat + } + + if prompt == "/parameters" { + h.handleParametersPrompt(conversation, mp) + return conversation, nil + } + + if prompt == "/reset" || prompt == "/clear" { + h.handleResetPrompt(conversation) + return conversation, nil + } + + if strings.HasPrefix(prompt, "/set ") { + h.handleSetPrompt(prompt, mp) + return conversation, nil + } + + if strings.HasPrefix(prompt, "/system-prompt ") { + conversation = h.handleSystemPrompt(prompt, conversation) + return conversation, nil + } + + if prompt == "/help" { + h.handleHelpPrompt() + return conversation, nil + } + + h.handleUnrecognizedPrompt(prompt) + return conversation, nil + } + + conversation.AddMessage(azuremodels.ChatMessageRoleUser, prompt) + return conversation, nil +} diff --git a/cmd/run/run_test.go b/cmd/run/run_test.go index 27cc468..7395e7c 100644 --- a/cmd/run/run_test.go +++ b/cmd/run/run_test.go @@ -139,7 +139,7 @@ messages: _, err = runCmd.ExecuteC() require.NoError(t, err) - require.Equal(t, 3, len(capturedReq.Messages)) + require.Equal(t, 2, len(capturedReq.Messages)) require.Equal(t, "You are a text summarizer.", *capturedReq.Messages[0].Content) require.Equal(t, "Hello there!", *capturedReq.Messages[1].Content) @@ -220,10 +220,114 @@ messages: _, err = runCmd.ExecuteC() require.NoError(t, err) - require.Len(t, capturedReq.Messages, 3) + require.Len(t, capturedReq.Messages, 2) require.Equal(t, "You are a text summarizer.", *capturedReq.Messages[0].Content) require.Equal(t, initialPrompt+"\n"+piped, *capturedReq.Messages[1].Content) // {{input}} -> "Please summarize the provided text.\nHello there!" require.Contains(t, out.String(), reply) }) + + t.Run("cli flags override params set in the prompt.yaml file", func(t *testing.T) { + // Begin setup: + const yamlBody = ` + name: Example Prompt + description: Example description + model: openai/example-model + modelParameters: + maxTokens: 300 + temperature: 0.8 + topP: 0.9 + messages: + - role: system + content: System message + - role: user + content: User message + ` + tmp, err := os.CreateTemp(t.TempDir(), "*.prompt.yaml") + require.NoError(t, err) + _, err = tmp.WriteString(yamlBody) + require.NoError(t, err) + require.NoError(t, tmp.Close()) + + client := azuremodels.NewMockClient() + modelSummary := &azuremodels.ModelSummary{ + Name: "example-model", + Publisher: "openai", + Task: "chat-completion", + } + modelSummary2 := &azuremodels.ModelSummary{ + Name: "example-model-4o-mini-plus", + Publisher: "openai", + Task: "chat-completion", + } + + client.MockListModels = func(ctx context.Context) ([]*azuremodels. + ModelSummary, error) { + return []*azuremodels.ModelSummary{modelSummary, modelSummary2}, nil + } + + var capturedReq azuremodels.ChatCompletionOptions + reply := "hello" + chatCompletion := azuremodels.ChatCompletion{ + Choices: []azuremodels.ChatChoice{{ + Message: &azuremodels.ChatChoiceMessage{ + Content: util.Ptr(reply), + Role: util.Ptr(string(azuremodels.ChatMessageRoleAssistant)), + }, + }}, + } + + client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) { + capturedReq = opt + return &azuremodels.ChatCompletionResponse{ + Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{chatCompletion}), + }, nil + } + + out := new(bytes.Buffer) + cfg := command.NewConfig(out, out, client, true, 100) + runCmd := NewRunCommand(cfg) + + // End setup. + // --- + // We're finally ready to start making assertions. + + // Test case 1: with no flags, the model params come from the YAML file + runCmd.SetArgs([]string{ + "--file", tmp.Name(), + }) + + _, err = runCmd.ExecuteC() + require.NoError(t, err) + + require.Equal(t, "openai/example-model", capturedReq.Model) + require.Equal(t, 300, *capturedReq.MaxTokens) + require.Equal(t, 0.8, *capturedReq.Temperature) + require.Equal(t, 0.9, *capturedReq.TopP) + + require.Equal(t, "System message", *capturedReq.Messages[0].Content) + require.Equal(t, "User message", *capturedReq.Messages[1].Content) + + // Hooray! + // Test case 2: values from flags override the params from the YAML file + runCmd = NewRunCommand(cfg) + runCmd.SetArgs([]string{ + "openai/example-model-4o-mini-plus", + "--file", tmp.Name(), + "--max-tokens", "150", + "--temperature", "0.1", + "--top-p", "0.3", + }) + + _, err = runCmd.ExecuteC() + require.NoError(t, err) + + require.Equal(t, "openai/example-model-4o-mini-plus", capturedReq.Model) + require.Equal(t, 150, *capturedReq.MaxTokens) + require.Equal(t, 0.1, *capturedReq.Temperature) + require.Equal(t, 0.3, *capturedReq.TopP) + + require.Equal(t, "System message", *capturedReq.Messages[0].Content) + require.Equal(t, "User message", *capturedReq.Messages[1].Content) + }) } diff --git a/script/build b/script/build index f481d7c..bfd66d7 100755 --- a/script/build +++ b/script/build @@ -28,6 +28,8 @@ fi if [[ "$OS" == "linux" || "$OS" == "all" ]]; then GOOS=linux GOARCH=amd64 build + GOOS=android GOARCH=arm64 build + GOOS=android GOARCH=amd64 build fi if [[ "$OS" == "darwin" || "$OS" == "all" ]]; then diff --git a/script/upload-release b/script/upload-release index db0e3ec..7c215f3 100755 --- a/script/upload-release +++ b/script/upload-release @@ -11,6 +11,6 @@ if [ -z $TAG ]; then fi shift -BINARIES="gh-models-darwin-amd64 gh-models-darwin-arm64 gh-models-linux-amd64 gh-models-windows-amd64.exe" +BINARIES="gh-models-darwin-amd64 gh-models-darwin-arm64 gh-models-linux-amd64 gh-models-windows-amd64.exe gh-models-android-arm64 gh-models-android-amd64" gh release upload $* $TAG $BINARIES